From da4f0a11b04c43bdbc5647f711825316327e3d6f Mon Sep 17 00:00:00 2001 From: Guillaume Ausset Date: Mon, 25 May 2026 12:27:40 +0200 Subject: [PATCH 1/4] Generalize GPU execution provider selection --- colgrep/README.md | 2 +- colgrep/src/acceleration.rs | 67 ++++++++ colgrep/src/config.rs | 105 +++++++++--- colgrep/src/index/mod.rs | 313 +++++++++++++++--------------------- next-plaid-api/Cargo.toml | 1 + next-plaid-onnx/src/lib.rs | 298 ++++++++++++++++++++++++++++------ 6 files changed, 527 insertions(+), 259 deletions(-) diff --git a/colgrep/README.md b/colgrep/README.md index 548e09e1..e1497438 100644 --- a/colgrep/README.md +++ b/colgrep/README.md @@ -250,7 +250,7 @@ colgrep settings --pool-factor 2 # Set parallel encoding sessions (default: CPU count, max 16) colgrep settings --parallel 8 -# Set batch size per session (default: 1 for CPU, 64 for CUDA) +# Set batch size per session (default: 1 for CPU, 64 for GPU inference providers) colgrep settings --batch-size 2 # Set parser recursion depth guard (default: 1024) diff --git a/colgrep/src/acceleration.rs b/colgrep/src/acceleration.rs index 43bbfde2..2d04223f 100644 --- a/colgrep/src/acceleration.rs +++ b/colgrep/src/acceleration.rs @@ -1,4 +1,5 @@ use anyhow::{bail, Result}; +use next_plaid_onnx::ExecutionProvider; const FORCE_CPU_ENV_VARS: &[&str] = &["FORCE_CPU", "COLGREP_FORCE_CPU", "NEXT_PLAID_FORCE_CPU"]; const FORCE_GPU_ENV_VARS: &[&str] = &["FORCE_GPU", "COLGREP_FORCE_GPU", "NEXT_PLAID_FORCE_GPU"]; @@ -79,6 +80,72 @@ pub fn apply_acceleration_mode(mode: AccelerationMode) { } } +/// Returns whether an ONNX execution provider is usable for colgrep. +/// +/// `next-plaid-onnx` can check whether a provider is compiled into the loaded +/// ONNX Runtime library. colgrep adds the CUDA/cuDNN readiness check because it +/// manages CUDA library discovery itself on Linux. +pub fn is_gpu_provider_available(provider: ExecutionProvider) -> bool { + if !next_plaid_onnx::is_execution_provider_available(provider) { + return false; + } + + match provider { + ExecutionProvider::Cuda => { + #[cfg(feature = "cuda")] + { + crate::onnx_runtime::is_cudnn_available() + } + #[cfg(not(feature = "cuda"))] + { + false + } + } + provider => provider.is_gpu(), + } +} + +/// Available GPU execution providers for colgrep in selection order. +pub fn available_gpu_providers() -> Vec { + next_plaid_onnx::compiled_gpu_execution_providers() + .into_iter() + .filter(|provider| is_gpu_provider_available(*provider)) + .collect() +} + +/// Preferred available GPU execution provider for colgrep. +pub fn preferred_gpu_provider() -> Option { + available_gpu_providers().into_iter().next() +} + +/// Whether colgrep currently has any usable GPU inference provider. +pub fn has_gpu_provider() -> bool { + preferred_gpu_provider().is_some() +} + +/// Require a GPU provider, returning a user-facing diagnostic if none is usable. +pub fn require_gpu_provider() -> Result { + if let Some(provider) = preferred_gpu_provider() { + return Ok(provider); + } + + let compiled = next_plaid_onnx::compiled_gpu_execution_providers(); + if compiled.is_empty() { + bail!( + "FORCE_GPU is set, but this colgrep binary was compiled without a GPU execution provider. Enable a feature such as 'cuda', 'migraphx', 'coreml', or 'directml'." + ); + } + + let names = compiled + .iter() + .map(|provider| provider.display_name()) + .collect::>() + .join(", "); + bail!( + "FORCE_GPU is set, but no compiled GPU execution provider is available. Compiled provider(s): {names}. For ROCm, set ORT_DYLIB_PATH to an ONNX Runtime build with MIGraphX support; for CUDA, ensure the GPU ONNX Runtime and cuDNN are loadable." + ) +} + #[cfg(test)] mod tests { use super::*; diff --git a/colgrep/src/config.rs b/colgrep/src/config.rs index bd6b23d9..25a5fa84 100644 --- a/colgrep/src/config.rs +++ b/colgrep/src/config.rs @@ -8,7 +8,13 @@ use std::path::PathBuf; use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; -#[cfg(feature = "cuda")] +#[cfg(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" +))] use crate::acceleration::{env_acceleration_mode_lossy, AccelerationMode}; use crate::index::paths::get_colgrep_data_dir; @@ -23,26 +29,44 @@ pub const DEFAULT_MAX_RECURSION_DEPTH: usize = 1024; /// Testing shows batch_size=1 gives best performance with parallel sessions on CPU pub const DEFAULT_BATCH_SIZE_CPU: usize = 1; -/// Default batch size per encoding session for GPU (CUDA) +/// Default batch size per encoding session for GPU inference providers. /// With 1 session, larger batch size (64) is optimal for GPU throughput pub const DEFAULT_BATCH_SIZE_GPU: usize = 64; -/// Default batch size - use GPU default when CUDA is enabled AND available, CPU otherwise -/// Note: At compile time we set the GPU default, but at runtime we check cuDNN availability -#[cfg(feature = "cuda")] +/// Default batch size - use GPU default when a GPU inference provider is enabled, CPU otherwise. +/// Note: At compile time we set the GPU default, but at runtime we check provider availability. +#[cfg(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" +))] pub const DEFAULT_BATCH_SIZE: usize = DEFAULT_BATCH_SIZE_GPU; -#[cfg(not(feature = "cuda"))] +#[cfg(not(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" +)))] pub const DEFAULT_BATCH_SIZE: usize = DEFAULT_BATCH_SIZE_CPU; /// Get the effective default batch size at runtime. -/// When CUDA feature is enabled but cuDNN is not available, returns CPU default. -#[cfg(feature = "cuda")] +/// When GPU features are enabled but no provider is available, returns CPU default. +#[cfg(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" +))] pub fn get_default_batch_size() -> usize { match env_acceleration_mode_lossy() { AccelerationMode::ForceCpu => DEFAULT_BATCH_SIZE_CPU, AccelerationMode::ForceGpu => DEFAULT_BATCH_SIZE_GPU, AccelerationMode::Auto => { - if crate::onnx_runtime::is_cudnn_available() { + if crate::acceleration::has_gpu_provider() { DEFAULT_BATCH_SIZE_GPU } else { DEFAULT_BATCH_SIZE_CPU @@ -51,7 +75,13 @@ pub fn get_default_batch_size() -> usize { } } -#[cfg(not(feature = "cuda"))] +#[cfg(not(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" +)))] pub fn get_default_batch_size() -> usize { DEFAULT_BATCH_SIZE_CPU } @@ -64,14 +94,20 @@ pub fn get_default_cpu_parallel_sessions() -> usize { } /// Get the effective default parallel sessions at runtime. -/// When CUDA feature is enabled but cuDNN is not available, returns CPU default. -#[cfg(feature = "cuda")] +/// When GPU features are enabled but no provider is available, returns CPU default. +#[cfg(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" +))] pub fn get_default_parallel_sessions() -> usize { match env_acceleration_mode_lossy() { AccelerationMode::ForceCpu => get_default_cpu_parallel_sessions(), AccelerationMode::ForceGpu => DEFAULT_PARALLEL_SESSIONS_GPU, AccelerationMode::Auto => { - if crate::onnx_runtime::is_cudnn_available() { + if crate::acceleration::has_gpu_provider() { DEFAULT_PARALLEL_SESSIONS_GPU } else { get_default_cpu_parallel_sessions() @@ -80,13 +116,19 @@ pub fn get_default_parallel_sessions() -> usize { } } -#[cfg(not(feature = "cuda"))] +#[cfg(not(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" +)))] pub fn get_default_parallel_sessions() -> usize { get_default_cpu_parallel_sessions() } -/// Default number of parallel sessions for GPU (CUDA) -/// Using 1 session with larger batch is optimal for CUDA to minimize session creation overhead +/// Default number of parallel sessions for GPU inference providers. +/// Using 1 session with larger batch minimizes session creation overhead. /// The GPU handles batched inference more efficiently than multiple parallel sessions pub const DEFAULT_PARALLEL_SESSIONS_GPU: usize = 1; @@ -247,14 +289,15 @@ impl Config { self.default_n = None; } - /// Check if FP32 (non-quantized) model should be used - /// Defaults to true when cuda feature is enabled (better CUDA performance with FP32) + /// Check if FP32 (non-quantized) model should be used. + /// Defaults to true for CUDA/ROCm-family GPU builds, where FP32 models are + /// generally better supported and faster than quantized ONNX graphs. pub fn use_fp32(&self) -> bool { - #[cfg(feature = "cuda")] + #[cfg(any(feature = "cuda", feature = "tensorrt", feature = "migraphx"))] { self.fp32.unwrap_or(true) } - #[cfg(not(feature = "cuda"))] + #[cfg(not(any(feature = "cuda", feature = "tensorrt", feature = "migraphx")))] { self.fp32.unwrap_or(false) } @@ -295,7 +338,7 @@ impl Config { /// Get the number of parallel sessions for encoding /// Returns the configured value or: - /// - 1 session when CUDA is enabled AND cuDNN is available (GPUs work best with single session + large batches) + /// - 1 session when a GPU inference provider is available /// - min(CPU count, 16) otherwise (CPUs benefit from parallel sessions) pub fn get_parallel_sessions(&self) -> usize { self.configured_parallel_sessions() @@ -320,7 +363,7 @@ impl Config { /// Get the batch size for encoding /// Returns the configured value or the runtime default: - /// - 64 when CUDA is enabled AND cuDNN is available + /// - 64 when a GPU inference provider is available /// - 1 otherwise (CPU mode) pub fn get_batch_size(&self) -> usize { self.configured_batch_size() @@ -653,7 +696,13 @@ mod tests { assert_eq!(MAX_PARALLEL_SESSIONS_CPU, 16); let sessions = get_default_parallel_sessions(); - #[cfg(feature = "cuda")] + #[cfg(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" + ))] let expected = match env_acceleration_mode_lossy() { AccelerationMode::ForceCpu => std::thread::available_parallelism() .map(|p| p.get()) @@ -661,7 +710,7 @@ mod tests { .min(MAX_PARALLEL_SESSIONS_CPU), AccelerationMode::ForceGpu => DEFAULT_PARALLEL_SESSIONS_GPU, AccelerationMode::Auto => { - if crate::onnx_runtime::is_cudnn_available() { + if crate::acceleration::has_gpu_provider() { DEFAULT_PARALLEL_SESSIONS_GPU } else { std::thread::available_parallelism() @@ -671,7 +720,13 @@ mod tests { } } }; - #[cfg(not(feature = "cuda"))] + #[cfg(not(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" + )))] let expected = std::thread::available_parallelism() .map(|p| p.get()) .unwrap_or(16) diff --git a/colgrep/src/index/mod.rs b/colgrep/src/index/mod.rs index 3342df30..9aa162c7 100644 --- a/colgrep/src/index/mod.rs +++ b/colgrep/src/index/mod.rs @@ -21,9 +21,10 @@ use next_plaid_onnx::{pool_document_embeddings, Colbert, ExecutionProvider}; use rayon::prelude::*; use serde::{Deserialize, Serialize}; -#[cfg(feature = "cuda")] -use crate::acceleration::apply_acceleration_mode; -use crate::acceleration::{env_acceleration_mode_lossy, AccelerationMode}; +use crate::acceleration::{ + apply_acceleration_mode, env_acceleration_mode_lossy, preferred_gpu_provider, + require_gpu_provider, AccelerationMode, +}; use crate::embed::build_embedding_text; use crate::parser::{build_call_graph, detect_language, extract_units, CodeUnit, Language}; use crate::signal::{is_interrupted, is_interrupted_outside_critical, CriticalSectionGuard}; @@ -154,9 +155,8 @@ const LARGE_BATCH_POOL_FACTOR: usize = 2; const DEFAULT_ENCODE_BATCH_SIZE: usize = 64; -/// Threshold for forcing CPU encoding even when CUDA is available. +/// Threshold for forcing CPU encoding even when a GPU provider is available. /// For small batches (< this many units), CPU is faster due to GPU initialization overhead. -#[cfg(feature = "cuda")] const SMALL_BATCH_CPU_THRESHOLD: usize = 300; /// Bounded channel capacity between the pool and index stages. /// Kept small (4 chunks) to limit memory: each chunk holds full embeddings @@ -924,69 +924,53 @@ impl IndexBuilder { /// /// # Arguments /// * `num_units` - Number of code units to encode. Used to decide whether to use GPU or CPU. - /// For small batches (< SMALL_BATCH_CPU_THRESHOLD), CPU is preferred even when CUDA is - /// available, as GPU initialization overhead outweighs the benefits for small workloads. + /// For small batches (< SMALL_BATCH_CPU_THRESHOLD), CPU is preferred even when a GPU + /// provider is available, as GPU initialization overhead outweighs the benefits for small workloads. fn ensure_model_created(&mut self, num_units: usize) -> Result<()> { if self.model.is_none() { - #[cfg(feature = "cuda")] let acceleration_mode = env_acceleration_mode_lossy(); - #[cfg(feature = "cuda")] - let (num_sessions, execution_provider) = { - match acceleration_mode { - AccelerationMode::ForceCpu => { + let (num_sessions, execution_provider) = match acceleration_mode { + AccelerationMode::ForceCpu => { + apply_acceleration_mode(AccelerationMode::ForceCpu); + crate::onnx_runtime::ensure_onnx_runtime() + .context("Failed to initialize ONNX Runtime")?; + + ( + self.parallel_sessions + .unwrap_or_else(crate::config::get_default_cpu_parallel_sessions), + ExecutionProvider::Cpu, + ) + } + AccelerationMode::ForceGpu => { + apply_acceleration_mode(AccelerationMode::ForceGpu); + crate::onnx_runtime::ensure_onnx_runtime() + .context("Failed to initialize ONNX Runtime")?; + + let provider = require_gpu_provider()?; + ( + self.parallel_sessions + .unwrap_or(crate::config::DEFAULT_PARALLEL_SESSIONS_GPU), + provider, + ) + } + AccelerationMode::Auto => { + let force_cpu_for_small_batch = num_units < SMALL_BATCH_CPU_THRESHOLD; + if force_cpu_for_small_batch { apply_acceleration_mode(AccelerationMode::ForceCpu); - crate::onnx_runtime::ensure_onnx_runtime() - .context("Failed to initialize ONNX Runtime")?; - - ( - self.parallel_sessions - .unwrap_or_else(crate::config::get_default_cpu_parallel_sessions), - ExecutionProvider::Cpu, - ) - } - AccelerationMode::ForceGpu => { - apply_acceleration_mode(AccelerationMode::ForceGpu); - crate::onnx_runtime::ensure_onnx_runtime() - .context("Failed to initialize ONNX Runtime")?; - - if !crate::onnx_runtime::is_cudnn_available() { - anyhow::bail!("FORCE_GPU is set, but cuDNN was not initialized"); - } - - if !next_plaid_onnx::is_cuda_available() { - anyhow::bail!( - "FORCE_GPU is set, but the CUDA execution provider was not initialized" - ); - } - - ( - self.parallel_sessions - .unwrap_or(crate::config::DEFAULT_PARALLEL_SESSIONS_GPU), - ExecutionProvider::Cuda, - ) + } else { + apply_acceleration_mode(AccelerationMode::Auto); } - AccelerationMode::Auto => { - let force_cpu = num_units < SMALL_BATCH_CPU_THRESHOLD; - if force_cpu { - apply_acceleration_mode(AccelerationMode::ForceCpu); - } else { - apply_acceleration_mode(AccelerationMode::Auto); - } - crate::onnx_runtime::ensure_onnx_runtime() - .context("Failed to initialize ONNX Runtime")?; + crate::onnx_runtime::ensure_onnx_runtime() + .context("Failed to initialize ONNX Runtime")?; - let use_cuda = !force_cpu && { - crate::onnx_runtime::is_cudnn_available() - && next_plaid_onnx::is_cuda_available() - }; - - if use_cuda { + if !force_cpu_for_small_batch { + if let Some(provider) = preferred_gpu_provider() { ( self.parallel_sessions .unwrap_or(crate::config::DEFAULT_PARALLEL_SESSIONS_GPU), - ExecutionProvider::Cuda, + provider, ) } else { ( @@ -996,29 +980,21 @@ impl IndexBuilder { ExecutionProvider::Cpu, ) } + } else { + ( + self.parallel_sessions + .unwrap_or_else(crate::config::get_default_cpu_parallel_sessions), + ExecutionProvider::Cpu, + ) } } }; - #[cfg(not(feature = "cuda"))] - let (num_sessions, execution_provider) = { - let _ = num_units; // Silence unused warning when cuda feature is disabled - - // Initialize ONNX Runtime (CPU-only build) - crate::onnx_runtime::ensure_onnx_runtime() - .context("Failed to initialize ONNX Runtime")?; - - ( - self.parallel_sessions - .unwrap_or_else(crate::config::get_default_cpu_parallel_sessions), - ExecutionProvider::Cpu, - ) - }; // Print model info after ONNX runtime is initialized (and any potential re-exec) eprintln!("šŸ¤– Model: {}", self.model_id); eprintln!("šŸ“‚ Building index..."); - // Use runtime default for batch size (respects cuDNN availability) + // Use runtime default for batch size (respects provider availability) let batch = self .batch_size .unwrap_or_else(crate::config::get_default_batch_size); @@ -1053,7 +1029,6 @@ impl IndexBuilder { } /// Check if the current model is using GPU execution. - #[cfg(feature = "cuda")] fn is_using_gpu(&self) -> bool { self.model .as_ref() @@ -1065,7 +1040,6 @@ impl IndexBuilder { /// Uses `dynamic_batch(false)` because CPU encoding processes fixed-size batches /// sequentially — the token-budget bucketing of dynamic batch only helps GPU /// where plan reuse across similar shapes reduces kernel launch overhead. - #[cfg(feature = "cuda")] fn rebuild_model_for_cpu(&mut self) -> Result<()> { self.model = None; apply_acceleration_mode(AccelerationMode::ForceCpu); @@ -1125,57 +1099,57 @@ impl IndexBuilder { }, ); - #[cfg(feature = "cuda")] - if let Err(gpu_err) = result { - if self.is_using_gpu() { - let accel = env_acceleration_mode_lossy(); - if accel == AccelerationMode::ForceGpu { - anyhow::bail!( - "GPU encoding failed with --force-gpu. \ - Not enough GPU memory for batch size {batch} and document length. \ - Try reducing the batch size or use auto mode to allow CPU fallback.\n\ - \nCaused by: {gpu_err}", - batch = self - .batch_size - .unwrap_or(crate::config::DEFAULT_BATCH_SIZE_GPU), - ); - } + match result { + Ok(was_interrupted) => Ok(was_interrupted), + Err(gpu_err) => { + if self.is_using_gpu() { + let accel = env_acceleration_mode_lossy(); + if accel == AccelerationMode::ForceGpu { + anyhow::bail!( + "GPU encoding failed with --force-gpu. \ + Not enough GPU memory for batch size {batch} and document length. \ + Try reducing the batch size or use auto mode to allow CPU fallback.\n\ + \nCaused by: {gpu_err}", + batch = self + .batch_size + .unwrap_or(crate::config::DEFAULT_BATCH_SIZE_GPU), + ); + } - eprintln!( - "\nāš ļø GPU encoding failed, falling back to CPU. \ - This is usually caused by insufficient GPU memory for the batch size.\n" - ); + eprintln!( + "\nāš ļø GPU encoding failed, falling back to CPU. \ + This is usually caused by insufficient GPU memory for the batch size.\n" + ); - self.rebuild_model_for_cpu()?; + self.rebuild_model_for_cpu()?; - let force_cpu = next_plaid::is_force_cpu(); - let config = IndexConfig { - force_cpu, - ..Default::default() - }; - let update_config = UpdateConfig { - force_cpu, - ..Default::default() - }; + let force_cpu = next_plaid::is_force_cpu(); + let config = IndexConfig { + force_cpu, + ..Default::default() + }; + let update_config = UpdateConfig { + force_cpu, + ..Default::default() + }; - return run_chunk_pipeline( - self.model().clone(), - sorted_units, - ChunkPipelineConfig { - index_chunk_size, - pool_factor, - index_path, - config, - update_config, - pb, - }, - ); + run_chunk_pipeline( + self.model().clone(), + sorted_units, + ChunkPipelineConfig { + index_chunk_size, + pool_factor, + index_path, + config, + update_config, + pb, + }, + ) + } else { + Err(gpu_err) + } } - - return Err(gpu_err); } - - result } /// Get the path to the index directory @@ -2098,11 +2072,12 @@ impl IndexBuilder { self.ensure_model_created(all_units.len())?; #[cfg(feature = "cuda")] - if !crate::onnx_runtime::is_cudnn_available() + if !self.is_using_gpu() + && !crate::onnx_runtime::is_cudnn_available() && std::env::var("_COLGREP_CUDNN_NOTICE").is_err() { std::env::set_var("_COLGREP_CUDNN_NOTICE", "1"); - eprintln!("šŸ“‚ cuDNN not found, encoding will use CPU."); + eprintln!("šŸ“‚ cuDNN not found, CUDA encoding will use CPU."); } // Build new index in temp directory to avoid destroying the old one @@ -3264,6 +3239,36 @@ pub struct Searcher { index_path: String, } +fn apply_search_acceleration_mode(acceleration_mode: AccelerationMode) { + match acceleration_mode { + AccelerationMode::ForceGpu => apply_acceleration_mode(AccelerationMode::ForceGpu), + AccelerationMode::ForceCpu => apply_acceleration_mode(AccelerationMode::ForceCpu), + // Keep the existing search behavior: single-query searches default to + // CPU to avoid GPU initialization overhead, except on CoreML builds + // where automatic acceleration has historically been the default. + AccelerationMode::Auto if cfg!(feature = "coreml") => { + apply_acceleration_mode(AccelerationMode::Auto) + } + AccelerationMode::Auto => apply_acceleration_mode(AccelerationMode::ForceCpu), + } +} + +fn resolve_search_execution_provider( + acceleration_mode: AccelerationMode, +) -> Result { + match acceleration_mode { + AccelerationMode::ForceGpu => require_gpu_provider(), + AccelerationMode::ForceCpu => Ok(ExecutionProvider::Cpu), + AccelerationMode::Auto => { + if next_plaid_onnx::is_coreml_available() { + Ok(ExecutionProvider::CoreML) + } else { + Ok(ExecutionProvider::Cpu) + } + } + } +} + impl Searcher { pub fn load(project_root: &Path, model_id: &str, model_path: &Path) -> Result { Self::load_with_quantized(project_root, model_id, model_path, false) @@ -3280,39 +3285,10 @@ impl Searcher { let index_path = vector_dir.to_str().unwrap().to_string(); let acceleration_mode = env_acceleration_mode_lossy(); - let execution_provider = match acceleration_mode { - AccelerationMode::ForceGpu => ExecutionProvider::Cuda, - AccelerationMode::ForceCpu => ExecutionProvider::Cpu, - AccelerationMode::Auto => { - if cfg!(feature = "coreml") { - ExecutionProvider::CoreML - } else { - ExecutionProvider::Cpu - } - } - }; - - #[cfg(feature = "cuda")] - match acceleration_mode { - AccelerationMode::ForceGpu => apply_acceleration_mode(AccelerationMode::ForceGpu), - AccelerationMode::ForceCpu | AccelerationMode::Auto => { - apply_acceleration_mode(AccelerationMode::ForceCpu) - } - } + apply_search_acceleration_mode(acceleration_mode); crate::onnx_runtime::ensure_onnx_runtime().context("Failed to initialize ONNX Runtime")?; - - #[cfg(feature = "cuda")] - if matches!(acceleration_mode, AccelerationMode::ForceGpu) { - if !crate::onnx_runtime::is_cudnn_available() { - anyhow::bail!("FORCE_GPU is set, but cuDNN was not initialized"); - } - if !next_plaid_onnx::is_cuda_available() { - anyhow::bail!( - "FORCE_GPU is set, but the CUDA execution provider was not initialized" - ); - } - } + let execution_provider = resolve_search_execution_provider(acceleration_mode)?; // Cap intra-op threads to avoid overhead on high-core-count systems let num_threads = std::thread::available_parallelism() @@ -3356,39 +3332,10 @@ impl Searcher { let index_path = vector_dir.to_str().unwrap().to_string(); let acceleration_mode = env_acceleration_mode_lossy(); - let execution_provider = match acceleration_mode { - AccelerationMode::ForceGpu => ExecutionProvider::Cuda, - AccelerationMode::ForceCpu => ExecutionProvider::Cpu, - AccelerationMode::Auto => { - if cfg!(feature = "coreml") { - ExecutionProvider::CoreML - } else { - ExecutionProvider::Cpu - } - } - }; - - #[cfg(feature = "cuda")] - match acceleration_mode { - AccelerationMode::ForceGpu => apply_acceleration_mode(AccelerationMode::ForceGpu), - AccelerationMode::ForceCpu | AccelerationMode::Auto => { - apply_acceleration_mode(AccelerationMode::ForceCpu) - } - } + apply_search_acceleration_mode(acceleration_mode); crate::onnx_runtime::ensure_onnx_runtime().context("Failed to initialize ONNX Runtime")?; - - #[cfg(feature = "cuda")] - if matches!(acceleration_mode, AccelerationMode::ForceGpu) { - if !crate::onnx_runtime::is_cudnn_available() { - anyhow::bail!("FORCE_GPU is set, but cuDNN was not initialized"); - } - if !next_plaid_onnx::is_cuda_available() { - anyhow::bail!( - "FORCE_GPU is set, but the CUDA execution provider was not initialized" - ); - } - } + let execution_provider = resolve_search_execution_provider(acceleration_mode)?; // Cap intra-op threads to avoid overhead on high-core-count systems let num_threads = std::thread::available_parallelism() diff --git a/next-plaid-api/Cargo.toml b/next-plaid-api/Cargo.toml index b917d6f4..524d1fd6 100644 --- a/next-plaid-api/Cargo.toml +++ b/next-plaid-api/Cargo.toml @@ -61,6 +61,7 @@ mkl = ["next-plaid/mkl"] metal_gpu = ["next-plaid/metal_gpu"] model = ["dep:next-plaid-onnx"] cuda = ["model", "next-plaid-onnx/cuda", "next-plaid/cuda"] +migraphx = ["model", "next-plaid-onnx/migraphx"] [dev-dependencies] reqwest = { version = "0.13", features = ["json"] } diff --git a/next-plaid-onnx/src/lib.rs b/next-plaid-onnx/src/lib.rs index 359c5545..609c4ad7 100644 --- a/next-plaid-onnx/src/lib.rs +++ b/next-plaid-onnx/src/lib.rs @@ -68,15 +68,31 @@ use tokenizers::Encoding; use tokenizers::Tokenizer; // Conditional imports for execution providers -#[cfg(feature = "cuda")] +#[cfg(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" +))] use ort::ep::ExecutionProvider as OrtExecutionProviderTrait; #[cfg(feature = "cuda")] use ort::execution_providers::CUDAExecutionProvider; -/// Run a closure, catching panics without printing the default panic message. -/// See `next_plaid::cuda::catch_cuda_panic` for the rationale. -#[cfg(feature = "cuda")] -fn catch_cuda_panic(f: F) -> std::result::Result> +/// Run a closure, catching execution-provider panics without printing the +/// default panic message. Provider availability checks can panic when the ORT +/// dylib has not been initialized yet or when a provider's driver libraries are +/// stubs/incompatible; callers convert that into "provider unavailable". +#[cfg(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" +))] +fn catch_execution_provider_panic( + f: F, +) -> std::result::Result> where F: FnOnce() -> R + std::panic::UnwindSafe, { @@ -181,7 +197,7 @@ fn find_onnxruntime_library() -> Option { #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum ExecutionProvider { /// Automatically detect and use the best available hardware. - /// Tries in order: CUDA > TensorRT > CoreML > DirectML > CPU + /// Tries in order: CUDA > TensorRT > CoreML > DirectML > MIGraphX > CPU #[default] Auto, /// CPU execution only @@ -198,6 +214,98 @@ pub enum ExecutionProvider { MIGraphX, } +impl ExecutionProvider { + /// Human-readable provider name for diagnostics and CLI messages. + pub fn display_name(self) -> &'static str { + match self { + ExecutionProvider::Auto => "auto", + ExecutionProvider::Cpu => "CPU", + ExecutionProvider::Cuda => "CUDA", + ExecutionProvider::TensorRT => "TensorRT", + ExecutionProvider::CoreML => "CoreML", + ExecutionProvider::DirectML => "DirectML", + ExecutionProvider::MIGraphX => "MIGraphX", + } + } + + /// Whether this provider represents a hardware accelerator rather than CPU. + pub fn is_gpu(self) -> bool { + matches!( + self, + ExecutionProvider::Cuda + | ExecutionProvider::TensorRT + | ExecutionProvider::CoreML + | ExecutionProvider::DirectML + | ExecutionProvider::MIGraphX + ) + } +} + +const GPU_PROVIDER_ORDER: [ExecutionProvider; 5] = [ + ExecutionProvider::Cuda, + ExecutionProvider::TensorRT, + ExecutionProvider::CoreML, + ExecutionProvider::DirectML, + ExecutionProvider::MIGraphX, +]; + +/// GPU execution providers compiled into this crate, in auto-selection order. +pub fn compiled_gpu_execution_providers() -> Vec { + #[allow(unused_mut)] + let mut providers = Vec::new(); + + #[cfg(feature = "cuda")] + providers.push(ExecutionProvider::Cuda); + #[cfg(feature = "tensorrt")] + providers.push(ExecutionProvider::TensorRT); + #[cfg(feature = "coreml")] + providers.push(ExecutionProvider::CoreML); + #[cfg(feature = "directml")] + providers.push(ExecutionProvider::DirectML); + #[cfg(feature = "migraphx")] + providers.push(ExecutionProvider::MIGraphX); + + providers +} + +/// First compiled GPU execution provider in auto-selection order. +pub fn compiled_gpu_execution_provider() -> Option { + compiled_gpu_execution_providers().into_iter().next() +} + +/// Return whether a specific execution provider is available in the currently +/// loaded ONNX Runtime library. +pub fn is_execution_provider_available(provider: ExecutionProvider) -> bool { + match provider { + ExecutionProvider::Auto => preferred_gpu_execution_provider().is_some(), + ExecutionProvider::Cpu => true, + ExecutionProvider::Cuda => is_cuda_available(), + ExecutionProvider::TensorRT => is_tensorrt_available(), + ExecutionProvider::CoreML => is_coreml_available(), + ExecutionProvider::DirectML => is_directml_available(), + ExecutionProvider::MIGraphX => is_migraphx_available(), + } +} + +/// Available GPU execution providers in auto-selection order. +pub fn available_gpu_execution_providers() -> Vec { + GPU_PROVIDER_ORDER + .iter() + .copied() + .filter(|provider| is_execution_provider_available(*provider)) + .collect() +} + +/// Preferred available GPU execution provider, if any. +pub fn preferred_gpu_execution_provider() -> Option { + available_gpu_execution_providers().into_iter().next() +} + +/// Whether any compiled GPU execution provider is available. +pub fn is_gpu_available() -> bool { + preferred_gpu_execution_provider().is_some() +} + fn configure_execution_provider( builder: SessionBuilder, provider: ExecutionProvider, @@ -285,7 +393,7 @@ pub fn is_cuda_available() -> bool { // Try to check if CUDA EP is available, catching any panics from CUDA driver loading // This can panic if CUDA libraries are present but corrupted/incomplete (stub libraries) - catch_cuda_panic(|| { + catch_execution_provider_panic(|| { CUDAExecutionProvider::default() .is_available() .unwrap_or(false) @@ -303,9 +411,102 @@ pub fn is_cuda_available() -> bool { false } +/// Check if TensorRT execution provider is available. +#[cfg(feature = "tensorrt")] +pub fn is_tensorrt_available() -> bool { + !is_force_cpu() + && catch_execution_provider_panic(|| { + TensorRTExecutionProvider::default() + .is_available() + .unwrap_or(false) + }) + .unwrap_or(false) +} + +/// Check if TensorRT execution provider is available. +/// Always returns false when TensorRT feature is not enabled. +#[cfg(not(feature = "tensorrt"))] +pub fn is_tensorrt_available() -> bool { + false +} + +/// Check if CoreML execution provider is available. +#[cfg(feature = "coreml")] +pub fn is_coreml_available() -> bool { + !is_force_cpu() + && catch_execution_provider_panic(|| { + CoreMLExecutionProvider::default() + .is_available() + .unwrap_or(false) + }) + .unwrap_or(false) +} + +/// Check if CoreML execution provider is available. +/// Always returns false when CoreML feature is not enabled. +#[cfg(not(feature = "coreml"))] +pub fn is_coreml_available() -> bool { + false +} + +/// Check if DirectML execution provider is available. +#[cfg(feature = "directml")] +pub fn is_directml_available() -> bool { + !is_force_cpu() + && catch_execution_provider_panic(|| { + DirectMLExecutionProvider::default() + .is_available() + .unwrap_or(false) + }) + .unwrap_or(false) +} + +/// Check if DirectML execution provider is available. +/// Always returns false when DirectML feature is not enabled. +#[cfg(not(feature = "directml"))] +pub fn is_directml_available() -> bool { + false +} + +/// Check if MIGraphX execution provider is available. +#[cfg(feature = "migraphx")] +pub fn is_migraphx_available() -> bool { + !is_force_cpu() + && catch_execution_provider_panic(|| { + MIGraphXExecutionProvider::default() + .is_available() + .unwrap_or(false) + }) + .unwrap_or(false) +} + +/// Check if MIGraphX execution provider is available. +/// Always returns false when MIGraphX feature is not enabled. +#[cfg(not(feature = "migraphx"))] +pub fn is_migraphx_available() -> bool { + false +} + fn configure_auto_provider(builder: SessionBuilder) -> Result { if is_force_gpu() { - return configure_cuda(builder); + let provider = preferred_gpu_execution_provider().ok_or_else(|| { + let compiled = compiled_gpu_execution_providers(); + if compiled.is_empty() { + anyhow::anyhow!( + "NEXT_PLAID_FORCE_GPU is set, but no GPU execution provider was compiled. Enable a feature such as 'cuda', 'migraphx', 'coreml', or 'directml'." + ) + } else { + let names = compiled + .iter() + .map(|provider| provider.display_name()) + .collect::>() + .join(", "); + anyhow::anyhow!( + "NEXT_PLAID_FORCE_GPU is set, but no compiled GPU execution provider is available in the loaded ONNX Runtime library. Compiled provider(s): {names}." + ) + } + })?; + return configure_execution_provider(builder, provider); } // Skip GPU providers entirely if CPU-only mode is forced @@ -322,10 +523,8 @@ fn configure_auto_provider(builder: SessionBuilder) -> Result { if !force_cpu { // Wrap CUDA initialization in catch_cuda_panic to handle panics from stub libraries // without printing the default panic message to stderr - let cuda_result = catch_cuda_panic(std::panic::AssertUnwindSafe(|| { - builder - .clone() - .with_execution_providers([configured_cuda_execution_provider().build()]) + let cuda_result = catch_execution_provider_panic(std::panic::AssertUnwindSafe(|| { + configure_cuda(builder.clone()) })); match cuda_result { Ok(Ok(b)) => return Ok(b), @@ -338,40 +537,28 @@ fn configure_auto_provider(builder: SessionBuilder) -> Result { #[cfg(feature = "tensorrt")] if !force_cpu { - if let Ok(b) = builder - .clone() - .with_execution_providers([TensorRTExecutionProvider::default().build()]) - { + if let Ok(b) = configure_tensorrt(builder.clone()) { return Ok(b); } } #[cfg(feature = "coreml")] - { - if let Ok(b) = builder - .clone() - .with_execution_providers([CoreMLExecutionProvider::default().build()]) - { + if !force_cpu { + if let Ok(b) = configure_coreml(builder.clone()) { return Ok(b); } } #[cfg(feature = "directml")] if !force_cpu { - if let Ok(b) = builder - .clone() - .with_execution_providers([DirectMLExecutionProvider::default().build()]) - { + if let Ok(b) = configure_directml(builder.clone()) { return Ok(b); } } #[cfg(feature = "migraphx")] if !force_cpu { - if let Ok(b) = builder - .clone() - .with_execution_providers([MIGraphXExecutionProvider::default().build()]) - { + if let Ok(b) = configure_migraphx(builder.clone()) { return Ok(b); } } @@ -388,10 +575,12 @@ fn configure_cuda(builder: SessionBuilder) -> Result { // Wrap CUDA initialization in catch_cuda_panic to handle panics from stub/invalid libraries // without printing the default panic message to stderr - let cuda_result = catch_cuda_panic(std::panic::AssertUnwindSafe(|| { + let cuda_result = catch_execution_provider_panic(std::panic::AssertUnwindSafe(|| { builder .clone() - .with_execution_providers([configured_cuda_execution_provider().build()]) + .with_execution_providers([configured_cuda_execution_provider() + .build() + .error_on_failure()]) })); match cuda_result { @@ -400,10 +589,9 @@ fn configure_cuda(builder: SessionBuilder) -> Result { "Failed to configure CUDA execution provider: {e:?}. Ensure CUDA toolkit and cuDNN are installed." ) }), - Err(_) => { - eprintln!("[next-plaid-onnx] CUDA init panicked (invalid/stub library?), falling back to CPU"); - Ok(builder) - } + Err(_) => Err(anyhow::anyhow!( + "Failed to configure CUDA execution provider: CUDA initialization panicked (invalid/stub library?)" + )), } } @@ -415,7 +603,9 @@ fn configure_cuda(_builder: SessionBuilder) -> Result { #[cfg(feature = "tensorrt")] fn configure_tensorrt(builder: SessionBuilder) -> Result { builder - .with_execution_providers([TensorRTExecutionProvider::default().build()]) + .with_execution_providers([TensorRTExecutionProvider::default() + .build() + .error_on_failure()]) .map_err(|e| anyhow::anyhow!("Failed to configure TensorRT execution provider: {e:?}")) } @@ -427,7 +617,9 @@ fn configure_tensorrt(_builder: SessionBuilder) -> Result { #[cfg(feature = "coreml")] fn configure_coreml(builder: SessionBuilder) -> Result { builder - .with_execution_providers([CoreMLExecutionProvider::default().build()]) + .with_execution_providers([CoreMLExecutionProvider::default() + .build() + .error_on_failure()]) .map_err(|e| anyhow::anyhow!("Failed to configure CoreML execution provider: {e:?}")) } @@ -439,7 +631,9 @@ fn configure_coreml(_builder: SessionBuilder) -> Result { #[cfg(feature = "directml")] fn configure_directml(builder: SessionBuilder) -> Result { builder - .with_execution_providers([DirectMLExecutionProvider::default().build()]) + .with_execution_providers([DirectMLExecutionProvider::default() + .build() + .error_on_failure()]) .map_err(|e| anyhow::anyhow!("Failed to configure DirectML execution provider: {e:?}")) } @@ -454,7 +648,9 @@ fn configure_migraphx(builder: SessionBuilder) -> Result { return Ok(builder); } builder - .with_execution_providers([MIGraphXExecutionProvider::default().build()]) + .with_execution_providers([MIGraphXExecutionProvider::default() + .build() + .error_on_failure()]) .context("Failed to configure MIGraphX execution provider. Ensure ROCm and MIGraphX are installed.") } @@ -918,14 +1114,15 @@ impl ColbertBuilder { update_token_ids(&mut config, &tokenizer); let skiplist_ids = build_skiplist(&config, &tokenizer); + let gpu_execution_requested = match self.execution_provider { + ExecutionProvider::Auto => preferred_gpu_execution_provider().is_some(), + provider => provider.is_gpu(), + }; + // For GPU execution, cap intra-op threads to 1 — the GPU handles parallelism // and extra threads only cause ORT to allocate per-thread CUDA workspace buffers, // wasting GPU memory. The high thread count only benefits CPU sessions. - let threads_per_session = if matches!( - self.execution_provider, - ExecutionProvider::Cuda | ExecutionProvider::Auto - ) && self.num_sessions == 1 - { + let threads_per_session = if gpu_execution_requested && self.num_sessions == 1 { 1 } else { self.threads_per_session @@ -961,11 +1158,10 @@ impl ColbertBuilder { // Determine batch size let batch_size = self.batch_size.unwrap_or(if self.num_sessions > 1 { 2 // Small batches optimal for parallel sessions + } else if gpu_execution_requested { + DEFAULT_GPU_BATCH_SIZE } else { - match self.execution_provider { - ExecutionProvider::Cpu => DEFAULT_CPU_BATCH_SIZE, - _ => DEFAULT_GPU_BATCH_SIZE, - } + DEFAULT_CPU_BATCH_SIZE }); Ok(Colbert { @@ -1067,8 +1263,10 @@ impl Colbert { let processed_texts = preprocess_texts(&self.config, documents); let tokenized = tokenize_processed_texts_individually(&self.tokenizer, &processed_texts)?; let truncate_limit = self.config.document_length.saturating_sub(1); - let use_gpu_batch_modes = - !matches!(self.requested_execution_provider, ExecutionProvider::Cpu); + let use_gpu_batch_modes = match self.requested_execution_provider { + ExecutionProvider::Auto => is_gpu_available(), + provider => provider.is_gpu(), + }; let use_dynamic_batch = self.dynamic_batch && use_gpu_batch_modes; // CPU path: simple fixed-size batches. Documents are batched in input From b18fb2a5cf4dd4f6a2313ab9533219920a3d3cc0 Mon Sep 17 00:00:00 2001 From: Guillaume Ausset Date: Mon, 25 May 2026 12:29:55 +0200 Subject: [PATCH 2/4] Handle MIGraphX ONNX Runtime discovery --- colgrep/README.md | 23 +- colgrep/src/onnx_runtime.rs | 691 +++++++++++++++++++++++++++++++----- next-plaid-onnx/README.md | 6 + next-plaid-onnx/src/lib.rs | 35 ++ 4 files changed, 656 insertions(+), 99 deletions(-) diff --git a/colgrep/README.md b/colgrep/README.md index e1497438..eabd672e 100644 --- a/colgrep/README.md +++ b/colgrep/README.md @@ -642,14 +642,29 @@ Then: `cargo install colgrep --features openblas` ### ONNX Runtime -ONNX Runtime is downloaded automatically on first use. No manual installation required. +ONNX Runtime CPU and CUDA builds are downloaded automatically on first use. +ROCm/MIGraphX builds are ROCm-versioned and are not downloaded automatically; +install AMD's wheel and point ColGREP at its runtime library if auto-discovery +does not find it: + +```bash +pip install onnxruntime-migraphx \ + -f https://repo.radeon.com/rocm/manylinux/rocm-rel-/ + +export ORT_DYLIB_PATH=/path/to/site-packages/onnxruntime/capi/libonnxruntime.so +colgrep --force-gpu search "your query" +``` Lookup order: 1. `ORT_DYLIB_PATH` environment variable -2. Python environments (pip/conda/venv) -3. System paths -4. Auto-download to `~/.cache/onnxruntime/` +2. MIGraphX-capable Python/system installs (`migraphx` builds) +3. Python environments (pip/conda/venv) +4. System paths +5. Auto-download to `~/.cache/colgrep/onnxruntime/` + +On Linux, ColGREP may re-exec itself once to add the ONNX Runtime, cuDNN, or +ROCm library directories to `LD_LIBRARY_PATH` before ONNX Runtime is loaded. --- diff --git a/colgrep/src/onnx_runtime.rs b/colgrep/src/onnx_runtime.rs index 6bd9a4f3..c14e30c2 100644 --- a/colgrep/src/onnx_runtime.rs +++ b/colgrep/src/onnx_runtime.rs @@ -2,6 +2,8 @@ //! //! Automatically finds or downloads ONNX Runtime library. //! When the `cuda` feature is enabled, downloads the GPU version with CUDA support. +//! When the `migraphx` feature is enabled, discovers an externally installed +//! ONNX Runtime build with MIGraphX support (AMD publishes these as Python wheels). use anyhow::{Context, Result}; use std::env; @@ -41,7 +43,11 @@ const ORT_LIB_NAME: &str = "libonnxruntime.so"; #[cfg(target_os = "windows")] const ORT_LIB_NAME: &str = "onnxruntime.dll"; -/// Whether to use GPU (CUDA) version of ONNX Runtime +/// Whether the managed auto-download should use the CUDA GPU ONNX Runtime. +/// +/// CUDA is available from the official GitHub release artifacts. MIGraphX/ROCm +/// is not: AMD publishes ROCm-versioned Python wheels, so those are discovered +/// from the environment instead of downloaded here. #[cfg(feature = "cuda")] const USE_GPU: bool = true; #[cfg(not(feature = "cuda"))] @@ -55,7 +61,10 @@ const ORT_CACHE_SUBDIR: &str = "cpu"; /// Ensure ONNX Runtime is available. /// Sets ORT_DYLIB_PATH if found or downloaded. -/// When `cuda` feature is enabled, ensures GPU version is used and checks for cuDNN. +/// When `cuda` feature is enabled, ensures the managed CUDA GPU version is used +/// unless an explicit compatible runtime is provided, and checks for cuDNN. +/// When `migraphx` is enabled, discovers an installed MIGraphX-capable +/// runtime and avoids downloading the CPU-only runtime when GPU is forced. /// /// NOTE: To force CPU-only mode and avoid CUDA initialization overhead, set /// COLGREP_FORCE_CPU="1" before calling this function. This makes the GPU @@ -65,13 +74,22 @@ const ORT_CACHE_SUBDIR: &str = "cpu"; /// this function will re-exec the current process with the updated LD_LIBRARY_PATH. /// This is necessary because Linux caches LD_LIBRARY_PATH at process startup. pub fn ensure_onnx_runtime() -> Result { + #[cfg(any(feature = "migraphx", all(target_os = "linux", feature = "cuda")))] + let acceleration_mode = crate::acceleration::env_acceleration_mode_lossy(); + + // ROCm/MIGraphX wheels depend on ROCm shared libraries that are often under + // /opt/rocm*/lib or a conda environment. Linux caches LD_LIBRARY_PATH at + // process startup, so prepare those directories before we dlopen ORT. + #[cfg(all(target_os = "linux", feature = "migraphx"))] + if acceleration_mode != crate::acceleration::AccelerationMode::ForceCpu { + ensure_rocm_loader_path().context("Failed to prepare ROCm library path")?; + } + // For CUDA builds on Linux, check if we need to re-exec with cuDNN in LD_LIBRARY_PATH // This is only needed on Linux because it caches LD_LIBRARY_PATH at process startup // Skip CUDA setup if COLGREP_FORCE_CPU is set (CPU-only mode) #[cfg(all(target_os = "linux", feature = "cuda"))] - if crate::acceleration::env_acceleration_mode_lossy() - != crate::acceleration::AccelerationMode::ForceCpu - { + if acceleration_mode != crate::acceleration::AccelerationMode::ForceCpu { // Check if we already have the marker indicating we've set up LD_LIBRARY_PATH if env::var("_COLGREP_CUDA_SETUP").is_err() { // First pass: find cuDNN and set up LD_LIBRARY_PATH, then re-exec @@ -133,8 +151,7 @@ pub fn ensure_onnx_runtime() -> Result { if let Ok(path) = env::var("ORT_DYLIB_PATH") { let path = PathBuf::from(&path); if path.exists() && is_valid_ort_dylib(&path) { - pin_runtime_library(&path); - return Ok(path); + return activate_runtime_library(&path); } // Path from env is missing or can't be loaded (wrong arch, broken // symlink, stale Homebrew formula, ...). Clear it so the search and @@ -147,17 +164,43 @@ pub fn ensure_onnx_runtime() -> Result { env::remove_var("ORT_DYLIB_PATH"); } + // Prefer an already-installed MIGraphX-enabled ORT over managed CPU/CUDA + // downloads whenever ROCm GPU usage is allowed. This is what makes + // `--features migraphx --force-gpu` usable without accidentally downloading the + // official CPU-only GitHub artifact first. + #[cfg(feature = "migraphx")] + if acceleration_mode != crate::acceleration::AccelerationMode::ForceCpu { + if let Some(path) = find_migraphx_onnx_runtime() { + return activate_runtime_library(&path); + } + } + + // A ROCm-only binary has no managed GPU runtime to download. If the user + // explicitly requested GPU and no MIGraphX runtime was found, fail now with + // installation guidance instead of downloading a CPU runtime and failing + // later with a less actionable provider-availability error. + #[cfg(all(feature = "migraphx", not(feature = "cuda")))] + if acceleration_mode == crate::acceleration::AccelerationMode::ForceGpu { + return Err(migraphx_runtime_not_found_error()); + } + // 2. Search common locations (skip for CUDA - we want our managed GPU version) #[cfg(not(feature = "cuda"))] if let Some(path) = find_onnx_runtime() { - pin_runtime_library(&path); - return Ok(path); + return activate_runtime_library(&path); } // 3. Download and cache let path = download_onnx_runtime()?; - pin_runtime_library(&path); - Ok(path) + activate_runtime_library(&path) +} + +fn activate_runtime_library(path: &Path) -> Result { + #[cfg(target_os = "linux")] + ensure_runtime_loader_path(path).context("Failed to prepare ONNX Runtime library path")?; + + pin_runtime_library(path); + Ok(path.to_path_buf()) } fn pin_runtime_library(path: &Path) { @@ -175,6 +218,31 @@ fn pin_runtime_library(path: &Path) { } } +/// Set `ORT_DYLIB_PATH` to the MIGraphX-capable ONNX Runtime ColGREP would +/// use, without initializing ORT sessions. +/// +/// MIGraphX static-cache keys include the ORT dylib fingerprint. Auto indexing +/// checks cache warmth before full ORT initialization to avoid startup overhead +/// on cold-cache CPU fallback paths, so it still needs this env var to match +/// `warm-cache` and the eventual GPU session path. +#[cfg(feature = "migraphx")] +pub fn ensure_migraphx_onnx_runtime_path_for_cache_key() -> Option { + if let Ok(path) = env::var("ORT_DYLIB_PATH") { + if !path.trim().is_empty() { + return Some(PathBuf::from(path)); + } + } + + find_migraphx_onnx_runtime().inspect(|path| { + env::set_var("ORT_DYLIB_PATH", path); + }) +} + +#[cfg(not(feature = "migraphx"))] +pub fn ensure_migraphx_onnx_runtime_path_for_cache_key() -> Option { + None +} + /// Find the cuDNN library directory (without setting any global state) #[cfg(all(target_os = "linux", feature = "cuda"))] fn find_cudnn_directory() -> Option { @@ -222,6 +290,209 @@ fn prepend_ld_library_path(dir: &Path) { } } +#[cfg(target_os = "linux")] +fn ld_library_path_contains(dir: &Path) -> bool { + env::var("LD_LIBRARY_PATH") + .ok() + .map(|current| { + current + .split(':') + .filter(|entry| !entry.is_empty()) + .any(|entry| Path::new(entry) == dir) + }) + .unwrap_or(false) +} + +#[cfg(target_os = "linux")] +fn reexec_with_library_dirs(dirs: &[PathBuf], marker: &str, context: &str) -> Result<()> { + if dirs.is_empty() { + return Ok(()); + } + + if env::var(marker).is_ok() { + return Ok(()); + } + + let mut missing = Vec::new(); + for dir in dirs { + if dir.exists() + && dir.is_dir() + && !ld_library_path_contains(dir) + && !missing.iter().any(|existing: &PathBuf| existing == dir) + { + missing.push(dir.clone()); + } + } + + if missing.is_empty() { + env::set_var(marker, "1"); + return Ok(()); + } + + let current_ld = env::var("LD_LIBRARY_PATH").unwrap_or_default(); + let prefix = missing + .iter() + .map(|dir| dir.to_string_lossy()) + .collect::>() + .join(":"); + let new_ld = if current_ld.is_empty() { + prefix + } else { + format!("{}:{}", prefix, current_ld) + }; + + env::set_var("LD_LIBRARY_PATH", &new_ld); + env::set_var(marker, "1"); + + let exe = env::current_exe().context("Failed to get current executable")?; + let args: Vec = env::args().collect(); + + let err = exec::execvp(&exe, &args); + Err(anyhow::anyhow!( + "Failed to re-exec with updated {context} library path: {err}" + )) +} + +#[cfg(target_os = "linux")] +fn ensure_runtime_loader_path(path: &Path) -> Result<()> { + let mut dirs = Vec::new(); + + // Provider shared libraries are loaded by ONNX Runtime after startup. Make + // their directory visible to the dynamic loader before ORT initializes. + if runtime_has_any_provider_companion(path) { + if let Some(parent) = path.parent() { + dirs.push(parent.to_path_buf()); + } + } + + #[cfg(feature = "migraphx")] + if runtime_has_migraphx_provider_companion(path) { + dirs.extend(get_rocm_library_dirs()); + } + + #[cfg(feature = "cuda")] + if runtime_has_cuda_provider_companion(path) { + if let Some(cudnn_dir) = find_cudnn_directory() { + dirs.push(cudnn_dir); + } + } + + reexec_with_library_dirs(&dirs, "_COLGREP_ORT_PROVIDER_SETUP", "ONNX Runtime") +} + +/// Prepare the Linux dynamic loader for ROCm/MIGraphX before validating ORT. +/// +/// Some ROCm ONNX Runtime wheels link directly or indirectly against ROCm +/// libraries. If those directories are not in LD_LIBRARY_PATH at process +/// startup, `dlopen(libonnxruntime.so)` can fail even when ORT_DYLIB_PATH is +/// correct. +#[cfg(all(target_os = "linux", feature = "migraphx"))] +fn ensure_rocm_loader_path() -> Result<()> { + let mut dirs = get_rocm_library_dirs(); + + if let Ok(path) = env::var("ORT_DYLIB_PATH") { + let path = PathBuf::from(path); + if let Some(parent) = path.parent() { + dirs.push(parent.to_path_buf()); + } + } + + reexec_with_library_dirs(&dirs, "_COLGREP_ROCM_SETUP", "ROCm") +} + +#[cfg(all(target_os = "linux", feature = "migraphx"))] +fn get_rocm_library_dirs() -> Vec { + let mut dirs = Vec::new(); + + fn push_existing_unique(dirs: &mut Vec, dir: PathBuf) { + if dir.exists() && dir.is_dir() && !dirs.iter().any(|existing| existing == &dir) { + dirs.push(dir); + } + } + + fn dir_contains_rocm_library(dir: &Path) -> bool { + fs::read_dir(dir) + .ok() + .into_iter() + .flat_map(|entries| entries.flatten()) + .any(|entry| { + let name = entry.file_name(); + let name = name.to_string_lossy(); + name.starts_with("libmigraphx") + || name.starts_with("libamdhip64") + || name.starts_with("libhsa-runtime64") + || name.starts_with("librocblas") + || name.starts_with("librocm") + || name.starts_with("librocrand") + || name.starts_with("librocsolver") + || name.starts_with("librocsparse") + || name.starts_with("libroctracer") + || name.starts_with("libroctx") + }) + } + + fn push_if_contains_rocm_library(dirs: &mut Vec, dir: PathBuf) { + if dir.exists() + && dir.is_dir() + && dir_contains_rocm_library(&dir) + && !dirs.iter().any(|existing| existing == &dir) + { + dirs.push(dir); + } + } + + if let Ok(conda_prefix) = env::var("CONDA_PREFIX") { + let conda = PathBuf::from(conda_prefix); + push_existing_unique(&mut dirs, conda.join("lib")); + push_existing_unique(&mut dirs, conda.join("lib64")); + } + + for var in [ + "ROCM_PATH", + "ROCM_HOME", + "HIP_PATH", + "HIP_HOME", + "MIGRAPHX_PATH", + "MIGRAPHX_HOME", + ] { + if let Ok(value) = env::var(var) { + let base = PathBuf::from(value); + push_existing_unique(&mut dirs, base.join("lib")); + push_existing_unique(&mut dirs, base.join("lib64")); + push_if_contains_rocm_library(&mut dirs, base); + } + } + + for dir in [ + PathBuf::from("/opt/rocm/lib"), + PathBuf::from("/opt/rocm/lib64"), + ] { + push_existing_unique(&mut dirs, dir); + } + + for dir in [ + PathBuf::from("/usr/lib/x86_64-linux-gnu"), + PathBuf::from("/usr/lib64"), + PathBuf::from("/usr/lib"), + ] { + push_if_contains_rocm_library(&mut dirs, dir); + } + + if let Ok(entries) = fs::read_dir("/opt") { + for entry in entries.flatten() { + let name = entry.file_name(); + let name = name.to_string_lossy(); + if name == "rocm" || name.starts_with("rocm-") { + let base = entry.path(); + push_existing_unique(&mut dirs, base.join("lib")); + push_existing_unique(&mut dirs, base.join("lib64")); + } + } + } + + dirs +} + /// Get all directories to search for cuDNN library (Linux only) #[cfg(all(target_os = "linux", feature = "cuda"))] fn get_cudnn_search_dirs() -> Vec { @@ -374,41 +645,110 @@ fn is_valid_ort_dylib(path: &Path) -> bool { } } -/// Search for ONNX Runtime in common locations +fn provider_companion_path(path: &Path, provider: &str) -> Option { + let parent = path.parent()?; + + #[cfg(target_os = "linux")] + let file_name = format!("libonnxruntime_providers_{provider}.so"); + + #[cfg(target_os = "macos")] + let file_name = format!("libonnxruntime_providers_{provider}.dylib"); + + #[cfg(target_os = "windows")] + let file_name = format!("onnxruntime_providers_{provider}.dll"); + + Some(parent.join(file_name)) +} + +fn runtime_has_provider_companion(path: &Path, provider: &str) -> bool { + provider_companion_path(path, provider).is_some_and(|path| path.exists()) +} + +fn runtime_has_migraphx_provider_companion(path: &Path) -> bool { + runtime_has_provider_companion(path, "migraphx") +} + +fn runtime_has_cuda_provider_companion(path: &Path) -> bool { + runtime_has_provider_companion(path, "cuda") +} + +fn runtime_has_any_provider_companion(path: &Path) -> bool { + runtime_has_provider_companion(path, "shared") + || runtime_has_cuda_provider_companion(path) + || runtime_has_migraphx_provider_companion(path) +} + +#[allow(dead_code)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum OrtRequirement { + Any, + MIGraphX, +} + +/// Search for any loadable ONNX Runtime in common locations. #[cfg(not(feature = "cuda"))] fn find_onnx_runtime() -> Option { - let search_paths = get_search_paths(); + find_onnx_runtime_with_requirement(get_search_paths(), OrtRequirement::Any) +} + +/// Search for an ONNX Runtime installation that includes the MIGraphX provider. +#[cfg(feature = "migraphx")] +fn find_migraphx_onnx_runtime() -> Option { + find_onnx_runtime_with_requirement(get_migraphx_search_paths(), OrtRequirement::MIGraphX) +} + +#[cfg(any(not(feature = "cuda"), feature = "migraphx"))] +fn find_onnx_runtime_with_requirement( + search_paths: Vec, + requirement: OrtRequirement, +) -> Option { let mut rejected: Vec = Vec::new(); + let mut seen_candidates = std::collections::HashSet::new(); - let try_candidate = |candidate: PathBuf, rejected: &mut Vec| -> Option { - if !candidate.exists() { + let mut try_candidate = |candidate: PathBuf, rejected: &mut Vec| -> Option { + if !candidate.exists() || !candidate.is_file() { return None; } - if is_valid_ort_dylib(&candidate) { - Some(candidate) - } else { + + let canonical = candidate + .canonicalize() + .unwrap_or_else(|_| candidate.clone()); + if !seen_candidates.insert(canonical) { + return None; + } + + if !is_valid_ort_dylib(&candidate) { rejected.push(candidate); - None + return None; } + + if !candidate_satisfies_ort_requirement(&candidate, requirement) { + return None; + } + + Some(candidate) }; for base_path in search_paths { + // The search list may include either directories or a direct dylib path. + if let Some(p) = try_candidate(base_path.clone(), &mut rejected) { + return Some(p); + } + // Direct library file if let Some(p) = try_candidate(base_path.join(ORT_LIB_NAME), &mut rejected) { return Some(p); } - // Versioned library (e.g., libonnxruntime.so.1.23.0 on Linux, libonnxruntime.1.20.1.dylib on macOS) - // Match "libonnxruntime.so*" or "libonnxruntime.*dylib" only — NOT companion libraries - // like libonnxruntime_providers_shared.so which lack OrtGetApiBase. + // Versioned library (e.g., libonnxruntime.so.1.23.0 on Linux, + // libonnxruntime.1.20.1.dylib on macOS). Match only the core ORT + // library — NOT companion libraries like + // libonnxruntime_providers_shared.so, which lack OrtGetApiBase. if let Ok(entries) = fs::read_dir(&base_path) { for entry in entries.flatten() { let name = entry.file_name(); let name_str = name.to_string_lossy(); - if name_str.starts_with("libonnxruntime.so") - || name_str.starts_with("libonnxruntime.dylib") - || (name_str.starts_with("libonnxruntime.") && name_str.ends_with(".dylib")) - { + if is_onnxruntime_library_name(&name_str) { if let Some(p) = try_candidate(entry.path(), &mut rejected) { return Some(p); } @@ -422,41 +762,114 @@ fn find_onnx_runtime() -> Option { } } - if !rejected.is_empty() { - // Guard against repeat logging: `ensure_onnx_runtime` can be re-entered - // within a single process (tests, re-execs that restore the env, code - // paths that clear ORT_DYLIB_PATH), and once we've explained the - // rejection the user doesn't need to see it again. - use std::sync::atomic::{AtomicBool, Ordering}; - static WARNED: AtomicBool = AtomicBool::new(false); - if !WARNED.swap(true, Ordering::Relaxed) { - let mut seen: std::collections::HashSet = std::collections::HashSet::new(); - let unique: Vec<&PathBuf> = rejected - .iter() - .filter(|p| { - let canon = p.canonicalize().unwrap_or_else(|_| (*p).clone()); - seen.insert(canon) - }) - .collect(); - eprintln!( - "āš ļø Found {} ONNX Runtime candidate(s) that failed to load (wrong arch, broken \ - signature, or companion library); downloading a managed copy instead:", - unique.len() - ); - for p in unique { - eprintln!(" - {}", p.display()); - } - } - } + warn_rejected_ort_candidates(&rejected, requirement); None } +#[cfg(any(not(feature = "cuda"), feature = "migraphx"))] +fn candidate_satisfies_ort_requirement(path: &Path, requirement: OrtRequirement) -> bool { + match requirement { + OrtRequirement::Any => true, + OrtRequirement::MIGraphX => runtime_has_migraphx_provider_companion(path), + } +} + +#[cfg(any(not(feature = "cuda"), feature = "migraphx"))] +fn is_onnxruntime_library_name(name: &str) -> bool { + #[cfg(target_os = "linux")] + return name.starts_with("libonnxruntime.so"); + + #[cfg(target_os = "macos")] + return name.starts_with("libonnxruntime.dylib") + || (name.starts_with("libonnxruntime.") && name.ends_with(".dylib")); + + #[cfg(target_os = "windows")] + return name.eq_ignore_ascii_case("onnxruntime.dll"); +} + +#[cfg(any(not(feature = "cuda"), feature = "migraphx"))] +fn warn_rejected_ort_candidates(rejected: &[PathBuf], requirement: OrtRequirement) { + if rejected.is_empty() { + return; + } + + // Guard against repeat logging: `ensure_onnx_runtime` can be re-entered + // within a single process (tests, re-execs that restore the env, code paths + // that clear ORT_DYLIB_PATH), and once we've explained the rejection the + // user doesn't need to see it again. + use std::sync::atomic::{AtomicBool, Ordering}; + static WARNED: AtomicBool = AtomicBool::new(false); + if WARNED.swap(true, Ordering::Relaxed) { + return; + } + + let mut seen: std::collections::HashSet = std::collections::HashSet::new(); + let unique: Vec<&PathBuf> = rejected + .iter() + .filter(|p| { + let canon = p.canonicalize().unwrap_or_else(|_| (*p).clone()); + seen.insert(canon) + }) + .collect(); + + let fallback = match requirement { + OrtRequirement::Any => "downloading a managed copy instead", + OrtRequirement::MIGraphX => "continuing to other runtime options", + }; + eprintln!( + "āš ļø Found {} ONNX Runtime candidate(s) that failed to load (wrong arch, broken \ + signature, missing dependencies, or companion library); {fallback}:", + unique.len() + ); + for p in unique { + eprintln!(" - {}", p.display()); + } +} + +#[cfg(all(feature = "migraphx", not(feature = "cuda")))] +fn migraphx_runtime_not_found_error() -> anyhow::Error { + anyhow::anyhow!( + "GPU execution was requested for a ROCm/MIGraphX build, but no ONNX Runtime library with MIGraphX support was found. Install AMD's ONNX Runtime package, for example `pip install onnxruntime-migraphx -f https://repo.radeon.com/rocm/manylinux/rocm-rel-/`, then set ORT_DYLIB_PATH to the wheel's `onnxruntime/capi/{}`. The official GitHub CPU ONNX Runtime package does not include MIGraphX.", + ORT_LIB_NAME + ) +} + /// Get list of paths to search for ONNX Runtime -#[cfg(not(feature = "cuda"))] +#[cfg(any(not(feature = "cuda"), feature = "migraphx"))] fn get_search_paths() -> Vec { let mut paths = Vec::new(); + // Explicit library/home env vars used by source builds and package managers. + for var in [ + "ORT_LIB_DIR", + "ORT_HOME", + "ONNXRUNTIME_LIB_DIR", + "ONNXRUNTIME_HOME", + ] { + if let Ok(value) = env::var(var) { + let path = PathBuf::from(value); + paths.push(path.clone()); + paths.push(path.join("lib")); + } + } + + if let Ok(conda_prefix) = env::var("CONDA_PREFIX") { + let conda_path = PathBuf::from(conda_prefix); + paths.push(conda_path.join("lib")); + push_python_prefix_runtime_dirs(&mut paths, &conda_path); + } + + if let Ok(virtual_env) = env::var("VIRTUAL_ENV") { + push_python_prefix_runtime_dirs(&mut paths, Path::new(&virtual_env)); + } + + if let Ok(current_dir) = env::current_dir() { + for venv_name in [".venv", "venv", ".env", "env", "python/.venv"] { + push_python_prefix_runtime_dirs(&mut paths, ¤t_dir.join(venv_name)); + } + } + // Home directory for cache if let Some(home) = dirs::home_dir() { // Our cache location (new path with cpu/gpu subdirs) @@ -470,48 +883,8 @@ fn get_search_paths() -> Vec { // Legacy cache location (for backwards compatibility) paths.push(home.join(".cache").join("onnxruntime").join(ORT_VERSION)); - // Conda environments - if let Ok(conda_prefix) = env::var("CONDA_PREFIX") { - let conda_path = PathBuf::from(&conda_prefix); - paths.push(conda_path.join("lib")); - - // Python site-packages in conda - for entry in [ - "lib/python3.12", - "lib/python3.11", - "lib/python3.10", - "lib/python3.9", - ] { - paths.push( - conda_path - .join(entry) - .join("site-packages/onnxruntime/capi"), - ); - } - } - - // Virtual environments - for venv_name in [".venv", "venv", ".env", "env"] { - let venv_path = std::env::current_dir() - .map(|cwd| cwd.join(venv_name)) - .unwrap_or_default(); - - #[cfg(target_os = "windows")] - paths.push(venv_path.join("Lib/site-packages/onnxruntime/capi")); - - #[cfg(not(target_os = "windows"))] - for py in ["python3.12", "python3.11", "python3.10", "python3.9"] { - paths.push( - venv_path - .join("lib") - .join(py) - .join("site-packages/onnxruntime/capi"), - ); - } - } - - // UV cache - paths.push(home.join(".cache/uv")); + push_python_prefix_runtime_dirs(&mut paths, &home.join(".local")); + push_uv_runtime_dirs(&mut paths, &home); // Homebrew (macOS) #[cfg(target_os = "macos")] @@ -532,6 +905,134 @@ fn get_search_paths() -> Vec { paths } +#[cfg(any(not(feature = "cuda"), feature = "migraphx"))] +fn push_site_packages_runtime_dirs(paths: &mut Vec, site_packages: &Path) { + for package_dir in ["onnxruntime", "onnxruntime_migraphx"] { + paths.push(site_packages.join(package_dir).join("capi")); + } +} + +#[cfg(any(not(feature = "cuda"), feature = "migraphx"))] +fn push_python_prefix_runtime_dirs(paths: &mut Vec, prefix: &Path) { + #[cfg(target_os = "windows")] + push_site_packages_runtime_dirs(paths, &prefix.join("Lib").join("site-packages")); + + #[cfg(not(target_os = "windows"))] + { + let lib_dir = prefix.join("lib"); + + // Dynamic discovery handles Python 3.13+ without hardcoding every minor + // version, while the fallback list still covers prefixes that do not + // exist yet at search-list construction time. + if let Ok(entries) = fs::read_dir(&lib_dir) { + for entry in entries.flatten() { + let name = entry.file_name(); + if name.to_string_lossy().starts_with("python") { + push_site_packages_runtime_dirs(paths, &entry.path().join("site-packages")); + } + } + } + + for py in [ + "python3.14", + "python3.13", + "python3.12", + "python3.11", + "python3.10", + "python3.9", + ] { + push_site_packages_runtime_dirs(paths, &lib_dir.join(py).join("site-packages")); + } + } +} + +#[cfg(any(not(feature = "cuda"), feature = "migraphx"))] +fn push_uv_runtime_dirs(paths: &mut Vec, home: &Path) { + let uv_cache = home.join(".cache").join("uv"); + paths.push(uv_cache.clone()); + + if let Ok(entries) = fs::read_dir(&uv_cache) { + for entry in entries.flatten() { + let name = entry.file_name(); + if !name.to_string_lossy().starts_with("archive-v") { + continue; + } + + if let Ok(archives) = fs::read_dir(entry.path()) { + for archive in archives.flatten() { + for package_dir in ["onnxruntime", "onnxruntime_migraphx"] { + paths.push(archive.path().join(package_dir).join("capi")); + } + } + } + } + } +} + +#[cfg(feature = "migraphx")] +fn get_migraphx_search_paths() -> Vec { + let mut paths = get_search_paths(); + + // Custom/source-built ORT installs. These are only searched for MIGraphX + // because `candidate_satisfies_ort_requirement` filters out CPU-only ORT + // libraries by requiring the MIGraphX provider companion library. + #[cfg(target_os = "linux")] + { + for base in [ + PathBuf::from("/opt"), + PathBuf::from("/usr/local"), + PathBuf::from("/usr"), + ] { + paths.push(base.join("lib")); + if let Ok(entries) = fs::read_dir(&base) { + for entry in entries.flatten() { + let name = entry.file_name(); + let name = name.to_string_lossy().to_ascii_lowercase(); + if name.contains("onnxruntime") || name.contains("ort") { + let path = entry.path(); + paths.push(path.clone()); + paths.push(path.join("lib")); + } + } + } + } + + for rocm_dir in get_rocm_library_dirs() { + paths.push(rocm_dir); + } + } + + paths +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn detects_migraphx_provider_companion_next_to_runtime() { + let dir = tempfile::tempdir().unwrap(); + let runtime = dir.path().join(ORT_LIB_NAME); + fs::write(&runtime, b"").unwrap(); + + let migraphx = provider_companion_path(&runtime, "migraphx").unwrap(); + fs::write(migraphx, b"").unwrap(); + + assert!(runtime_has_migraphx_provider_companion(&runtime)); + assert!(runtime_has_any_provider_companion(&runtime)); + } + + #[test] + fn missing_migraphx_provider_companion_is_not_migraphx_runtime() { + let dir = tempfile::tempdir().unwrap(); + let runtime = dir.path().join(ORT_LIB_NAME); + fs::write(&runtime, b"").unwrap(); + + assert!(!runtime_has_migraphx_provider_companion(&runtime)); + assert!(!runtime_has_any_provider_companion(&runtime)); + } +} + /// Download ONNX Runtime from GitHub releases fn download_onnx_runtime() -> Result { let cache_dir = dirs::home_dir() diff --git a/next-plaid-onnx/README.md b/next-plaid-onnx/README.md index 573e12bd..2e9f0f21 100644 --- a/next-plaid-onnx/README.md +++ b/next-plaid-onnx/README.md @@ -79,6 +79,12 @@ next-plaid-onnx = { version = "0.2", features = ["directml"] } `ExecutionProvider::Auto` tries providers in order: CUDA → TensorRT → CoreML → DirectML → CPU. Set `NEXT_PLAID_FORCE_CPU=1` to bypass all GPU providers. +For ROCm, install AMD's ONNX Runtime wheel for your ROCm release (for example +`pip install onnxruntime-migraphx -f https://repo.radeon.com/rocm/manylinux/rocm-rel-/`) +or provide a custom ONNX Runtime build, then set `ORT_DYLIB_PATH` to +`.../site-packages/onnxruntime/capi/libonnxruntime.so` before starting the +process. The official GitHub CPU ONNX Runtime package does not include MIGraphX. + ### Token Pooling Reduce token count with hierarchical clustering (Ward's method): diff --git a/next-plaid-onnx/src/lib.rs b/next-plaid-onnx/src/lib.rs index 609c4ad7..0ec947cd 100644 --- a/next-plaid-onnx/src/lib.rs +++ b/next-plaid-onnx/src/lib.rs @@ -306,6 +306,41 @@ pub fn is_gpu_available() -> bool { preferred_gpu_execution_provider().is_some() } +fn execution_provider_list_display(providers: &[ExecutionProvider]) -> String { + providers + .iter() + .map(|provider| provider.display_name()) + .collect::>() + .join(", ") +} + +fn unavailable_gpu_execution_provider_reason() -> String { + let compiled = compiled_gpu_execution_providers(); + if compiled.is_empty() { + "no GPU execution provider was compiled. Enable a feature such as 'cuda', 'migraphx', 'coreml', or 'directml'.".to_string() + } else { + let names = execution_provider_list_display(&compiled); + let rocm_hint = if compiled.contains(&ExecutionProvider::MIGraphX) { + " For ROCm/MIGraphX, install AMD's `onnxruntime-migraphx` wheel or use a custom ORT build, then set ORT_DYLIB_PATH to its `onnxruntime/capi/libonnxruntime.so`." + } else { + "" + }; + format!( + "no compiled GPU execution provider is available in the loaded ONNX Runtime library. Compiled provider(s): {names}.{rocm_hint}" + ) + } +} + +/// Return the preferred available GPU execution provider or a user-facing error. +pub fn require_gpu_execution_provider() -> Result { + preferred_gpu_execution_provider().ok_or_else(|| { + anyhow::anyhow!( + "GPU execution requested, but {}", + unavailable_gpu_execution_provider_reason() + ) + }) +} + fn configure_execution_provider( builder: SessionBuilder, provider: ExecutionProvider, From d218d60471274c78d9cf284055ca9d5e34b0919b Mon Sep 17 00:00:00 2001 From: Guillaume Ausset Date: Tue, 26 May 2026 21:55:16 +0200 Subject: [PATCH 3/4] Add cache-hit-only MIGraphX static-shape indexing Teach the ONNX layer to specialize MIGraphX sessions to validated static document shapes, key the cache by the selected model and provider options, and preserve strict --force-gpu semantics. ColGREP auto mode keeps CPU as the default path: it only enables MIGraphX for warm eligible document shapes when the run has enough work to amortize session/GPU overhead, with CPU fallback for cold shapes. Add opt-in COLGREP_PROFILE diagnostics so backend routing, model loading, and indexing/search phases can be measured without changing normal output. --- Cargo.lock | 2 + colgrep/src/commands/init.rs | 174 +- colgrep/src/commands/search.rs | 131 +- colgrep/src/config.rs | 24 + colgrep/src/hardware.rs | 638 +++++++ colgrep/src/index/mod.rs | 673 +++++-- colgrep/src/lib.rs | 2 + colgrep/src/model.rs | 32 +- colgrep/src/profile.rs | 171 ++ next-plaid-onnx/Cargo.toml | 2 + next-plaid-onnx/src/lib.rs | 3063 ++++++++++++++++++++++++++++++-- 11 files changed, 4478 insertions(+), 434 deletions(-) create mode 100644 colgrep/src/hardware.rs create mode 100644 colgrep/src/profile.rs diff --git a/Cargo.lock b/Cargo.lock index 28f84bb4..553ed054 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2552,6 +2552,7 @@ name = "next-plaid-onnx" version = "1.5.2" dependencies = [ "anyhow", + "fs2", "glob", "ndarray 0.16.1", "numpy", @@ -2560,6 +2561,7 @@ dependencies = [ "rayon", "serde", "serde_json", + "sha2", "tokenizers", ] diff --git a/colgrep/src/commands/init.rs b/colgrep/src/commands/init.rs index 2f7a154c..5f832ec0 100644 --- a/colgrep/src/commands/init.rs +++ b/colgrep/src/commands/init.rs @@ -4,6 +4,7 @@ use anyhow::Result; use crate::commands::search::{resolve_model, resolve_pool_factor}; use colgrep::{ensure_model, find_parent_index, index_exists, Config, IndexBuilder}; +use next_plaid_onnx::ExecutionProvider; pub struct InitOptions<'a> { pub cli_model: Option<&'a str>, @@ -29,84 +30,117 @@ fn resolve_index_runtime_overrides( } pub fn cmd_init(path: &PathBuf, options: InitOptions<'_>) -> Result<()> { - let path = std::fs::canonicalize(path) - .map_err(|_| anyhow::anyhow!("Path does not exist: {}", path.display()))?; - - if !path.is_dir() { - anyhow::bail!("Path is not a directory: {}", path.display()); - } - - let model = resolve_model(options.cli_model); - let pool_factor = resolve_pool_factor(options.pool_factor, options.no_pool); - - let config = Config::load().unwrap_or_default(); - let quantized = !config.use_fp32(); - let (parallel_sessions, batch_size) = - resolve_index_runtime_overrides(&config, options.batch_size); - - // Check if path is inside an already-indexed parent project - let parent_info = find_parent_index(&path, &model)?; - let effective_root = match &parent_info { - Some(info) => info.project_path.clone(), - None => path.clone(), - }; - - // Check if index already exists for the effective root - let has_existing_index = index_exists(&effective_root, &model); - - // Ensure model is downloaded - let model_path = ensure_model(Some(&model), has_existing_index)?; - - let mut builder = IndexBuilder::with_options( - &effective_root, - &model, - &model_path, - quantized, - pool_factor, - parallel_sessions, - batch_size, - )?; - builder.set_auto_confirm(options.auto_confirm); - builder.set_dynamic_batch(!options.static_batch); - if let Some(encode_batch_size) = options.encode_batch_size { - builder.set_encode_batch_size(encode_batch_size.max(1)); - } - if let Some(index_chunk_size) = options.index_chunk_size { - builder.set_index_chunk_size(index_chunk_size.max(1)); - } - let stats = builder.index(None, false)?; + colgrep::profile::start_command("init"); + let result = (|| -> Result<()> { + let path = colgrep::profile::time_result("init.canonicalize_path", || { + std::fs::canonicalize(path) + .map_err(|_| anyhow::anyhow!("Path does not exist: {}", path.display())) + })?; + + if !path.is_dir() { + anyhow::bail!("Path is not a directory: {}", path.display()); + } - let changes = stats.added + stats.changed + stats.deleted; - if changes > 0 { - if let Some(ref info) = parent_info { - eprintln!( - "Indexed {} (subdir: {}) (added: {}, changed: {}, deleted: {}, unchanged: {})", - info.project_path.display(), - info.relative_subdir.display(), - stats.added, - stats.changed, - stats.deleted, - stats.unchanged, + let model = resolve_model(options.cli_model); + let pool_factor = resolve_pool_factor(options.pool_factor, options.no_pool); + + let config = Config::load().unwrap_or_default(); + let quantized = !config.use_fp32(); + let cpu_fallback_quantized = + !config.use_fp32_for_execution_provider(ExecutionProvider::Cpu); + let (parallel_sessions, batch_size) = + resolve_index_runtime_overrides(&config, options.batch_size); + colgrep::profile::set_metadata("model", &model); + colgrep::profile::set_metadata("path", path.display().to_string()); + colgrep::profile::set_metadata("primary_quantized", quantized); + colgrep::profile::set_metadata("cpu_fallback_quantized", cpu_fallback_quantized); + colgrep::profile::set_metadata("parallel_sessions", parallel_sessions); + colgrep::profile::set_metadata("batch_size", batch_size); + + // Check if path is inside an already-indexed parent project, and reuse + // the parent index when invoked from a subdirectory. + let (parent_info, has_existing_index) = + colgrep::profile::time_result("init.check_existing_index", || { + let parent_info = find_parent_index(&path, &model)?; + let has_existing_index = match &parent_info { + Some(info) => index_exists(&info.project_path, &model), + None => index_exists(&path, &model), + }; + Ok::<_, anyhow::Error>((parent_info, has_existing_index)) + })?; + let effective_root = match &parent_info { + Some(info) => info.project_path.clone(), + None => path.clone(), + }; + colgrep::profile::set_metadata("effective_root", effective_root.display().to_string()); + if let Some(info) = &parent_info { + colgrep::profile::set_metadata( + "relative_subdir", + info.relative_subdir.display().to_string(), ); + } + + // Ensure model is downloaded + let model_path = colgrep::profile::time_result("init.ensure_model", || { + ensure_model(Some(&model), has_existing_index) + })?; + colgrep::profile::set_metadata("model_path", model_path.display().to_string()); + + let mut builder = IndexBuilder::with_options( + &effective_root, + &model, + &model_path, + quantized, + pool_factor, + parallel_sessions, + batch_size, + )?; + builder.set_cpu_fallback_quantized(cpu_fallback_quantized); + builder.set_auto_confirm(options.auto_confirm); + builder.set_dynamic_batch(!options.static_batch); + if let Some(encode_batch_size) = options.encode_batch_size { + builder.set_encode_batch_size(encode_batch_size.max(1)); + } + if let Some(index_chunk_size) = options.index_chunk_size { + builder.set_index_chunk_size(index_chunk_size.max(1)); + } + let stats = + colgrep::profile::time_result("init.index_total", || builder.index(None, false))?; + + let changes = stats.added + stats.changed + stats.deleted; + if changes > 0 { + if let Some(ref info) = parent_info { + eprintln!( + "Indexed {} (subdir: {}) (added: {}, changed: {}, deleted: {}, unchanged: {})", + info.project_path.display(), + info.relative_subdir.display(), + stats.added, + stats.changed, + stats.deleted, + stats.unchanged, + ); + } else { + eprintln!( + "Indexed {} (added: {}, changed: {}, deleted: {}, unchanged: {})", + effective_root.display(), + stats.added, + stats.changed, + stats.deleted, + stats.unchanged, + ); + } } else { eprintln!( - "Indexed {} (added: {}, changed: {}, deleted: {}, unchanged: {})", + "Index is up to date for {} ({} files)", effective_root.display(), - stats.added, - stats.changed, - stats.deleted, - stats.unchanged, + stats.unchanged ); } - } else { - eprintln!( - "Index is up to date for {} ({} files)", - effective_root.display(), - stats.unchanged - ); - } - Ok(()) + Ok(()) + })(); + colgrep::profile::finish_command(result.is_ok()); + result } #[cfg(test)] diff --git a/colgrep/src/commands/search.rs b/colgrep/src/commands/search.rs index 111884b3..07511e57 100644 --- a/colgrep/src/commands/search.rs +++ b/colgrep/src/commands/search.rs @@ -9,6 +9,7 @@ use colgrep::{ get_index_dir_for_project, get_vector_index_path, index_exists, is_text_format, path_contains_ignored_dir, Config, IndexBuilder, IndexState, Searcher, DEFAULT_MODEL, }; +use next_plaid_onnx::ExecutionProvider; use crate::display::{ calc_display_ranges, find_representative_lines, group_results_by_file, @@ -558,6 +559,64 @@ pub fn cmd_search( auto_confirm: bool, static_batch: bool, ) -> Result<()> { + colgrep::profile::command_result("search", || { + cmd_search_impl( + query, + paths, + top_k, + top_k_explicit, + cli_model, + json, + include_patterns, + files_only, + show_content, + cli_context_lines, + text_pattern, + extended_regexp, + fixed_strings, + word_regexp, + case_sensitive, + exclude_patterns, + exclude_dirs, + code_only, + no_fts, + alpha, + pool_factor, + auto_confirm, + static_batch, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +fn cmd_search_impl( + query: &str, + paths: &[PathBuf], + top_k: usize, + top_k_explicit: bool, + cli_model: Option<&str>, + json: bool, + include_patterns: &[String], + files_only: bool, + show_content: bool, + cli_context_lines: Option, + text_pattern: Option<&str>, + extended_regexp: bool, + fixed_strings: bool, + word_regexp: bool, + case_sensitive: bool, + exclude_patterns: &[String], + exclude_dirs: &[String], + code_only: bool, + no_fts: bool, + alpha: Option, + pool_factor: Option, + auto_confirm: bool, + static_batch: bool, +) -> Result<()> { + colgrep::profile::set_metadata("query", query); + colgrep::profile::set_metadata("top_k", top_k); + colgrep::profile::set_metadata("path_count", paths.len()); // Resolve context_lines: CLI > config > default (20) let context_lines = resolve_context_lines(cli_context_lines, 20); // Resolve relative paths: config > default (false = absolute) @@ -1053,13 +1112,17 @@ fn search_single_path( auto_confirm: bool, static_batch: bool, ) -> Result> { - let path = match std::fs::canonicalize(path) { - Ok(p) => p, - Err(_) => { - let help = find_existing_parent_and_list(path); - anyhow::bail!("Path does not exist: {}\n\n{}", path.display(), help); - } - }; + let path = + colgrep::profile::time_result("search.canonicalize_path", || match std::fs::canonicalize( + path, + ) { + Ok(p) => Ok(p), + Err(_) => { + let help = find_existing_parent_and_list(path); + anyhow::bail!("Path does not exist: {}\n\n{}", path.display(), help); + } + })?; + colgrep::profile::set_metadata("path", path.display().to_string()); // Check if path is a file (not a directory) // If so, we'll use the parent directory for indexing and filter to this specific file @@ -1079,6 +1142,7 @@ fn search_single_path( // Resolve model: CLI > config > default let model = resolve_model(cli_model); + colgrep::profile::set_metadata("model", &model); // Load config for settings let config = Config::load().unwrap_or_default(); @@ -1094,20 +1158,23 @@ fn search_single_path( index_exists(&search_path, &model) || find_parent_index(&search_path, &model)?.is_some(); // Ensure model is downloaded (quiet if we already have an index) - let model_path = ensure_model(Some(&model), has_existing_index)?; + let model_path = colgrep::profile::time_result("search.ensure_model", || { + ensure_model(Some(&model), has_existing_index) + })?; + colgrep::profile::set_metadata("model_path", model_path.display().to_string()); // Check for parent index (scoped to current model) unless the resolved path // is outside the current directory (external project) - let parent_info = { + let parent_info = colgrep::profile::time_result("search.find_parent_index", || { let current_dir = std::env::current_dir().ok(); let is_external_project = is_external_project_path(&search_path, current_dir.as_deref()); if is_external_project { - None + Ok(None) } else { - find_parent_index(&search_path, &model)? + find_parent_index(&search_path, &model) } - }; + })?; // Determine effective project root and subdirectory filter let (effective_root, subdir_filter): (PathBuf, Option) = match &parent_info { @@ -1154,11 +1221,15 @@ fn search_single_path( parallel_sessions, batch_size, )?; + builder.set_cpu_fallback_quantized( + !config.use_fp32_for_execution_provider(ExecutionProvider::Cpu), + ); builder.set_auto_confirm(auto_confirm); builder.set_dynamic_batch(!static_batch); // Try non-blocking index update - match builder.try_index(None, false) { + match colgrep::profile::time_result("search.auto_index", || builder.try_index(None, false)) + { Ok(Some(stats)) => { let changes = stats.added + stats.changed + stats.deleted; if changes > 0 && !json && !files_only { @@ -1217,9 +1288,14 @@ fn search_single_path( parallel_sessions, batch_size, )?; + new_builder.set_cpu_fallback_quantized( + !config.use_fp32_for_execution_provider(ExecutionProvider::Cpu), + ); new_builder.set_auto_confirm(auto_confirm); new_builder.set_dynamic_batch(!static_batch); - new_builder.index(None, false)?; + colgrep::profile::time_result("search.rebuild_index", || { + new_builder.index(None, false) + })?; } else { return Err(e); } @@ -1251,18 +1327,26 @@ fn search_single_path( // If loading fails while another process holds the lock, retry a few times in case // the failure is due to a transient mid-write state. // If loading fails without a concurrent updater, clear and rebuild the index. + let cpu_fallback_quantized = !config.use_fp32_for_execution_provider(ExecutionProvider::Cpu); let load_searcher = || -> Result { match &parent_info { - Some(info) => Searcher::load_from_index_dir_with_quantized( + Some(info) => Searcher::load_from_index_dir_with_quantized_options( &info.index_dir, &model_path, quantized, + cpu_fallback_quantized, + ), + None => Searcher::load_with_quantized_options( + &effective_root, + &model, + &model_path, + quantized, + cpu_fallback_quantized, ), - None => Searcher::load_with_quantized(&effective_root, &model, &model_path, quantized), } }; - let searcher = match load_searcher() { + let searcher = match colgrep::profile::time_result("search.load_searcher", || load_searcher()) { Ok(s) => s, Err(e) if index_locked => { // Another process is updating the index — the load failure is likely @@ -1277,7 +1361,9 @@ fn search_single_path( let mut loaded = None; for _ in 0..MAX_RETRIES { std::thread::sleep(RETRY_DELAY); - match load_searcher() { + match colgrep::profile::time_result("search.load_searcher_retry", || { + load_searcher() + }) { Ok(s) => { loaded = Some(s); break; @@ -1328,11 +1414,14 @@ fn search_single_path( parallel_sessions, batch_size, )?; + builder.set_cpu_fallback_quantized( + !config.use_fp32_for_execution_provider(ExecutionProvider::Cpu), + ); builder.set_auto_confirm(auto_confirm); builder.set_dynamic_batch(!static_batch); - builder.index(None, false)?; + colgrep::profile::time_result("search.rebuild_index", || builder.index(None, false))?; - load_searcher()? + colgrep::profile::time_result("search.load_searcher", || load_searcher())? } Err(e) => return Err(e), }; @@ -1512,6 +1601,7 @@ fn search_single_path( // When no -e flag is provided, run BOTH semantic/hybrid search and text-pattern search // This ensures exact matches are found even if the vector database doesn't rank them highly + let search_execute_phase = colgrep::profile::phase("search.execute"); let results = if let Some(pattern) = &text_pattern { // -e flag provided: use existing hybrid search logic // Enhance semantic query with -e pattern (strip regex metacharacters and dedupe tokens) @@ -1632,6 +1722,7 @@ fn search_single_path( } merged.into_values().collect::>() }; + drop(search_execute_phase); // Note: When -e is used, results are already filtered to units containing the pattern // via filter_by_text_pattern_with_options() above, which supports -E, -F, -w flags. diff --git a/colgrep/src/config.rs b/colgrep/src/config.rs index 25a5fa84..4f3f8d11 100644 --- a/colgrep/src/config.rs +++ b/colgrep/src/config.rs @@ -6,6 +6,7 @@ use std::fs; use std::path::PathBuf; use anyhow::{Context, Result}; +use next_plaid_onnx::ExecutionProvider; use serde::{Deserialize, Serialize}; #[cfg(any( @@ -33,6 +34,15 @@ pub const DEFAULT_BATCH_SIZE_CPU: usize = 1; /// With 1 session, larger batch size (64) is optimal for GPU throughput pub const DEFAULT_BATCH_SIZE_GPU: usize = 64; +pub fn default_batch_size_for_execution_provider(provider: ExecutionProvider) -> usize { + match provider { + ExecutionProvider::Cpu => DEFAULT_BATCH_SIZE_CPU, + provider if provider.is_gpu() => DEFAULT_BATCH_SIZE_GPU, + ExecutionProvider::Auto => get_default_batch_size(), + _ => DEFAULT_BATCH_SIZE_CPU, + } +} + /// Default batch size - use GPU default when a GPU inference provider is enabled, CPU otherwise. /// Note: At compile time we set the GPU default, but at runtime we check provider availability. #[cfg(any( @@ -303,6 +313,20 @@ impl Config { } } + /// Check whether FP32/non-INT8 should be used for a resolved execution provider. + /// + /// An explicit `settings --fp32` / `settings --int8` preference is honored + /// for every provider. Without an explicit preference, GPU providers use + /// the non-INT8 ONNX graph (MIGraphX may prefer `model_fp16.onnx` or apply + /// FP16 conversion), while CPU uses the pre-quantized INT8 graph. + pub fn use_fp32_for_execution_provider(&self, provider: ExecutionProvider) -> bool { + if let Some(fp32) = self.fp32 { + return fp32; + } + + matches!(provider, ExecutionProvider::MIGraphX) || provider.is_gpu() + } + /// Set whether to use FP32 (non-quantized) model pub fn set_fp32(&mut self, fp32: bool) { self.fp32 = Some(fp32); diff --git a/colgrep/src/hardware.rs b/colgrep/src/hardware.rs new file mode 100644 index 00000000..cb83c49f --- /dev/null +++ b/colgrep/src/hardware.rs @@ -0,0 +1,638 @@ +//! Cheap host hardware inspection for production acceleration heuristics. +//! +//! This module intentionally uses `/proc` and `/sys` on Linux instead of ROCm +//! APIs. The auto path can call it before deciding whether touching the GPU +//! stack is worthwhile. + +use serde::Serialize; +use std::collections::HashSet; +use std::fs; +use std::path::{Path, PathBuf}; + +pub const DEFAULT_MIGRAPHX_AUTO_MIN_UNITS: usize = 10_000; + +#[derive(Debug, Clone, Serialize)] +pub struct CpuInfo { + pub logical_cores: usize, + pub model_name: Option, + pub has_avx2: bool, + pub has_avx512: bool, + pub has_fma: bool, + pub has_neon: bool, +} + +#[derive(Debug, Clone, Serialize)] +pub struct AmdGpuInfo { + pub drm_node: String, + pub vendor_id: Option, + pub device_id: Option, + pub driver: Option, + pub vram_total_bytes: Option, + pub gtt_total_bytes: Option, + pub max_sclk_mhz: Option, + pub integrated_guess: bool, +} + +#[derive(Debug, Clone, Serialize)] +pub struct MigraphxAutoPolicy { + pub min_units: usize, + pub source: String, + pub reason: String, + pub cpu: CpuInfo, + pub gpu: Option, + pub model: Option, +} + +#[derive(Debug, Clone, Serialize)] +pub struct ModelInfo { + pub model_name: Option, + pub model_class: Option, + pub embedding_dim: Option, + pub query_length: Option, + pub document_length: Option, + pub hidden_size: Option, + pub intermediate_size: Option, + pub num_hidden_layers: Option, + pub num_attention_heads: Option, + pub local_attention: Option, + pub global_attn_every_n_layers: Option, + pub estimated_query_macs: Option, + pub estimated_document_macs: Option, + pub model_onnx_bytes: Option, + pub model_fp16_onnx_bytes: Option, + pub model_int8_onnx_bytes: Option, +} + +pub fn migraphx_auto_policy() -> MigraphxAutoPolicy { + migraphx_auto_policy_for_model(None) +} + +pub fn migraphx_auto_policy_for_model(model_dir: Option<&Path>) -> MigraphxAutoPolicy { + let cpu = detect_cpu_info(); + let gpu = detect_amd_gpus().into_iter().next(); + let model = model_dir.map(detect_model_info); + + if let Some(value) = std::env::var("NEXT_PLAID_MIGRAPHX_AUTO_MIN_UNITS") + .ok() + .and_then(|value| value.trim().parse::().ok()) + .filter(|value| *value > 0) + { + return MigraphxAutoPolicy { + min_units: value, + source: "env".to_string(), + reason: "NEXT_PLAID_MIGRAPHX_AUTO_MIN_UNITS override".to_string(), + cpu, + gpu, + model, + }; + } + + let (min_units, reason) = estimate_migraphx_auto_min_units(&cpu, gpu.as_ref(), model.as_ref()); + MigraphxAutoPolicy { + min_units, + source: "hardware".to_string(), + reason, + cpu, + gpu, + model, + } +} + +fn detect_model_info(model_dir: &Path) -> ModelInfo { + let onnx_config = fs::read_to_string(model_dir.join("onnx_config.json")) + .ok() + .and_then(|contents| serde_json::from_str::(&contents).ok()); + let model_config = fs::read_to_string(model_dir.join("config.json")) + .ok() + .and_then(|contents| serde_json::from_str::(&contents).ok()); + let get_onnx_string = |key: &str| { + onnx_config + .as_ref() + .and_then(|value| value.get(key)) + .and_then(|value| value.as_str()) + .map(ToString::to_string) + }; + let get_onnx_usize = |key: &str| { + onnx_config + .as_ref() + .and_then(|value| value.get(key)) + .and_then(|value| value.as_u64()) + .and_then(|value| usize::try_from(value).ok()) + }; + let get_model_usize = |key: &str| { + model_config + .as_ref() + .and_then(|value| value.get(key)) + .and_then(|value| value.as_u64()) + .and_then(|value| usize::try_from(value).ok()) + }; + + let query_length = get_onnx_usize("query_length"); + let document_length = get_onnx_usize("document_length"); + let hidden_size = get_model_usize("hidden_size"); + let intermediate_size = get_model_usize("intermediate_size"); + let num_hidden_layers = get_model_usize("num_hidden_layers"); + let num_attention_heads = get_model_usize("num_attention_heads"); + let local_attention = get_model_usize("local_attention"); + let global_attn_every_n_layers = get_model_usize("global_attn_every_n_layers"); + let estimated_query_macs = estimate_transformer_macs( + query_length, + hidden_size, + intermediate_size, + num_hidden_layers, + local_attention, + global_attn_every_n_layers, + ); + let estimated_document_macs = estimate_transformer_macs( + document_length, + hidden_size, + intermediate_size, + num_hidden_layers, + local_attention, + global_attn_every_n_layers, + ); + + ModelInfo { + model_name: get_onnx_string("model_name"), + model_class: get_onnx_string("model_class"), + embedding_dim: get_onnx_usize("embedding_dim"), + query_length, + document_length, + hidden_size, + intermediate_size, + num_hidden_layers, + num_attention_heads, + local_attention, + global_attn_every_n_layers, + estimated_query_macs, + estimated_document_macs, + model_onnx_bytes: file_len(model_dir.join("model.onnx")), + model_fp16_onnx_bytes: file_len(model_dir.join("model_fp16.onnx")), + model_int8_onnx_bytes: file_len(model_dir.join("model_int8.onnx")), + } +} + +fn estimate_transformer_macs( + seq_len: Option, + hidden_size: Option, + intermediate_size: Option, + num_hidden_layers: Option, + local_attention: Option, + global_attn_every_n_layers: Option, +) -> Option { + let seq_len = seq_len?; + let hidden_size = hidden_size?; + let intermediate_size = intermediate_size?; + let num_hidden_layers = num_hidden_layers?; + if seq_len == 0 || hidden_size == 0 || intermediate_size == 0 || num_hidden_layers == 0 { + return None; + } + + let seq = seq_len as u128; + let hidden = hidden_size as u128; + let intermediate = intermediate_size as u128; + let mut total = 0u128; + + for layer in 0..num_hidden_layers { + // Transformer block approximation in MACs, not FLOPs: + // - Q/K/V/O projections: 4 Ɨ S Ɨ H Ɨ H + // - MLP up/down projections: 2 Ɨ S Ɨ H Ɨ I + // - attention score and value matmuls: 2 Ɨ S Ɨ attention_window Ɨ H + let linear_macs = seq * ((4 * hidden * hidden) + (2 * hidden * intermediate)); + let is_global_attention = global_attn_every_n_layers + .filter(|every| *every > 0) + .is_none_or(|every| layer % every == 0); + let attention_window = if is_global_attention { + seq_len + } else { + local_attention.unwrap_or(seq_len).min(seq_len) + } as u128; + let attention_macs = 2 * seq * attention_window * hidden; + total = total.saturating_add(linear_macs.saturating_add(attention_macs)); + } + + Some(total.min(u64::MAX as u128) as u64) +} + +fn file_len(path: PathBuf) -> Option { + fs::metadata(path).ok().map(|metadata| metadata.len()) +} + +fn detect_cpu_info() -> CpuInfo { + let logical_cores = std::thread::available_parallelism() + .map(|cores| cores.get()) + .unwrap_or(1); + + let mut model_name = None; + let mut flags = HashSet::new(); + + #[cfg(target_os = "linux")] + if let Ok(cpuinfo) = fs::read_to_string("/proc/cpuinfo") { + for line in cpuinfo.lines() { + if model_name.is_none() { + if let Some((key, value)) = line.split_once(':') { + if matches!(key.trim(), "model name" | "Hardware") { + let value = value.trim(); + if !value.is_empty() { + model_name = Some(value.to_string()); + } + } + } + } + if flags.is_empty() { + if let Some((key, value)) = line.split_once(':') { + if matches!(key.trim(), "flags" | "Features") { + flags.extend(value.split_whitespace().map(|flag| flag.to_string())); + } + } + } + if model_name.is_some() && !flags.is_empty() { + break; + } + } + } + + CpuInfo { + logical_cores, + model_name, + has_avx2: flags.contains("avx2"), + has_avx512: flags.iter().any(|flag| flag.starts_with("avx512")), + has_fma: flags.contains("fma"), + has_neon: flags.contains("neon") || flags.contains("asimd"), + } +} + +fn detect_amd_gpus() -> Vec { + #[cfg(target_os = "linux")] + { + detect_amd_gpus_linux() + } + #[cfg(not(target_os = "linux"))] + { + Vec::new() + } +} + +#[cfg(target_os = "linux")] +fn detect_amd_gpus_linux() -> Vec { + let mut gpus = Vec::new(); + let mut seen_devices = HashSet::new(); + let Ok(entries) = fs::read_dir("/sys/class/drm") else { + return gpus; + }; + + let mut drm_nodes = entries + .flatten() + .filter_map(|entry| { + let name = entry.file_name().to_string_lossy().to_string(); + (name.starts_with("renderD") || (name.starts_with("card") && !name.contains('-'))) + .then_some((name, entry.path())) + }) + .collect::>(); + // Prefer render nodes, which are less likely to be display connector dirs. + drm_nodes.sort_by_key(|(name, _)| if name.starts_with("renderD") { 0 } else { 1 }); + + for (name, path) in drm_nodes { + let device_path = path.join("device"); + let canonical = fs::canonicalize(&device_path).unwrap_or_else(|_| device_path.clone()); + if !seen_devices.insert(canonical) { + continue; + } + let vendor_id = read_trimmed(device_path.join("vendor")); + if !vendor_id + .as_deref() + .is_some_and(|vendor| vendor.eq_ignore_ascii_case("0x1002")) + { + continue; + } + + let device_id = read_trimmed(device_path.join("device")); + let driver = read_uevent_value(&device_path, "DRIVER"); + let vram_total_bytes = read_u64(device_path.join("mem_info_vram_total")); + let gtt_total_bytes = read_u64(device_path.join("mem_info_gtt_total")); + let max_sclk_mhz = read_max_clock_mhz(device_path.join("pp_dpm_sclk")); + let integrated_guess = looks_integrated_gpu(vram_total_bytes, gtt_total_bytes); + + gpus.push(AmdGpuInfo { + drm_node: name, + vendor_id, + device_id, + driver, + vram_total_bytes, + gtt_total_bytes, + max_sclk_mhz, + integrated_guess, + }); + } + + gpus.sort_by_key(|gpu| std::cmp::Reverse(gpu.vram_total_bytes.unwrap_or(0))); + gpus +} + +#[cfg(target_os = "linux")] +fn read_trimmed>(path: P) -> Option { + fs::read_to_string(path) + .ok() + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) +} + +#[cfg(target_os = "linux")] +fn read_u64(path: PathBuf) -> Option { + read_trimmed(path).and_then(|value| value.parse::().ok()) +} + +#[cfg(target_os = "linux")] +fn read_uevent_value(device_path: &Path, key: &str) -> Option { + let uevent = fs::read_to_string(device_path.join("uevent")).ok()?; + uevent.lines().find_map(|line| { + line.split_once('=').and_then(|(line_key, value)| { + (line_key == key && !value.trim().is_empty()).then(|| value.trim().to_string()) + }) + }) +} + +#[cfg(target_os = "linux")] +fn read_max_clock_mhz(path: PathBuf) -> Option { + let contents = fs::read_to_string(path).ok()?; + contents.lines().filter_map(parse_clock_mhz).max() +} + +#[cfg(target_os = "linux")] +fn parse_clock_mhz(line: &str) -> Option { + let lower = line.to_ascii_lowercase(); + let idx = lower.find("mhz")?; + let before = &lower[..idx]; + let digits = before + .chars() + .rev() + .skip_while(|ch| ch.is_whitespace() || *ch == '*') + .take_while(|ch| ch.is_ascii_digit()) + .collect::(); + if digits.is_empty() { + return None; + } + digits.chars().rev().collect::().parse().ok() +} + +fn looks_integrated_gpu(vram_total_bytes: Option, gtt_total_bytes: Option) -> bool { + const TWO_GIB: u64 = 2 * 1024 * 1024 * 1024; + match (vram_total_bytes, gtt_total_bytes) { + (Some(vram), Some(gtt)) => vram < TWO_GIB || (vram < gtt / 4), + (Some(vram), None) => vram < TWO_GIB, + _ => false, + } +} + +fn estimate_migraphx_auto_min_units( + cpu: &CpuInfo, + gpu: Option<&AmdGpuInfo>, + model: Option<&ModelInfo>, +) -> (usize, String) { + let Some(gpu) = gpu else { + return ( + DEFAULT_MIGRAPHX_AUTO_MIN_UNITS, + "no AMD GPU detected via sysfs; provider availability still decides final fallback" + .to_string(), + ); + }; + + let gib = 1024 * 1024 * 1024u64; + let vram = gpu.vram_total_bytes.unwrap_or(0); + let mut threshold = if gpu.integrated_guess { + 2_000usize + } else if vram >= 12 * gib { + 1_000 + } else if vram >= 8 * gib { + 1_500 + } else if vram >= 4 * gib { + 2_500 + } else if vram >= 2 * gib { + 4_000 + } else { + 6_000 + }; + + let mut reasons = Vec::new(); + if gpu.integrated_guess { + reasons.push("integrated/UMA AMD GPU".to_string()); + } else { + reasons.push(format!("AMD GPU with {} GiB VRAM", vram / gib)); + } + + if cpu.logical_cores >= 24 && (cpu.has_avx2 || cpu.has_avx512) { + threshold = threshold.saturating_mul(3) / 2; + reasons.push(format!( + "strong CPU ({} logical cores with SIMD)", + cpu.logical_cores + )); + } else if cpu.logical_cores >= 16 && (cpu.has_avx2 || cpu.has_neon) { + threshold = threshold.saturating_mul(5) / 4; + reasons.push(format!( + "multi-core CPU ({} logical cores)", + cpu.logical_cores + )); + } else if cpu.logical_cores <= 8 { + threshold = threshold.saturating_mul(3) / 4; + reasons.push(format!("smaller CPU ({} logical cores)", cpu.logical_cores)); + } + + if let Some(max_sclk_mhz) = gpu.max_sclk_mhz { + if !gpu.integrated_guess && max_sclk_mhz >= 2400 { + threshold = threshold.saturating_mul(9) / 10; + reasons.push(format!("high GPU clock ({max_sclk_mhz} MHz)")); + } else if max_sclk_mhz < 1500 { + threshold = threshold.saturating_mul(5) / 4; + reasons.push(format!("low GPU clock ({max_sclk_mhz} MHz)")); + } + } + + if let Some(model) = model { + let before = threshold; + threshold = apply_model_complexity_adjustment(threshold, model); + if threshold < before { + reasons.push(model_complexity_reason(model, "larger model favors GPU")); + } else if threshold > before { + reasons.push(model_complexity_reason(model, "small model favors CPU")); + } + } + + threshold = threshold.clamp(500, 25_000); + (threshold, reasons.join("; ")) +} + +fn apply_model_complexity_adjustment(threshold: usize, model: &ModelInfo) -> usize { + // Estimated 1Ɨ2048 document MACs for LateOn-Code-edge from config.json + // (hidden=256, intermediate=384, layers=7, local attention=128, + // global attention every 3 layers). This anchors the model scaling to a + // known lightweight ColGREP model instead of raw file size. + const LATEON_EDGE_DOC_MACS: f64 = 13_555_990_528.0; + const EDGE_MODEL_BYTES: u64 = 65 * 1024 * 1024; + const LARGE_MODEL_BYTES: u64 = 400 * 1024 * 1024; + const MID_MODEL_BYTES: u64 = 150 * 1024 * 1024; + + if let Some(document_macs) = model.estimated_document_macs.filter(|macs| *macs > 0) { + let factor = (1.25 * (LATEON_EDGE_DOC_MACS / document_macs as f64).sqrt()).clamp(0.4, 1.25); + return ((threshold as f64) * factor).round() as usize; + } + + let complexity_bytes = model + .model_onnx_bytes + .or(model + .model_fp16_onnx_bytes + .map(|bytes| bytes.saturating_mul(2))) + .or(model + .model_int8_onnx_bytes + .map(|bytes| bytes.saturating_mul(4))); + + if model.embedding_dim.is_some_and(|dim| dim >= 128) + || complexity_bytes.is_some_and(|bytes| bytes >= LARGE_MODEL_BYTES) + { + threshold.saturating_mul(2) / 5 + } else if model.embedding_dim.is_some_and(|dim| dim >= 96) + || complexity_bytes.is_some_and(|bytes| bytes >= MID_MODEL_BYTES) + { + threshold.saturating_mul(3) / 5 + } else if model.embedding_dim.is_some_and(|dim| dim <= 64) + || complexity_bytes.is_some_and(|bytes| bytes <= EDGE_MODEL_BYTES) + { + threshold.saturating_mul(5) / 4 + } else { + threshold + } +} + +fn model_complexity_reason(model: &ModelInfo, prefix: &str) -> String { + let name = model.model_name.as_deref().unwrap_or("unknown model"); + let dim = model + .embedding_dim + .map(|dim| dim.to_string()) + .unwrap_or_else(|| "?".to_string()); + let mib = model + .model_onnx_bytes + .map(|bytes| (bytes as f64 / (1024.0 * 1024.0)).round() as u64) + .map(|mib| mib.to_string()) + .unwrap_or_else(|| "?".to_string()); + let doc_gmacs = model + .estimated_document_macs + .map(|macs| format!(", docā‰ˆ{:.1} GMAC", macs as f64 / 1_000_000_000.0)) + .unwrap_or_default(); + format!("{prefix}: {name} (dim={dim}, model.onnxā‰ˆ{mib} MiB{doc_gmacs})") +} + +#[cfg(test)] +mod tests { + use super::*; + + fn cpu(logical_cores: usize, simd: bool) -> CpuInfo { + CpuInfo { + logical_cores, + model_name: None, + has_avx2: simd, + has_avx512: false, + has_fma: simd, + has_neon: false, + } + } + + fn gpu(vram_gib: u64, integrated_guess: bool) -> AmdGpuInfo { + AmdGpuInfo { + drm_node: "renderD128".to_string(), + vendor_id: Some("0x1002".to_string()), + device_id: Some("0x0000".to_string()), + driver: Some("amdgpu".to_string()), + vram_total_bytes: Some(vram_gib * 1024 * 1024 * 1024), + gtt_total_bytes: None, + max_sclk_mhz: None, + integrated_guess, + } + } + + fn model(embedding_dim: usize, model_mib: u64) -> ModelInfo { + ModelInfo { + model_name: Some(format!("test-dim-{embedding_dim}")), + model_class: Some("ModernBertModel".to_string()), + embedding_dim: Some(embedding_dim), + query_length: Some(256), + document_length: Some(2048), + hidden_size: None, + intermediate_size: None, + num_hidden_layers: None, + num_attention_heads: None, + local_attention: None, + global_attn_every_n_layers: None, + estimated_query_macs: None, + estimated_document_macs: None, + model_onnx_bytes: Some(model_mib * 1024 * 1024), + model_fp16_onnx_bytes: None, + model_int8_onnx_bytes: None, + } + } + + #[test] + fn integrated_gpu_with_strong_cpu_uses_high_threshold() { + let (threshold, reason) = + estimate_migraphx_auto_min_units(&cpu(32, true), Some(&gpu(1, true)), None); + + assert_eq!(threshold, 3_000); + assert!(reason.contains("integrated")); + } + + #[test] + fn discrete_gpu_with_small_cpu_uses_lower_threshold() { + let (threshold, reason) = + estimate_migraphx_auto_min_units(&cpu(8, true), Some(&gpu(16, false)), None); + + assert_eq!(threshold, 750); + assert!(reason.contains("16 GiB")); + } + + #[test] + fn larger_model_lowers_gpu_auto_threshold() { + let base_cpu = cpu(32, true); + let integrated = gpu(1, true); + let (edge_threshold, edge_reason) = + estimate_migraphx_auto_min_units(&base_cpu, Some(&integrated), Some(&model(48, 65))); + let (full_threshold, full_reason) = + estimate_migraphx_auto_min_units(&base_cpu, Some(&integrated), Some(&model(128, 570))); + + assert_eq!(edge_threshold, 3_750); + assert_eq!(full_threshold, 1_200); + assert!(edge_reason.contains("small model")); + assert!(full_reason.contains("larger model")); + } + + #[test] + fn transformer_mac_estimate_tracks_lateon_model_scale() { + let edge_doc_macs = estimate_transformer_macs( + Some(2048), + Some(256), + Some(384), + Some(7), + Some(128), + Some(3), + ) + .unwrap(); + let full_doc_macs = estimate_transformer_macs( + Some(2048), + Some(768), + Some(1152), + Some(22), + Some(128), + Some(3), + ) + .unwrap(); + let ratio = full_doc_macs as f64 / edge_doc_macs as f64; + + assert!((13_000_000_000..14_000_000_000).contains(&edge_doc_macs)); + assert!((17.0..19.0).contains(&ratio)); + } + + #[cfg(target_os = "linux")] + #[test] + fn parses_amdgpu_clock_lines() { + assert_eq!(parse_clock_mhz("2: 2900Mhz *"), Some(2900)); + assert_eq!(parse_clock_mhz("0: 600Mhz"), Some(600)); + assert_eq!(parse_clock_mhz("not a clock"), None); + } +} diff --git a/colgrep/src/index/mod.rs b/colgrep/src/index/mod.rs index 9aa162c7..b99584ba 100644 --- a/colgrep/src/index/mod.rs +++ b/colgrep/src/index/mod.rs @@ -158,6 +158,7 @@ const DEFAULT_ENCODE_BATCH_SIZE: usize = 64; /// Threshold for forcing CPU encoding even when a GPU provider is available. /// For small batches (< this many units), CPU is faster due to GPU initialization overhead. const SMALL_BATCH_CPU_THRESHOLD: usize = 300; + /// Bounded channel capacity between the pool and index stages. /// Kept small (4 chunks) to limit memory: each chunk holds full embeddings /// waiting to be written to disk. Back-pressure here slows encoding when @@ -334,6 +335,85 @@ fn prepare_deduplicated_chunk(unit_chunk: &[SortedUnit]) -> PreparedChunk { } } +fn should_auto_use_warm_migraphx_for_indexing( + model_path: &Path, + quantized: bool, + batch_size: usize, + num_units: usize, +) -> bool { + let trace_enabled = std::env::var("NEXT_PLAID_MIGRAPHX_TRACE") + .map(|v| v == "1" || v.eq_ignore_ascii_case("true")) + .unwrap_or(false); + + if let Ok(cache_root) = std::env::var("NEXT_PLAID_MIGRAPHX_STATIC_CACHE_ROOT") { + let cache_root = Path::new(&cache_root); + if !cache_root.exists() { + if trace_enabled { + eprintln!( + "__MIGRAPHX_CACHE_STATUS__ cache_root={} missing", + cache_root.display() + ); + } + return false; + } + if cache_root + .read_dir() + .map(|mut entries| entries.next().is_none()) + .unwrap_or(false) + { + if trace_enabled { + eprintln!( + "__MIGRAPHX_CACHE_STATUS__ cache_root={} empty", + cache_root.display() + ); + } + return false; + } + } + + let _ = crate::onnx_runtime::ensure_migraphx_onnx_runtime_path_for_cache_key(); + + match next_plaid_onnx::migraphx_static_shape_cache_status(model_path, quantized, batch_size) { + Ok(status) => { + if trace_enabled { + eprintln!( + "__MIGRAPHX_CACHE_STATUS__ cache_root={} model_key={} document_shapes={:?} warm_document_shapes={:?} cold_document_shapes={:?}", + status.cache_root.display(), + status.model_cache_key, + status.document_shapes, + status.warm_document_shapes, + status.cold_document_shapes + ); + } + if status.warm_document_shapes.is_empty() { + return false; + } + + let min_warm_sequence_len = status + .warm_document_shapes + .iter() + .map(|shape| shape.sequence_length) + .min() + .unwrap_or(0); + let estimated_warm_tokens = num_units.saturating_mul(min_warm_sequence_len.max(1)); + let min_run_tokens = next_plaid_onnx::migraphx_default_min_run_tokens(); + if trace_enabled { + eprintln!( + "__MIGRAPHX_CACHE_STATUS__ estimated_warm_tokens={} min_run_tokens={}", + estimated_warm_tokens, min_run_tokens + ); + } + estimated_warm_tokens >= min_run_tokens + } + Err(err) => { + if trace_enabled { + eprintln!("__MIGRAPHX_CACHE_STATUS__ error={err:#}"); + } + false + } + } +} + fn run_encode_stage( receiver: mpsc::Receiver, sender: mpsc::Sender, @@ -727,21 +807,23 @@ fn run_chunk_pipeline( // all downstream stages. drop(tokenize_tx); - tokenize_handle - .join() - .map_err(|_| anyhow::anyhow!("Tokenize stage thread panicked"))??; - encode_handle - .join() - .map_err(|_| anyhow::anyhow!("Encode stage thread panicked"))??; - pool_handle - .join() - .map_err(|_| anyhow::anyhow!("Pool stage thread panicked"))??; - index_handle - .join() - .map_err(|_| anyhow::anyhow!("Index stage thread panicked"))??; - metadata_handle - .join() - .map_err(|_| anyhow::anyhow!("Metadata stage thread panicked"))??; + let mut stage_errors = Vec::new(); + for (stage, result) in [ + ("tokenize", tokenize_handle.join()), + ("encode", encode_handle.join()), + ("pool", pool_handle.join()), + ("index", index_handle.join()), + ("metadata", metadata_handle.join()), + ] { + match result { + Ok(Ok(())) => {} + Ok(Err(err)) => stage_errors.push(format!("{stage}: {err:#}")), + Err(_) => stage_errors.push(format!("{stage}: stage thread panicked")), + } + } + if !stage_errors.is_empty() { + anyhow::bail!("Indexing pipeline failed:\n{}", stage_errors.join("\n")); + } Ok(was_interrupted) } @@ -833,7 +915,10 @@ pub struct IndexBuilder { model: Option, /// Builder parameters for lazy model creation model_path: PathBuf, + /// Quantization setting for the primary execution provider. quantized: bool, + /// Quantization setting to use when execution resolves to CPU fallback. + cpu_fallback_quantized: bool, parallel_sessions: Option, batch_size: Option, project_root: PathBuf, @@ -888,6 +973,7 @@ impl IndexBuilder { model: None, // Lazily created when needed model_path: model_path.to_path_buf(), quantized, + cpu_fallback_quantized: quantized, parallel_sessions, batch_size, project_root: project_root.to_path_buf(), @@ -918,14 +1004,20 @@ impl IndexBuilder { self.dynamic_batch = dynamic_batch; } + pub fn set_cpu_fallback_quantized(&mut self, quantized: bool) { + self.cpu_fallback_quantized = quantized; + } + /// Ensure the model is created for encoding. /// The model is lazily created on first use to avoid overhead when just scanning files /// or when checking for index updates that have no changes. /// /// # Arguments /// * `num_units` - Number of code units to encode. Used to decide whether to use GPU or CPU. - /// For small batches (< SMALL_BATCH_CPU_THRESHOLD), CPU is preferred even when a GPU - /// provider is available, as GPU initialization overhead outweighs the benefits for small workloads. + /// For small batches (< SMALL_BATCH_CPU_THRESHOLD), CPU is preferred for dynamic GPU + /// providers because startup overhead outweighs the benefits. MIGraphX is handled + /// separately: warm static-shape caches are used even for small repos, while cold + /// caches fall back to CPU to avoid compilation cost. fn ensure_model_created(&mut self, num_units: usize) -> Result<()> { if self.model.is_none() { let acceleration_mode = env_acceleration_mode_lossy(); @@ -955,24 +1047,113 @@ impl IndexBuilder { ) } AccelerationMode::Auto => { - let force_cpu_for_small_batch = num_units < SMALL_BATCH_CPU_THRESHOLD; - if force_cpu_for_small_batch { - apply_acceleration_mode(AccelerationMode::ForceCpu); + // On ROCm/MIGraphX, probing provider availability starts + // the ROCm/ORT stack. For the common cold-cache/CPU path, + // decide first and force CPU before ORT initialization so + // a MIGraphX-feature build does not regress normal indexing. + if next_plaid_onnx::compiled_gpu_execution_provider() + == Some(ExecutionProvider::MIGraphX) + { + let migraphx_batch = self.batch_size.unwrap_or_else(|| { + crate::config::default_batch_size_for_execution_provider( + ExecutionProvider::MIGraphX, + ) + }); + let migraphx_policy = + crate::hardware::migraphx_auto_policy_for_model(Some(&self.model_path)); + let migraphx_min_units = migraphx_policy.min_units; + crate::profile::set_metadata( + "index.migraphx_auto_policy", + &migraphx_policy, + ); + if num_units < migraphx_min_units { + apply_acceleration_mode(AccelerationMode::ForceCpu); + crate::onnx_runtime::ensure_onnx_runtime() + .context("Failed to initialize ONNX Runtime")?; + eprintln!( + "ā„¹ļø Auto mode using CPU inference for indexing: {num_units} code units is below the MIGraphX auto threshold ({migraphx_min_units}). Use --force-gpu to benchmark MIGraphX." + ); + ( + self.parallel_sessions.unwrap_or_else( + crate::config::get_default_cpu_parallel_sessions, + ), + ExecutionProvider::Cpu, + ) + } else if should_auto_use_warm_migraphx_for_indexing( + &self.model_path, + self.quantized, + migraphx_batch, + num_units, + ) { + apply_acceleration_mode(AccelerationMode::Auto); + crate::onnx_runtime::ensure_onnx_runtime() + .context("Failed to initialize ONNX Runtime")?; + if preferred_gpu_provider() == Some(ExecutionProvider::MIGraphX) { + eprintln!( + "ā„¹ļø Warm MIGraphX cache shape(s) detected; auto mode using GPU inference for matching indexing batches with CPU fallback for cold shapes." + ); + ( + self.parallel_sessions + .unwrap_or(crate::config::DEFAULT_PARALLEL_SESSIONS_GPU), + ExecutionProvider::MIGraphX, + ) + } else { + apply_acceleration_mode(AccelerationMode::ForceCpu); + ( + self.parallel_sessions.unwrap_or_else( + crate::config::get_default_cpu_parallel_sessions, + ), + ExecutionProvider::Cpu, + ) + } + } else { + apply_acceleration_mode(AccelerationMode::ForceCpu); + crate::onnx_runtime::ensure_onnx_runtime() + .context("Failed to initialize ONNX Runtime")?; + eprintln!( + "ā„¹ļø Auto mode using CPU inference for indexing: MIGraphX static-shape caches are not warm for this model/batch." + ); + ( + self.parallel_sessions.unwrap_or_else( + crate::config::get_default_cpu_parallel_sessions, + ), + ExecutionProvider::Cpu, + ) + } } else { apply_acceleration_mode(AccelerationMode::Auto); - } - - crate::onnx_runtime::ensure_onnx_runtime() - .context("Failed to initialize ONNX Runtime")?; + crate::onnx_runtime::ensure_onnx_runtime() + .context("Failed to initialize ONNX Runtime")?; - if !force_cpu_for_small_batch { if let Some(provider) = preferred_gpu_provider() { + if num_units >= SMALL_BATCH_CPU_THRESHOLD { + ( + self.parallel_sessions + .unwrap_or(crate::config::DEFAULT_PARALLEL_SESSIONS_GPU), + provider, + ) + } else { + apply_acceleration_mode(AccelerationMode::ForceCpu); + eprintln!( + "ā„¹ļø Auto mode using CPU inference for indexing: {num_units} code units is below the small-batch threshold ({SMALL_BATCH_CPU_THRESHOLD})." + ); + ( + self.parallel_sessions.unwrap_or_else( + crate::config::get_default_cpu_parallel_sessions, + ), + ExecutionProvider::Cpu, + ) + } + } else if num_units >= SMALL_BATCH_CPU_THRESHOLD { + apply_acceleration_mode(AccelerationMode::ForceCpu); ( - self.parallel_sessions - .unwrap_or(crate::config::DEFAULT_PARALLEL_SESSIONS_GPU), - provider, + self.parallel_sessions.unwrap_or_else( + crate::config::get_default_cpu_parallel_sessions, + ), + ExecutionProvider::Cpu, ) } else { + apply_acceleration_mode(AccelerationMode::ForceCpu); ( self.parallel_sessions.unwrap_or_else( crate::config::get_default_cpu_parallel_sessions, @@ -980,12 +1161,6 @@ impl IndexBuilder { ExecutionProvider::Cpu, ) } - } else { - ( - self.parallel_sessions - .unwrap_or_else(crate::config::get_default_cpu_parallel_sessions), - ExecutionProvider::Cpu, - ) } } }; @@ -994,26 +1169,56 @@ impl IndexBuilder { eprintln!("šŸ¤– Model: {}", self.model_id); eprintln!("šŸ“‚ Building index..."); - // Use runtime default for batch size (respects provider availability) - let batch = self - .batch_size - .unwrap_or_else(crate::config::get_default_batch_size); + // Use a provider-specific runtime default. For MIGraphX the batch + // size determines the static-shape token budget that colgrep can + // warm and reuse, so its default is intentionally smaller than + // throughput-oriented GPU providers. + let batch = self.batch_size.unwrap_or_else(|| { + crate::config::default_batch_size_for_execution_provider(execution_provider) + }); + let model_quantized = if execution_provider == ExecutionProvider::Cpu { + self.cpu_fallback_quantized + } else { + self.quantized + }; + crate::profile::set_metadata( + "index.execution_provider", + execution_provider.display_name(), + ); + crate::profile::set_metadata("index.num_sessions", num_sessions); + crate::profile::set_metadata("index.batch_size", batch); + crate::profile::set_metadata("index.model_quantized", model_quantized); // Suppress stderr during model loading to hide CoreML's harmless // "Context leak detected" warnings on macOS. // `with_suppressed_stderr` captures any panic message via a temporary // panic hook and prints it to the restored stderr before resuming, // so panics inside the suppressed region remain visible. - let model = crate::stderr::with_suppressed_stderr(|| { - Colbert::builder(&self.model_path) - .with_quantized(self.quantized) - .with_parallel(num_sessions) - .with_batch_size(batch) - .with_dynamic_batch(self.dynamic_batch) - .with_execution_provider(execution_provider) - .build() - }) - .context("Failed to load ColBERT model")?; + let model = crate::profile::time_result("index.model_load", || { + crate::stderr::with_suppressed_stderr(|| { + let mut builder = Colbert::builder(&self.model_path) + .with_quantized(model_quantized) + .with_parallel(num_sessions) + .with_batch_size(batch) + .with_dynamic_batch(self.dynamic_batch) + .with_execution_provider(execution_provider) + .with_migraphx_cpu_fallback_quantized(self.cpu_fallback_quantized); + + if acceleration_mode == AccelerationMode::Auto + && execution_provider == ExecutionProvider::MIGraphX + { + // Auto selection already applied the command-level + // warm-work threshold. Do not re-apply it per index + // chunk, otherwise normal 1024-unit index chunks never + // reach the intended run-level threshold for short + // document buckets. + builder = builder.with_migraphx_min_run_tokens(0); + } + + builder.build() + }) + .context("Failed to load ColBERT model") + })?; self.model = Some(model); } @@ -1051,7 +1256,7 @@ impl IndexBuilder { let model = crate::stderr::with_suppressed_stderr(|| { Colbert::builder(&self.model_path) - .with_quantized(self.quantized) + .with_quantized(self.cpu_fallback_quantized) .with_parallel(num_sessions) .with_batch_size(batch) .with_dynamic_batch(false) @@ -1107,8 +1312,8 @@ impl IndexBuilder { if accel == AccelerationMode::ForceGpu { anyhow::bail!( "GPU encoding failed with --force-gpu. \ - Not enough GPU memory for batch size {batch} and document length. \ - Try reducing the batch size or use auto mode to allow CPU fallback.\n\ + This can be caused by a missing warm static shape, a MIGraphX runtime error, or insufficient GPU memory for batch size {batch} and document length. \ + Warm matching MIGraphX shapes, adjust --batch-size, or use auto mode to allow CPU fallback.\n\ \nCaused by: {gpu_err}", batch = self .batch_size @@ -2020,7 +2225,9 @@ impl IndexBuilder { std::fs::remove_dir_all(&old_path)?; } - let (files, skipped) = self.scan_files(languages)?; + let (files, skipped) = + crate::profile::time_result("index.scan_files", || self.scan_files(languages))?; + crate::profile::set_metadata("index.files", files.len()); let mut state = IndexState::default(); let mut all_units: Vec = Vec::new(); @@ -2036,18 +2243,22 @@ impl IndexBuilder { pb.set_message("Parsing files..."); // Extract units from all files - for parsed in parse_files_parallel(&self.project_root, &files, Some(&pb)) { - if let Some(reason) = parsed.skip_reason { - eprintln!("āš ļø {}", reason); - state.ignored_files.insert(parsed.path); - continue; - } + { + let _phase = crate::profile::phase("index.parse_files"); + for parsed in parse_files_parallel(&self.project_root, &files, Some(&pb)) { + if let Some(reason) = parsed.skip_reason { + eprintln!("āš ļø {}", reason); + state.ignored_files.insert(parsed.path); + continue; + } - all_units.extend(parsed.units); - if let Some(file_info) = parsed.file_info { - state.files.insert(parsed.path, file_info); + all_units.extend(parsed.units); + if let Some(file_info) = parsed.file_info { + state.files.insert(parsed.path, file_info); + } } } + crate::profile::set_metadata("index.units", all_units.len()); let parsing_interrupted = is_interrupted(); pb.finish_and_clear(); @@ -2057,7 +2268,9 @@ impl IndexBuilder { } // Build call graph to populate called_by - build_call_graph(&mut all_units); + crate::profile::time("index.build_call_graph", || { + build_call_graph(&mut all_units) + }); // Prompt for confirmation if indexing a large codebase if !self.auto_confirm @@ -2069,7 +2282,9 @@ impl IndexBuilder { let was_interrupted = if !all_units.is_empty() { // Ensure model is created before encoding (lazy initialization) - self.ensure_model_created(all_units.len())?; + crate::profile::time_result("index.model_ready", || { + self.ensure_model_created(all_units.len()) + })?; #[cfg(feature = "cuda")] if !self.is_using_gpu() @@ -2081,7 +2296,9 @@ impl IndexBuilder { } // Build new index in temp directory to avoid destroying the old one - self.write_index_impl(&all_units, true, Some(&temp_path))? + crate::profile::time_result("index.write_index", || { + self.write_index_impl(&all_units, true, Some(&temp_path)) + })? } else { false }; @@ -2093,34 +2310,40 @@ impl IndexBuilder { } // Atomic swap: replace old index with newly built one - if all_units.is_empty() { - // No files to index — just remove the old index if it exists - if index_path.exists() { - std::fs::remove_dir_all(&index_path)?; - } - } else { - if index_path.exists() { - std::fs::rename(&index_path, &old_path) - .context("Failed to move old index aside")?; - } - if let Err(e) = std::fs::rename(&temp_path, &index_path) { - // Try to restore old index - if old_path.exists() && !index_path.exists() { - let _ = std::fs::rename(&old_path, &index_path); + crate::profile::time_result("index.atomic_swap", || -> Result<()> { + if all_units.is_empty() { + // No files to index — just remove the old index if it exists + if index_path.exists() { + std::fs::remove_dir_all(&index_path)?; + } + } else { + if index_path.exists() { + std::fs::rename(&index_path, &old_path) + .context("Failed to move old index aside")?; + } + if let Err(e) = std::fs::rename(&temp_path, &index_path) { + // Try to restore old index + if old_path.exists() && !index_path.exists() { + let _ = std::fs::rename(&old_path, &index_path); + } + return Err(anyhow::anyhow!( + "Failed to move new index into place: {}", + e + )); + } + if old_path.exists() { + let _ = std::fs::remove_dir_all(&old_path); } - return Err(anyhow::anyhow!( - "Failed to move new index into place: {}", - e - )); - } - if old_path.exists() { - let _ = std::fs::remove_dir_all(&old_path); } - } + Ok(()) + })?; // Save state and project metadata only on successful completion - state.save(&self.index_dir)?; - ProjectMetadata::new(&self.project_root, &self.model_id).save(&self.index_dir)?; + crate::profile::time_result("index.save_state", || -> Result<()> { + state.save(&self.index_dir)?; + ProjectMetadata::new(&self.project_root, &self.model_id).save(&self.index_dir)?; + Ok(()) + })?; Ok(UpdateStats { added: files.len(), @@ -2137,7 +2360,13 @@ impl IndexBuilder { old_state: &IndexState, languages: Option<&[Language]>, ) -> Result { - let plan = self.compute_update_plan(old_state, languages)?; + let plan = crate::profile::time_result("index.compute_update_plan", || { + self.compute_update_plan(old_state, languages) + })?; + crate::profile::set_metadata("index.plan_added", plan.added.len()); + crate::profile::set_metadata("index.plan_changed", plan.changed.len()); + crate::profile::set_metadata("index.plan_deleted", plan.deleted.len()); + crate::profile::set_metadata("index.plan_unchanged", plan.unchanged); let index_dir = get_vector_index_path(&self.index_dir); let index_path = index_dir.to_str().unwrap(); @@ -2151,7 +2380,9 @@ impl IndexBuilder { // 0. Clean up orphaned entries (files in index but not on disk) // This handles directory deletion/rename and any inconsistencies - let orphaned_deleted = self.cleanup_orphaned_entries(index_path)?; + let orphaned_deleted = crate::profile::time_result("index.cleanup_orphans", || { + self.cleanup_orphaned_entries(index_path) + })?; // Nothing to do if plan.added.is_empty() @@ -2233,20 +2464,24 @@ impl IndexBuilder { }; let mut skipped_files: Vec = Vec::new(); - for parsed in parse_files_parallel(&self.project_root, &files_to_index, pb.as_ref()) { - if let Some(reason) = parsed.skip_reason { - eprintln!("āš ļø {}", reason); - state.files.remove(&parsed.path); - state.ignored_files.insert(parsed.path.clone()); - skipped_files.push(parsed.path); - continue; - } + { + let _phase = crate::profile::phase("index.parse_files"); + for parsed in parse_files_parallel(&self.project_root, &files_to_index, pb.as_ref()) { + if let Some(reason) = parsed.skip_reason { + eprintln!("āš ļø {}", reason); + state.files.remove(&parsed.path); + state.ignored_files.insert(parsed.path.clone()); + skipped_files.push(parsed.path); + continue; + } - new_units.extend(parsed.units); - if let Some(file_info) = parsed.file_info { - state.files.insert(parsed.path, file_info); + new_units.extend(parsed.units); + if let Some(file_info) = parsed.file_info { + state.files.insert(parsed.path, file_info); + } } } + crate::profile::set_metadata("index.units", new_units.len()); let parsing_interrupted = is_interrupted(); if let Some(pb) = pb { pb.finish_and_clear(); @@ -2272,7 +2507,9 @@ impl IndexBuilder { let mut was_interrupted = false; if !new_units.is_empty() { // Build call graph for new units - build_call_graph(&mut new_units); + crate::profile::time("index.build_call_graph", || { + build_call_graph(&mut new_units) + }); // Prompt for confirmation if indexing a large number of new units if !self.auto_confirm @@ -2283,7 +2520,9 @@ impl IndexBuilder { } // Ensure model is created before encoding (lazy initialization) - self.ensure_model_created(new_units.len())?; + crate::profile::time_result("index.model_ready", || { + self.ensure_model_created(new_units.len()) + })?; // Progress bar for encoding let pb = ProgressBar::new(new_units.len() as u64); @@ -2311,14 +2550,19 @@ impl IndexBuilder { // into a single index rewrite — see delete_files_from_index / issue #116. delete_files_from_index(index_path, &plan.changed)?; - let sorted_units = prepare_units_for_encoding(&new_units, index_chunk_size); - let pipeline_interrupted = self.run_encoding_pipeline( - &sorted_units, - index_chunk_size, - pool_factor, - index_path, - Some(&pb), - )?; + let sorted_units = crate::profile::time("index.prepare_units", || { + prepare_units_for_encoding(&new_units, index_chunk_size) + }); + let pipeline_interrupted = + crate::profile::time_result("index.encoding_pipeline", || { + self.run_encoding_pipeline( + &sorted_units, + index_chunk_size, + pool_factor, + index_path, + Some(&pb), + ) + })?; was_interrupted |= pipeline_interrupted; pb.finish_and_clear(); @@ -2331,7 +2575,7 @@ impl IndexBuilder { } state.dirty = false; - state.save(&self.index_dir)?; + crate::profile::time_result("index.save_state", || state.save(&self.index_dir))?; Ok(UpdateStats { added: plan.added.len(), @@ -2719,15 +2963,21 @@ impl IndexBuilder { // Compute effective pool factor based on batch size let pool_factor = self.resolve_pool_factor(units.len()); - let sorted_units = prepare_units_for_encoding(units, index_chunk_size); - self.ensure_model_created(units.len())?; - let was_interrupted = self.run_encoding_pipeline( - &sorted_units, - index_chunk_size, - pool_factor, - index_path, - pb.as_ref(), - )?; + let sorted_units = crate::profile::time("index.prepare_units", || { + prepare_units_for_encoding(units, index_chunk_size) + }); + crate::profile::time_result("index.model_ready", || { + self.ensure_model_created(units.len()) + })?; + let was_interrupted = crate::profile::time_result("index.encoding_pipeline", || { + self.run_encoding_pipeline( + &sorted_units, + index_chunk_size, + pool_factor, + index_path, + pb.as_ref(), + ) + })?; if let Some(pb) = pb { pb.finish_and_clear(); @@ -3243,18 +3493,24 @@ fn apply_search_acceleration_mode(acceleration_mode: AccelerationMode) { match acceleration_mode { AccelerationMode::ForceGpu => apply_acceleration_mode(AccelerationMode::ForceGpu), AccelerationMode::ForceCpu => apply_acceleration_mode(AccelerationMode::ForceCpu), - // Keep the existing search behavior: single-query searches default to - // CPU to avoid GPU initialization overhead, except on CoreML builds - // where automatic acceleration has historically been the default. - AccelerationMode::Auto if cfg!(feature = "coreml") => { - apply_acceleration_mode(AccelerationMode::Auto) + // Search is a tiny, one-query workload in the CLI. On Linux, auto GPU + // discovery (especially ROCm/MIGraphX) can add hundreds of ms even if + // we ultimately resolve to CPU. Keep auto search CPU-first unless the + // user explicitly requests `--force-gpu`. macOS is left in Auto so the + // cheap CoreML path can still be selected below. + AccelerationMode::Auto => { + #[cfg(target_os = "macos")] + apply_acceleration_mode(AccelerationMode::Auto); + #[cfg(not(target_os = "macos"))] + apply_acceleration_mode(AccelerationMode::ForceCpu); } - AccelerationMode::Auto => apply_acceleration_mode(AccelerationMode::ForceCpu), } } fn resolve_search_execution_provider( acceleration_mode: AccelerationMode, + _model_path: &Path, + _quantized: bool, ) -> Result { match acceleration_mode { AccelerationMode::ForceGpu => require_gpu_provider(), @@ -3263,6 +3519,9 @@ fn resolve_search_execution_provider( if next_plaid_onnx::is_coreml_available() { Ok(ExecutionProvider::CoreML) } else { + eprintln!( + "ā„¹ļø Auto mode using CPU inference for search. Use --force-gpu to benchmark GPU query encoding." + ); Ok(ExecutionProvider::Cpu) } } @@ -3279,6 +3538,16 @@ impl Searcher { model_id: &str, model_path: &Path, quantized: bool, + ) -> Result { + Self::load_with_quantized_options(project_root, model_id, model_path, quantized, quantized) + } + + pub fn load_with_quantized_options( + project_root: &Path, + model_id: &str, + model_path: &Path, + quantized: bool, + cpu_fallback_quantized: bool, ) -> Result { let index_dir = get_index_dir_for_project(project_root, model_id)?; let vector_dir = get_vector_index_path(&index_dir); @@ -3287,8 +3556,25 @@ impl Searcher { let acceleration_mode = env_acceleration_mode_lossy(); apply_search_acceleration_mode(acceleration_mode); - crate::onnx_runtime::ensure_onnx_runtime().context("Failed to initialize ONNX Runtime")?; - let execution_provider = resolve_search_execution_provider(acceleration_mode)?; + crate::profile::time_result("search.ort_init", || { + crate::onnx_runtime::ensure_onnx_runtime().context("Failed to initialize ONNX Runtime") + })?; + let execution_provider = crate::profile::time_result("search.resolve_provider", || { + resolve_search_execution_provider(acceleration_mode, model_path, quantized) + })?; + if execution_provider == ExecutionProvider::Cpu { + apply_acceleration_mode(AccelerationMode::ForceCpu); + } + let model_quantized = if execution_provider == ExecutionProvider::Cpu { + cpu_fallback_quantized + } else { + quantized + }; + crate::profile::set_metadata( + "search.execution_provider", + execution_provider.display_name(), + ); + crate::profile::set_metadata("search.model_quantized", model_quantized); // Cap intra-op threads to avoid overhead on high-core-count systems let num_threads = std::thread::available_parallelism() @@ -3298,17 +3584,21 @@ impl Searcher { // Suppress stderr during model loading to hide CoreML's harmless // "Context leak detected" warnings on macOS - let model = crate::stderr::with_suppressed_stderr(|| { - Colbert::builder(model_path) - .with_quantized(quantized) - .with_threads(num_threads) - .with_execution_provider(execution_provider) - .build() - }) - .context("Failed to load ColBERT model")?; + let model = crate::profile::time_result("search.model_load", || { + crate::stderr::with_suppressed_stderr(|| { + Colbert::builder(model_path) + .with_quantized(model_quantized) + .with_threads(num_threads) + .with_execution_provider(execution_provider) + .build() + }) + .context("Failed to load ColBERT model") + })?; // Load index - let index = MmapIndex::load(&index_path).context("Failed to load index")?; + let index = crate::profile::time_result("search.index_load", || { + MmapIndex::load(&index_path).context("Failed to load index") + })?; Ok(Self { model, @@ -3327,6 +3617,17 @@ impl Searcher { index_dir: &Path, model_path: &Path, quantized: bool, + ) -> Result { + Self::load_from_index_dir_with_quantized_options( + index_dir, model_path, quantized, quantized, + ) + } + + pub fn load_from_index_dir_with_quantized_options( + index_dir: &Path, + model_path: &Path, + quantized: bool, + cpu_fallback_quantized: bool, ) -> Result { let vector_dir = get_vector_index_path(index_dir); let index_path = vector_dir.to_str().unwrap().to_string(); @@ -3334,8 +3635,25 @@ impl Searcher { let acceleration_mode = env_acceleration_mode_lossy(); apply_search_acceleration_mode(acceleration_mode); - crate::onnx_runtime::ensure_onnx_runtime().context("Failed to initialize ONNX Runtime")?; - let execution_provider = resolve_search_execution_provider(acceleration_mode)?; + crate::profile::time_result("search.ort_init", || { + crate::onnx_runtime::ensure_onnx_runtime().context("Failed to initialize ONNX Runtime") + })?; + let execution_provider = crate::profile::time_result("search.resolve_provider", || { + resolve_search_execution_provider(acceleration_mode, model_path, quantized) + })?; + if execution_provider == ExecutionProvider::Cpu { + apply_acceleration_mode(AccelerationMode::ForceCpu); + } + let model_quantized = if execution_provider == ExecutionProvider::Cpu { + cpu_fallback_quantized + } else { + quantized + }; + crate::profile::set_metadata( + "search.execution_provider", + execution_provider.display_name(), + ); + crate::profile::set_metadata("search.model_quantized", model_quantized); // Cap intra-op threads to avoid overhead on high-core-count systems let num_threads = std::thread::available_parallelism() @@ -3345,16 +3663,20 @@ impl Searcher { // Suppress stderr during model loading to hide CoreML's harmless // "Context leak detected" warnings on macOS - let model = crate::stderr::with_suppressed_stderr(|| { - Colbert::builder(model_path) - .with_quantized(quantized) - .with_threads(num_threads) - .with_execution_provider(execution_provider) - .build() - }) - .context("Failed to load ColBERT model")?; + let model = crate::profile::time_result("search.model_load", || { + crate::stderr::with_suppressed_stderr(|| { + Colbert::builder(model_path) + .with_quantized(model_quantized) + .with_threads(num_threads) + .with_execution_provider(execution_provider) + .build() + }) + .context("Failed to load ColBERT model") + })?; - let index = MmapIndex::load(&index_path).context("Failed to load index")?; + let index = crate::profile::time_result("search.index_load", || { + MmapIndex::load(&index_path).context("Failed to load index") + })?; Ok(Self { model, @@ -3614,9 +3936,10 @@ impl Searcher { /// Encode a query once for reuse across multiple searches. pub fn encode_query(&self, query: &str) -> Result> { - let query_embeddings = + let query_embeddings = crate::profile::time_result("search.encode_query", || { crate::stderr::with_suppressed_stderr(|| self.model.encode_queries(&[query])) - .context("Failed to encode query")?; + .context("Failed to encode query") + })?; Ok(query_embeddings.into_iter().next().unwrap()) } @@ -3640,12 +3963,19 @@ impl Searcher { if sanitized_query.is_empty() { return None; } - if let Some(sub) = subset { - next_plaid::text_search::search_filtered(&self.index_path, &sanitized_query, top_k, sub) + crate::profile::time("search.fts5", || { + if let Some(sub) = subset { + next_plaid::text_search::search_filtered( + &self.index_path, + &sanitized_query, + top_k, + sub, + ) .ok() - } else { - next_plaid::text_search::search(&self.index_path, &sanitized_query, top_k).ok() - } + } else { + next_plaid::text_search::search(&self.index_path, &sanitized_query, top_k).ok() + } + }) } pub fn search( @@ -3666,14 +3996,17 @@ impl Searcher { subset: Option<&[i64]>, ) -> Result> { let params = search_params_from_env(top_k); - let results = self - .index - .search(query_emb, ¶ms, subset) - .context("Search failed")?; + let results = crate::profile::time_result("search.semantic_index", || { + self.index + .search(query_emb, ¶ms, subset) + .context("Search failed") + })?; let doc_ids: Vec = results.passage_ids.to_vec(); - let metadata = filtering::get(&self.index_path, None, &[], Some(&doc_ids)) - .context("Failed to retrieve metadata")?; + let metadata = crate::profile::time_result("search.metadata_fetch", || { + filtering::get(&self.index_path, None, &[], Some(&doc_ids)) + .context("Failed to retrieve metadata") + })?; let search_results: Vec = metadata .into_iter() @@ -3730,10 +4063,11 @@ impl Searcher { self.index.num_documents().max(top_k), ); let params = search_params_from_env(fetch_k); - let semantic = self - .index - .search(query_emb, ¶ms, subset) - .context("Semantic search failed")?; + let semantic = crate::profile::time_result("search.semantic_index", || { + self.index + .search(query_emb, ¶ms, subset) + .context("Semantic search failed") + })?; trace_log( query, "semantic", @@ -3811,8 +4145,10 @@ impl Searcher { }; trace_log(query, "fused", &fused_ids, &fused_scores, 20); - let metadata = filtering::get(&self.index_path, None, &[], Some(&fused_ids)) - .context("Failed to retrieve metadata")?; + let metadata = crate::profile::time_result("search.metadata_fetch", || { + filtering::get(&self.index_path, None, &[], Some(&fused_ids)) + .context("Failed to retrieve metadata") + })?; let apply_penalty = crate::ranking::should_apply_path_penalty(query); @@ -4710,6 +5046,7 @@ mod tests { model: None, model_path: PathBuf::from("/nonexistent-model"), quantized: false, + cpu_fallback_quantized: false, parallel_sessions: None, batch_size: None, project_root: project_root.to_path_buf(), diff --git a/colgrep/src/lib.rs b/colgrep/src/lib.rs index bca81230..c6e8716d 100644 --- a/colgrep/src/lib.rs +++ b/colgrep/src/lib.rs @@ -8,11 +8,13 @@ pub mod acceleration; pub mod config; pub mod embed; +pub mod hardware; pub mod index; pub mod install; pub mod model; pub mod onnx_runtime; pub mod parser; +pub mod profile; pub mod ranking; pub mod signal; pub mod stderr; diff --git a/colgrep/src/model.rs b/colgrep/src/model.rs index 47f1e9b4..291e3a8c 100644 --- a/colgrep/src/model.rs +++ b/colgrep/src/model.rs @@ -1,5 +1,6 @@ use anyhow::Result; use hf_hub::api::sync::ApiBuilder; +use hf_hub::Cache; use std::path::PathBuf; pub const DEFAULT_MODEL: &str = "lightonai/LateOn-Code-edge"; @@ -8,19 +9,24 @@ pub const DEFAULT_MODEL: &str = "lightonai/LateOn-Code-edge"; const REQUIRED_FILES: &[&str] = &[ "model_int8.onnx", "tokenizer.json", - "config_sentence_transformers.json", "config.json", "onnx_config.json", ]; -/// Optional files (non-quantized model) -const OPTIONAL_FILES: &[&str] = &["model.onnx"]; +/// Optional files (non-quantized models) +const OPTIONAL_FILES: &[&str] = &[ + "model.onnx", + "model_fp16.onnx", + "config_sentence_transformers.json", +]; /// Load model from cache or download from HuggingFace. /// Returns path to the model directory. -/// The `quiet` parameter is kept for API compatibility but no longer used -/// (output is now handled in IndexBuilder::ensure_model_created after ONNX runtime init). -pub fn ensure_model(model_id: Option<&str>, _quiet: bool) -> Result { +/// When `quiet` is true (the common search path for an already-indexed repo), +/// optional files are resolved from the local HuggingFace cache only. This +/// avoids a network metadata check for missing optional artifacts on every +/// query while still downloading required files if they are absent. +pub fn ensure_model(model_id: Option<&str>, quiet: bool) -> Result { let model_id = model_id.unwrap_or(DEFAULT_MODEL); // Check if it's a local path @@ -62,9 +68,19 @@ pub fn ensure_model(model_id: Option<&str>, _quiet: bool) -> Result { } } - // Try to download optional files (non-quantized model) - ignore errors + let local_cache = Cache::from_env().model(model_id.to_string()); + + // Try to download optional files (non-quantized models) - ignore errors. + // In quiet/search mode, only touch files already present in the local cache + // so missing optional artifacts do not add a remote HEAD/GET round trip to + // every query. for file in OPTIONAL_FILES { - let _ = repo.get(file); + if local_cache.get(file).is_some() { + continue; + } + if !quiet { + let _ = repo.get(file); + } } model_dir.ok_or_else(|| anyhow::anyhow!("Failed to determine model directory")) diff --git a/colgrep/src/profile.rs b/colgrep/src/profile.rs new file mode 100644 index 00000000..df66245b --- /dev/null +++ b/colgrep/src/profile.rs @@ -0,0 +1,171 @@ +//! Lightweight opt-in profiling for ColGREP commands. +//! +//! Set `COLGREP_PROFILE=1` to emit one JSON line on stderr at command exit: +//! +//! ```text +//! __COLGREP_PROFILE__ {"type":"colgrep_profile", ...} +//! ``` + +use std::sync::{Mutex, OnceLock}; +use std::time::Instant; + +use serde::Serialize; + +#[derive(Default)] +struct ProfileState { + command: Option, + started_at: Option, + phases: Vec, + metadata: serde_json::Map, +} + +#[derive(Clone, Serialize)] +struct ProfilePhase { + name: String, + start_ms: f64, + duration_ms: f64, + thread: String, +} + +#[derive(Serialize)] +struct ProfileReport<'a> { + #[serde(rename = "type")] + kind: &'static str, + command: &'a str, + status: &'static str, + total_ms: f64, + phases: &'a [ProfilePhase], + metadata: &'a serde_json::Map, +} + +pub struct PhaseGuard { + name: String, + start: Instant, +} + +fn state() -> &'static Mutex { + static STATE: OnceLock> = OnceLock::new(); + STATE.get_or_init(|| Mutex::new(ProfileState::default())) +} + +fn truthy(value: &str) -> bool { + !matches!( + value.trim().to_ascii_lowercase().as_str(), + "" | "0" | "false" | "no" | "off" + ) +} + +pub fn enabled() -> bool { + std::env::var("COLGREP_PROFILE") + .map(|value| truthy(&value)) + .unwrap_or(false) +} + +pub fn start_command(command: &str) { + if !enabled() { + return; + } + let mut guard = state().lock().unwrap(); + guard.command = Some(command.to_string()); + guard.started_at = Some(Instant::now()); + guard.phases.clear(); + guard.metadata.clear(); +} + +pub fn set_metadata(key: &str, value: T) { + if !enabled() { + return; + } + if let Ok(value) = serde_json::to_value(value) { + state() + .lock() + .unwrap() + .metadata + .insert(key.to_string(), value); + } +} + +pub fn phase(name: &str) -> Option { + enabled().then(|| PhaseGuard { + name: name.to_string(), + start: Instant::now(), + }) +} + +pub fn time(name: &str, f: F) -> T +where + F: FnOnce() -> T, +{ + let _guard = phase(name); + f() +} + +pub fn time_result(name: &str, f: F) -> Result +where + F: FnOnce() -> Result, +{ + let _guard = phase(name); + f() +} + +pub fn command_result(command: &str, f: F) -> Result +where + F: FnOnce() -> Result, +{ + start_command(command); + let result = f(); + finish_command(result.is_ok()); + result +} + +pub fn finish_command(ok: bool) { + if !enabled() { + return; + } + let guard = state().lock().unwrap(); + let Some(command) = guard.command.as_deref() else { + return; + }; + let total_ms = guard + .started_at + .map(|started| started.elapsed().as_secs_f64() * 1000.0) + .unwrap_or_default(); + let report = ProfileReport { + kind: "colgrep_profile", + command, + status: if ok { "ok" } else { "error" }, + total_ms, + phases: &guard.phases, + metadata: &guard.metadata, + }; + if let Ok(json) = serde_json::to_string(&report) { + eprintln!("__COLGREP_PROFILE__ {json}"); + } +} + +impl Drop for PhaseGuard { + fn drop(&mut self) { + if !enabled() { + return; + } + let duration_ms = self.start.elapsed().as_secs_f64() * 1000.0; + let (start_ms, thread) = { + let guard = state().lock().unwrap(); + let start_ms = guard + .started_at + .map(|started| self.start.duration_since(started).as_secs_f64() * 1000.0) + .unwrap_or_default(); + let thread = std::thread::current() + .name() + .unwrap_or("unnamed") + .to_string(); + (start_ms, thread) + }; + state().lock().unwrap().phases.push(ProfilePhase { + name: self.name.clone(), + start_ms, + duration_ms, + thread, + }); + } +} diff --git a/next-plaid-onnx/Cargo.toml b/next-plaid-onnx/Cargo.toml index 4360c80f..3c245fa2 100644 --- a/next-plaid-onnx/Cargo.toml +++ b/next-plaid-onnx/Cargo.toml @@ -41,6 +41,8 @@ anyhow = "1.0" glob = "0.3" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +sha2 = "0.10" +fs2 = "0.4" pyo3 = { version = "0.23", features = ["extension-module"], optional = true } numpy = { version = "0.23", optional = true } rayon = "1.10" diff --git a/next-plaid-onnx/src/lib.rs b/next-plaid-onnx/src/lib.rs index 0ec947cd..133b05c0 100644 --- a/next-plaid-onnx/src/lib.rs +++ b/next-plaid-onnx/src/lib.rs @@ -49,6 +49,7 @@ pub mod hierarchy; use anyhow::{Context, Result}; +use fs2::FileExt; use ndarray::Array2; use ort::session::builder::GraphOptimizationLevel; use ort::session::Session; @@ -56,9 +57,12 @@ use ort::value::Tensor; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use rayon::{ThreadPool, ThreadPoolBuilder}; use serde::{Deserialize, Serialize}; -use std::collections::{HashSet, VecDeque}; +use sha2::{Digest, Sha256}; +use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; use std::fs; -use std::path::Path; +use std::io::{Read, Write}; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::mpsc; use std::sync::Once; @@ -110,9 +114,10 @@ use ort::execution_providers::DirectMLExecutionProvider; use ort::execution_providers::MIGraphXExecutionProvider; #[cfg(feature = "tensorrt")] use ort::execution_providers::TensorRTExecutionProvider; +#[cfg(feature = "migraphx")] +use ort::ortsys; use ort::session::builder::SessionBuilder; - // ============================================================================= // ONNX Runtime initialization (internal) // ============================================================================= @@ -249,23 +254,29 @@ const GPU_PROVIDER_ORDER: [ExecutionProvider; 5] = [ ExecutionProvider::MIGraphX, ]; +/// Whether this crate was compiled with support for a given execution provider. +/// +/// CPU and `Auto` do not require a feature-gated provider, so they always +/// return `true`. GPU providers only return `true` when their corresponding +/// Cargo feature is enabled. +pub fn is_execution_provider_compiled(provider: ExecutionProvider) -> bool { + match provider { + ExecutionProvider::Auto | ExecutionProvider::Cpu => true, + ExecutionProvider::Cuda => cfg!(feature = "cuda"), + ExecutionProvider::TensorRT => cfg!(feature = "tensorrt"), + ExecutionProvider::CoreML => cfg!(feature = "coreml"), + ExecutionProvider::DirectML => cfg!(feature = "directml"), + ExecutionProvider::MIGraphX => cfg!(feature = "migraphx"), + } +} + /// GPU execution providers compiled into this crate, in auto-selection order. pub fn compiled_gpu_execution_providers() -> Vec { - #[allow(unused_mut)] - let mut providers = Vec::new(); - - #[cfg(feature = "cuda")] - providers.push(ExecutionProvider::Cuda); - #[cfg(feature = "tensorrt")] - providers.push(ExecutionProvider::TensorRT); - #[cfg(feature = "coreml")] - providers.push(ExecutionProvider::CoreML); - #[cfg(feature = "directml")] - providers.push(ExecutionProvider::DirectML); - #[cfg(feature = "migraphx")] - providers.push(ExecutionProvider::MIGraphX); - - providers + GPU_PROVIDER_ORDER + .iter() + .copied() + .filter(|provider| is_execution_provider_compiled(*provider)) + .collect() } /// First compiled GPU execution provider in auto-selection order. @@ -275,7 +286,26 @@ pub fn compiled_gpu_execution_provider() -> Option { /// Return whether a specific execution provider is available in the currently /// loaded ONNX Runtime library. +/// +/// For `ExecutionProvider::Auto`, this returns whether any compiled GPU +/// provider is available. CPU fallback is intentionally not counted as an +/// available accelerator. pub fn is_execution_provider_available(provider: ExecutionProvider) -> bool { + if !is_execution_provider_compiled(provider) { + return false; + } + + if (matches!(provider, ExecutionProvider::Auto) || provider.is_gpu()) && is_force_cpu() { + return false; + } + + let needs_provider_probe = provider.is_gpu() + || (matches!(provider, ExecutionProvider::Auto) + && !compiled_gpu_execution_providers().is_empty()); + if needs_provider_probe { + init_ort_runtime(); + } + match provider { ExecutionProvider::Auto => preferred_gpu_execution_provider().is_some(), ExecutionProvider::Cpu => true, @@ -341,18 +371,22 @@ pub fn require_gpu_execution_provider() -> Result { }) } -fn configure_execution_provider( +fn configure_execution_provider_with_options( builder: SessionBuilder, provider: ExecutionProvider, + migraphx_model_cache_dir: Option<&Path>, + migraphx_fp16_enable: bool, ) -> Result { match provider { - ExecutionProvider::Auto => configure_auto_provider(builder), + ExecutionProvider::Auto => configure_auto_provider(builder, migraphx_fp16_enable), ExecutionProvider::Cpu => Ok(builder), ExecutionProvider::Cuda => configure_cuda(builder), ExecutionProvider::TensorRT => configure_tensorrt(builder), ExecutionProvider::CoreML => configure_coreml(builder), ExecutionProvider::DirectML => configure_directml(builder), - ExecutionProvider::MIGraphX => configure_migraphx(builder), + ExecutionProvider::MIGraphX => { + configure_migraphx(builder, migraphx_model_cache_dir, migraphx_fp16_enable) + } } } @@ -522,26 +556,23 @@ pub fn is_migraphx_available() -> bool { false } -fn configure_auto_provider(builder: SessionBuilder) -> Result { +fn configure_auto_provider( + builder: SessionBuilder, + migraphx_fp16_enable: bool, +) -> Result { if is_force_gpu() { let provider = preferred_gpu_execution_provider().ok_or_else(|| { - let compiled = compiled_gpu_execution_providers(); - if compiled.is_empty() { - anyhow::anyhow!( - "NEXT_PLAID_FORCE_GPU is set, but no GPU execution provider was compiled. Enable a feature such as 'cuda', 'migraphx', 'coreml', or 'directml'." - ) - } else { - let names = compiled - .iter() - .map(|provider| provider.display_name()) - .collect::>() - .join(", "); - anyhow::anyhow!( - "NEXT_PLAID_FORCE_GPU is set, but no compiled GPU execution provider is available in the loaded ONNX Runtime library. Compiled provider(s): {names}." - ) - } + anyhow::anyhow!( + "NEXT_PLAID_FORCE_GPU is set, but {}", + unavailable_gpu_execution_provider_reason() + ) })?; - return configure_execution_provider(builder, provider); + return configure_execution_provider_with_options( + builder, + provider, + None, + migraphx_fp16_enable, + ); } // Skip GPU providers entirely if CPU-only mode is forced @@ -593,7 +624,7 @@ fn configure_auto_provider(builder: SessionBuilder) -> Result { #[cfg(feature = "migraphx")] if !force_cpu { - if let Ok(b) = configure_migraphx(builder.clone()) { + if let Ok(b) = configure_migraphx(builder.clone(), None, migraphx_fp16_enable) { return Ok(b); } } @@ -678,19 +709,77 @@ fn configure_directml(_builder: SessionBuilder) -> Result { } #[cfg(feature = "migraphx")] -fn configure_migraphx(builder: SessionBuilder) -> Result { +fn configure_migraphx( + builder: SessionBuilder, + model_cache_dir: Option<&Path>, + fp16_enable: bool, +) -> Result { if is_force_cpu() { return Ok(builder); } - builder - .with_execution_providers([MIGraphXExecutionProvider::default() - .build() - .error_on_failure()]) - .context("Failed to configure MIGraphX execution provider. Ensure ROCm and MIGraphX are installed.") + let mut builder = builder; + append_migraphx_execution_provider(&mut builder, model_cache_dir, fp16_enable).context( + "Failed to configure MIGraphX execution provider. Ensure ROCm and MIGraphX are installed.", + )?; + Ok(builder) +} + +#[cfg(feature = "migraphx")] +fn append_migraphx_execution_provider( + builder: &mut SessionBuilder, + model_cache_dir: Option<&Path>, + fp16_enable: bool, +) -> ort::Result<()> { + use ort::AsPointer; + + // Use the provider-options map API instead of the legacy + // `OrtMIGraphXProviderOptions` struct. The Rust `ort` crate currently ships + // an older struct layout, and ORT 1.24's legacy MIGraphX wrapper also + // stringifies an empty model-cache path as `""`, which enables MXR caching + // to an invalid directory. Supplying only explicit non-default options via + // the map leaves MIGraphX's cache path truly empty. + let provider_name = std::ffi::CString::new("MIGraphXExecutionProvider").unwrap(); + let mut options = vec![("device_id".to_string(), "0".to_string())]; + if fp16_enable { + // FP16 is the fast path when MIGraphX receives an FP32 ColBERT ONNX. + // Do not set this for model_fp16.onnx: those graphs are already + // precision-shaped by export, and may intentionally keep selected + // operations (for example the final layer) in FP32. + options.push(("migraphx_fp16_enable".to_string(), "1".to_string())); + } + if let Some(path) = model_cache_dir { + options.push(( + "migraphx_model_cache_dir".to_string(), + path.display().to_string(), + )); + } + let keys = options + .iter() + .map(|(key, _)| std::ffi::CString::new(key.as_str()).unwrap()) + .collect::>(); + let values = options + .iter() + .map(|(_, value)| std::ffi::CString::new(value.as_str()).unwrap()) + .collect::>(); + let key_ptrs = keys.iter().map(|s| s.as_ptr()).collect::>(); + let value_ptrs = values.iter().map(|s| s.as_ptr()).collect::>(); + + ortsys![unsafe SessionOptionsAppendExecutionProvider( + builder.ptr_mut(), + provider_name.as_ptr(), + key_ptrs.as_ptr(), + value_ptrs.as_ptr(), + key_ptrs.len(), + )?]; + Ok(()) } #[cfg(not(feature = "migraphx"))] -fn configure_migraphx(_builder: SessionBuilder) -> Result { +fn configure_migraphx( + _builder: SessionBuilder, + _model_cache_dir: Option<&Path>, + _fp16_enable: bool, +) -> Result { anyhow::bail!("MIGraphX support not compiled. Enable the 'migraphx' feature.") } @@ -852,6 +941,36 @@ const DEFAULT_CPU_BATCH_SIZE: usize = 32; /// Default batch size for GPU encoding. const DEFAULT_GPU_BATCH_SIZE: usize = 64; +/// Fixed ONNX input shape used for shape-specialized MIGraphX sessions. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub struct MigraphxStaticShape { + pub batch_size: usize, + pub sequence_length: usize, +} + +impl MigraphxStaticShape { + fn cache_dir_name(self) -> String { + format!("{}x{}", self.batch_size, self.sequence_length) + } +} + +#[derive(Clone, Debug)] +pub struct MigraphxStaticShapeCacheStatus { + pub cache_root: PathBuf, + pub model_cache_key: String, + pub document_shapes: Vec, + pub warm_document_shapes: Vec, + pub cold_document_shapes: Vec, + pub query_shape: MigraphxStaticShape, + pub query_shape_warm: bool, +} + +impl MigraphxStaticShapeCacheStatus { + pub fn all_document_shapes_warm(&self) -> bool { + !self.document_shapes.is_empty() && self.cold_document_shapes.is_empty() + } +} + /// Type alias for batch encoding data: (input_ids, attention_mask, token_type_ids, token_ids) /// ColBERT model for encoding documents and queries into multi-vector embeddings. /// @@ -883,10 +1002,42 @@ pub struct Colbert { pub requested_execution_provider: ExecutionProvider, batch_size: usize, dynamic_batch: bool, + migraphx_hybrid: Option>, +} + +struct MigraphxHybrid { + model_dir: PathBuf, + quantized: bool, + cpu_fallback_quantized: bool, + tokenizer: Arc, + config: Arc, + query_length: usize, + document_length: usize, + cpu_fallback_parallel: usize, + cpu_model: Mutex>, + cache_root: PathBuf, + model_cache_key: String, + document_shapes: HashSet, + supported_shapes: HashSet, + shape_models: Mutex>, + min_run_tokens: usize, +} + +struct MigraphxGpuJob { + batch_idx: usize, + cache_shape: MigraphxStaticShape, + route: &'static str, + prepared: PreparedDocumentBatch, + cpu_fallback: PreparedDocumentBatch, } +#[derive(Clone)] pub struct PreparedDocumentBatch { + /// Number of real documents/chunks in this prepared batch. batch_size: usize, + /// Number of rows in the ONNX tensors. Shape-sensitive execution + /// providers may pad this above `batch_size` to reuse compiled plans. + tensor_batch_size: usize, batch_max_len: usize, all_input_ids: Vec, all_attention_mask: Vec, @@ -914,6 +1065,10 @@ impl PreparedDocumentBatch { self.batch_size } + pub fn tensor_batch_size(&self) -> usize { + self.tensor_batch_size + } + pub fn batch_max_len(&self) -> usize { self.batch_max_len } @@ -1029,6 +1184,12 @@ pub struct ColbertBuilder { dynamic_batch: bool, query_length: Option, document_length: Option, + migraphx_static_shape: Option, + migraphx_model_cache_dir: Option, + migraphx_cold_shape_cpu_fallback: Option, + migraphx_cpu_fallback_parallel: Option, + migraphx_cpu_fallback_quantized: Option, + migraphx_min_run_tokens_override: Option, } impl ColbertBuilder { @@ -1052,6 +1213,12 @@ impl ColbertBuilder { dynamic_batch: true, query_length: None, document_length: None, + migraphx_static_shape: None, + migraphx_model_cache_dir: None, + migraphx_cold_shape_cpu_fallback: None, + migraphx_cpu_fallback_parallel: None, + migraphx_cpu_fallback_quantized: None, + migraphx_min_run_tokens_override: None, } } @@ -1105,6 +1272,52 @@ impl ColbertBuilder { self } + /// Specialize a MIGraphX session to one fixed ONNX input shape. + /// + /// This is primarily used internally by the cold-shape CPU fallback/cache + /// path. It binds the model's symbolic `batch_size` and `sequence_length` + /// dimensions before creating the ONNX Runtime session. + pub fn with_migraphx_static_shape(mut self, batch_size: usize, sequence_length: usize) -> Self { + self.migraphx_static_shape = Some(MigraphxStaticShape { + batch_size: batch_size.max(1), + sequence_length: sequence_length.max(1), + }); + self + } + + /// Set the MIGraphX model-cache directory for this session. + /// + /// Shape-specialized callers should provide one directory per fixed input + /// shape to avoid cross-shape MXR cache reuse. + pub fn with_migraphx_model_cache_dir>(mut self, cache_dir: P) -> Self { + self.migraphx_model_cache_dir = Some(cache_dir.as_ref().to_path_buf()); + self + } + + /// Enable or disable MIGraphX cold-shape CPU fallback. + pub fn with_migraphx_cold_shape_cpu_fallback(mut self, enabled: bool) -> Self { + self.migraphx_cold_shape_cpu_fallback = Some(enabled); + self + } + + /// Set the number of CPU sessions used by MIGraphX cold-shape fallback. + pub fn with_migraphx_cpu_fallback_parallel(mut self, num_sessions: usize) -> Self { + self.migraphx_cpu_fallback_parallel = Some(num_sessions.max(1)); + self + } + + pub fn with_migraphx_cpu_fallback_quantized(mut self, quantized: bool) -> Self { + self.migraphx_cpu_fallback_quantized = Some(quantized); + self + } + + /// Override the minimum warm MIGraphX token work required before the + /// hybrid router sends matching batches to the GPU lane. + pub fn with_migraphx_min_run_tokens(mut self, min_tokens: usize) -> Self { + self.migraphx_min_run_tokens_override = Some(min_tokens); + self + } + /// Set the maximum query length. /// /// If not set, uses `query_length` from `onnx_config.json` (default: 48). @@ -1125,10 +1338,23 @@ impl ColbertBuilder { /// Build the Colbert model. pub fn build(self) -> Result { + let model_dir_path = self.model_dir.clone(); + let quantized = self.quantized; + let requested_execution_provider = self.execution_provider; + let migraphx_static_shape = self.migraphx_static_shape; + let migraphx_model_cache_dir = self.migraphx_model_cache_dir.clone(); + let migraphx_cold_shape_cpu_fallback = self.migraphx_cold_shape_cpu_fallback; + let migraphx_cpu_fallback_parallel = self.migraphx_cpu_fallback_parallel; + let migraphx_cpu_fallback_quantized = self.migraphx_cpu_fallback_quantized; + let migraphx_min_run_tokens_override = self.migraphx_min_run_tokens_override; + init_ort_runtime(); let model_dir = &self.model_dir; - let onnx_path = select_onnx_file(model_dir, self.quantized)?; + let onnx_path = + select_onnx_file_for_provider(model_dir, self.quantized, requested_execution_provider)?; + validate_preconverted_fp16_migraphx_env(&onnx_path)?; + let migraphx_fp16_enable = migraphx_should_enable_fp16_conversion(&onnx_path, quantized); let tokenizer_path = model_dir.join("tokenizer.json"); let tokenizer = Tokenizer::from_file(&tokenizer_path) @@ -1154,6 +1380,26 @@ impl ColbertBuilder { provider => provider.is_gpu(), }; + // Determine batch size before session creation because MIGraphX hybrid + // mode needs it to derive the supported static-shape set. In hybrid + // mode the top-level model is only a router: actual work runs on + // lazily-created static MIGraphX child sessions or CPU fallback + // sessions, so creating a dynamic parent MIGraphX session is pure + // startup overhead. + let batch_size = self.batch_size.unwrap_or(if self.num_sessions > 1 { + 2 // Small batches optimal for parallel sessions + } else if gpu_execution_requested { + DEFAULT_GPU_BATCH_SIZE + } else { + DEFAULT_CPU_BATCH_SIZE + }); + + let migraphx_hybrid_enabled = should_enable_migraphx_cold_shape_cpu_fallback( + requested_execution_provider, + migraphx_static_shape, + migraphx_cold_shape_cpu_fallback, + ); + // For GPU execution, cap intra-op threads to 1 — the GPU handles parallelism // and extra threads only cause ORT to allocate per-thread CUDA workspace buffers, // wasting GPU memory. The high thread count only benefits CPU sessions. @@ -1163,51 +1409,94 @@ impl ColbertBuilder { self.threads_per_session }; - let mut sessions = Vec::with_capacity(self.num_sessions); - for _i in 0..self.num_sessions { - let builder = Session::builder() - .map_err(|e| anyhow::anyhow!("Failed to create ONNX session builder: {e:?}"))? - .with_optimization_level(GraphOptimizationLevel::Level3) - .map_err(|e| anyhow::anyhow!("Failed to set ONNX optimization level: {e:?}"))? - .with_intra_threads(threads_per_session) - .map_err(|e| anyhow::anyhow!("Failed to set ONNX intra-op threads: {e:?}"))? - .with_inter_threads(if self.num_sessions > 1 { 1 } else { 2 }) - .map_err(|e| anyhow::anyhow!("Failed to set ONNX inter-op threads: {e:?}"))?; - // Disable memory pattern optimization for all providers. - // On CPU this helps with variable-length sequences (~7% speedup). - // On GPU this prevents ORT from pre-allocating a large memory arena - // that can cause OOM on GPUs with limited free memory. - let builder = builder - .with_memory_pattern(false) - .map_err(|e| anyhow::anyhow!("Failed to configure ONNX memory pattern: {e:?}"))?; - - let builder = configure_execution_provider(builder, self.execution_provider)?; - - let session = builder - .commit_from_file(&onnx_path) - .context("Failed to load ONNX model")?; - - sessions.push(Arc::new(Mutex::new(session))); - } - - // Determine batch size - let batch_size = self.batch_size.unwrap_or(if self.num_sessions > 1 { - 2 // Small batches optimal for parallel sessions - } else if gpu_execution_requested { - DEFAULT_GPU_BATCH_SIZE + let sessions = if migraphx_hybrid_enabled { + Vec::new() } else { - DEFAULT_CPU_BATCH_SIZE - }); + let mut sessions = Vec::with_capacity(self.num_sessions); + for _i in 0..self.num_sessions { + let builder = Session::builder() + .map_err(|e| anyhow::anyhow!("Failed to create ONNX session builder: {e:?}"))? + .with_optimization_level(GraphOptimizationLevel::Level3) + .map_err(|e| anyhow::anyhow!("Failed to set ONNX optimization level: {e:?}"))? + .with_intra_threads(threads_per_session) + .map_err(|e| anyhow::anyhow!("Failed to set ONNX intra-op threads: {e:?}"))? + .with_inter_threads(if self.num_sessions > 1 { 1 } else { 2 }) + .map_err(|e| anyhow::anyhow!("Failed to set ONNX inter-op threads: {e:?}"))?; + let builder = if let Some(shape) = migraphx_static_shape { + builder + .with_dimension_override("batch_size", shape.batch_size as i64) + .map_err(|e| { + anyhow::anyhow!("Failed to set MIGraphX static batch dimension: {e:?}") + })? + .with_dimension_override("sequence_length", shape.sequence_length as i64) + .map_err(|e| { + anyhow::anyhow!( + "Failed to set MIGraphX static sequence dimension: {e:?}" + ) + })? + } else { + builder + }; + // Disable memory pattern optimization for all providers. + // On CPU this helps with variable-length sequences (~7% speedup). + // On GPU this prevents ORT from pre-allocating a large memory arena + // that can cause OOM on GPUs with limited free memory. + let builder = builder.with_memory_pattern(false).map_err(|e| { + anyhow::anyhow!("Failed to configure ONNX memory pattern: {e:?}") + })?; + + let builder = configure_execution_provider_with_options( + builder, + self.execution_provider, + migraphx_model_cache_dir.as_deref(), + migraphx_fp16_enable, + )?; + + let session = builder + .commit_from_file(&onnx_path) + .context("Failed to load ONNX model")?; + + sessions.push(Arc::new(Mutex::new(session))); + } + sessions + }; + + let tokenizer = Arc::new(tokenizer); + let config = Arc::new(config); + let skiplist_ids = Arc::new(skiplist_ids); + + let migraphx_hybrid = if migraphx_hybrid_enabled { + let cache_root = default_migraphx_static_cache_root().ok_or_else(|| { + anyhow::anyhow!( + "Failed to determine MIGraphX static-shape cache directory. Set NEXT_PLAID_MIGRAPHX_STATIC_CACHE_ROOT." + ) + })?; + Some(Arc::new(MigraphxHybrid::new( + model_dir_path.clone(), + quantized, + &onnx_path, + Arc::clone(&tokenizer), + Arc::clone(&config), + batch_size, + cache_root, + migraphx_cpu_fallback_parallel, + migraphx_cpu_fallback_quantized.unwrap_or(quantized), + migraphx_min_run_tokens_override, + )?)) + } else { + None + }; Ok(Colbert { sessions, - tokenizer: Arc::new(tokenizer), - config: Arc::new(config), - skiplist_ids: Arc::new(skiplist_ids), + tokenizer, + config, + skiplist_ids, next_session_idx: Arc::new(AtomicUsize::new(0)), requested_execution_provider: self.execution_provider, batch_size, dynamic_batch: self.dynamic_batch, + migraphx_hybrid, }) } } @@ -1276,6 +1565,11 @@ impl Colbert { return Ok(Vec::new()); } + if self.migraphx_hybrid.is_some() { + let prepared = self.tokenize_documents_in_batches(documents)?; + return self.encode_prepared_document_batches(prepared); + } + if self.sessions.len() == 1 { self.encode_single_session(documents, false, true) } else { @@ -1328,6 +1622,7 @@ impl Colbert { false, true, piece_indices, + None, )?); } @@ -1335,9 +1630,11 @@ impl Colbert { } // GPU path: token-budget dynamic batching. Documents are sorted by - // length and bucketed into fixed shapes (quantized to 32-token steps). - // This lets the GPU reuse execution plans across batches with the same - // shape, reducing kernel launch overhead and minimizing padding waste. + // length and bucketed into planned sequence lengths. Shape-sensitive + // execution providers (currently MIGraphX) pad tensors to those + // planned sequence lengths so compiled execution plans can be reused. + // Other providers keep the historical exact per-batch tensor sizes to + // avoid changing their padding/throughput behavior. // We carry the original input index alongside each tokenized doc so // `encode_prepared_document_batches` can restore the caller-visible // input order in the returned embeddings. @@ -1355,6 +1652,8 @@ impl Colbert { let shapes = build_fixed_dynamic_shapes(self.batch_size.max(1), self.config.document_length); + let pad_to_planned_sequence_len = + execution_provider_prefers_planned_sequence_lengths(self.requested_execution_provider); let mut buckets: Vec> = (0..shapes.len()).map(|_| Vec::new()).collect(); @@ -1379,6 +1678,15 @@ impl Colbert { piece_encodings.push(encoding); piece_indices.push(idx); } + // For MIGraphX we must avoid per-batch sequence lengths like + // 255/505/1008, because each distinct tensor shape triggers a + // new compile/cache entry. Keep the real row count here: + // padding a small/final short-doc batch up to `shape.docs` + // can turn one real document into a much larger tensor. + let planned_shape = pad_to_planned_sequence_len.then_some(FixedDynamicShape { + docs: piece_encodings.len(), + planned_len: shape.planned_len, + }); batches.push(prepare_batch_from_tokenized_documents( &self.tokenizer, &self.config, @@ -1386,6 +1694,7 @@ impl Colbert { false, true, piece_indices, + planned_shape, )?); } } @@ -1397,6 +1706,14 @@ impl Colbert { &self, prepared: PreparedDocumentBatch, ) -> Result>> { + if let Some(hybrid) = &self.migraphx_hybrid { + return hybrid.encode_one_prepared(prepared); + } + + if self.sessions.is_empty() { + anyhow::bail!("ColBERT model has no ONNX sessions available for encoding"); + } + let session_idx = self.next_session_idx.fetch_add(1, Ordering::Relaxed) % self.sessions.len().max(1); let mut session = self.sessions[session_idx].lock().unwrap(); @@ -1411,6 +1728,10 @@ impl Colbert { return Ok(Vec::new()); } + if let Some(hybrid) = &self.migraphx_hybrid { + return hybrid.encode_prepared_document_batches(prepared_batches); + } + // Collect the original-input position for every document across all // batches in the order they appear here. When `tokenize_documents_in_batches` // sorts documents by length (GPU dynamic batching path) the embeddings @@ -1427,73 +1748,58 @@ impl Colbert { } } - let encoded: Vec> = if self.sessions.len() <= 1 || prepared_batches.len() == 1 { - let mut all_embeddings = Vec::new(); - for prepared_batch in prepared_batches { - all_embeddings.extend(self.encode_prepared_documents(prepared_batch)?); - } - all_embeddings - } else { - let results: Vec>>> = std::thread::scope(|scope| { - let mut handles = Vec::with_capacity(prepared_batches.len()); - - for (i, prepared_batch) in prepared_batches.into_iter().enumerate() { - let session_idx = i % self.sessions.len(); - let session_mutex = &self.sessions[session_idx]; - let config = &self.config; - let skiplist_ids = &self.skiplist_ids; + let encoded = self.encode_prepared_batches_unordered(prepared_batches)?; - handles.push(scope.spawn(move || { - let mut session = session_mutex.lock().unwrap(); - encode_prepared_batch_with_session( - &mut session, - config, - skiplist_ids, - prepared_batch, - ) - })); - } + restore_original_input_order(encoded, combined_indices, has_reordering) + } - handles - .into_iter() - .map(|handle| handle.join().unwrap()) - .collect() - }); + fn encode_prepared_batches_unordered( + &self, + prepared_batches: Vec, + ) -> Result>> { + if self.sessions.is_empty() { + anyhow::bail!("ColBERT model has no ONNX sessions available for encoding"); + } + if self.sessions.len() <= 1 || prepared_batches.len() == 1 { let mut all_embeddings = Vec::new(); - for result in results { - all_embeddings.extend(result?); + for prepared_batch in prepared_batches { + all_embeddings.extend(self.encode_prepared_documents(prepared_batch)?); } - all_embeddings - }; - - if !has_reordering || combined_indices.len() != encoded.len() { - return Ok(encoded); + return Ok(all_embeddings); } - // Restore input order: encoded[i] belongs at output position combined_indices[i]. - let n = encoded.len(); - let mut reordered: Vec>> = (0..n).map(|_| None).collect(); - for (encoded_pos, embedding) in encoded.into_iter().enumerate() { - let target = combined_indices[encoded_pos]; - if target >= n { - anyhow::bail!( - "original_input_indices points to out-of-range slot ({} >= {})", - target, - n - ); + let results: Vec>>> = std::thread::scope(|scope| { + let mut handles = Vec::with_capacity(prepared_batches.len()); + + for (i, prepared_batch) in prepared_batches.into_iter().enumerate() { + let session_idx = i % self.sessions.len(); + let session_mutex = &self.sessions[session_idx]; + let config = &self.config; + let skiplist_ids = &self.skiplist_ids; + + handles.push(scope.spawn(move || { + let mut session = session_mutex.lock().unwrap(); + encode_prepared_batch_with_session( + &mut session, + config, + skiplist_ids, + prepared_batch, + ) + })); } - reordered[target] = Some(embedding); + + handles + .into_iter() + .map(|handle| handle.join().unwrap()) + .collect() + }); + + let mut all_embeddings = Vec::new(); + for result in results { + all_embeddings.extend(result?); } - reordered - .into_iter() - .enumerate() - .map(|(i, opt)| { - opt.ok_or_else(|| { - anyhow::anyhow!("original_input_indices missing slot {} in output", i) - }) - }) - .collect() + Ok(all_embeddings) } /// Stream document embeddings chunk-by-chunk. @@ -1546,6 +1852,31 @@ impl Colbert { }); } + if self.migraphx_hybrid.is_some() { + let model = self.clone(); + let (raw_tx, raw_rx) = mpsc::channel::>(); + let handle = std::thread::Builder::new() + .name("next-plaid-hybrid-stream".to_string()) + .spawn(move || { + let refs: Vec<&str> = documents.iter().map(String::as_str).collect(); + let result = model.encode_documents_raw(&refs).map(|embeddings| { + RawDocumentEmbeddingChunk { + chunk_index: 0, + start_offset: 0, + embeddings, + } + }); + + let _ = raw_tx.send(result); + }) + .expect("failed to spawn next-plaid hybrid stream worker"); + + return Ok(RawDocumentEmbeddingStream { + receiver: raw_rx, + handles: vec![handle], + }); + } + let chunk_queue = Arc::new(Mutex::new(self.build_document_work_queue(documents))); let (raw_tx, raw_rx) = mpsc::channel::>(); @@ -1623,6 +1954,10 @@ impl Colbert { return Ok(Vec::new()); } + if let Some(hybrid) = &self.migraphx_hybrid { + return hybrid.encode_queries(queries, self.batch_size); + } + if self.sessions.len() == 1 { self.encode_single_session(queries, true, false) } else { @@ -1650,6 +1985,44 @@ impl Colbert { self.sessions.len() } + /// Warm and validate all fixed-shape MIGraphX caches for this model. + /// + /// This is only available when the model was built with + /// `ExecutionProvider::MIGraphX` and cold-shape CPU fallback enabled. + pub fn warm_migraphx_static_shape_cache(&self) -> Result { + self.migraphx_hybrid + .as_ref() + .ok_or_else(|| { + anyhow::anyhow!( + "MIGraphX static-shape cache warming is only available for non-static MIGraphX models with cold-shape CPU fallback enabled" + ) + })? + .warm_all_default_shapes() + } + + /// Warm and validate fixed-shape MIGraphX caches up to `max_sequence_len`. + pub fn warm_migraphx_static_shape_cache_up_to(&self, max_sequence_len: usize) -> Result { + self.migraphx_hybrid + .as_ref() + .ok_or_else(|| { + anyhow::anyhow!( + "MIGraphX static-shape cache warming is only available for non-static MIGraphX models with cold-shape CPU fallback enabled" + ) + })? + .warm_default_shapes(max_sequence_len) + } + + /// Return the fixed MIGraphX shapes this model may use when their caches + /// are warm and validated. + pub fn migraphx_static_shapes(&self) -> Vec { + let Some(hybrid) = &self.migraphx_hybrid else { + return Vec::new(); + }; + let mut shapes: Vec<_> = hybrid.document_shapes.iter().copied().collect(); + shapes.sort_by_key(|shape| (shape.sequence_length, shape.batch_size)); + shapes + } + // ========================================================================= // Internal encoding implementations // ========================================================================= @@ -1785,7 +2158,68 @@ fn tokenizer_thread_pool() -> &'static ThreadPool { // Helper functions // ============================================================================= -fn select_onnx_file>(model_dir: P, quantized: bool) -> Result { +fn execution_provider_prefers_fp16_onnx(provider: ExecutionProvider) -> bool { + match provider { + ExecutionProvider::MIGraphX => true, + ExecutionProvider::Auto => { + preferred_gpu_execution_provider() == Some(ExecutionProvider::MIGraphX) + } + _ => false, + } +} + +fn is_preconverted_fp16_onnx(path: &Path) -> bool { + path.file_name() + .and_then(|name| name.to_str()) + .is_some_and(|name| name.eq_ignore_ascii_case("model_fp16.onnx")) +} + +fn env_flag_enabled(name: &str) -> bool { + std::env::var(name) + .map(|value| { + let value = value.trim(); + !(value.is_empty() + || value == "0" + || value.eq_ignore_ascii_case("false") + || value.eq_ignore_ascii_case("off")) + }) + .unwrap_or(false) +} + +fn validate_preconverted_fp16_migraphx_env(onnx_path: &Path) -> Result<()> { + if !is_preconverted_fp16_onnx(onnx_path) { + return Ok(()); + } + + let forced_precision_env = [ + "ORT_MIGRAPHX_FP16_ENABLE", + "ORT_MIGRAPHX_BF16_ENABLE", + "ORT_MIGRAPHX_INT8_ENABLE", + "ORT_MIGRAPHX_FP8_ENABLE", + ] + .into_iter() + .find(|name| env_flag_enabled(name)); + + if let Some(name) = forced_precision_env { + anyhow::bail!( + "{} is set while loading {}. Unset it so MIGraphX preserves the precision layout encoded in model_fp16.onnx.", + name, + onnx_path.display() + ); + } + + Ok(()) +} + +fn migraphx_should_enable_fp16_conversion(onnx_path: &Path, quantized: bool) -> bool { + !quantized && !is_preconverted_fp16_onnx(onnx_path) +} + +fn select_onnx_file_for_provider>( + model_dir: P, + quantized: bool, + provider: ExecutionProvider, +) -> Result { let model_dir = model_dir.as_ref(); if quantized { @@ -1799,9 +2233,30 @@ fn select_onnx_file>(model_dir: P, quantized: bool) -> Result bool { + match provider { + ExecutionProvider::MIGraphX => true, + ExecutionProvider::Auto => { + preferred_gpu_execution_provider() == Some(ExecutionProvider::MIGraphX) + } + _ => false, + } +} + +fn can_pad_migraphx_warm_tail_rows(real_rows: usize, planned_rows: usize) -> bool { + can_pad_migraphx_warm_tail_rows_with_factor(real_rows, planned_rows, 2) +} + +fn can_pad_migraphx_active_warm_tail_rows(real_rows: usize, planned_rows: usize) -> bool { + can_pad_migraphx_warm_tail_rows_with_factor(real_rows, planned_rows, 16) +} + +fn can_pad_migraphx_warm_tail_rows_with_factor( + real_rows: usize, + planned_rows: usize, + max_factor: usize, +) -> bool { + real_rows > 0 + && planned_rows > real_rows + && planned_rows <= real_rows.saturating_mul(max_factor.max(1)) +} + +fn migraphx_trace_enabled() -> bool { + env_truthy("NEXT_PLAID_MIGRAPHX_TRACE") +} + +fn migraphx_split_warm_batches_enabled() -> bool { + env_truthy("NEXT_PLAID_MIGRAPHX_SPLIT_WARM_BATCHES") +} + +fn env_truthy(name: &str) -> bool { + std::env::var(name) + .map(|value| { + !matches!( + value.trim().to_ascii_lowercase().as_str(), + "" | "0" | "false" | "no" | "off" + ) + }) + .unwrap_or(false) +} + +fn record_migraphx_route( + summary: &mut BTreeMap, + route: &str, + shape: MigraphxStaticShape, + real_rows: usize, + tensor_rows: usize, +) { + let key = format!("{route}:{}x{}", shape.batch_size, shape.sequence_length); + let entry = summary.entry(key).or_insert((0, 0, 0)); + entry.0 += 1; + entry.1 += real_rows; + entry.2 += tensor_rows; +} + +fn migraphx_padded_rows_for_warm_shape(real_rows: usize, warm_rows: &[usize]) -> Option { + warm_rows.iter().copied().find(|rows| *rows >= real_rows) +} + fn build_fixed_dynamic_shapes(batch_size: usize, document_length: usize) -> Vec { let total_budget = batch_size.max(1).saturating_mul(document_length.max(1)); let mut shapes = Vec::new(); @@ -1904,19 +2424,1734 @@ fn build_fixed_dynamic_shapes(batch_size: usize, document_length: usize) -> Vec< shapes } -fn update_token_ids(config: &mut ColbertConfig, tokenizer: &Tokenizer) { - if config.mask_token_id == default_mask_token_id() { - if let Some(mask_id) = tokenizer.token_to_id("[MASK]") { - config.mask_token_id = mask_id; - } else if let Some(mask_id) = tokenizer.token_to_id("") { - config.mask_token_id = mask_id; - } - } - if config.pad_token_id == default_pad_token_id() { - if let Some(pad_id) = tokenizer.token_to_id("[PAD]") { - config.pad_token_id = pad_id; - } else if let Some(pad_id) = tokenizer.token_to_id("") { - config.pad_token_id = pad_id; +fn build_migraphx_document_static_shapes( + batch_size: usize, + document_length: usize, +) -> Vec { + build_migraphx_document_static_shapes_with_min_tokens( + batch_size, + document_length, + migraphx_min_static_shape_tokens(), + ) +} + +const DEFAULT_MIGRAPHX_MIN_STATIC_SHAPE_TOKENS: usize = 65_536; +const DEFAULT_MIGRAPHX_MIN_RUN_TOKENS: usize = 1_048_576; + +fn migraphx_min_static_shape_tokens() -> usize { + std::env::var("NEXT_PLAID_MIGRAPHX_MIN_STATIC_SHAPE_TOKENS") + .ok() + .and_then(|value| value.trim().parse::().ok()) + .unwrap_or(DEFAULT_MIGRAPHX_MIN_STATIC_SHAPE_TOKENS) +} + +fn migraphx_min_run_tokens() -> usize { + std::env::var("NEXT_PLAID_MIGRAPHX_MIN_RUN_TOKENS") + .ok() + .and_then(|value| value.trim().parse::().ok()) + .unwrap_or(DEFAULT_MIGRAPHX_MIN_RUN_TOKENS) +} + +pub fn migraphx_default_min_run_tokens() -> usize { + migraphx_min_run_tokens() +} + +fn build_migraphx_document_static_shapes_with_min_tokens( + batch_size: usize, + document_length: usize, + min_shape_tokens: usize, +) -> Vec { + let mut shapes = HashSet::new(); + let fixed_shapes = build_fixed_dynamic_shapes(batch_size.max(1), document_length); + + for shape in fixed_shapes { + let shape_tokens = shape.docs.saturating_mul(shape.planned_len); + if shape_tokens < min_shape_tokens { + continue; + } + shapes.insert(MigraphxStaticShape { + batch_size: shape.docs, + sequence_length: shape.planned_len, + }); + } + + let mut shapes: Vec<_> = shapes.into_iter().collect(); + shapes.sort_by_key(|shape| (shape.sequence_length, shape.batch_size)); + shapes +} + +fn should_enable_migraphx_cold_shape_cpu_fallback( + provider: ExecutionProvider, + static_shape: Option, + override_enabled: Option, +) -> bool { + provider == ExecutionProvider::MIGraphX + && static_shape.is_none() + && override_enabled.unwrap_or(true) +} + +fn default_migraphx_cpu_fallback_sessions() -> usize { + std::thread::available_parallelism() + .map(|p| p.get()) + .unwrap_or(16) + .min(16) + .max(1) +} + +fn default_migraphx_static_cache_root() -> Option { + if let Ok(path) = std::env::var("NEXT_PLAID_MIGRAPHX_STATIC_CACHE_ROOT") { + if !path.trim().is_empty() { + return Some(PathBuf::from(path)); + } + } + + if let Ok(path) = std::env::var("XDG_CACHE_HOME") { + if !path.trim().is_empty() { + return Some(PathBuf::from(path).join("next-plaid").join("migraphx")); + } + } + + std::env::var("HOME").ok().and_then(|home| { + if home.trim().is_empty() { + None + } else { + Some( + PathBuf::from(home) + .join(".cache") + .join("next-plaid") + .join("migraphx"), + ) + } + }) +} + +fn hex_encode(bytes: &[u8]) -> String { + const HEX: &[u8; 16] = b"0123456789abcdef"; + let mut out = String::with_capacity(bytes.len() * 2); + for &byte in bytes { + out.push(HEX[(byte >> 4) as usize] as char); + out.push(HEX[(byte & 0x0f) as usize] as char); + } + out +} + +fn hash_file_sha256(path: &Path) -> Result { + let mut file = fs::File::open(path) + .with_context(|| format!("failed to open {} for hashing", path.display()))?; + let mut hasher = Sha256::new(); + let mut buffer = [0u8; 64 * 1024]; + loop { + let read = file + .read(&mut buffer) + .with_context(|| format!("failed to read {} for hashing", path.display()))?; + if read == 0 { + break; + } + hasher.update(&buffer[..read]); + } + Ok(hex_encode(&hasher.finalize())) +} + +fn metadata_fingerprint(path: &Path) -> String { + let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf()); + let mut parts = vec![format!("path={}", canonical.display())]; + match fs::metadata(path) { + Ok(metadata) => { + parts.push(format!("len={}", metadata.len())); + if let Ok(modified) = metadata.modified() { + if let Ok(duration) = modified.duration_since(std::time::UNIX_EPOCH) { + parts.push(format!( + "mtime={}.{}", + duration.as_secs(), + duration.subsec_nanos() + )); + } + } + } + Err(err) => parts.push(format!("metadata_error={err}")), + } + parts.join(";") +} + +fn onnxruntime_dylib_path_for_cache_key() -> Option { + if let Ok(path) = std::env::var("ORT_DYLIB_PATH") { + if !path.trim().is_empty() { + return Some(PathBuf::from(path)); + } + } + + find_onnxruntime_library().map(PathBuf::from) +} + +fn onnxruntime_dylib_fingerprint() -> String { + onnxruntime_dylib_path_for_cache_key() + .map(|path| metadata_fingerprint(&path)) + .unwrap_or_else(|| "unavailable".to_string()) +} + +fn migraphx_provider_library_fingerprint() -> String { + let mut candidates = Vec::new(); + + if let Some(ort_path) = onnxruntime_dylib_path_for_cache_key() { + if let Some(parent) = ort_path.parent() { + candidates.push(parent.join("libonnxruntime_providers_migraphx.so")); + candidates.push(parent.join("onnxruntime_providers_migraphx.dll")); + } + } + + candidates.extend([ + PathBuf::from("/usr/lib/libonnxruntime_providers_migraphx.so"), + PathBuf::from("/usr/local/lib/libonnxruntime_providers_migraphx.so"), + PathBuf::from("/opt/rocm/lib/libonnxruntime_providers_migraphx.so"), + ]); + + candidates + .into_iter() + .find(|path| path.exists()) + .map(|path| metadata_fingerprint(&path)) + .unwrap_or_else(|| "unavailable".to_string()) +} + +fn migraphx_driver_version_fingerprint() -> String { + static VERSION: OnceLock = OnceLock::new(); + VERSION + .get_or_init(|| { + Command::new("migraphx-driver") + .arg("--version") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output() + .ok() + .and_then(|output| { + if output.status.success() { + let mut text = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr); + if !stderr.trim().is_empty() { + if !text.trim().is_empty() { + text.push('\n'); + } + text.push_str(stderr.trim()); + } + Some(text.trim().to_string()) + } else { + None + } + }) + .filter(|value| !value.is_empty()) + .unwrap_or_else(|| "unavailable".to_string()) + }) + .clone() +} + +#[cfg(target_os = "linux")] +fn linux_kfd_gpu_topology_fingerprint() -> Option { + let nodes = fs::read_dir("/sys/class/kfd/kfd/topology/nodes").ok()?; + let mut entries = Vec::new(); + + for node in nodes.flatten() { + let node_path = node.path(); + let properties = fs::read_to_string(node_path.join("properties")).ok()?; + let mut selected = BTreeMap::new(); + for line in properties.lines() { + let mut parts = line.split_whitespace(); + let Some(key) = parts.next() else { continue }; + let Some(value) = parts.next() else { continue }; + if matches!( + key, + "gfx_target_version" | "vendor_id" | "device_id" | "simd_count" + ) { + selected.insert(key.to_string(), value.to_string()); + } + } + + if selected + .get("gfx_target_version") + .is_some_and(|value| value != "0") + { + if let Ok(gpu_id) = fs::read_to_string(node_path.join("gpu_id")) { + selected.insert("gpu_id".to_string(), gpu_id.trim().to_string()); + } + let node_name = node.file_name().to_string_lossy().to_string(); + entries.push(format!("node={node_name};{:?}", selected)); + } + } + + if entries.is_empty() { + None + } else { + entries.sort(); + Some(entries.join("|")) + } +} + +#[cfg(not(target_os = "linux"))] +fn linux_kfd_gpu_topology_fingerprint() -> Option { + None +} + +fn migraphx_gpu_fingerprint() -> String { + let mut parts = Vec::new(); + + if let Some(topology) = linux_kfd_gpu_topology_fingerprint() { + parts.push(format!("kfd={topology}")); + } + + for name in [ + "HSA_OVERRIDE_GFX_VERSION", + "HIP_VISIBLE_DEVICES", + "ROCR_VISIBLE_DEVICES", + "GPU_DEVICE_ORDINAL", + "CUDA_VISIBLE_DEVICES", + ] { + if let Ok(value) = std::env::var(name) { + let value = value.trim(); + if !value.is_empty() { + parts.push(format!("{name}={value}")); + } + } + } + + if parts.is_empty() { + "unavailable".to_string() + } else { + parts.join(";") + } +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +struct MigraphxCacheOptions { + entries: Vec<(String, String)>, +} + +impl MigraphxCacheOptions { + fn from_provider_options(fp16_enable: bool) -> Self { + let mut entries = vec![("device_id".to_string(), "0".to_string())]; + if fp16_enable { + entries.push(("migraphx_fp16_enable".to_string(), "1".to_string())); + } + + // ORT documents these environment variables as global MIGraphX knobs + // that take precedence over provider/session options. Include any + // non-empty values in the static MXR cache key so user overrides do + // not accidentally reuse validation markers from a differently + // compiled graph. Cache-path variables themselves are intentionally + // excluded because this function chooses our per-shape cache path. + for name in [ + "ORT_MIGRAPHX_FP16_ENABLE", + "ORT_MIGRAPHX_BF16_ENABLE", + "ORT_MIGRAPHX_INT8_ENABLE", + "ORT_MIGRAPHX_FP8_ENABLE", + "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME", + "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE", + "ORT_MIGRAPHX_EXHAUSTIVE_TUNE", + "ORT_MIGRAPHX_MEM_LIMIT", + ] { + if let Ok(value) = std::env::var(name) { + let value = value.trim(); + if !value.is_empty() { + entries.push((name.to_string(), value.to_string())); + } + } + } + + entries.sort(); + Self { entries } + } +} + +fn cache_key_for_onnx(path: &Path, quantized: bool) -> String { + cache_key_for_onnx_with_options( + path, + quantized, + MigraphxCacheOptions::from_provider_options(migraphx_should_enable_fp16_conversion( + path, quantized, + )), + ) +} + +fn cache_key_for_onnx_with_options( + path: &Path, + quantized: bool, + migraphx_options: MigraphxCacheOptions, +) -> String { + let mut hasher = Sha256::new(); + + fn add_entry(hasher: &mut Sha256, key: &str, value: &str) { + hasher.update(key.as_bytes()); + hasher.update([0]); + hasher.update(value.as_bytes()); + hasher.update([0xff]); + } + + // This key is the model/provider/runtime prefix of the full cache path: + // `//x/`. The trailing shape + // directory contributes the static ONNX input shape. Everything below can + // affect MIGraphX's compiled MXR program and must not share a validation + // marker/cache directory across incompatible runs. + add_entry(&mut hasher, "namespace", "migraphx-static-cache-v3"); + add_entry(&mut hasher, "quantized", &quantized.to_string()); + add_entry( + &mut hasher, + "model_path", + &path + .canonicalize() + .unwrap_or_else(|_| path.to_path_buf()) + .display() + .to_string(), + ); + add_entry( + &mut hasher, + "model_sha256", + &hash_file_sha256(path).unwrap_or_else(|err| format!("unavailable:{err:#}")), + ); + add_entry( + &mut hasher, + "ort_api_version", + &ort::MINOR_VERSION.to_string(), + ); + add_entry( + &mut hasher, + "onnxruntime_dylib", + &onnxruntime_dylib_fingerprint(), + ); + add_entry( + &mut hasher, + "migraphx_provider_library", + &migraphx_provider_library_fingerprint(), + ); + add_entry( + &mut hasher, + "migraphx_driver_version", + &migraphx_driver_version_fingerprint(), + ); + add_entry(&mut hasher, "gpu", &migraphx_gpu_fingerprint()); + + let mut option_entries = migraphx_options.entries; + option_entries.sort(); + for (key, value) in option_entries { + add_entry(&mut hasher, &format!("migraphx_option:{key}"), &value); + } + + hex_encode(&hasher.finalize()) +} + +fn shape_cache_has_mxr(cache_dir: &Path) -> bool { + fs::read_dir(cache_dir) + .ok() + .into_iter() + .flat_map(|entries| entries.flatten()) + .any(|entry| entry.path().extension().is_some_and(|ext| ext == "mxr")) +} + +fn migraphx_shape_cache_is_warm( + cache_root: &Path, + model_cache_key: &str, + shape: MigraphxStaticShape, +) -> bool { + let cache_dir = cache_root + .join(model_cache_key) + .join(shape.cache_dir_name()); + cache_dir.join("validated-v1").exists() && shape_cache_has_mxr(&cache_dir) +} + +struct MigraphxShapeCacheLock { + file: fs::File, +} + +impl Drop for MigraphxShapeCacheLock { + fn drop(&mut self) { + let _ = self.file.unlock(); + } +} + +fn acquire_migraphx_shape_cache_lock(cache_dir: &Path) -> Result { + // ONNX Runtime's MIGraphX EP writes `.mxr` files directly into the model + // cache directory. Serialize cache-producing session creation per fixed + // shape so concurrent warmers do not read/write the same MXR path at once. + fs::create_dir_all(cache_dir).with_context(|| { + format!( + "Failed to create MIGraphX cache directory {}", + cache_dir.display() + ) + })?; + + let lock_path = cache_dir.join(".compile.lock"); + let file = fs::OpenOptions::new() + .create(true) + .read(true) + .write(true) + .open(&lock_path) + .with_context(|| { + format!( + "Failed to open MIGraphX shape-cache lock {}", + lock_path.display() + ) + })?; + + file.lock_exclusive().with_context(|| { + format!( + "Failed to acquire MIGraphX shape-cache lock {}", + lock_path.display() + ) + })?; + + Ok(MigraphxShapeCacheLock { file }) +} + +fn write_migraphx_validation_marker(marker_path: &Path, contents: &str) -> Result<()> { + // Readers treat this marker as the publish point for a usable shape cache, + // so write it via a temporary file and atomic rename after validation. + let parent = marker_path.parent().ok_or_else(|| { + anyhow::anyhow!( + "MIGraphX validation marker has no parent directory: {}", + marker_path.display() + ) + })?; + fs::create_dir_all(parent).with_context(|| { + format!( + "Failed to create MIGraphX marker directory {}", + parent.display() + ) + })?; + + let file_name = marker_path + .file_name() + .map(|name| name.to_string_lossy()) + .unwrap_or_else(|| "validated-v1".into()); + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|duration| duration.as_nanos()) + .unwrap_or_default(); + let tmp_path = + marker_path.with_file_name(format!("{file_name}.tmp.{}.{}", std::process::id(), nanos)); + + let result = (|| -> Result<()> { + let mut file = fs::OpenOptions::new() + .create_new(true) + .write(true) + .open(&tmp_path) + .with_context(|| { + format!( + "Failed to create temporary MIGraphX validation marker {}", + tmp_path.display() + ) + })?; + file.write_all(contents.as_bytes()).with_context(|| { + format!( + "Failed to write temporary MIGraphX validation marker {}", + tmp_path.display() + ) + })?; + file.sync_all().with_context(|| { + format!( + "Failed to sync temporary MIGraphX validation marker {}", + tmp_path.display() + ) + })?; + drop(file); + + fs::rename(&tmp_path, marker_path).with_context(|| { + format!( + "Failed to publish MIGraphX validation marker {}", + marker_path.display() + ) + })?; + Ok(()) + })(); + + if result.is_err() { + let _ = fs::remove_file(&tmp_path); + } + + result +} + +/// Inspect fixed-shape MIGraphX caches for a model without creating ONNX +/// sessions. +/// +/// This is intended for higher-level auto-selection policy: callers can choose +/// MIGraphX only when all document shapes have already been compiled and +/// validated, avoiding cold graph-compilation stalls in interactive commands. +pub fn migraphx_static_shape_cache_status>( + model_dir: P, + quantized: bool, + batch_size: usize, +) -> Result { + let model_dir = model_dir.as_ref(); + let onnx_path = + select_onnx_file_for_provider(model_dir, quantized, ExecutionProvider::MIGraphX)?; + let config = ColbertConfig::from_model_dir(model_dir)?; + + let cache_root = default_migraphx_static_cache_root().ok_or_else(|| { + anyhow::anyhow!( + "Failed to determine MIGraphX static-shape cache directory. Set NEXT_PLAID_MIGRAPHX_STATIC_CACHE_ROOT." + ) + })?; + let model_cache_key = cache_key_for_onnx(&onnx_path, quantized); + + let document_shapes = build_migraphx_document_static_shapes(batch_size, config.document_length); + + let mut warm_document_shapes = Vec::new(); + let mut cold_document_shapes = Vec::new(); + for shape in &document_shapes { + if migraphx_shape_cache_is_warm(&cache_root, &model_cache_key, *shape) { + warm_document_shapes.push(*shape); + } else { + cold_document_shapes.push(*shape); + } + } + + let query_shape = MigraphxStaticShape { + batch_size: 1, + sequence_length: config.query_length, + }; + let query_shape_warm = migraphx_shape_cache_is_warm(&cache_root, &model_cache_key, query_shape); + + Ok(MigraphxStaticShapeCacheStatus { + cache_root, + model_cache_key, + document_shapes, + warm_document_shapes, + cold_document_shapes, + query_shape, + query_shape_warm, + }) +} + +/// Return true when every document fixed-shape cache that MIGraphX indexing may +/// use for this model/batch size is present and validated. +pub fn migraphx_document_static_shape_caches_warm>( + model_dir: P, + quantized: bool, + batch_size: usize, +) -> Result { + Ok( + migraphx_static_shape_cache_status(model_dir, quantized, batch_size)? + .all_document_shapes_warm(), + ) +} + +fn restore_original_input_order( + encoded: Vec>, + combined_indices: Vec, + has_reordering: bool, +) -> Result>> { + if !has_reordering || combined_indices.len() != encoded.len() { + return Ok(encoded); + } + + let n = encoded.len(); + let mut reordered: Vec>> = (0..n).map(|_| None).collect(); + for (encoded_pos, embedding) in encoded.into_iter().enumerate() { + let target = combined_indices[encoded_pos]; + if target >= n { + anyhow::bail!( + "original_input_indices points to out-of-range slot ({} >= {})", + target, + n + ); + } + reordered[target] = Some(embedding); + } + reordered + .into_iter() + .enumerate() + .map(|(i, opt)| { + opt.ok_or_else(|| { + anyhow::anyhow!("original_input_indices missing slot {} in output", i) + }) + }) + .collect() +} + +fn trim_prepared_batch_for_cpu_fallback( + prepared: PreparedDocumentBatch, +) -> Result { + if prepared.batch_size == 0 { + return Ok(prepared); + } + + let source_rows = prepared.tensor_batch_size; + let source_len = prepared.batch_max_len; + if source_rows < prepared.batch_size { + anyhow::bail!( + "prepared batch has {} tensor rows but {} real documents", + source_rows, + prepared.batch_size + ); + } + if prepared.original_lengths.len() != prepared.batch_size { + anyhow::bail!( + "prepared batch has {} original lengths but {} real documents", + prepared.original_lengths.len(), + prepared.batch_size + ); + } + if prepared.all_token_ids.len() != prepared.batch_size { + anyhow::bail!( + "prepared batch has {} token-id rows but {} real documents", + prepared.all_token_ids.len(), + prepared.batch_size + ); + } + + // Document batches prepared for MIGraphX are padded to fixed sequence + // lengths such as 128/256/512 so that warm static-shape caches can be + // reused. When a shape is cold and we fall back to CPU, those padded tokens + // only add CPU work. Trim documents back to the longest real document in + // this batch. Query batches keep their full length so query expansion still + // returns the configured number of query vectors. + let required_len = prepared + .original_lengths + .iter() + .copied() + .chain(prepared.all_token_ids.iter().map(Vec::len)) + .max() + .unwrap_or(source_len) + .max(1); + if required_len > source_len { + anyhow::bail!( + "prepared batch requires {} tokens but tensor sequence length is {}", + required_len, + source_len + ); + } + + let target_len = if prepared.is_query { + source_len + } else { + required_len + }; + let target_rows = prepared.batch_size; + + if target_rows == source_rows && target_len == source_len { + return Ok(prepared); + } + + fn trim_matrix( + data: Vec, + source_rows: usize, + source_len: usize, + target_rows: usize, + target_len: usize, + name: &str, + ) -> Result> { + let expected = source_rows.checked_mul(source_len).ok_or_else(|| { + anyhow::anyhow!("{name} source shape [{source_rows},{source_len}] overflows") + })?; + if data.len() != expected { + anyhow::bail!( + "{name} length {} does not match source shape [{},{}]", + data.len(), + source_rows, + source_len + ); + } + + let target_elements = target_rows.checked_mul(target_len).ok_or_else(|| { + anyhow::anyhow!("{name} target shape [{target_rows},{target_len}] overflows") + })?; + let mut trimmed = Vec::with_capacity(target_elements); + for row in 0..target_rows { + let row_start = row * source_len; + trimmed.extend_from_slice(&data[row_start..row_start + target_len]); + } + Ok(trimmed) + } + + let all_input_ids = trim_matrix( + prepared.all_input_ids, + source_rows, + source_len, + target_rows, + target_len, + "input_ids", + )?; + let all_attention_mask = trim_matrix( + prepared.all_attention_mask, + source_rows, + source_len, + target_rows, + target_len, + "attention_mask", + )?; + let all_token_type_ids = prepared + .all_token_type_ids + .map(|ids| { + trim_matrix( + ids, + source_rows, + source_len, + target_rows, + target_len, + "token_type_ids", + ) + }) + .transpose()?; + + Ok(PreparedDocumentBatch { + batch_size: prepared.batch_size, + tensor_batch_size: target_rows, + batch_max_len: target_len, + all_input_ids, + all_attention_mask, + all_token_type_ids, + all_token_ids: prepared.all_token_ids, + original_lengths: prepared.original_lengths, + is_query: prepared.is_query, + filter_skiplist: prepared.filter_skiplist, + original_input_indices: prepared.original_input_indices, + }) +} + +fn split_prepared_batch_rows( + prepared: PreparedDocumentBatch, + max_rows: usize, +) -> Result> { + let max_rows = max_rows.max(1); + if prepared.batch_size <= max_rows && prepared.tensor_batch_size == prepared.batch_size { + return Ok(vec![prepared]); + } + + if prepared.tensor_batch_size < prepared.batch_size { + anyhow::bail!( + "prepared batch has {} tensor rows but {} real documents", + prepared.tensor_batch_size, + prepared.batch_size + ); + } + + let source_rows = prepared.tensor_batch_size; + let source_len = prepared.batch_max_len; + let expected = source_rows.checked_mul(source_len).ok_or_else(|| { + anyhow::anyhow!("prepared batch source shape [{source_rows},{source_len}] overflows") + })?; + if prepared.all_input_ids.len() != expected { + anyhow::bail!( + "input_ids length {} does not match source shape [{},{}]", + prepared.all_input_ids.len(), + source_rows, + source_len + ); + } + if prepared.all_attention_mask.len() != expected { + anyhow::bail!( + "attention_mask length {} does not match source shape [{},{}]", + prepared.all_attention_mask.len(), + source_rows, + source_len + ); + } + if let Some(token_type_ids) = &prepared.all_token_type_ids { + if token_type_ids.len() != expected { + anyhow::bail!( + "token_type_ids length {} does not match source shape [{},{}]", + token_type_ids.len(), + source_rows, + source_len + ); + } + } + + fn slice_rows(data: &[i64], source_len: usize, start: usize, end: usize) -> Vec { + let mut out = Vec::with_capacity((end - start) * source_len); + for row in start..end { + let row_start = row * source_len; + out.extend_from_slice(&data[row_start..row_start + source_len]); + } + out + } + + let PreparedDocumentBatch { + batch_size, + tensor_batch_size: _, + batch_max_len, + all_input_ids, + all_attention_mask, + all_token_type_ids, + all_token_ids, + original_lengths, + is_query, + filter_skiplist, + original_input_indices, + } = prepared; + + let mut chunks = Vec::new(); + let mut start = 0usize; + while start < batch_size { + let end = (start + max_rows).min(batch_size); + let rows = end - start; + chunks.push(PreparedDocumentBatch { + batch_size: rows, + tensor_batch_size: rows, + batch_max_len, + all_input_ids: slice_rows(&all_input_ids, source_len, start, end), + all_attention_mask: slice_rows(&all_attention_mask, source_len, start, end), + all_token_type_ids: all_token_type_ids + .as_ref() + .map(|ids| slice_rows(ids, source_len, start, end)), + all_token_ids: all_token_ids[start..end].to_vec(), + original_lengths: original_lengths[start..end].to_vec(), + is_query, + filter_skiplist, + original_input_indices: if original_input_indices.is_empty() { + Vec::new() + } else { + original_input_indices[start..end].to_vec() + }, + }); + start = end; + } + + Ok(chunks) +} + +fn pad_prepared_batch_rows_for_migraphx_tail( + prepared: PreparedDocumentBatch, + target_rows: usize, + config: &ColbertConfig, +) -> Result { + if target_rows < prepared.batch_size { + anyhow::bail!( + "MIGraphX tail row padding target has {} rows but batch contains {} documents", + target_rows, + prepared.batch_size + ); + } + if prepared.tensor_batch_size < prepared.batch_size { + anyhow::bail!( + "prepared batch has {} tensor rows but {} real documents", + prepared.tensor_batch_size, + prepared.batch_size + ); + } + if target_rows <= prepared.tensor_batch_size { + return Ok(prepared); + } + + let source_rows = prepared.tensor_batch_size; + let source_len = prepared.batch_max_len; + let expected = source_rows.checked_mul(source_len).ok_or_else(|| { + anyhow::anyhow!("prepared batch source shape [{source_rows},{source_len}] overflows") + })?; + if prepared.all_input_ids.len() != expected { + anyhow::bail!( + "input_ids length {} does not match source shape [{},{}]", + prepared.all_input_ids.len(), + source_rows, + source_len + ); + } + if prepared.all_attention_mask.len() != expected { + anyhow::bail!( + "attention_mask length {} does not match source shape [{},{}]", + prepared.all_attention_mask.len(), + source_rows, + source_len + ); + } + if let Some(token_type_ids) = &prepared.all_token_type_ids { + if token_type_ids.len() != expected { + anyhow::bail!( + "token_type_ids length {} does not match source shape [{},{}]", + token_type_ids.len(), + source_rows, + source_len + ); + } + } + + let extra_elements = target_rows + .checked_sub(source_rows) + .and_then(|rows| rows.checked_mul(source_len)) + .ok_or_else(|| { + anyhow::anyhow!( + "MIGraphX tail row padding target shape [{target_rows},{source_len}] overflows" + ) + })?; + let default_input_id = if prepared.is_query && config.do_query_expansion { + config.mask_token_id as i64 + } else { + config.pad_token_id as i64 + }; + let default_attention = if prepared.is_query && config.do_query_expansion { + 1i64 + } else { + 0i64 + }; + + let mut all_input_ids = prepared.all_input_ids; + all_input_ids.extend(std::iter::repeat_n(default_input_id, extra_elements)); + let mut all_attention_mask = prepared.all_attention_mask; + all_attention_mask.extend(std::iter::repeat_n(default_attention, extra_elements)); + let all_token_type_ids = prepared.all_token_type_ids.map(|mut ids| { + ids.extend(std::iter::repeat_n(0, extra_elements)); + ids + }); + + Ok(PreparedDocumentBatch { + batch_size: prepared.batch_size, + tensor_batch_size: target_rows, + batch_max_len: prepared.batch_max_len, + all_input_ids, + all_attention_mask, + all_token_type_ids, + all_token_ids: prepared.all_token_ids, + original_lengths: prepared.original_lengths, + is_query: prepared.is_query, + filter_skiplist: prepared.filter_skiplist, + original_input_indices: prepared.original_input_indices, + }) +} + +impl MigraphxHybrid { + fn new( + model_dir: PathBuf, + quantized: bool, + onnx_path: &Path, + tokenizer: Arc, + config: Arc, + batch_size: usize, + cache_root: PathBuf, + cpu_fallback_parallel: Option, + cpu_fallback_quantized: bool, + min_run_tokens_override: Option, + ) -> Result { + let cpu_sessions = + cpu_fallback_parallel.unwrap_or_else(default_migraphx_cpu_fallback_sessions); + + let document_shapes: HashSet = + build_migraphx_document_static_shapes(batch_size, config.document_length) + .into_iter() + .collect(); + let mut supported_shapes = document_shapes.clone(); + supported_shapes.insert(MigraphxStaticShape { + batch_size: 1, + sequence_length: config.query_length, + }); + + let hybrid = Self { + model_dir, + quantized, + cpu_fallback_quantized, + tokenizer, + config: Arc::clone(&config), + query_length: config.query_length, + document_length: config.document_length, + cpu_fallback_parallel: cpu_sessions, + cpu_model: Mutex::new(None), + cache_root, + model_cache_key: cache_key_for_onnx(onnx_path, quantized), + document_shapes, + supported_shapes, + shape_models: Mutex::new(HashMap::new()), + min_run_tokens: min_run_tokens_override.unwrap_or_else(migraphx_min_run_tokens), + }; + + Ok(hybrid) + } + + fn cpu_model(&self) -> Result { + let mut guard = self.cpu_model.lock().unwrap(); + if guard.is_none() { + let model = ColbertBuilder::new(&self.model_dir) + .with_quantized(self.cpu_fallback_quantized) + .with_parallel(self.cpu_fallback_parallel) + .with_batch_size(1) + .with_dynamic_batch(false) + .with_query_length(self.query_length) + .with_document_length(self.document_length) + .with_execution_provider(ExecutionProvider::Cpu) + .with_migraphx_cold_shape_cpu_fallback(false) + .build() + .context("Failed to build CPU fallback model for MIGraphX cold shapes")?; + *guard = Some(model); + } + Ok(guard + .as_ref() + .expect("CPU fallback model just initialized") + .clone()) + } + + fn shape_cache_dir(&self, shape: MigraphxStaticShape) -> PathBuf { + self.cache_root + .join(&self.model_cache_key) + .join(shape.cache_dir_name()) + } + + fn marker_path(&self, shape: MigraphxStaticShape) -> PathBuf { + self.shape_cache_dir(shape).join("validated-v1") + } + + fn is_supported_shape(&self, shape: MigraphxStaticShape) -> bool { + self.supported_shapes.contains(&shape) + } + + fn is_shape_cache_warm(&self, shape: MigraphxStaticShape) -> bool { + if !self.is_supported_shape(shape) { + return false; + } + let cache_dir = self.shape_cache_dir(shape); + self.marker_path(shape).exists() && shape_cache_has_mxr(&cache_dir) + } + + fn warm_tail_shape_for_prepared_with( + &self, + prepared: &PreparedDocumentBatch, + can_pad: impl Fn(usize, usize) -> bool, + ) -> Option { + if prepared.is_query + || prepared.batch_size == 0 + || prepared.tensor_batch_size > prepared.batch_size + { + return None; + } + + self.supported_shapes + .iter() + .copied() + .filter(|shape| { + if shape.sequence_length != prepared.batch_max_len { + return false; + } + if is_force_gpu() { + can_pad_migraphx_warm_tail_rows_with_factor( + prepared.batch_size, + shape.batch_size, + usize::MAX, + ) + } else { + can_pad(prepared.batch_size, shape.batch_size) + } + }) + .filter(|shape| self.is_shape_cache_warm(*shape)) + .min_by_key(|shape| shape.batch_size) + } + + fn warm_tail_shape_for_prepared( + &self, + prepared: &PreparedDocumentBatch, + ) -> Option { + self.warm_tail_shape_for_prepared_with(prepared, can_pad_migraphx_warm_tail_rows) + } + + fn active_warm_tail_shape_for_prepared( + &self, + prepared: &PreparedDocumentBatch, + ) -> Option { + self.warm_tail_shape_for_prepared_with(prepared, can_pad_migraphx_active_warm_tail_rows) + } + + fn warm_tail_shape_model_if_warm( + &self, + prepared: &PreparedDocumentBatch, + ) -> Result> { + let Some(shape) = self.warm_tail_shape_for_prepared(prepared) else { + return Ok(None); + }; + Ok(self.shape_model_if_warm(shape)?.map(|model| (shape, model))) + } + + fn warm_route_for_prepared( + &self, + prepared: &PreparedDocumentBatch, + active_gpu_lane: bool, + ) -> Option<(&'static str, MigraphxStaticShape, usize)> { + let exact = MigraphxStaticShape { + batch_size: prepared.tensor_batch_size, + sequence_length: prepared.batch_max_len, + }; + if self.is_shape_cache_warm(exact) { + return Some(("gpu-exact", exact, prepared.tensor_batch_size)); + } + if active_gpu_lane { + self.active_warm_tail_shape_for_prepared(prepared) + } else { + self.warm_tail_shape_for_prepared(prepared) + } + .map(|shape| ("gpu-tail", shape, shape.batch_size)) + } + + fn warm_row_sizes_for_sequence(&self, sequence_length: usize) -> Vec { + let mut rows = self + .supported_shapes + .iter() + .copied() + .filter(|shape| shape.sequence_length == sequence_length) + .filter(|shape| self.is_shape_cache_warm(*shape)) + .map(|shape| shape.batch_size) + .collect::>(); + rows.sort_unstable(); + rows.dedup(); + rows + } + + fn split_prepared_for_warm_shapes( + &self, + prepared: PreparedDocumentBatch, + ) -> Result> { + if prepared.is_query + || prepared.batch_size <= 1 + || prepared.tensor_batch_size != prepared.batch_size + { + return Ok(vec![prepared]); + } + + let warm_rows = self.warm_row_sizes_for_sequence(prepared.batch_max_len); + if warm_rows.is_empty() || warm_rows.binary_search(&prepared.batch_size).is_ok() { + return Ok(vec![prepared]); + } + + let Some(whole_rows) = migraphx_padded_rows_for_warm_shape(prepared.batch_size, &warm_rows) + else { + return Ok(vec![prepared]); + }; + const MIN_MIGRAPHX_SPLIT_ROWS: usize = 64; + let Some(split_rows) = warm_rows + .iter() + .copied() + .filter(|rows| *rows < prepared.batch_size && *rows >= MIN_MIGRAPHX_SPLIT_ROWS) + .max() + else { + return Ok(vec![prepared]); + }; + + let chunks = split_prepared_batch_rows(prepared.clone(), split_rows)?; + if chunks.len() <= 1 { + return Ok(vec![prepared]); + } + + let Some(split_padded_rows) = chunks.iter().try_fold(0usize, |total, chunk| { + migraphx_padded_rows_for_warm_shape(chunk.batch_size, &warm_rows) + .map(|rows| total.saturating_add(rows)) + }) else { + return Ok(vec![prepared]); + }; + + const EXTRA_GPU_RUN_PENALTY_TOKENS: usize = 8_192; + let sequence_length = prepared.batch_max_len.max(1); + let whole_tokens = whole_rows.saturating_mul(sequence_length); + let split_tokens = split_padded_rows + .saturating_mul(sequence_length) + .saturating_add(EXTRA_GPU_RUN_PENALTY_TOKENS.saturating_mul(chunks.len() - 1)); + + if split_tokens < whole_tokens { + Ok(chunks) + } else { + Ok(vec![prepared]) + } + } + + fn split_prepared_batches_for_warm_shapes( + &self, + prepared_batches: Vec, + ) -> Result> { + if !migraphx_split_warm_batches_enabled() { + return Ok(prepared_batches); + } + + let mut out = Vec::with_capacity(prepared_batches.len()); + for prepared in prepared_batches { + out.extend(self.split_prepared_for_warm_shapes(prepared)?); + } + Ok(out) + } + + fn invalidate_shape_cache(&self, shape: MigraphxStaticShape) { + let _ = fs::remove_file(self.marker_path(shape)); + self.shape_models.lock().unwrap().remove(&shape); + } + + fn bail_if_force_gpu_would_fallback_to_cpu( + &self, + prepared: &PreparedDocumentBatch, + reason: &str, + ) -> Result<()> { + if is_force_gpu() { + anyhow::bail!( + "NEXT_PLAID_FORCE_GPU is set, but MIGraphX hybrid indexing would fall back to CPU for input shape [{} real row(s), {} tensor row(s), {} token(s)]: {}. Warm a matching static-shape cache, adjust --batch-size, or unset --force-gpu to allow CPU fallback.", + prepared.batch_size, + prepared.tensor_batch_size, + prepared.batch_max_len, + reason, + ); + } + Ok(()) + } + + fn build_shape_model(&self, shape: MigraphxStaticShape) -> Result { + let cache_dir = self.shape_cache_dir(shape); + fs::create_dir_all(&cache_dir).with_context(|| { + format!( + "Failed to create MIGraphX cache directory {}", + cache_dir.display() + ) + })?; + + ColbertBuilder::new(&self.model_dir) + .with_quantized(self.quantized) + .with_parallel(1) + .with_batch_size(shape.batch_size) + .with_dynamic_batch(false) + .with_query_length(self.query_length) + .with_document_length(self.document_length) + .with_execution_provider(ExecutionProvider::MIGraphX) + .with_migraphx_static_shape(shape.batch_size, shape.sequence_length) + .with_migraphx_model_cache_dir(cache_dir) + .with_migraphx_cold_shape_cpu_fallback(false) + .build() + } + + fn shape_model_if_warm(&self, shape: MigraphxStaticShape) -> Result> { + if !self.is_shape_cache_warm(shape) { + return Ok(None); + } + + if let Some(model) = self.shape_models.lock().unwrap().get(&shape).cloned() { + return Ok(Some(model)); + } + + let cache_dir = self.shape_cache_dir(shape); + let _lock = acquire_migraphx_shape_cache_lock(&cache_dir)?; + if !self.is_shape_cache_warm(shape) { + return Ok(None); + } + + if let Some(model) = self.shape_models.lock().unwrap().get(&shape).cloned() { + return Ok(Some(model)); + } + + let model = self.build_shape_model(shape).with_context(|| { + format!( + "Failed to build warm MIGraphX static-shape model for {:?}", + shape + ) + })?; + self.shape_models + .lock() + .unwrap() + .insert(shape, model.clone()); + Ok(Some(model)) + } + + fn dummy_prepared_batch(&self, shape: MigraphxStaticShape) -> PreparedDocumentBatch { + let element_count = shape.batch_size * shape.sequence_length; + let token_id = self.config.mask_token_id; + PreparedDocumentBatch { + batch_size: shape.batch_size, + tensor_batch_size: shape.batch_size, + batch_max_len: shape.sequence_length, + all_input_ids: vec![token_id as i64; element_count], + all_attention_mask: vec![1; element_count], + all_token_type_ids: if self.config.uses_token_type_ids { + Some(vec![0; element_count]) + } else { + None + }, + all_token_ids: vec![vec![token_id; shape.sequence_length]; shape.batch_size], + original_lengths: vec![shape.sequence_length; shape.batch_size], + is_query: false, + filter_skiplist: false, + original_input_indices: Vec::new(), + } + } + + fn warm_shape(&self, shape: MigraphxStaticShape) -> Result<()> { + if !self.is_supported_shape(shape) { + anyhow::bail!( + "MIGraphX static shape {:?} is not in the supported shape set", + shape + ); + } + if self.is_shape_cache_warm(shape) { + return Ok(()); + } + + let cache_dir = self.shape_cache_dir(shape); + let _lock = acquire_migraphx_shape_cache_lock(&cache_dir)?; + if self.is_shape_cache_warm(shape) { + return Ok(()); + } + + let model = self.build_shape_model(shape)?; + let prepared = self.dummy_prepared_batch(shape); + model + .encode_prepared_documents(prepared) + .with_context(|| format!("Failed to validate MIGraphX static shape {:?}", shape))?; + write_migraphx_validation_marker( + &self.marker_path(shape), + &format!( + "validated-v1\nshape={}x{}\n", + shape.batch_size, shape.sequence_length + ), + ) + .context("Failed to write MIGraphX shape-cache validation marker")?; + self.shape_models.lock().unwrap().insert(shape, model); + Ok(()) + } + + fn encode_one_prepared( + self: &Arc, + prepared: PreparedDocumentBatch, + ) -> Result>> { + let shape = MigraphxStaticShape { + batch_size: prepared.tensor_batch_size, + sequence_length: prepared.batch_max_len, + }; + + if let Some(model) = self.shape_model_if_warm(shape)? { + let cpu_fallback = prepared.clone(); + match model.encode_prepared_documents(prepared) { + Ok(embeddings) => { + return Ok(embeddings); + } + Err(err) => { + self.invalidate_shape_cache(shape); + self.bail_if_force_gpu_would_fallback_to_cpu( + &cpu_fallback, + &format!("warm MIGraphX static shape {shape:?} failed: {err}"), + )?; + return self.cpu_model()?.encode_prepared_documents( + trim_prepared_batch_for_cpu_fallback(cpu_fallback)?, + ); + } + } + } + + if let Some((tail_shape, model)) = self.warm_tail_shape_model_if_warm(&prepared)? { + let cpu_fallback = prepared.clone(); + let padded = pad_prepared_batch_rows_for_migraphx_tail( + prepared, + tail_shape.batch_size, + &self.config, + )?; + match model.encode_prepared_documents(padded) { + Ok(embeddings) => { + return Ok(embeddings); + } + Err(err) => { + self.invalidate_shape_cache(tail_shape); + self.bail_if_force_gpu_would_fallback_to_cpu( + &cpu_fallback, + &format!("warm MIGraphX tail static shape {tail_shape:?} failed: {err}"), + )?; + return self.cpu_model()?.encode_prepared_documents( + trim_prepared_batch_for_cpu_fallback(cpu_fallback)?, + ); + } + } + } + + self.bail_if_force_gpu_would_fallback_to_cpu( + &prepared, + &format!("no warm MIGraphX static shape for {shape:?}"), + )?; + self.cpu_model()? + .encode_prepared_documents(trim_prepared_batch_for_cpu_fallback(prepared)?) + } + + fn encode_gpu_jobs( + self: &Arc, + jobs: Vec, + ) -> Result>)>> { + let mut encoded_segments = Vec::with_capacity(jobs.len()); + for job in jobs { + let MigraphxGpuJob { + batch_idx, + cache_shape, + route, + prepared, + cpu_fallback, + } = job; + + let model = match self.shape_model_if_warm(cache_shape) { + Ok(Some(model)) => model, + Ok(None) => { + self.bail_if_force_gpu_would_fallback_to_cpu( + &cpu_fallback, + &format!( + "warm MIGraphX {route} static shape {cache_shape:?} disappeared before execution" + ), + )?; + let embeddings = self.cpu_model()?.encode_prepared_documents( + trim_prepared_batch_for_cpu_fallback(cpu_fallback)?, + )?; + encoded_segments.push((batch_idx, embeddings)); + continue; + } + Err(err) => { + self.invalidate_shape_cache(cache_shape); + self.bail_if_force_gpu_would_fallback_to_cpu( + &cpu_fallback, + &format!( + "failed to load warm MIGraphX {route} static shape {cache_shape:?}: {err}" + ), + )?; + let embeddings = self.cpu_model()?.encode_prepared_documents( + trim_prepared_batch_for_cpu_fallback(cpu_fallback)?, + )?; + encoded_segments.push((batch_idx, embeddings)); + continue; + } + }; + + match model.encode_prepared_documents(prepared) { + Ok(embeddings) => encoded_segments.push((batch_idx, embeddings)), + Err(err) => { + self.invalidate_shape_cache(cache_shape); + self.bail_if_force_gpu_would_fallback_to_cpu( + &cpu_fallback, + &format!( + "warm MIGraphX {route} static shape {cache_shape:?} failed: {err}" + ), + )?; + let embeddings = self.cpu_model()?.encode_prepared_documents( + trim_prepared_batch_for_cpu_fallback(cpu_fallback)?, + )?; + encoded_segments.push((batch_idx, embeddings)); + } + } + } + Ok(encoded_segments) + } + + fn encode_cpu_fallback_batches( + &self, + cpu_batches: Vec<(usize, PreparedDocumentBatch)>, + ) -> Result>)>> { + if cpu_batches.is_empty() { + return Ok(Vec::new()); + } + + let counts: Vec<(usize, usize)> = cpu_batches + .iter() + .map(|(idx, batch)| (*idx, batch.batch_size)) + .collect(); + let mut batches = Vec::new(); + for (_, batch) in cpu_batches { + for chunk in split_prepared_batch_rows(batch, DEFAULT_CPU_BATCH_SIZE)? { + batches.push(trim_prepared_batch_for_cpu_fallback(chunk)?); + } + } + let cpu_model = self.cpu_model()?; + let cpu_encoded = cpu_model.encode_prepared_batches_unordered(batches)?; + let mut iter = cpu_encoded.into_iter(); + let mut encoded_segments = Vec::with_capacity(counts.len()); + for (batch_idx, count) in counts { + let mut embeddings = Vec::with_capacity(count); + for _ in 0..count { + embeddings.push(iter.next().ok_or_else(|| { + anyhow::anyhow!( + "CPU fallback returned fewer embeddings than expected for MIGraphX hybrid batch" + ) + })?); + } + encoded_segments.push((batch_idx, embeddings)); + } + if iter.next().is_some() { + anyhow::bail!( + "CPU fallback returned more embeddings than expected for MIGraphX hybrid batches" + ); + } + Ok(encoded_segments) + } + + fn encode_prepared_document_batches( + self: &Arc, + prepared_batches: Vec, + ) -> Result>> { + let prepared_batches = self.split_prepared_batches_for_warm_shapes(prepared_batches)?; + let mut combined_indices: Vec = + Vec::with_capacity(prepared_batches.iter().map(|b| b.batch_size).sum()); + let mut has_reordering = false; + for batch in &prepared_batches { + if !batch.original_input_indices.is_empty() { + combined_indices.extend_from_slice(&batch.original_input_indices); + has_reordering = true; + } + } + + let warm_gpu_tokens: usize = prepared_batches + .iter() + .filter_map(|prepared| self.warm_route_for_prepared(prepared, false)) + .map(|(_, shape, tensor_rows)| tensor_rows.saturating_mul(shape.sequence_length)) + .sum(); + let use_gpu_lane = is_force_gpu() || warm_gpu_tokens >= self.min_run_tokens; + + let mut gpu_jobs: Vec = Vec::new(); + let mut cpu_batches: Vec<(usize, PreparedDocumentBatch)> = Vec::new(); + let trace_enabled = migraphx_trace_enabled(); + let mut route_summary: BTreeMap = BTreeMap::new(); + + for (batch_idx, prepared) in prepared_batches.into_iter().enumerate() { + let shape = MigraphxStaticShape { + batch_size: prepared.tensor_batch_size, + sequence_length: prepared.batch_max_len, + }; + + if !use_gpu_lane { + if trace_enabled { + record_migraphx_route( + &mut route_summary, + "cpu-below-gpu-threshold", + shape, + prepared.batch_size, + prepared.tensor_batch_size, + ); + } + cpu_batches.push((batch_idx, prepared)); + continue; + } + + if let Some((route, warm_shape, tensor_rows)) = + self.warm_route_for_prepared(&prepared, use_gpu_lane) + { + let cpu_fallback = prepared.clone(); + let prepared = if warm_shape == shape { + prepared + } else { + pad_prepared_batch_rows_for_migraphx_tail( + prepared, + warm_shape.batch_size, + &self.config, + )? + }; + if trace_enabled { + record_migraphx_route( + &mut route_summary, + route, + warm_shape, + cpu_fallback.batch_size, + tensor_rows, + ); + } + gpu_jobs.push(MigraphxGpuJob { + batch_idx, + cache_shape: warm_shape, + route, + prepared, + cpu_fallback, + }); + } else { + self.bail_if_force_gpu_would_fallback_to_cpu( + &prepared, + &format!("no warm MIGraphX static shape for {shape:?}"), + )?; + if trace_enabled { + record_migraphx_route( + &mut route_summary, + "cpu-cold", + shape, + prepared.batch_size, + prepared.tensor_batch_size, + ); + } + cpu_batches.push((batch_idx, prepared)); + } + } + + if trace_enabled && !route_summary.is_empty() { + let parts = route_summary + .iter() + .map(|(key, (batches, real_rows, tensor_rows))| { + format!( + "{key}:batches={batches},real_rows={real_rows},tensor_rows={tensor_rows}" + ) + }) + .collect::>() + .join(" "); + eprintln!("__MIGRAPHX_HYBRID_TRACE__ {parts}"); + } + + let mut encoded_segments: Vec<(usize, Vec>)> = Vec::new(); + if gpu_jobs.is_empty() { + encoded_segments.extend(self.encode_cpu_fallback_batches(cpu_batches)?); + } else if cpu_batches.is_empty() { + encoded_segments.extend(self.encode_gpu_jobs(gpu_jobs)?); + } else { + let gpu_self = Arc::clone(self); + let (gpu_result, cpu_result) = std::thread::scope(|scope| { + let gpu_handle = scope.spawn(move || gpu_self.encode_gpu_jobs(gpu_jobs)); + let cpu_result = self.encode_cpu_fallback_batches(cpu_batches); + let gpu_result = gpu_handle.join().unwrap(); + (gpu_result, cpu_result) + }); + encoded_segments.extend(gpu_result?); + encoded_segments.extend(cpu_result?); + } + + encoded_segments.sort_by_key(|(batch_idx, _)| *batch_idx); + let mut encoded = Vec::new(); + for (_, embeddings) in encoded_segments { + encoded.extend(embeddings); + } + restore_original_input_order(encoded, combined_indices, has_reordering) + } + + fn encode_queries( + self: &Arc, + queries: &[&str], + batch_size: usize, + ) -> Result>> { + let _ = batch_size; + let mut encoded = Vec::with_capacity(queries.len()); + for query in queries { + let processed = preprocess_texts(&self.config, &[*query]); + let tokenized = tokenize_processed_texts_individually(&self.tokenizer, &processed)?; + let prepared = prepare_batch_from_tokenized_documents( + &self.tokenizer, + &self.config, + tokenized, + true, + false, + Vec::new(), + Some(FixedDynamicShape { + docs: 1, + planned_len: self.query_length, + }), + )?; + encoded.extend(self.encode_one_prepared(prepared)?); + } + Ok(encoded) + } + + fn warm_default_shapes(&self, max_sequence_len: usize) -> Result { + let mut shapes: Vec<_> = self.document_shapes.iter().copied().collect(); + shapes.retain(|shape| shape.sequence_length <= max_sequence_len); + shapes.sort_by_key(|shape| (shape.sequence_length, shape.batch_size)); + let mut warmed = 0; + for shape in shapes { + self.warm_shape(shape)?; + warmed += 1; + } + Ok(warmed) + } + + fn warm_all_default_shapes(&self) -> Result { + let mut shapes: Vec<_> = self.document_shapes.iter().copied().collect(); + shapes.sort_by_key(|shape| (shape.sequence_length, shape.batch_size)); + let mut warmed = 0; + for shape in shapes { + self.warm_shape(shape)?; + warmed += 1; + } + Ok(warmed) + } +} + +fn update_token_ids(config: &mut ColbertConfig, tokenizer: &Tokenizer) { + if config.mask_token_id == default_mask_token_id() { + if let Some(mask_id) = tokenizer.token_to_id("[MASK]") { + config.mask_token_id = mask_id; + } else if let Some(mask_id) = tokenizer.token_to_id("") { + config.mask_token_id = mask_id; + } + } + if config.pad_token_id == default_pad_token_id() { + if let Some(pad_id) = tokenizer.token_to_id("[PAD]") { + config.pad_token_id = pad_id; + } else if let Some(pad_id) = tokenizer.token_to_id("") { + config.pad_token_id = pad_id; } } } @@ -1966,6 +4201,7 @@ fn prepare_batch_for_session( if texts.is_empty() { return Ok(PreparedDocumentBatch { batch_size: 0, + tensor_batch_size: 0, batch_max_len: 0, all_input_ids: Vec::new(), all_attention_mask: Vec::new(), @@ -2001,6 +4237,7 @@ fn prepare_batch_from_tokenized_documents( is_query: bool, filter_skiplist: bool, original_input_indices: Vec, + planned_shape: Option, ) -> Result { let (prefix_str, prefix_token_id_opt, max_length) = if is_query { ( @@ -2041,6 +4278,25 @@ fn prepare_batch_from_tokenized_documents( } let batch_size = batch_docs.len(); + let (tensor_batch_size, batch_max_len) = if let Some(shape) = planned_shape { + if shape.docs < batch_size { + anyhow::bail!( + "planned batch shape has {} rows but batch contains {} documents", + shape.docs, + batch_size + ); + } + if shape.planned_len < batch_max_len { + anyhow::bail!( + "planned batch shape has sequence length {} but batch requires {} tokens", + shape.planned_len, + batch_max_len + ); + } + (shape.docs, shape.planned_len) + } else { + (batch_size, batch_max_len) + }; let default_input_id = if is_query && config.do_query_expansion { config.mask_token_id as i64 } else { @@ -2051,9 +4307,10 @@ fn prepare_batch_from_tokenized_documents( } else { 0i64 }; - let mut all_input_ids: Vec = vec![default_input_id; batch_size * batch_max_len]; - let mut all_attention_mask: Vec = vec![default_attention; batch_size * batch_max_len]; - let mut all_token_type_ids: Vec = vec![0; batch_size * batch_max_len]; + let mut all_input_ids: Vec = vec![default_input_id; tensor_batch_size * batch_max_len]; + let mut all_attention_mask: Vec = + vec![default_attention; tensor_batch_size * batch_max_len]; + let mut all_token_type_ids: Vec = vec![0; tensor_batch_size * batch_max_len]; let mut all_token_ids: Vec> = Vec::with_capacity(batch_size); let mut original_lengths: Vec = Vec::with_capacity(batch_size); @@ -2102,6 +4359,7 @@ fn prepare_batch_from_tokenized_documents( Ok(PreparedDocumentBatch { batch_size, + tensor_batch_size, batch_max_len, all_input_ids, all_attention_mask, @@ -2180,6 +4438,7 @@ fn prepare_batch_from_tokenizer_encodings( } let batch_size = batch_encodings.len(); + let tensor_batch_size = batch_size; let default_input_id = if is_query && config.do_query_expansion { config.mask_token_id as i64 } else { @@ -2246,6 +4505,7 @@ fn prepare_batch_from_tokenizer_encodings( Ok(PreparedDocumentBatch { batch_size, + tensor_batch_size, batch_max_len, all_input_ids, all_attention_mask, @@ -2273,6 +4533,7 @@ fn encode_prepared_batch_with_session( ) -> Result>> { let PreparedDocumentBatch { batch_size, + tensor_batch_size, batch_max_len, all_input_ids, all_attention_mask, @@ -2288,12 +4549,12 @@ fn encode_prepared_batch_with_session( return Ok(Vec::new()); } - let input_ids_tensor = Tensor::from_array(([batch_size, batch_max_len], all_input_ids))?; + let input_ids_tensor = Tensor::from_array(([tensor_batch_size, batch_max_len], all_input_ids))?; let attention_mask_tensor = - Tensor::from_array(([batch_size, batch_max_len], all_attention_mask))?; + Tensor::from_array(([tensor_batch_size, batch_max_len], all_attention_mask))?; let token_type_ids_tensor = all_token_type_ids - .map(|ids| Tensor::from_array(([batch_size, batch_max_len], ids))) + .map(|ids| Tensor::from_array(([tensor_batch_size, batch_max_len], ids))) .transpose()?; let (shape_slice, output_owned): (Vec, Vec) = @@ -2318,7 +4579,27 @@ fn encode_prepared_batch_with_session( (output_shape.to_vec(), output_data.to_vec()) }; - let embedding_dim = shape_slice[2] as usize; + if shape_slice.len() != 3 { + anyhow::bail!( + "ONNX output tensor has rank {} but expected rank 3 for input shape [{},{}]", + shape_slice.len(), + tensor_batch_size, + batch_max_len + ); + } + let output_batch_size = + usize::try_from(shape_slice[0]).context("Negative output batch size")?; + let output_sequence_len = + usize::try_from(shape_slice[1]).context("Negative output sequence length")?; + let embedding_dim = usize::try_from(shape_slice[2]).context("Negative embedding dimension")?; + if output_batch_size != tensor_batch_size || output_sequence_len != batch_max_len { + anyhow::bail!( + "ONNX output shape {:?} does not match input shape [{},{}]. Clear any stale execution-provider model cache and retry.", + shape_slice, + tensor_batch_size, + batch_max_len + ); + } let output_data = &output_owned; let mut all_embeddings = Vec::with_capacity(batch_size); @@ -2620,6 +4901,7 @@ mod tests { assert_ne!(ExecutionProvider::Cuda, ExecutionProvider::TensorRT); assert_ne!(ExecutionProvider::TensorRT, ExecutionProvider::CoreML); assert_ne!(ExecutionProvider::CoreML, ExecutionProvider::DirectML); + assert_ne!(ExecutionProvider::DirectML, ExecutionProvider::MIGraphX); } #[test] @@ -2636,6 +4918,451 @@ mod tests { assert_eq!(debug_str, "Cuda"); } + #[test] + fn test_execution_provider_display_names() { + assert_eq!(ExecutionProvider::Auto.display_name(), "auto"); + assert_eq!(ExecutionProvider::Cpu.display_name(), "CPU"); + assert_eq!(ExecutionProvider::Cuda.display_name(), "CUDA"); + assert_eq!(ExecutionProvider::TensorRT.display_name(), "TensorRT"); + assert_eq!(ExecutionProvider::CoreML.display_name(), "CoreML"); + assert_eq!(ExecutionProvider::DirectML.display_name(), "DirectML"); + assert_eq!(ExecutionProvider::MIGraphX.display_name(), "MIGraphX"); + } + + #[test] + fn test_execution_provider_gpu_classification() { + assert!(!ExecutionProvider::Auto.is_gpu()); + assert!(!ExecutionProvider::Cpu.is_gpu()); + assert!(ExecutionProvider::Cuda.is_gpu()); + assert!(ExecutionProvider::TensorRT.is_gpu()); + assert!(ExecutionProvider::CoreML.is_gpu()); + assert!(ExecutionProvider::DirectML.is_gpu()); + assert!(ExecutionProvider::MIGraphX.is_gpu()); + } + + #[test] + fn test_execution_provider_compiled_flags() { + assert!(is_execution_provider_compiled(ExecutionProvider::Auto)); + assert!(is_execution_provider_compiled(ExecutionProvider::Cpu)); + assert_eq!( + is_execution_provider_compiled(ExecutionProvider::Cuda), + cfg!(feature = "cuda") + ); + assert_eq!( + is_execution_provider_compiled(ExecutionProvider::TensorRT), + cfg!(feature = "tensorrt") + ); + assert_eq!( + is_execution_provider_compiled(ExecutionProvider::CoreML), + cfg!(feature = "coreml") + ); + assert_eq!( + is_execution_provider_compiled(ExecutionProvider::DirectML), + cfg!(feature = "directml") + ); + assert_eq!( + is_execution_provider_compiled(ExecutionProvider::MIGraphX), + cfg!(feature = "migraphx") + ); + } + + #[test] + fn test_compiled_gpu_execution_provider_order() { + let expected = GPU_PROVIDER_ORDER + .iter() + .copied() + .filter(|provider| is_execution_provider_compiled(*provider)) + .collect::>(); + + assert_eq!(compiled_gpu_execution_providers(), expected); + assert_eq!(compiled_gpu_execution_provider(), expected.first().copied()); + } + + #[test] + #[cfg(not(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "migraphx" + )))] + fn test_require_gpu_execution_provider_without_gpu_features() { + let error = require_gpu_execution_provider().unwrap_err().to_string(); + assert!(error.contains("GPU execution requested")); + assert!(error.contains("no GPU execution provider was compiled")); + } + + // ========================================================================= + // MIGraphX CPU fallback tests + // ========================================================================= + + #[test] + fn test_trim_prepared_batch_for_cpu_fallback_removes_padding() { + let prepared = PreparedDocumentBatch { + batch_size: 2, + tensor_batch_size: 4, + batch_max_len: 8, + all_input_ids: (0..32).collect(), + all_attention_mask: (100..132).collect(), + all_token_type_ids: Some((200..232).collect()), + all_token_ids: vec![vec![1, 2, 3], vec![4, 5, 6, 7, 8]], + original_lengths: vec![3, 5], + is_query: false, + filter_skiplist: true, + original_input_indices: vec![1, 0], + }; + + let trimmed = trim_prepared_batch_for_cpu_fallback(prepared).unwrap(); + + assert_eq!(trimmed.batch_size, 2); + assert_eq!(trimmed.tensor_batch_size, 2); + assert_eq!(trimmed.batch_max_len, 5); + assert_eq!(trimmed.all_input_ids, vec![0, 1, 2, 3, 4, 8, 9, 10, 11, 12]); + assert_eq!( + trimmed.all_attention_mask, + vec![100, 101, 102, 103, 104, 108, 109, 110, 111, 112] + ); + assert_eq!( + trimmed.all_token_type_ids, + Some(vec![200, 201, 202, 203, 204, 208, 209, 210, 211, 212]) + ); + assert_eq!( + trimmed.all_token_ids, + vec![vec![1, 2, 3], vec![4, 5, 6, 7, 8]] + ); + assert_eq!(trimmed.original_lengths, vec![3, 5]); + assert_eq!(trimmed.original_input_indices, vec![1, 0]); + } + + #[test] + fn test_trim_prepared_batch_for_cpu_fallback_preserves_query_length() { + let prepared = PreparedDocumentBatch { + batch_size: 1, + tensor_batch_size: 4, + batch_max_len: 6, + all_input_ids: (0..24).collect(), + all_attention_mask: vec![1; 24], + all_token_type_ids: None, + all_token_ids: vec![vec![1, 2, 3]], + original_lengths: vec![3], + is_query: true, + filter_skiplist: false, + original_input_indices: Vec::new(), + }; + + let trimmed = trim_prepared_batch_for_cpu_fallback(prepared).unwrap(); + + assert_eq!(trimmed.batch_size, 1); + assert_eq!(trimmed.tensor_batch_size, 1); + assert_eq!(trimmed.batch_max_len, 6); + assert_eq!(trimmed.all_input_ids, vec![0, 1, 2, 3, 4, 5]); + assert_eq!(trimmed.all_attention_mask, vec![1; 6]); + assert_eq!(trimmed.all_token_type_ids, None); + } + + #[test] + fn test_split_prepared_batch_rows_preserves_row_data() { + let prepared = PreparedDocumentBatch { + batch_size: 5, + tensor_batch_size: 5, + batch_max_len: 3, + all_input_ids: (0..15).collect(), + all_attention_mask: (100..115).collect(), + all_token_type_ids: Some((200..215).collect()), + all_token_ids: vec![vec![0], vec![1], vec![2], vec![3], vec![4]], + original_lengths: vec![1, 2, 3, 1, 2], + is_query: false, + filter_skiplist: true, + original_input_indices: vec![4, 3, 2, 1, 0], + }; + + let chunks = split_prepared_batch_rows(prepared, 2).unwrap(); + + assert_eq!(chunks.len(), 3); + assert_eq!(chunks[0].batch_size, 2); + assert_eq!(chunks[0].tensor_batch_size, 2); + assert_eq!(chunks[0].all_input_ids, vec![0, 1, 2, 3, 4, 5]); + assert_eq!( + chunks[0].all_attention_mask, + vec![100, 101, 102, 103, 104, 105] + ); + assert_eq!( + chunks[0].all_token_type_ids, + Some(vec![200, 201, 202, 203, 204, 205]) + ); + assert_eq!(chunks[0].all_token_ids, vec![vec![0], vec![1]]); + assert_eq!(chunks[0].original_lengths, vec![1, 2]); + assert_eq!(chunks[0].original_input_indices, vec![4, 3]); + + assert_eq!(chunks[1].batch_size, 2); + assert_eq!(chunks[1].all_input_ids, vec![6, 7, 8, 9, 10, 11]); + assert_eq!(chunks[1].all_token_ids, vec![vec![2], vec![3]]); + assert_eq!(chunks[1].original_input_indices, vec![2, 1]); + + assert_eq!(chunks[2].batch_size, 1); + assert_eq!(chunks[2].all_input_ids, vec![12, 13, 14]); + assert_eq!(chunks[2].all_token_ids, vec![vec![4]]); + assert_eq!(chunks[2].original_input_indices, vec![0]); + } + + #[test] + fn test_pad_prepared_batch_rows_for_migraphx_tail_adds_dummy_rows() { + let config = ColbertConfig { + pad_token_id: 99, + ..Default::default() + }; + let prepared = PreparedDocumentBatch { + batch_size: 2, + tensor_batch_size: 2, + batch_max_len: 3, + all_input_ids: vec![1, 2, 3, 4, 5, 6], + all_attention_mask: vec![1, 1, 1, 1, 1, 0], + all_token_type_ids: Some(vec![0, 0, 0, 0, 0, 0]), + all_token_ids: vec![vec![1, 2, 3], vec![4, 5]], + original_lengths: vec![3, 2], + is_query: false, + filter_skiplist: true, + original_input_indices: vec![0, 1], + }; + + let padded = pad_prepared_batch_rows_for_migraphx_tail(prepared, 4, &config).unwrap(); + + assert_eq!(padded.batch_size, 2); + assert_eq!(padded.tensor_batch_size, 4); + assert_eq!(padded.batch_max_len, 3); + assert_eq!( + padded.all_input_ids, + vec![1, 2, 3, 4, 5, 6, 99, 99, 99, 99, 99, 99] + ); + assert_eq!( + padded.all_attention_mask, + vec![1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!(padded.all_token_type_ids, Some(vec![0; 12])); + assert_eq!(padded.all_token_ids, vec![vec![1, 2, 3], vec![4, 5]]); + assert_eq!(padded.original_lengths, vec![3, 2]); + assert_eq!(padded.original_input_indices, vec![0, 1]); + } + + #[test] + fn test_can_pad_migraphx_warm_tail_rows_uses_bounded_factor() { + assert!(can_pad_migraphx_warm_tail_rows(8, 16)); + assert!(can_pad_migraphx_warm_tail_rows(3, 4)); + assert!(!can_pad_migraphx_warm_tail_rows(7, 16)); + assert!(!can_pad_migraphx_warm_tail_rows(2, 64)); + assert!(!can_pad_migraphx_warm_tail_rows(10, 512)); + assert!(!can_pad_migraphx_warm_tail_rows(7, 7)); + assert!(!can_pad_migraphx_warm_tail_rows(3, 256)); + assert!(!can_pad_migraphx_warm_tail_rows(0, 16)); + } + + #[test] + fn test_migraphx_document_shapes_keep_only_expensive_planned_shapes() { + let shapes = build_migraphx_document_static_shapes_with_min_tokens(64, 2048, 65_536); + assert_eq!( + shapes, + vec![ + MigraphxStaticShape { + batch_size: 1024, + sequence_length: 128, + }, + MigraphxStaticShape { + batch_size: 512, + sequence_length: 256, + }, + MigraphxStaticShape { + batch_size: 256, + sequence_length: 512, + }, + MigraphxStaticShape { + batch_size: 128, + sequence_length: 1024, + }, + MigraphxStaticShape { + batch_size: 64, + sequence_length: 2048, + }, + ] + ); + assert!(!shapes.contains(&MigraphxStaticShape { + batch_size: 1, + sequence_length: 128, + })); + assert!(!shapes.contains(&MigraphxStaticShape { + batch_size: 16, + sequence_length: 128, + })); + assert!(shapes.contains(&MigraphxStaticShape { + batch_size: 512, + sequence_length: 256, + })); + assert!(!shapes.contains(&MigraphxStaticShape { + batch_size: 16, + sequence_length: 256, + })); + assert!(!shapes.contains(&MigraphxStaticShape { + batch_size: 2, + sequence_length: 2048, + })); + } + + #[test] + fn test_migraphx_cache_key_includes_precision_options() { + let path = std::path::Path::new("/tmp/next-plaid-cache-key-test/model.onnx"); + + let fp32_key = cache_key_for_onnx_with_options( + path, + false, + MigraphxCacheOptions { + entries: Vec::new(), + }, + ); + let fp16_key = cache_key_for_onnx_with_options( + path, + false, + MigraphxCacheOptions { + entries: vec![("migraphx_fp16_enable".to_string(), "1".to_string())], + }, + ); + let ort_fp16_key = cache_key_for_onnx_with_options( + path, + false, + MigraphxCacheOptions { + entries: vec![("ORT_MIGRAPHX_FP16_ENABLE".to_string(), "1".to_string())], + }, + ); + let int8_key = cache_key_for_onnx_with_options( + path, + true, + MigraphxCacheOptions { + entries: Vec::new(), + }, + ); + + assert_ne!(fp32_key, fp16_key); + assert_ne!(fp32_key, ort_fp16_key); + assert_ne!(fp32_key, int8_key); + assert_ne!(fp16_key, int8_key); + assert_eq!( + fp32_key, + cache_key_for_onnx_with_options( + path, + false, + MigraphxCacheOptions { + entries: Vec::new(), + } + ) + ); + } + + #[test] + fn test_migraphx_cache_key_hashes_model_contents() { + let unique = format!( + "next-plaid-cache-key-content-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() + ); + let root = std::env::temp_dir().join(unique); + fs::create_dir_all(&root).unwrap(); + let path = root.join("model.onnx"); + + fs::write(&path, b"same-length-a").unwrap(); + let first_key = cache_key_for_onnx_with_options( + &path, + false, + MigraphxCacheOptions { + entries: Vec::new(), + }, + ); + fs::write(&path, b"same-length-b").unwrap(); + let second_key = cache_key_for_onnx_with_options( + &path, + false, + MigraphxCacheOptions { + entries: Vec::new(), + }, + ); + + assert_ne!(first_key, second_key); + let _ = fs::remove_dir_all(root); + } + + #[test] + fn test_migraphx_static_shape_cache_status_tracks_warm_document_shapes() { + let unique = format!( + "next-plaid-cache-status-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() + ); + let root = std::env::temp_dir().join(unique); + let model_dir = root.join("model"); + let cache_root = root.join("cache"); + fs::create_dir_all(&model_dir).unwrap(); + fs::write(model_dir.join("model.onnx"), b"fake model").unwrap(); + fs::write( + model_dir.join("onnx_config.json"), + r#"{"query_length":256,"document_length":512,"embedding_dim":48}"#, + ) + .unwrap(); + + let old_cache_root = std::env::var("NEXT_PLAID_MIGRAPHX_STATIC_CACHE_ROOT").ok(); + let old_min_tokens = std::env::var("NEXT_PLAID_MIGRAPHX_MIN_STATIC_SHAPE_TOKENS").ok(); + std::env::set_var("NEXT_PLAID_MIGRAPHX_STATIC_CACHE_ROOT", &cache_root); + std::env::set_var("NEXT_PLAID_MIGRAPHX_MIN_STATIC_SHAPE_TOKENS", "65536"); + + let status = migraphx_static_shape_cache_status(&model_dir, false, 128).unwrap(); + assert_eq!( + status.document_shapes, + vec![ + MigraphxStaticShape { + batch_size: 512, + sequence_length: 128 + }, + MigraphxStaticShape { + batch_size: 256, + sequence_length: 256 + }, + MigraphxStaticShape { + batch_size: 128, + sequence_length: 512 + } + ] + ); + assert!(!status.all_document_shapes_warm()); + assert_eq!(status.warm_document_shapes.len(), 0); + assert_eq!(status.cold_document_shapes.len(), 3); + + let warm_shape = status.document_shapes[0]; + let warm_dir = cache_root + .join(&status.model_cache_key) + .join(warm_shape.cache_dir_name()); + fs::create_dir_all(&warm_dir).unwrap(); + fs::write(warm_dir.join("validated-v1"), b"validated").unwrap(); + fs::write(warm_dir.join("fake.mxr"), b"mxr").unwrap(); + + let status = migraphx_static_shape_cache_status(&model_dir, false, 128).unwrap(); + assert_eq!(status.warm_document_shapes, vec![warm_shape]); + assert_eq!(status.cold_document_shapes.len(), 2); + + if let Some(value) = old_cache_root { + std::env::set_var("NEXT_PLAID_MIGRAPHX_STATIC_CACHE_ROOT", value); + } else { + std::env::remove_var("NEXT_PLAID_MIGRAPHX_STATIC_CACHE_ROOT"); + } + if let Some(value) = old_min_tokens { + std::env::set_var("NEXT_PLAID_MIGRAPHX_MIN_STATIC_SHAPE_TOKENS", value); + } else { + std::env::remove_var("NEXT_PLAID_MIGRAPHX_MIN_STATIC_SHAPE_TOKENS"); + } + let _ = fs::remove_dir_all(root); + } + // ========================================================================= // Pool embeddings tests // ========================================================================= From b9171fe25bdea4c7013ed34cfe45eb19696d99a6 Mon Sep 17 00:00:00 2001 From: Guillaume Ausset Date: Tue, 26 May 2026 21:55:38 +0200 Subject: [PATCH 4/4] Add ColGREP cache warming command Expose colgrep warm-cache as an explicit, advanced path for preparing provider-specific runtime caches. For MIGraphX it warms only eligible expensive static document shapes and reports when there is nothing worth compiling. --- colgrep/README.md | 17 ++++++ colgrep/src/cli.rs | 42 +++++++++++++- colgrep/src/commands/mod.rs | 2 + colgrep/src/commands/warm_cache.rs | 91 ++++++++++++++++++++++++++++++ colgrep/src/main.rs | 8 ++- next-plaid-onnx/README.md | 33 ++++++++++- 6 files changed, 190 insertions(+), 3 deletions(-) create mode 100644 colgrep/src/commands/warm_cache.rs diff --git a/colgrep/README.md b/colgrep/README.md index eabd672e..62178058 100644 --- a/colgrep/README.md +++ b/colgrep/README.md @@ -219,6 +219,7 @@ colgrep --json "auth" | jq '.[] | .unit.file' | `colgrep clear` | Clear index for current project | | `colgrep clear --all` | Clear all indexes | | `colgrep set-model ` | Change the default ColBERT model | +| `colgrep warm-cache --provider migraphx` | Warm MIGraphX runtime caches | | `colgrep settings` | View or modify configuration | | `colgrep settings --ignore` | Add extra ignore patterns (persistent) | | `colgrep settings --force-include` | Force-include normally ignored paths | @@ -493,6 +494,22 @@ This is useful for: - **CI/dev setup** scripts where you want indexing to happen ahead of time - **Updating** the index after pulling new code +### `colgrep warm-cache` + +Warm provider-specific runtime caches without indexing. For MIGraphX this +pre-compiles only eligible expensive static shapes that colgrep can later reuse +for experimental GPU indexing. Cold or ineligible shapes continue to use CPU; +warming can take minutes and usually only pays off for repeated large batches. + +```bash +colgrep warm-cache --provider migraphx +colgrep warm-cache --provider migraphx --batch-size 64 --max-sequence-len 1024 +``` + +Advanced MIGraphX thresholds can be tuned with +`NEXT_PLAID_MIGRAPHX_MIN_STATIC_SHAPE_TOKENS` and +`NEXT_PLAID_MIGRAPHX_MIN_RUN_TOKENS`. + ```bash # Check index status colgrep status diff --git a/colgrep/src/cli.rs b/colgrep/src/cli.rs index b08a70e8..2d04b5d3 100644 --- a/colgrep/src/cli.rs +++ b/colgrep/src/cli.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use clap::{Parser, Subcommand}; +use clap::{Parser, Subcommand, ValueEnum}; use crate::color::ColorChoice; @@ -182,6 +182,21 @@ NOTES: • Useful for pre-warming the index before searching • Subsequent searches will be fast since the index is already built"; +pub const WARM_CACHE_HELP: &str = "\ +MIGraphX cache warming is experimental. It can take minutes and usually only +pays off for repeated large indexing batches. ColGREP never compiles cold +MIGraphX shapes during normal indexing; cold or ineligible shapes use CPU. + +EXAMPLES: + # Warm eligible expensive MIGraphX static-shape caches for the configured model + colgrep warm-cache --provider migraphx + + # Warm caches for a specific model and batch size + colgrep warm-cache --provider migraphx --model lightonai/LateOn-Code-edge --batch-size 64 + + # Limit warming to shorter document shapes only + colgrep warm-cache --provider migraphx --max-sequence-len 1024"; + pub const CONFIG_HELP: &str = "\ EXAMPLES: # Show current configuration @@ -449,6 +464,11 @@ pub struct Cli { pub force_gpu: bool, } +#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)] +pub enum CacheProvider { + Migraphx, +} + #[derive(Subcommand)] pub enum Commands { /// Search for code semantically (auto-indexes if needed) @@ -621,6 +641,26 @@ pub enum Commands { static_batch: bool, }, + /// Warm provider-specific runtime caches without indexing + #[command(name = "warm-cache", after_help = WARM_CACHE_HELP)] + WarmCache { + /// Cache provider to warm + #[arg(long, value_enum, default_value_t = CacheProvider::Migraphx)] + provider: CacheProvider, + + /// ColBERT model HuggingFace ID or local path (uses saved preference if not specified) + #[arg(long)] + model: Option, + + /// Model/session batch size whose MIGraphX static shapes should be warmed + #[arg(long = "batch-size", value_name = "SIZE")] + batch_size: Option, + + /// Maximum sequence length of MIGraphX static shapes to warm (default: all eligible expensive shapes) + #[arg(long = "max-sequence-len")] + max_sequence_len: Option, + }, + /// View or set configuration options (default k, n values) #[command(name = "settings", after_help = CONFIG_HELP)] Settings { diff --git a/colgrep/src/commands/mod.rs b/colgrep/src/commands/mod.rs index fe6c6259..9aa483a3 100644 --- a/colgrep/src/commands/mod.rs +++ b/colgrep/src/commands/mod.rs @@ -6,6 +6,7 @@ pub mod search; mod stats; mod status; mod update; +mod warm_cache; pub use clear::cmd_clear; pub use config::{cmd_config, cmd_set_model}; @@ -15,3 +16,4 @@ pub use search::cmd_search; pub use stats::{cmd_reset_stats, cmd_stats}; pub use status::cmd_status; pub use update::cmd_update; +pub use warm_cache::cmd_warm_cache; diff --git a/colgrep/src/commands/warm_cache.rs b/colgrep/src/commands/warm_cache.rs new file mode 100644 index 00000000..0b63a0d6 --- /dev/null +++ b/colgrep/src/commands/warm_cache.rs @@ -0,0 +1,91 @@ +use anyhow::Result; + +use crate::cli::CacheProvider; + +pub fn cmd_warm_cache( + provider: CacheProvider, + cli_model: Option<&str>, + batch_size: Option, + max_sequence_len: Option, +) -> Result<()> { + match provider { + CacheProvider::Migraphx => warm_migraphx_cache(cli_model, batch_size, max_sequence_len), + } +} + +#[cfg(not(feature = "migraphx"))] +fn warm_migraphx_cache( + _cli_model: Option<&str>, + _batch_size: Option, + _max_sequence_len: Option, +) -> Result<()> { + anyhow::bail!("MIGraphX support is not compiled. Rebuild colgrep with --features migraphx."); +} + +#[cfg(feature = "migraphx")] +fn warm_migraphx_cache( + cli_model: Option<&str>, + batch_size: Option, + max_sequence_len: Option, +) -> Result<()> { + use anyhow::Context; + use colgrep::acceleration::{ + apply_acceleration_mode, env_acceleration_mode_lossy, AccelerationMode, + }; + use colgrep::{config, ensure_model, onnx_runtime, Config}; + use next_plaid_onnx::{Colbert, ExecutionProvider}; + + if env_acceleration_mode_lossy() == AccelerationMode::ForceCpu { + anyhow::bail!("warm-cache --provider migraphx requires GPU execution; remove --force-cpu."); + } + + // Force GPU runtime discovery so a missing MIGraphX-capable ONNX Runtime + // fails early with the MIGraphX installation guidance instead of falling + // back to a CPU-only runtime. + apply_acceleration_mode(AccelerationMode::ForceGpu); + onnx_runtime::ensure_onnx_runtime().context("Failed to initialize ONNX Runtime")?; + + let model_id = crate::commands::search::resolve_model(cli_model); + let config = Config::load().unwrap_or_default(); + let quantized = !config.use_fp32(); + let batch_size = batch_size + .map(|batch_size| batch_size.max(1)) + .or_else(|| config.configured_batch_size()) + .unwrap_or_else(|| { + config::default_batch_size_for_execution_provider(ExecutionProvider::MIGraphX) + }); + let model_path = ensure_model(Some(&model_id), false)?; + + eprintln!("šŸ¤– Model: {model_id}"); + let max_sequence_len_label = max_sequence_len + .map(|len| len.max(1).to_string()) + .unwrap_or_else(|| "all".to_string()); + eprintln!( + "šŸ”„ Warming eligible expensive MIGraphX static-shape cache(s) (batch_size={batch_size}, max_sequence_len={max_sequence_len_label})..." + ); + + let model = Colbert::builder(&model_path) + .with_quantized(quantized) + .with_parallel(1) + .with_batch_size(batch_size) + .with_execution_provider(ExecutionProvider::MIGraphX) + .build() + .context("Failed to load ColBERT model for MIGraphX cache warming")?; + + let mut shapes = model.migraphx_static_shapes(); + if let Some(max_sequence_len) = max_sequence_len { + shapes.retain(|shape| shape.sequence_length <= max_sequence_len.max(1)); + } + eprintln!("Eligible planned shapes: {shapes:?}"); + if shapes.is_empty() { + eprintln!("No eligible MIGraphX shapes for this model/batch-size; nothing to warm."); + } + + let warmed = if let Some(max_sequence_len) = max_sequence_len { + model.warm_migraphx_static_shape_cache_up_to(max_sequence_len.max(1))? + } else { + model.warm_migraphx_static_shape_cache()? + }; + eprintln!("āœ… Warmed {warmed} MIGraphX shape cache(s)."); + Ok(()) +} diff --git a/colgrep/src/main.rs b/colgrep/src/main.rs index 368566dd..9100b19f 100644 --- a/colgrep/src/main.rs +++ b/colgrep/src/main.rs @@ -20,7 +20,7 @@ use cli::{Cli, Commands}; use commands::search::{resolve_pool_factor, resolve_top_k}; use commands::{ cmd_clear, cmd_config, cmd_init, cmd_reset_stats, cmd_search, cmd_session_hook, cmd_set_model, - cmd_stats, cmd_status, cmd_task_hook, cmd_update, InitOptions, + cmd_stats, cmd_status, cmd_task_hook, cmd_update, cmd_warm_cache, InitOptions, }; fn main() -> Result<()> { @@ -255,6 +255,12 @@ fn main() -> Result<()> { }, ), Some(Commands::Update) => cmd_update(), + Some(Commands::WarmCache { + provider, + model, + batch_size, + max_sequence_len, + }) => cmd_warm_cache(provider, model.as_deref(), batch_size, max_sequence_len), Some(Commands::Status { path }) => cmd_status(&path), Some(Commands::Clear { path, all }) => cmd_clear(&path, all), Some(Commands::SetModel { model }) => cmd_set_model(&model), diff --git a/next-plaid-onnx/README.md b/next-plaid-onnx/README.md index 2e9f0f21..ca9894dc 100644 --- a/next-plaid-onnx/README.md +++ b/next-plaid-onnx/README.md @@ -75,9 +75,12 @@ next-plaid-onnx = { version = "0.2", features = ["coreml"] } # Windows DirectML (DirectX 12) next-plaid-onnx = { version = "0.2", features = ["directml"] } + +# AMD ROCm / MIGraphX +next-plaid-onnx = { version = "0.2", features = ["migraphx"] } ``` -`ExecutionProvider::Auto` tries providers in order: CUDA → TensorRT → CoreML → DirectML → CPU. Set `NEXT_PLAID_FORCE_CPU=1` to bypass all GPU providers. +`ExecutionProvider::Auto` tries providers in order: CUDA → TensorRT → CoreML → DirectML → MIGraphX → CPU. Set `NEXT_PLAID_FORCE_CPU=1` to bypass all GPU providers. For ROCm, install AMD's ONNX Runtime wheel for your ROCm release (for example `pip install onnxruntime-migraphx -f https://repo.radeon.com/rocm/manylinux/rocm-rel-/`) @@ -144,6 +147,33 @@ impl ColbertBuilder { } ``` +#### MIGraphX cache helpers + +```rust +pub fn migraphx_static_shape_cache_status( + model_dir: impl AsRef, + quantized: bool, + batch_size: usize, +) -> Result; + +pub fn migraphx_document_static_shape_caches_warm( + model_dir: impl AsRef, + quantized: bool, + batch_size: usize, +) -> Result; + +impl Colbert { + pub fn warm_migraphx_static_shape_cache_up_to( + &self, + max_sequence_len: usize, + ) -> Result; +} +``` + +These inspect per-shape validation markers and MXR files without creating ONNX +sessions, so callers can route fully warmed MIGraphX workloads to MIGraphX and +keep cold/incomplete workloads on CPU. + #### `ExecutionProvider` ```rust @@ -154,6 +184,7 @@ pub enum ExecutionProvider { TensorRT, // NVIDIA TensorRT (requires `tensorrt` feature) CoreML, // Apple Silicon (requires `coreml` feature) DirectML, // Windows DirectX 12 (requires `directml` feature) + MIGraphX, // AMD ROCm/MIGraphX (requires `migraphx` feature) } ```