diff --git a/colgrep/src/index/mod.rs b/colgrep/src/index/mod.rs index 3342df30..ae74afa1 100644 --- a/colgrep/src/index/mod.rs +++ b/colgrep/src/index/mod.rs @@ -999,11 +999,15 @@ impl IndexBuilder { } } }; - #[cfg(not(feature = "cuda"))] + #[cfg(not(any( + feature = "cuda", + feature = "directml", + feature = "migraphx", + feature = "coreml" + )))] let (num_sessions, execution_provider) = { - let _ = num_units; // Silence unused warning when cuda feature is disabled + let _ = num_units; - // Initialize ONNX Runtime (CPU-only build) crate::onnx_runtime::ensure_onnx_runtime() .context("Failed to initialize ONNX Runtime")?; @@ -1014,8 +1018,31 @@ impl IndexBuilder { ) }; + #[cfg(any(feature = "directml", feature = "migraphx", feature = "coreml"))] + #[cfg(not(feature = "cuda"))] + let (num_sessions, execution_provider) = { + let _ = num_units; + + crate::onnx_runtime::ensure_onnx_runtime() + .context("Failed to initialize ONNX Runtime")?; + + let provider = if cfg!(feature = "directml") { + ExecutionProvider::DirectML + } else if cfg!(feature = "migraphx") { + ExecutionProvider::MIGraphX + } else { + ExecutionProvider::CoreML + }; + + ( + self.parallel_sessions + .unwrap_or_else(crate::config::get_default_cpu_parallel_sessions), + provider, + ) + }; + // Print model info after ONNX runtime is initialized (and any potential re-exec) - eprintln!("🤖 Model: {}", self.model_id); + eprintln!("🤖 Model: {} ({})", self.model_id, execution_provider); eprintln!("📂 Building index..."); // Use runtime default for batch size (respects cuDNN availability) @@ -3286,6 +3313,10 @@ impl Searcher { AccelerationMode::Auto => { if cfg!(feature = "coreml") { ExecutionProvider::CoreML + } else if cfg!(feature = "directml") { + ExecutionProvider::DirectML + } else if cfg!(feature = "migraphx") { + ExecutionProvider::MIGraphX } else { ExecutionProvider::Cpu } @@ -3362,6 +3393,10 @@ impl Searcher { AccelerationMode::Auto => { if cfg!(feature = "coreml") { ExecutionProvider::CoreML + } else if cfg!(feature = "directml") { + ExecutionProvider::DirectML + } else if cfg!(feature = "migraphx") { + ExecutionProvider::MIGraphX } else { ExecutionProvider::Cpu } diff --git a/colgrep/src/onnx_runtime.rs b/colgrep/src/onnx_runtime.rs index 6bd9a4f3..daa04472 100644 --- a/colgrep/src/onnx_runtime.rs +++ b/colgrep/src/onnx_runtime.rs @@ -41,16 +41,10 @@ const ORT_LIB_NAME: &str = "libonnxruntime.so"; #[cfg(target_os = "windows")] const ORT_LIB_NAME: &str = "onnxruntime.dll"; -/// Whether to use GPU (CUDA) version of ONNX Runtime -#[cfg(feature = "cuda")] -const USE_GPU: bool = true; -#[cfg(not(feature = "cuda"))] -const USE_GPU: bool = false; - /// Subdirectory name for caching (gpu vs cpu) -#[cfg(feature = "cuda")] +#[cfg(any(feature = "cuda", feature = "directml"))] const ORT_CACHE_SUBDIR: &str = "gpu"; -#[cfg(not(feature = "cuda"))] +#[cfg(not(any(feature = "cuda", feature = "directml")))] const ORT_CACHE_SUBDIR: &str = "cpu"; /// Ensure ONNX Runtime is available. @@ -557,7 +551,10 @@ fn download_onnx_runtime() -> Result { && cache_dir.join("onnxruntime_providers_shared.dll").exists() && cache_dir.join("onnxruntime_providers_cuda.dll").exists(); - #[cfg(not(feature = "cuda"))] + #[cfg(all(feature = "directml", not(feature = "cuda")))] + let already_cached = lib_path.exists(); + + #[cfg(not(any(feature = "cuda", feature = "directml")))] let already_cached = lib_path.exists(); if already_cached { @@ -568,11 +565,12 @@ fn download_onnx_runtime() -> Result { let (url, files_to_extract) = get_download_info()?; - if USE_GPU { - eprintln!("⚙️ Runtime: ONNX {} (GPU/CUDA)", ORT_VERSION); - } else { - eprintln!("⚙️ Runtime: ONNX {} (CPU)", ORT_VERSION); - } + #[cfg(feature = "cuda")] + eprintln!("⚙️ Runtime: ONNX {} (GPU/CUDA)", ORT_VERSION); + #[cfg(all(feature = "directml", not(feature = "cuda")))] + eprintln!("⚙️ Runtime: ONNX {} (GPU/DirectML)", ORT_VERSION); + #[cfg(not(any(feature = "cuda", feature = "directml")))] + eprintln!("⚙️ Runtime: ONNX {} (CPU)", ORT_VERSION); // Download archive let response = ureq::get(&url) @@ -593,128 +591,159 @@ type FileToExtract = (String, String); /// Get download URL and files to extract for current platform fn get_download_info() -> Result<(String, Vec)> { - let base = format!( - "https://github.com/microsoft/onnxruntime/releases/download/v{}", - ORT_VERSION - ); - - // macOS - no GPU support via GitHub releases (use CoreML instead) - #[cfg(all(target_os = "macos", target_arch = "aarch64"))] - let (archive, files) = ( - format!("onnxruntime-osx-arm64-{}.tgz", ORT_VERSION), + // DirectML: download from NuGet (Microsoft GPU package does not include DirectML) + #[cfg(all( + target_os = "windows", + target_arch = "x86_64", + feature = "directml", + not(feature = "cuda") + ))] + return Ok(( + format!( + "https://www.nuget.org/api/v2/package/Microsoft.ML.OnnxRuntime.DirectML/{}", + ORT_VERSION + ), vec![( - format!( - "onnxruntime-osx-arm64-{}/lib/libonnxruntime.{}.dylib", - ORT_VERSION, ORT_VERSION - ), - "libonnxruntime.dylib".to_string(), + "runtimes/win-x64/native/onnxruntime.dll".to_string(), + "onnxruntime.dll".to_string(), )], - ); + )); - #[cfg(all(target_os = "macos", target_arch = "x86_64"))] - let (archive, files) = ( - format!("onnxruntime-osx-x86_64-{}.tgz", ORT_VERSION), - vec![( - format!( - "onnxruntime-osx-x86_64-{}/lib/libonnxruntime.{}.dylib", - ORT_VERSION, ORT_VERSION - ), - "libonnxruntime.dylib".to_string(), - )], - ); - - // Linux x86_64 - supports both CPU and GPU - #[cfg(all(target_os = "linux", target_arch = "x86_64", feature = "cuda"))] - let (archive, files) = { - let archive_name = format!("onnxruntime-linux-x64-gpu-{}", ORT_VERSION); - ( - format!("{}.tgz", archive_name), - vec![ - ( - format!("{}/lib/libonnxruntime.so.{}", archive_name, ORT_VERSION), - "libonnxruntime.so".to_string(), - ), - ( - format!("{}/lib/libonnxruntime_providers_shared.so", archive_name), - "libonnxruntime_providers_shared.so".to_string(), + // All other configurations: download from GitHub releases + #[cfg(not(all( + target_os = "windows", + target_arch = "x86_64", + feature = "directml", + not(feature = "cuda") + )))] + { + let base = format!( + "https://github.com/microsoft/onnxruntime/releases/download/v{}", + ORT_VERSION + ); + + // macOS - no GPU support via GitHub releases (use CoreML instead) + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] + let (archive, files) = ( + format!("onnxruntime-osx-arm64-{}.tgz", ORT_VERSION), + vec![( + format!( + "onnxruntime-osx-arm64-{}/lib/libonnxruntime.{}.dylib", + ORT_VERSION, ORT_VERSION ), - ( - format!("{}/lib/libonnxruntime_providers_cuda.so", archive_name), - "libonnxruntime_providers_cuda.so".to_string(), + "libonnxruntime.dylib".to_string(), + )], + ); + + #[cfg(all(target_os = "macos", target_arch = "x86_64"))] + let (archive, files) = ( + format!("onnxruntime-osx-x86_64-{}.tgz", ORT_VERSION), + vec![( + format!( + "onnxruntime-osx-x86_64-{}/lib/libonnxruntime.{}.dylib", + ORT_VERSION, ORT_VERSION ), - ], - ) - }; + "libonnxruntime.dylib".to_string(), + )], + ); - #[cfg(all(target_os = "linux", target_arch = "x86_64", not(feature = "cuda")))] - let (archive, files) = ( - format!("onnxruntime-linux-x64-{}.tgz", ORT_VERSION), - vec![( - format!( - "onnxruntime-linux-x64-{}/lib/libonnxruntime.so.{}", - ORT_VERSION, ORT_VERSION - ), - "libonnxruntime.so".to_string(), - )], - ); + // Linux x86_64 - supports both CPU and GPU + #[cfg(all(target_os = "linux", target_arch = "x86_64", feature = "cuda"))] + let (archive, files) = { + let archive_name = format!("onnxruntime-linux-x64-gpu-{}", ORT_VERSION); + ( + format!("{}.tgz", archive_name), + vec![ + ( + format!("{}/lib/libonnxruntime.so.{}", archive_name, ORT_VERSION), + "libonnxruntime.so".to_string(), + ), + ( + format!("{}/lib/libonnxruntime_providers_shared.so", archive_name), + "libonnxruntime_providers_shared.so".to_string(), + ), + ( + format!("{}/lib/libonnxruntime_providers_cuda.so", archive_name), + "libonnxruntime_providers_cuda.so".to_string(), + ), + ], + ) + }; - // Linux aarch64 - CPU only (no GPU releases available) - #[cfg(all(target_os = "linux", target_arch = "aarch64"))] - let (archive, files) = ( - format!("onnxruntime-linux-aarch64-{}.tgz", ORT_VERSION), - vec![( - format!( - "onnxruntime-linux-aarch64-{}/lib/libonnxruntime.so.{}", - ORT_VERSION, ORT_VERSION - ), - "libonnxruntime.so".to_string(), - )], - ); - - // Windows - supports both CPU and GPU - #[cfg(all(target_os = "windows", target_arch = "x86_64", feature = "cuda"))] - let (archive, files) = { - let archive_name = format!("onnxruntime-win-x64-gpu-{}", ORT_VERSION); - ( - format!("{}.zip", archive_name), - vec![ - ( - format!("{}/lib/onnxruntime.dll", archive_name), - "onnxruntime.dll".to_string(), + #[cfg(all(target_os = "linux", target_arch = "x86_64", not(feature = "cuda")))] + let (archive, files) = ( + format!("onnxruntime-linux-x64-{}.tgz", ORT_VERSION), + vec![( + format!( + "onnxruntime-linux-x64-{}/lib/libonnxruntime.so.{}", + ORT_VERSION, ORT_VERSION ), - ( - format!("{}/lib/onnxruntime_providers_shared.dll", archive_name), - "onnxruntime_providers_shared.dll".to_string(), - ), - ( - format!("{}/lib/onnxruntime_providers_cuda.dll", archive_name), - "onnxruntime_providers_cuda.dll".to_string(), + "libonnxruntime.so".to_string(), + )], + ); + + // Linux aarch64 - CPU only (no GPU releases available) + #[cfg(all(target_os = "linux", target_arch = "aarch64"))] + let (archive, files) = ( + format!("onnxruntime-linux-aarch64-{}.tgz", ORT_VERSION), + vec![( + format!( + "onnxruntime-linux-aarch64-{}/lib/libonnxruntime.so.{}", + ORT_VERSION, ORT_VERSION ), - ], - ) - }; + "libonnxruntime.so".to_string(), + )], + ); - #[cfg(all(target_os = "windows", target_arch = "x86_64", not(feature = "cuda")))] - let (archive, files) = ( - format!("onnxruntime-win-x64-{}.zip", ORT_VERSION), - vec![( - format!("onnxruntime-win-x64-{}/lib/onnxruntime.dll", ORT_VERSION), - "onnxruntime.dll".to_string(), - )], - ); - - #[cfg(not(any( - all(target_os = "macos", target_arch = "aarch64"), - all(target_os = "macos", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - all(target_os = "windows", target_arch = "x86_64"), - )))] - return Err(anyhow::anyhow!( - "Unsupported platform. Please install ONNX Runtime manually and set ORT_DYLIB_PATH." - )); + // Windows - supports both CPU and GPU + #[cfg(all(target_os = "windows", target_arch = "x86_64", feature = "cuda"))] + let (archive, files) = { + let archive_name = format!("onnxruntime-win-x64-gpu-{}", ORT_VERSION); + ( + format!("{}.zip", archive_name), + vec![ + ( + format!("{}/lib/onnxruntime.dll", archive_name), + "onnxruntime.dll".to_string(), + ), + ( + format!("{}/lib/onnxruntime_providers_shared.dll", archive_name), + "onnxruntime_providers_shared.dll".to_string(), + ), + ( + format!("{}/lib/onnxruntime_providers_cuda.dll", archive_name), + "onnxruntime_providers_cuda.dll".to_string(), + ), + ], + ) + }; + + #[cfg(all( + target_os = "windows", + target_arch = "x86_64", + not(any(feature = "cuda", feature = "directml")) + ))] + let (archive, files) = ( + format!("onnxruntime-win-x64-{}.zip", ORT_VERSION), + vec![( + format!("onnxruntime-win-x64-{}/lib/onnxruntime.dll", ORT_VERSION), + "onnxruntime.dll".to_string(), + )], + ); - Ok((format!("{}/{}", base, archive), files)) + #[cfg(not(any( + all(target_os = "macos", target_arch = "aarch64"), + all(target_os = "macos", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + all(target_os = "windows", target_arch = "x86_64"), + )))] + return Err(anyhow::anyhow!( + "Unsupported platform. Please install ONNX Runtime manually and set ORT_DYLIB_PATH." + )); + + Ok((format!("{}/{}", base, archive), files)) + } } /// Extract libraries from tgz archive diff --git a/next-plaid-onnx/src/lib.rs b/next-plaid-onnx/src/lib.rs index 359c5545..617ab003 100644 --- a/next-plaid-onnx/src/lib.rs +++ b/next-plaid-onnx/src/lib.rs @@ -198,6 +198,23 @@ pub enum ExecutionProvider { MIGraphX, } +impl std::fmt::Display for ExecutionProvider { + /// Short user-facing label matching the tokens used in onnx_runtime.rs + /// download messages (e.g. "CPU", "CUDA", "DirectML"). Stable across + /// releases; do not change without updating callers that log the label. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Auto => f.write_str("auto"), + Self::Cpu => f.write_str("CPU"), + Self::Cuda => f.write_str("CUDA"), + Self::TensorRT => f.write_str("TensorRT"), + Self::CoreML => f.write_str("CoreML"), + Self::DirectML => f.write_str("DirectML"), + Self::MIGraphX => f.write_str("MIGraphX"), + } + } +} + fn configure_execution_provider( builder: SessionBuilder, provider: ExecutionProvider, @@ -347,7 +364,7 @@ fn configure_auto_provider(builder: SessionBuilder) -> Result { } #[cfg(feature = "coreml")] - { + if !force_cpu { if let Ok(b) = builder .clone() .with_execution_providers([CoreMLExecutionProvider::default().build()]) @@ -2403,6 +2420,19 @@ mod tests { assert_eq!(debug_str, "Cuda"); } + #[test] + fn test_execution_provider_display() { + // Labels are part of the user-facing surface (e.g. colgrep's + // `Model: ()` line); lock them in. + assert_eq!(format!("{}", ExecutionProvider::Auto), "auto"); + assert_eq!(format!("{}", ExecutionProvider::Cpu), "CPU"); + assert_eq!(format!("{}", ExecutionProvider::Cuda), "CUDA"); + assert_eq!(format!("{}", ExecutionProvider::TensorRT), "TensorRT"); + assert_eq!(format!("{}", ExecutionProvider::CoreML), "CoreML"); + assert_eq!(format!("{}", ExecutionProvider::DirectML), "DirectML"); + assert_eq!(format!("{}", ExecutionProvider::MIGraphX), "MIGraphX"); + } + // ========================================================================= // Pool embeddings tests // =========================================================================