Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions colgrep/src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -999,11 +999,15 @@ impl IndexBuilder {
}
}
};
#[cfg(not(feature = "cuda"))]
#[cfg(not(any(
feature = "cuda",
feature = "directml",
feature = "migraphx",
feature = "coreml"
)))]
let (num_sessions, execution_provider) = {
let _ = num_units; // Silence unused warning when cuda feature is disabled
let _ = num_units;

// Initialize ONNX Runtime (CPU-only build)
crate::onnx_runtime::ensure_onnx_runtime()
.context("Failed to initialize ONNX Runtime")?;

Expand All @@ -1014,8 +1018,31 @@ impl IndexBuilder {
)
};

#[cfg(any(feature = "directml", feature = "migraphx", feature = "coreml"))]
#[cfg(not(feature = "cuda"))]
let (num_sessions, execution_provider) = {
let _ = num_units;

crate::onnx_runtime::ensure_onnx_runtime()
.context("Failed to initialize ONNX Runtime")?;

let provider = if cfg!(feature = "directml") {
ExecutionProvider::DirectML
} else if cfg!(feature = "migraphx") {
ExecutionProvider::MIGraphX
} else {
ExecutionProvider::CoreML
};

(
self.parallel_sessions
.unwrap_or_else(crate::config::get_default_cpu_parallel_sessions),
provider,
)
};

// Print model info after ONNX runtime is initialized (and any potential re-exec)
eprintln!("🤖 Model: {}", self.model_id);
eprintln!("🤖 Model: {} ({})", self.model_id, execution_provider);
eprintln!("📂 Building index...");

