From 499df8ae0fcee26de2ca1bfcf5a80e78f93abc09 Mon Sep 17 00:00:00 2001 From: noctrex Date: Thu, 4 Jun 2026 21:57:54 +0300 Subject: [PATCH 1/3] feat(onnx): add DirectML/MIGraphX/CoreML execution provider support Extend colgrep's ONNX Runtime integration beyond CUDA so the same code path can drive non-NVIDIA GPUs via DirectML (Windows), AMD GPUs via MIGraphX, and Apple Silicon via CoreML, selected by cargo feature. The active backend is now also surfaced in the user-facing Model line. next-plaid-onnx/src/lib.rs - Add impl Display for ExecutionProvider with short canonical labels (CPU, CUDA, DirectML, CoreML, MIGraphX, TensorRT, auto) matching the tokens already used in onnx_runtime.rs download messages. - Add test_execution_provider_display alongside the existing EP tests. colgrep/src/index/mod.rs - New cfg-gated IndexBuilder branch for directml/migraphx/coreml features that initializes the runtime and picks the matching ExecutionProvider. Falls through to the existing CPU branch when none of those features are enabled. - Searcher::Auto (two sites) extends its fallback chain past CoreML to also consider DirectML and MIGraphX before Cpu. - The Model: eprintln! now appends the selected backend in parentheses, e.g. "Model: lightonai/LateOn-Code (DirectML)". Provider is resolved at runtime so a CUDA build that falls back to CPU honestly prints CPU. colgrep/src/onnx_runtime.rs - Drop the USE_GPU bool constant in favor of per-feature eprintln!s so each provider gets an accurate label at download time. - Route DirectML into the existing "gpu" cache subdir (shared with CUDA) and keep CPU in "cpu". - Add a NuGet download path for Microsoft.ML.OnnxRuntime.DirectML on win-x64, since the GitHub release artifacts do not ship DirectML. Other platforms/configs keep using GitHub releases unchanged. 
Note: MIGraphX and CoreML are wired into provider selection but their runtime download still falls through to the CPU package; users on those platforms are expected to supply the matching ORT build out-of-band for now. DirectML is the only new provider with full auto-download support. Verified: - cargo test -p next-plaid-onnx execution_provider (6 passed) - cargo build --release -p colgrep --features directml on Windows x86_64 - cargo build --release -p colgrep (default features) - Runtime: colgrep prints "Model: lightonai/LateOn-Code (DirectML)" with --features directml Disclaimer: This commit was authored with assistance from an LLM. --- colgrep/src/index/mod.rs | 38 ++++++++++++++++++++++++++---- colgrep/src/onnx_runtime.rs | 46 ++++++++++++++++++++++++++----------- next-plaid-onnx/src/lib.rs | 30 ++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 18 deletions(-) diff --git a/colgrep/src/index/mod.rs b/colgrep/src/index/mod.rs index 3342df30..0ef66ae5 100644 --- a/colgrep/src/index/mod.rs +++ b/colgrep/src/index/mod.rs @@ -999,11 +999,10 @@ impl IndexBuilder { } } }; - #[cfg(not(feature = "cuda"))] + #[cfg(not(any(feature = "cuda", feature = "directml", feature = "migraphx", feature = "coreml")))] let (num_sessions, execution_provider) = { - let _ = num_units; // Silence unused warning when cuda feature is disabled + let _ = num_units; - // Initialize ONNX Runtime (CPU-only build) crate::onnx_runtime::ensure_onnx_runtime() .context("Failed to initialize ONNX Runtime")?; @@ -1014,8 +1013,31 @@ impl IndexBuilder { ) }; + #[cfg(any(feature = "directml", feature = "migraphx", feature = "coreml"))] + #[cfg(not(feature = "cuda"))] + let (num_sessions, execution_provider) = { + let _ = num_units; + + crate::onnx_runtime::ensure_onnx_runtime() + .context("Failed to initialize ONNX Runtime")?; + + let provider = if cfg!(feature = "directml") { + ExecutionProvider::DirectML + } else if cfg!(feature = "migraphx") { + ExecutionProvider::MIGraphX + } else { + 
ExecutionProvider::CoreML + }; + + ( + self.parallel_sessions + .unwrap_or_else(crate::config::get_default_cpu_parallel_sessions), + provider, + ) + }; + // Print model info after ONNX runtime is initialized (and any potential re-exec) - eprintln!("🤖 Model: {}", self.model_id); + eprintln!("🤖 Model: {} ({})", self.model_id, execution_provider); eprintln!("📂 Building index..."); // Use runtime default for batch size (respects cuDNN availability) @@ -3286,6 +3308,10 @@ impl Searcher { AccelerationMode::Auto => { if cfg!(feature = "coreml") { ExecutionProvider::CoreML + } else if cfg!(feature = "directml") { + ExecutionProvider::DirectML + } else if cfg!(feature = "migraphx") { + ExecutionProvider::MIGraphX } else { ExecutionProvider::Cpu } @@ -3362,6 +3388,10 @@ impl Searcher { AccelerationMode::Auto => { if cfg!(feature = "coreml") { ExecutionProvider::CoreML + } else if cfg!(feature = "directml") { + ExecutionProvider::DirectML + } else if cfg!(feature = "migraphx") { + ExecutionProvider::MIGraphX } else { ExecutionProvider::Cpu } diff --git a/colgrep/src/onnx_runtime.rs b/colgrep/src/onnx_runtime.rs index 6bd9a4f3..6db842c6 100644 --- a/colgrep/src/onnx_runtime.rs +++ b/colgrep/src/onnx_runtime.rs @@ -41,16 +41,11 @@ const ORT_LIB_NAME: &str = "libonnxruntime.so"; #[cfg(target_os = "windows")] const ORT_LIB_NAME: &str = "onnxruntime.dll"; -/// Whether to use GPU (CUDA) version of ONNX Runtime -#[cfg(feature = "cuda")] -const USE_GPU: bool = true; -#[cfg(not(feature = "cuda"))] -const USE_GPU: bool = false; /// Subdirectory name for caching (gpu vs cpu) -#[cfg(feature = "cuda")] +#[cfg(any(feature = "cuda", feature = "directml"))] const ORT_CACHE_SUBDIR: &str = "gpu"; -#[cfg(not(feature = "cuda"))] +#[cfg(not(any(feature = "cuda", feature = "directml")))] const ORT_CACHE_SUBDIR: &str = "cpu"; /// Ensure ONNX Runtime is available. 
@@ -557,7 +552,10 @@ fn download_onnx_runtime() -> Result { && cache_dir.join("onnxruntime_providers_shared.dll").exists() && cache_dir.join("onnxruntime_providers_cuda.dll").exists(); - #[cfg(not(feature = "cuda"))] + #[cfg(all(feature = "directml", not(feature = "cuda")))] + let already_cached = lib_path.exists(); + + #[cfg(not(any(feature = "cuda", feature = "directml")))] let already_cached = lib_path.exists(); if already_cached { @@ -568,11 +566,12 @@ fn download_onnx_runtime() -> Result { let (url, files_to_extract) = get_download_info()?; - if USE_GPU { - eprintln!("⚙️ Runtime: ONNX {} (GPU/CUDA)", ORT_VERSION); - } else { - eprintln!("⚙️ Runtime: ONNX {} (CPU)", ORT_VERSION); - } + #[cfg(feature = "cuda")] + eprintln!("⚙️ Runtime: ONNX {} (GPU/CUDA)", ORT_VERSION); + #[cfg(all(feature = "directml", not(feature = "cuda")))] + eprintln!("⚙️ Runtime: ONNX {} (GPU/DirectML)", ORT_VERSION); + #[cfg(not(any(feature = "cuda", feature = "directml")))] + eprintln!("⚙️ Runtime: ONNX {} (CPU)", ORT_VERSION); // Download archive let response = ureq::get(&url) @@ -593,6 +592,24 @@ type FileToExtract = (String, String); /// Get download URL and files to extract for current platform fn get_download_info() -> Result<(String, Vec)> { + // DirectML: download from NuGet (Microsoft GPU package does not include DirectML) + #[cfg(all(target_os = "windows", target_arch = "x86_64", feature = "directml", not(feature = "cuda")))] + return Ok(( + format!( + "https://www.nuget.org/api/v2/package/Microsoft.ML.OnnxRuntime.DirectML/{}", + ORT_VERSION + ), + vec![ + ( + "runtimes/win-x64/native/onnxruntime.dll".to_string(), + "onnxruntime.dll".to_string(), + ), + ], + )); + + // All other configurations: download from GitHub releases + #[cfg(not(all(target_os = "windows", target_arch = "x86_64", feature = "directml", not(feature = "cuda"))))] + { let base = format!( "https://github.com/microsoft/onnxruntime/releases/download/v{}", ORT_VERSION @@ -694,7 +711,7 @@ fn get_download_info() -> 
Result<(String, Vec<FileToExtract>)> { ) }; - #[cfg(all(target_os = "windows", target_arch = "x86_64", not(feature = "cuda")))] + #[cfg(all(target_os = "windows", target_arch = "x86_64", not(any(feature = "cuda", feature = "directml"))))] let (archive, files) = ( format!("onnxruntime-win-x64-{}.zip", ORT_VERSION), vec![( @@ -715,6 +732,7 @@ fn get_download_info() -> Result<(String, Vec<FileToExtract>)> { )); Ok((format!("{}/{}", base, archive), files)) + } } /// Extract libraries from tgz archive diff --git a/next-plaid-onnx/src/lib.rs b/next-plaid-onnx/src/lib.rs index 359c5545..7a9c7bbb 100644 --- a/next-plaid-onnx/src/lib.rs +++ b/next-plaid-onnx/src/lib.rs @@ -198,6 +198,23 @@ pub enum ExecutionProvider { MIGraphX, } +impl std::fmt::Display for ExecutionProvider { + /// Short user-facing label matching the tokens used in onnx_runtime.rs + /// download messages (e.g. "CPU", "CUDA", "DirectML"). Stable across + /// releases; do not change without updating callers that log the label. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Auto => f.write_str("auto"), + Self::Cpu => f.write_str("CPU"), + Self::Cuda => f.write_str("CUDA"), + Self::TensorRT => f.write_str("TensorRT"), + Self::CoreML => f.write_str("CoreML"), + Self::DirectML => f.write_str("DirectML"), + Self::MIGraphX => f.write_str("MIGraphX"), + } + } +} + fn configure_execution_provider( builder: SessionBuilder, provider: ExecutionProvider, @@ -2403,6 +2420,19 @@ mod tests { assert_eq!(debug_str, "Cuda"); } + #[test] + fn test_execution_provider_display() { + // Labels are part of the user-facing surface (e.g. colgrep's + // `Model: <model_id> (<provider>)` line); lock them in. 
+ assert_eq!(format!("{}", ExecutionProvider::Auto), "auto"); + assert_eq!(format!("{}", ExecutionProvider::Cpu), "CPU"); + assert_eq!(format!("{}", ExecutionProvider::Cuda), "CUDA"); + assert_eq!(format!("{}", ExecutionProvider::TensorRT), "TensorRT"); + assert_eq!(format!("{}", ExecutionProvider::CoreML), "CoreML"); + assert_eq!(format!("{}", ExecutionProvider::DirectML), "DirectML"); + assert_eq!(format!("{}", ExecutionProvider::MIGraphX), "MIGraphX"); + } + // ========================================================================= // Pool embeddings tests // ========================================================================= From 317b4bdd55e77531dc75332cebe469cf2165cf14 Mon Sep 17 00:00:00 2001 From: Raphael Sourty Date: Mon, 8 Jun 2026 18:11:58 +0200 Subject: [PATCH 2/3] style(onnx): apply rustfmt to execution provider changes Fixes the Format / Colgrep Crate Format CI jobs. Pure rustfmt output: wrap long #[cfg] attributes, collapse single-element vec!, and re-indent the get_download_info block. No logic change. 
--- colgrep/src/index/mod.rs | 7 +- colgrep/src/onnx_runtime.rs | 255 +++++++++++++++++++----------------- 2 files changed, 139 insertions(+), 123 deletions(-) diff --git a/colgrep/src/index/mod.rs b/colgrep/src/index/mod.rs index 0ef66ae5..ae74afa1 100644 --- a/colgrep/src/index/mod.rs +++ b/colgrep/src/index/mod.rs @@ -999,7 +999,12 @@ impl IndexBuilder { } } }; - #[cfg(not(any(feature = "cuda", feature = "directml", feature = "migraphx", feature = "coreml")))] + #[cfg(not(any( + feature = "cuda", + feature = "directml", + feature = "migraphx", + feature = "coreml" + )))] let (num_sessions, execution_provider) = { let _ = num_units; diff --git a/colgrep/src/onnx_runtime.rs b/colgrep/src/onnx_runtime.rs index 6db842c6..daa04472 100644 --- a/colgrep/src/onnx_runtime.rs +++ b/colgrep/src/onnx_runtime.rs @@ -41,7 +41,6 @@ const ORT_LIB_NAME: &str = "libonnxruntime.so"; #[cfg(target_os = "windows")] const ORT_LIB_NAME: &str = "onnxruntime.dll"; - /// Subdirectory name for caching (gpu vs cpu) #[cfg(any(feature = "cuda", feature = "directml"))] const ORT_CACHE_SUBDIR: &str = "gpu"; @@ -593,145 +592,157 @@ type FileToExtract = (String, String); /// Get download URL and files to extract for current platform fn get_download_info() -> Result<(String, Vec)> { // DirectML: download from NuGet (Microsoft GPU package does not include DirectML) - #[cfg(all(target_os = "windows", target_arch = "x86_64", feature = "directml", not(feature = "cuda")))] + #[cfg(all( + target_os = "windows", + target_arch = "x86_64", + feature = "directml", + not(feature = "cuda") + ))] return Ok(( format!( "https://www.nuget.org/api/v2/package/Microsoft.ML.OnnxRuntime.DirectML/{}", ORT_VERSION ), - vec![ - ( - "runtimes/win-x64/native/onnxruntime.dll".to_string(), - "onnxruntime.dll".to_string(), - ), - ], + vec![( + "runtimes/win-x64/native/onnxruntime.dll".to_string(), + "onnxruntime.dll".to_string(), + )], )); // All other configurations: download from GitHub releases - #[cfg(not(all(target_os = 
"windows", target_arch = "x86_64", feature = "directml", not(feature = "cuda"))))] + #[cfg(not(all( + target_os = "windows", + target_arch = "x86_64", + feature = "directml", + not(feature = "cuda") + )))] { - let base = format!( - "https://github.com/microsoft/onnxruntime/releases/download/v{}", - ORT_VERSION - ); - - // macOS - no GPU support via GitHub releases (use CoreML instead) - #[cfg(all(target_os = "macos", target_arch = "aarch64"))] - let (archive, files) = ( - format!("onnxruntime-osx-arm64-{}.tgz", ORT_VERSION), - vec![( - format!( - "onnxruntime-osx-arm64-{}/lib/libonnxruntime.{}.dylib", - ORT_VERSION, ORT_VERSION - ), - "libonnxruntime.dylib".to_string(), - )], - ); + let base = format!( + "https://github.com/microsoft/onnxruntime/releases/download/v{}", + ORT_VERSION + ); - #[cfg(all(target_os = "macos", target_arch = "x86_64"))] - let (archive, files) = ( - format!("onnxruntime-osx-x86_64-{}.tgz", ORT_VERSION), - vec![( - format!( - "onnxruntime-osx-x86_64-{}/lib/libonnxruntime.{}.dylib", - ORT_VERSION, ORT_VERSION - ), - "libonnxruntime.dylib".to_string(), - )], - ); - - // Linux x86_64 - supports both CPU and GPU - #[cfg(all(target_os = "linux", target_arch = "x86_64", feature = "cuda"))] - let (archive, files) = { - let archive_name = format!("onnxruntime-linux-x64-gpu-{}", ORT_VERSION); - ( - format!("{}.tgz", archive_name), - vec![ - ( - format!("{}/lib/libonnxruntime.so.{}", archive_name, ORT_VERSION), - "libonnxruntime.so".to_string(), + // macOS - no GPU support via GitHub releases (use CoreML instead) + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] + let (archive, files) = ( + format!("onnxruntime-osx-arm64-{}.tgz", ORT_VERSION), + vec![( + format!( + "onnxruntime-osx-arm64-{}/lib/libonnxruntime.{}.dylib", + ORT_VERSION, ORT_VERSION ), - ( - format!("{}/lib/libonnxruntime_providers_shared.so", archive_name), - "libonnxruntime_providers_shared.so".to_string(), - ), - ( - format!("{}/lib/libonnxruntime_providers_cuda.so", 
archive_name), - "libonnxruntime_providers_cuda.so".to_string(), + "libonnxruntime.dylib".to_string(), + )], + ); + + #[cfg(all(target_os = "macos", target_arch = "x86_64"))] + let (archive, files) = ( + format!("onnxruntime-osx-x86_64-{}.tgz", ORT_VERSION), + vec![( + format!( + "onnxruntime-osx-x86_64-{}/lib/libonnxruntime.{}.dylib", + ORT_VERSION, ORT_VERSION ), - ], - ) - }; + "libonnxruntime.dylib".to_string(), + )], + ); - #[cfg(all(target_os = "linux", target_arch = "x86_64", not(feature = "cuda")))] - let (archive, files) = ( - format!("onnxruntime-linux-x64-{}.tgz", ORT_VERSION), - vec![( - format!( - "onnxruntime-linux-x64-{}/lib/libonnxruntime.so.{}", - ORT_VERSION, ORT_VERSION - ), - "libonnxruntime.so".to_string(), - )], - ); + // Linux x86_64 - supports both CPU and GPU + #[cfg(all(target_os = "linux", target_arch = "x86_64", feature = "cuda"))] + let (archive, files) = { + let archive_name = format!("onnxruntime-linux-x64-gpu-{}", ORT_VERSION); + ( + format!("{}.tgz", archive_name), + vec![ + ( + format!("{}/lib/libonnxruntime.so.{}", archive_name, ORT_VERSION), + "libonnxruntime.so".to_string(), + ), + ( + format!("{}/lib/libonnxruntime_providers_shared.so", archive_name), + "libonnxruntime_providers_shared.so".to_string(), + ), + ( + format!("{}/lib/libonnxruntime_providers_cuda.so", archive_name), + "libonnxruntime_providers_cuda.so".to_string(), + ), + ], + ) + }; - // Linux aarch64 - CPU only (no GPU releases available) - #[cfg(all(target_os = "linux", target_arch = "aarch64"))] - let (archive, files) = ( - format!("onnxruntime-linux-aarch64-{}.tgz", ORT_VERSION), - vec![( - format!( - "onnxruntime-linux-aarch64-{}/lib/libonnxruntime.so.{}", - ORT_VERSION, ORT_VERSION - ), - "libonnxruntime.so".to_string(), - )], - ); - - // Windows - supports both CPU and GPU - #[cfg(all(target_os = "windows", target_arch = "x86_64", feature = "cuda"))] - let (archive, files) = { - let archive_name = format!("onnxruntime-win-x64-gpu-{}", ORT_VERSION); - ( - 
format!("{}.zip", archive_name), - vec![ - ( - format!("{}/lib/onnxruntime.dll", archive_name), - "onnxruntime.dll".to_string(), - ), - ( - format!("{}/lib/onnxruntime_providers_shared.dll", archive_name), - "onnxruntime_providers_shared.dll".to_string(), + #[cfg(all(target_os = "linux", target_arch = "x86_64", not(feature = "cuda")))] + let (archive, files) = ( + format!("onnxruntime-linux-x64-{}.tgz", ORT_VERSION), + vec![( + format!( + "onnxruntime-linux-x64-{}/lib/libonnxruntime.so.{}", + ORT_VERSION, ORT_VERSION ), - ( - format!("{}/lib/onnxruntime_providers_cuda.dll", archive_name), - "onnxruntime_providers_cuda.dll".to_string(), + "libonnxruntime.so".to_string(), + )], + ); + + // Linux aarch64 - CPU only (no GPU releases available) + #[cfg(all(target_os = "linux", target_arch = "aarch64"))] + let (archive, files) = ( + format!("onnxruntime-linux-aarch64-{}.tgz", ORT_VERSION), + vec![( + format!( + "onnxruntime-linux-aarch64-{}/lib/libonnxruntime.so.{}", + ORT_VERSION, ORT_VERSION ), - ], - ) - }; + "libonnxruntime.so".to_string(), + )], + ); - #[cfg(all(target_os = "windows", target_arch = "x86_64", not(any(feature = "cuda", feature = "directml"))))] - let (archive, files) = ( - format!("onnxruntime-win-x64-{}.zip", ORT_VERSION), - vec![( - format!("onnxruntime-win-x64-{}/lib/onnxruntime.dll", ORT_VERSION), - "onnxruntime.dll".to_string(), - )], - ); - - #[cfg(not(any( - all(target_os = "macos", target_arch = "aarch64"), - all(target_os = "macos", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - all(target_os = "windows", target_arch = "x86_64"), - )))] - return Err(anyhow::anyhow!( - "Unsupported platform. Please install ONNX Runtime manually and set ORT_DYLIB_PATH." 
- )); + // Windows - supports both CPU and GPU + #[cfg(all(target_os = "windows", target_arch = "x86_64", feature = "cuda"))] + let (archive, files) = { + let archive_name = format!("onnxruntime-win-x64-gpu-{}", ORT_VERSION); + ( + format!("{}.zip", archive_name), + vec![ + ( + format!("{}/lib/onnxruntime.dll", archive_name), + "onnxruntime.dll".to_string(), + ), + ( + format!("{}/lib/onnxruntime_providers_shared.dll", archive_name), + "onnxruntime_providers_shared.dll".to_string(), + ), + ( + format!("{}/lib/onnxruntime_providers_cuda.dll", archive_name), + "onnxruntime_providers_cuda.dll".to_string(), + ), + ], + ) + }; + + #[cfg(all( + target_os = "windows", + target_arch = "x86_64", + not(any(feature = "cuda", feature = "directml")) + ))] + let (archive, files) = ( + format!("onnxruntime-win-x64-{}.zip", ORT_VERSION), + vec![( + format!("onnxruntime-win-x64-{}/lib/onnxruntime.dll", ORT_VERSION), + "onnxruntime.dll".to_string(), + )], + ); - Ok((format!("{}/{}", base, archive), files)) + #[cfg(not(any( + all(target_os = "macos", target_arch = "aarch64"), + all(target_os = "macos", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + all(target_os = "windows", target_arch = "x86_64"), + )))] + return Err(anyhow::anyhow!( + "Unsupported platform. Please install ONNX Runtime manually and set ORT_DYLIB_PATH." + )); + + Ok((format!("{}/{}", base, archive), files)) } } From 2bd056f80cf20cd85d25d24dacdd79ccfb6f1e78 Mon Sep 17 00:00:00 2001 From: Raphael Sourty Date: Mon, 8 Jun 2026 18:16:22 +0200 Subject: [PATCH 3/3] fix(onnx): honor force-CPU override for CoreML provider The CoreML branch in configure_auto_provider was the only execution provider missing the `if !force_cpu` guard, so is_force_cpu() was silently ignored on macOS and force_cpu went unread under a coreml-only build (failing clippy -D warnings). Add the guard so CoreML respects the override like CUDA/TensorRT/DirectML/MIGraphX. 
--- next-plaid-onnx/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/next-plaid-onnx/src/lib.rs b/next-plaid-onnx/src/lib.rs index 7a9c7bbb..617ab003 100644 --- a/next-plaid-onnx/src/lib.rs +++ b/next-plaid-onnx/src/lib.rs @@ -364,7 +364,7 @@ fn configure_auto_provider(builder: SessionBuilder) -> Result { } #[cfg(feature = "coreml")] - { + if !force_cpu { if let Ok(b) = builder .clone() .with_execution_providers([CoreMLExecutionProvider::default().build()])