// Use runtime default for batch size (respects cuDNN availability)
Expand Down Expand Up @@ -3286,6 +3313,10 @@ impl Searcher {
AccelerationMode::Auto => {
if cfg!(feature = "coreml") {
ExecutionProvider::CoreML
} else if cfg!(feature = "directml") {
ExecutionProvider::DirectML
} else if cfg!(feature = "migraphx") {
ExecutionProvider::MIGraphX
} else {
ExecutionProvider::Cpu
}
Expand Down Expand Up @@ -3362,6 +3393,10 @@ impl Searcher {
AccelerationMode::Auto => {
if cfg!(feature = "coreml") {
ExecutionProvider::CoreML
} else if cfg!(feature = "directml") {
ExecutionProvider::DirectML
} else if cfg!(feature = "migraphx") {
ExecutionProvider::MIGraphX
} else {
ExecutionProvider::Cpu
}
Expand Down
279 changes: 154 additions & 125 deletions colgrep/src/onnx_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,10 @@ const ORT_LIB_NAME: &str = "libonnxruntime.so";
#[cfg(target_os = "windows")]
const ORT_LIB_NAME: &str = "onnxruntime.dll";

/// Whether to use GPU (CUDA) version of ONNX Runtime
#[cfg(feature = "cuda")]
const USE_GPU: bool = true;
#[cfg(not(feature = "cuda"))]
const USE_GPU: bool = false;

/// Subdirectory name for caching (gpu vs cpu)
#[cfg(feature = "cuda")]
#[cfg(any(feature = "cuda", feature = "directml"))]
const ORT_CACHE_SUBDIR: &str = "gpu";
#[cfg(not(feature = "cuda"))]
#[cfg(not(any(feature = "cuda", feature = "directml")))]
const ORT_CACHE_SUBDIR: &str = "cpu";

/// Ensure ONNX Runtime is available.
Expand Down Expand Up @@ -557,7 +551,10 @@ fn download_onnx_runtime() -> Result<PathBuf> {
&& cache_dir.join("onnxruntime_providers_shared.dll").exists()
&& cache_dir.join("onnxruntime_providers_cuda.dll").exists();

#[cfg(not(feature = "cuda"))]
#[cfg(all(feature = "directml", not(feature = "cuda")))]
let already_cached = lib_path.exists();

#[cfg(not(any(feature = "cuda", feature = "directml")))]
let already_cached = lib_path.exists();

if already_cached {
Expand All @@ -568,11 +565,12 @@ fn download_onnx_runtime() -> Result<PathBuf> {

let (url, files_to_extract) = get_download_info()?;

if USE_GPU {
eprintln!("⚙️ Runtime: ONNX {} (GPU/CUDA)", ORT_VERSION);
} else {
eprintln!("⚙️ Runtime: ONNX {} (CPU)", ORT_VERSION);
}
#[cfg(feature = "cuda")]
eprintln!("⚙️ Runtime: ONNX {} (GPU/CUDA)", ORT_VERSION);
#[cfg(all(feature = "directml", not(feature = "cuda")))]
eprintln!("⚙️ Runtime: ONNX {} (GPU/DirectML)", ORT_VERSION);
#[cfg(not(any(feature = "cuda", feature = "directml")))]
eprintln!("⚙️ Runtime: ONNX {} (CPU)", ORT_VERSION);

// Download archive
let response = ureq::get(&url)
Expand All @@ -593,128 +591,159 @@ type FileToExtract = (String, String);

/// Get download URL and files to extract for current platform
fn get_download_info() -> Result<(String, Vec<FileToExtract>)> {
let base = format!(
"https://github.com/microsoft/onnxruntime/releases/download/v{}",
ORT_VERSION
);

// macOS - no GPU support via GitHub releases (use CoreML instead)
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
let (archive, files) = (
format!("onnxruntime-osx-arm64-{}.tgz", ORT_VERSION),
// DirectML: download from NuGet (Microsoft GPU package does not include DirectML)
#[cfg(all(
target_os = "windows",
target_arch = "x86_64",
feature = "directml",
not(feature = "cuda")
))]
return Ok((
format!(
"https://www.nuget.org/api/v2/package/Microsoft.ML.OnnxRuntime.DirectML/{}",
ORT_VERSION
),
vec![(
format!(
"onnxruntime-osx-arm64-{}/lib/libonnxruntime.{}.dylib",
ORT_VERSION, ORT_VERSION
),
"libonnxruntime.dylib".to_string(),
"runtimes/win-x64/native/onnxruntime.dll".to_string(),
"onnxruntime.dll".to_string(),
)],
);
));

#[cfg(all(target_os = "macos", target_arch = "x86_64"))]
let (archive, files) = (
format!("onnxruntime-osx-x86_64-{}.tgz", ORT_VERSION),
vec![(
format!(
"onnxruntime-osx-x86_64-{}/lib/libonnxruntime.{}.dylib",
ORT_VERSION, ORT_VERSION
),
"libonnxruntime.dylib".to_string(),
)],
);

// Linux x86_64 - supports both CPU and GPU
#[cfg(all(target_os = "linux", target_arch = "x86_64", feature = "cuda"))]
let (archive, files) = {
let archive_name = format!("onnxruntime-linux-x64-gpu-{}", ORT_VERSION);
(
format!("{}.tgz", archive_name),
vec![
(
format!("{}/lib/libonnxruntime.so.{}", archive_name, ORT_VERSION),
"libonnxruntime.so".to_string(),
),
(
format!("{}/lib/libonnxruntime_providers_shared.so", archive_name),
"libonnxruntime_providers_shared.so".to_string(),
// All other configurations: download from GitHub releases
#[cfg(not(all(
target_os = "windows",
target_arch = "x86_64",
feature = "directml",
not(feature = "cuda")
)))]
{
let base = format!(
"https://github.com/microsoft/onnxruntime/releases/download/v{}",
ORT_VERSION
);

// macOS - no GPU support via GitHub releases (use CoreML instead)
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
let (archive, files) = (
format!("onnxruntime-osx-arm64-{}.tgz", ORT_VERSION),
vec![(
format!(
"onnxruntime-osx-arm64-{}/lib/libonnxruntime.{}.dylib",
ORT_VERSION, ORT_VERSION
),
(
format!("{}/lib/libonnxruntime_providers_cuda.so", archive_name),
"libonnxruntime_providers_cuda.so".to_string(),
"libonnxruntime.dylib".to_string(),
)],
);

#[cfg(all(target_os = "macos", target_arch = "x86_64"))]
let (archive, files) = (
format!("onnxruntime-osx-x86_64-{}.tgz", ORT_VERSION),
vec![(
format!(
"onnxruntime-osx-x86_64-{}/lib/libonnxruntime.{}.dylib",
ORT_VERSION, ORT_VERSION
),
],
)
};
"libonnxruntime.dylib".to_string(),
)],
);

#[cfg(all(target_os = "linux", target_arch = "x86_64", not(feature = "cuda")))]
let (archive, files) = (
format!("onnxruntime-linux-x64-{}.tgz", ORT_VERSION),
vec![(
format!(
"onnxruntime-linux-x64-{}/lib/libonnxruntime.so.{}",
ORT_VERSION, ORT_VERSION
),
"libonnxruntime.so".to_string(),
)],
);
// Linux x86_64 - supports both CPU and GPU
#[cfg(all(target_os = "linux", target_arch = "x86_64", feature = "cuda"))]
let (archive, files) = {
let archive_name = format!("onnxruntime-linux-x64-gpu-{}", ORT_VERSION);
(
format!("{}.tgz", archive_name),
vec![
(
format!("{}/lib/libonnxruntime.so.{}", archive_name, ORT_VERSION),
"libonnxruntime.so".to_string(),
),
(
format!("{}/lib/libonnxruntime_providers_shared.so", archive_name),
"libonnxruntime_providers_shared.so".to_string(),
),
(
format!("{}/lib/libonnxruntime_providers_cuda.so", archive_name),
"libonnxruntime_providers_cuda.so".to_string(),
),
],
)
};

// Linux aarch64 - CPU only (no GPU releases available)
#[cfg(all(target_os = "linux", target_arch = "aarch64"))]
let (archive, files) = (
format!("onnxruntime-linux-aarch64-{}.tgz", ORT_VERSION),
vec![(
format!(
"onnxruntime-linux-aarch64-{}/lib/libonnxruntime.so.{}",
ORT_VERSION, ORT_VERSION
),
"libonnxruntime.so".to_string(),
)],
);

// Windows - supports both CPU and GPU
#[cfg(all(target_os = "windows", target_arch = "x86_64", feature = "cuda"))]
let (archive, files) = {
let archive_name = format!("onnxruntime-win-x64-gpu-{}", ORT_VERSION);
(
format!("{}.zip", archive_name),
vec![
(
format!("{}/lib/onnxruntime.dll", archive_name),
"onnxruntime.dll".to_string(),
#[cfg(all(target_os = "linux", target_arch = "x86_64", not(feature = "cuda")))]
let (archive, files) = (
format!("onnxruntime-linux-x64-{}.tgz", ORT_VERSION),
vec![(
format!(
"onnxruntime-linux-x64-{}/lib/libonnxruntime.so.{}",
ORT_VERSION, ORT_VERSION
),
(
format!("{}/lib/onnxruntime_providers_shared.dll", archive_name),
"onnxruntime_providers_shared.dll".to_string(),
),
(
format!("{}/lib/onnxruntime_providers_cuda.dll", archive_name),
"onnxruntime_providers_cuda.dll".to_string(),
"libonnxruntime.so".to_string(),
)],
);

// Linux aarch64 - CPU only (no GPU releases available)
#[cfg(all(target_os = "linux", target_arch = "aarch64"))]
let (archive, files) = (
format!("onnxruntime-linux-aarch64-{}.tgz", ORT_VERSION),
vec![(
format!(
"onnxruntime-linux-aarch64-{}/lib/libonnxruntime.so.{}",
ORT_VERSION, ORT_VERSION
),
],
)
};
"libonnxruntime.so".to_string(),
)],
);

#[cfg(all(target_os = "windows", target_arch = "x86_64", not(feature = "cuda")))]
let (archive, files) = (
format!("onnxruntime-win-x64-{}.zip", ORT_VERSION),
vec![(
format!("onnxruntime-win-x64-{}/lib/onnxruntime.dll", ORT_VERSION),
"onnxruntime.dll".to_string(),
)],
);

#[cfg(not(any(
all(target_os = "macos", target_arch = "aarch64"),
all(target_os = "macos", target_arch = "x86_64"),
all(target_os = "linux", target_arch = "x86_64"),
all(target_os = "linux", target_arch = "aarch64"),
all(target_os = "windows", target_arch = "x86_64"),
)))]
return Err(anyhow::anyhow!(
"Unsupported platform. Please install ONNX Runtime manually and set ORT_DYLIB_PATH."
));
// Windows - supports both CPU and GPU
#[cfg(all(target_os = "windows", target_arch = "x86_64", feature = "cuda"))]
let (archive, files) = {
let archive_name = format!("onnxruntime-win-x64-gpu-{}", ORT_VERSION);
(
format!("{}.zip", archive_name),
vec![
(
format!("{}/lib/onnxruntime.dll", archive_name),
"onnxruntime.dll".to_string(),
),
(
format!("{}/lib/onnxruntime_providers_shared.dll", archive_name),
"onnxruntime_providers_shared.dll".to_string(),
),
(
format!("{}/lib/onnxruntime_providers_cuda.dll", archive_name),
"onnxruntime_providers_cuda.dll".to_string(),
),
],
)
};

#[cfg(all(
target_os = "windows",
target_arch = "x86_64",
not(any(feature = "cuda", feature = "directml"))
))]
let (archive, files) = (
format!("onnxruntime-win-x64-{}.zip", ORT_VERSION),
vec![(
format!("onnxruntime-win-x64-{}/lib/onnxruntime.dll", ORT_VERSION),
"onnxruntime.dll".to_string(),
)],
);

Ok((format!("{}/{}", base, archive), files))
#[cfg(not(any(
all(target_os = "macos", target_arch = "aarch64"),
all(target_os = "macos", target_arch = "x86_64"),
all(target_os = "linux", target_arch = "x86_64"),
all(target_os = "linux", target_arch = "aarch64"),
all(target_os = "windows", target_arch = "x86_64"),
)))]
return Err(anyhow::anyhow!(
"Unsupported platform. Please install ONNX Runtime manually and set ORT_DYLIB_PATH."
));

Ok((format!("{}/{}", base, archive), files))
}
}

/// Extract libraries from tgz archive
Expand Down
Loading
Loading