diff --git a/AGENTS.md b/AGENTS.md index eafb5f6..511aba1 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,8 +17,7 @@ By default, trtx builds with real TensorRT-RTX and cudarc for CUDA operations: ```bash # Set environment -export TENSORRT_RTX_DIR=/path/to/tensorrt-rtx -export CUDA_ROOT=/usr/local/cuda +export LD_LIBRARY_PATH=/path/to/tensorrt-rtx/lib # Build (uses real TensorRT-RTX) make build # Debug build @@ -69,29 +68,6 @@ make publish # Publish to crates.io ## Architecture -### Three-Layer FFI Design - -``` -┌─────────────────────────┐ -│ Rust Safe API (trtx) │ <- RAII, Result, lifetimes -├─────────────────────────┤ -│ Raw FFI (trtx-sys) │ <- Bindgen-generated from wrapper.hpp -├─────────────────────────┤ -│ C Wrapper Layer │ <- wrapper.hpp/cpp (exception handling) -├─────────────────────────┤ -│ TensorRT-RTX C++ API │ <- NVIDIA library -└─────────────────────────┘ -``` - -**Why three layers?** -- TensorRT-RTX is C++ with exceptions and classes -- C wrapper provides `extern "C"` interface with opaque pointers -- C wrapper catches exceptions and converts to error codes -- Bindgen generates Rust FFI from C wrapper -- trtx crate provides safe Rust abstractions - -See `docs/FFI_GUIDE.md` for detailed FFI development workflow. - ### Two-Phase Workflow **Build Phase (AOT):** @@ -114,57 +90,6 @@ Runtime → Deserialize Engine → ExecutionContext → Bind Tensors → Execute - `trtx/src/executor.rs`: High-level executor API (rustnn-compatible) - `trtx/src/error.rs`: Error types -## FFI Bindings - -### When Modifying FFI - -1. **Update C wrapper** in `trtx-sys/wrapper.hpp` and `trtx-sys/wrapper.cpp` -2. **Rebuild** to regenerate bindings: `cargo clean -p trtx-sys && cargo build` -3. **Update mock** in `trtx-sys/build.rs` (`generate_mock_bindings`) and `trtx-sys/mock.c` -4. 
**Add safe wrapper** in appropriate `trtx/src/*.rs` file - -### Naming Conventions - -- C functions: `trtx_<object>_<method>` (e.g., `trtx_cuda_engine_get_tensor_name`) - Types: `Trtx<Object>` (e.g., `TrtxCudaEngine`) - Constants: `TRTX_<NAME>` (e.g., `TRTX_SUCCESS`) - -### Error Handling Pattern - -All FFI functions follow this signature: -```c -int32_t trtx_function_name( - // ... input parameters ... - char* error_msg, // Always second-to-last - size_t error_msg_len // Always last -); -``` - -Return `TRTX_SUCCESS` (0) on success, error code on failure. - -## Mock Mode - -Mock mode provides stub implementations for development without TensorRT-RTX: - Development on machines without TensorRT-RTX (e.g., macOS) - CI/CD on any platform - API validation without GPU - -**Use `--features mock`** when TensorRT-RTX is not available. When adding new FFI functions, update both real bindings AND mock implementations in `trtx-sys/build.rs` and `trtx-sys/mock.c`. - -### Mock Implementation Files - -- `trtx-sys/build.rs`: `generate_mock_bindings()` function defines Rust FFI stubs - `trtx-sys/mock.c`: C implementations that return `TRTX_SUCCESS` with dummy data - -## Important Notes - -### Build System - -- `trtx-sys/build.rs` uses bindgen to auto-generate `bindings.rs` from `wrapper.hpp` - Generated file: `target/debug/build/trtx-sys-*/out/bindings.rs` - Manual edits to generated files are **overwritten** on rebuild - Changes must go in source files (`wrapper.hpp`, `wrapper.cpp`, `build.rs`) - ### Memory Management - - **RAII everywhere**: Use `Drop` trait for automatic cleanup diff --git a/README.md b/README.md index a8518de..d55e8cf 100644 --- a/README.md +++ b/README.md @@ -77,9 +77,8 @@ Mock mode provides stub implementations that allow you to: The `trtx` crate has the following Cargo features: -- `default`: "real", "dlopen_tensorrt_onnxparser", "dlopen_tensorrt_rtx", "onnxparser", "v_1_3" +- `default`: "dlopen_tensorrt_onnxparser", "dlopen_tensorrt_rtx", "onnxparser", "v_1_3" - `mock`: use this 
library in mock mode. TensorRT libraries and an NVIDIA GPU are no longer necessary for execution -- `real`: opposite of `mock` mode. TensorRT and Nvidia GPU are required for execution - `dlopen_tensorrt_rtx`: enables dynamic loading of the TensorRT library via `trtx::dynamically_load_tensorrt` - `dlopen_tensorrt_onnxparser`: enables dynamic loading of the TensorRT ONNX parser library via `trtx::dynamically_load_tensorrt_onnxparser` - `links_tensorrt_rtx`: links the TensorRT library, `trtx::dynamically_load_tensorrt` is now optional diff --git a/docs/DESIGN.md b/docs/DESIGN.md index c6fece5..2fbb602 100644 --- a/docs/DESIGN.md +++ b/docs/DESIGN.md @@ -72,7 +72,7 @@ trtx-rs/ ### 1. FFI Layer (trtx-sys) **Approach:** -- Use bindgen to generate raw bindings +- Use autocxx to generate raw bindings - Minimal manual intervention - Direct mapping to C++ API - No safety guarantees diff --git a/docs/FFI_GUIDE.md b/docs/FFI_GUIDE.md deleted file mode 100644 index b9a8e25..0000000 --- a/docs/FFI_GUIDE.md +++ /dev/null @@ -1,456 +0,0 @@ -# FFI Bindings Guide - -This document explains how the FFI bindings for TensorRT-RTX are built and how to update them when the C++ API changes. - -## Architecture Overview - -The FFI layer uses a **three-layer architecture**: - -``` -┌─────────────────────────┐ -│ Rust Safe API (trtx) │ <- High-level safe Rust -├─────────────────────────┤ -│ Raw FFI (trtx-sys) │ <- Auto-generated bindings -├─────────────────────────┤ -│ C Wrapper Layer │ <- Hand-written C interface -│ (wrapper.hpp/cpp) │ -├─────────────────────────┤ -│ TensorRT-RTX C++ API │ <- NVIDIA's C++ library -└─────────────────────────┘ -``` - -### Why Three Layers? - -1. **C Wrapper Layer** (`wrapper.hpp` + `wrapper.cpp`) - - TensorRT-RTX is a C++ API that cannot be directly bound by Rust - - Provides a C-compatible interface with `extern "C"` linkage - - Handles C++ exceptions and converts them to error codes - - Uses opaque pointer types to hide C++ objects - -2. 
**Bindgen Layer** (automatic) - - Reads `wrapper.hpp` header file - - Automatically generates Rust FFI declarations - - Outputs to `target/debug/build/trtx-sys-*/out/bindings.rs` - -3. **Safe Rust API** (in `trtx` crate) - - Wraps unsafe FFI calls in safe Rust abstractions - - Provides RAII types, Result-based error handling - - Adds Rust idioms and type safety - -## Build Process - -The build happens in `build.rs`: - -### 1. Compile C++ Wrapper - -```rust -// Build C++ wrapper -let mut build = cc::Build::new(); -build.cpp(true) - .file("wrapper.cpp") - .include(&include_dir) - .flag("-std=c++17"); - -build.compile("trtx_wrapper"); -``` - -This compiles `wrapper.cpp` into a static library that links against TensorRT-RTX. - -### 2. Generate Rust Bindings - -```rust -let bindings = bindgen::Builder::default() - .header("wrapper.hpp") // Input C header - .clang_arg(format!("-I{}", include_dir)) // Include TensorRT headers - .allowlist_function("trtx_.*") // Only expose trtx_* functions - .allowlist_type("TrtxLogger.*") // Only expose Trtx* types - .allowlist_var("TRTX_.*") // Only expose TRTX_* constants - .derive_debug(true) // Add Debug trait - .derive_default(true) // Add Default trait - .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) - .generate() - .expect("Unable to generate bindings"); - -bindings - .write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings!"); -``` - -### 3. 
Link Libraries - -```rust -println!("cargo:rustc-link-search=native={}", lib_dir); -println!("cargo:rustc-link-lib=dylib=nvinfer_10"); -println!("cargo:rustc-link-lib=dylib=nvonnxparser_10"); -println!("cargo:rustc-link-lib=dylib=cudart"); -``` - -## How to Update Bindings When C++ API Changes - -### Scenario 1: Adding New TensorRT-RTX Functions - -**Example**: You want to expose `ICudaEngine::getTensorShape()` - -#### Step 1: Add C wrapper function in `wrapper.hpp` - -```c -int32_t trtx_cuda_engine_get_tensor_shape( - TrtxCudaEngine* engine, - const char* tensor_name, - int32_t* out_dims, - int32_t* out_nb_dims, - char* error_msg, - size_t error_msg_len -); -``` - -#### Step 2: Implement in `wrapper.cpp` - -```cpp -int32_t trtx_cuda_engine_get_tensor_shape( - TrtxCudaEngine* engine_ptr, - const char* tensor_name, - int32_t* out_dims, - int32_t* out_nb_dims, - char* error_msg, - size_t error_msg_len -) { - try { - auto* engine = reinterpret_cast<nvinfer1::ICudaEngine*>(engine_ptr); - - nvinfer1::Dims dims = engine->getTensorShape(tensor_name); - *out_nb_dims = dims.nbDims; - - for (int i = 0; i < dims.nbDims; i++) { - out_dims[i] = dims.d[i]; - } - - return TRTX_SUCCESS; - } catch (const std::exception& e) { - copy_error_msg(error_msg, error_msg_len, e.what()); - return TRTX_ERROR_RUNTIME_ERROR; - } -} -``` - -#### Step 3: Rebuild to regenerate bindings - -```bash -cargo clean -p trtx-sys -cargo build -``` - -Bindgen automatically picks up the new function from `wrapper.hpp` and generates the Rust FFI declaration. 
#### Step 4: Add safe wrapper in `trtx/src/engine.rs` - -```rust -impl CudaEngine { - pub fn tensor_shape(&self, name: &str) -> Result<Vec<i32>> { - let c_name = CString::new(name)?; - let mut dims = [0i32; 8]; // Max 8 dimensions - let mut nb_dims = 0i32; - let mut error_msg = [0i8; ERROR_MSG_LEN]; // buffer for the C error message - - unsafe { - check_error(trtx_cuda_engine_get_tensor_shape( - self.ptr, - c_name.as_ptr(), - dims.as_mut_ptr(), - &mut nb_dims, - error_msg.as_mut_ptr(), - ERROR_MSG_LEN, - ))?; - } - - Ok(dims[..nb_dims as usize].to_vec()) - } -} -``` - -### Scenario 2: Updating Existing Function Signatures - -**Example**: TensorRT changes `setMemoryPoolLimit()` signature - -#### Step 1: Update `wrapper.hpp` signature - -```c -// Old -int32_t trtx_builder_config_set_memory_pool_limit( - TrtxBuilderConfig* config, - int32_t pool_type, - size_t pool_size, - char* error_msg, - size_t error_msg_len -); - -// New - added pool_flags parameter -int32_t trtx_builder_config_set_memory_pool_limit( - TrtxBuilderConfig* config, - int32_t pool_type, - size_t pool_size, - uint32_t pool_flags, - char* error_msg, - size_t error_msg_len -); -``` - -#### Step 2: Update implementation in `wrapper.cpp` - -```cpp -int32_t trtx_builder_config_set_memory_pool_limit( - TrtxBuilderConfig* config_ptr, - int32_t pool_type, - size_t pool_size, - uint32_t pool_flags, - char* error_msg, - size_t error_msg_len -) { - try { - auto* config = reinterpret_cast<nvinfer1::IBuilderConfig*>(config_ptr); - config->setMemoryPoolLimit( - static_cast<nvinfer1::MemoryPoolType>(pool_type), - pool_size, - pool_flags // New parameter - ); - return TRTX_SUCCESS; - } catch (const std::exception& e) { - copy_error_msg(error_msg, error_msg_len, e.what()); - return TRTX_ERROR_RUNTIME_ERROR; - } -} -``` - -#### Step 3: Rebuild - -```bash -cargo clean -p trtx-sys -cargo build -``` - -This will fail if the high-level `trtx` crate still uses the old signature. 
- -#### Step 4: Update all call sites in `trtx` crate - -Find and update all uses: - -```bash -# Find all uses -rg "trtx_builder_config_set_memory_pool_limit" trtx/src/ - -# Update each call site to pass the new parameter -``` - -### Scenario 3: Adding Mock Support for New Functions - -When adding new functions, also update the mock bindings for development without TensorRT. - -#### Update `build.rs` mock bindings - -In the `generate_mock_bindings()` function, add: - -```rust -fn generate_mock_bindings(out_path: &Path) { - let mock_bindings = r#" - // ... existing mock code ... - - // Add new function declaration - extern "C" { - pub fn trtx_cuda_engine_get_tensor_shape( - engine: *mut TrtxCudaEngine, - tensor_name: *const ::std::os::raw::c_char, - out_dims: *mut i32, - out_nb_dims: *mut i32, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - } - "#; - - std::fs::write(out_path.join("bindings.rs"), mock_bindings) - .expect("Couldn't write mock bindings!"); -} -``` - -#### Update `mock.c` implementation - -```c -int32_t trtx_cuda_engine_get_tensor_shape( - TrtxCudaEngine* engine, - const char* tensor_name, - int32_t* out_dims, - int32_t* out_nb_dims, - char* error_msg, - size_t error_msg_len -) { - // Mock implementation: return dummy shape - *out_nb_dims = 3; - out_dims[0] = 1; - out_dims[1] = 224; - out_dims[2] = 224; - return TRTX_SUCCESS; -} -``` - -## Testing Changes - -### 1. Test with Mock Mode - -```bash -cargo test --features mock -``` - -### 2. Test with Real TensorRT-RTX - -```bash -TENSORRT_RTX_DIR=/usr/local/tensorrt-rtx cargo test -``` - -### 3. 
Check Generated Bindings - -```bash -# View generated bindings -cat target/debug/build/trtx-sys-*/out/bindings.rs | less - -# Search for your new function -rg "trtx_cuda_engine_get_tensor_shape" target/debug/build/trtx-sys-*/out/bindings.rs -``` - -## Common Issues - -### Issue 1: Bindgen Can't Find Headers - -**Error**: `fatal error: 'NvInfer.h' file not found` - -**Solution**: Set `TENSORRT_RTX_DIR` or update include path in `build.rs`: - -```bash -export TENSORRT_RTX_DIR=/path/to/tensorrt-rtx -cargo build -``` - -### Issue 2: Linker Can't Find Libraries - -**Error**: `ld: library not found for -lnvinfer_10` - -**Solution**: Check library search path and library names: - -```bash -ls $TENSORRT_RTX_DIR/lib/ -# Update build.rs if library names have changed -``` - -### Issue 3: C++ Exceptions Crash Rust - -**Problem**: C++ exception crosses FFI boundary - -**Solution**: Always wrap C++ code in try-catch in `wrapper.cpp`: - -```cpp -try { - // C++ code that might throw -} catch (const std::exception& e) { - copy_error_msg(error_msg, error_msg_len, e.what()); - return TRTX_ERROR_RUNTIME_ERROR; -} -``` - -### Issue 4: Function Not Exposed to Rust - -**Problem**: New function in wrapper.hpp but not appearing in Rust - -**Solution**: Check bindgen allowlist patterns in `build.rs`: - -```rust -.allowlist_function("trtx_.*") // Must match your function name -``` - -## Best Practices - -### Naming Conventions - -- **C functions**: `trtx_<object>_<method>` (e.g., `trtx_cuda_engine_get_tensor_shape`) - **Types**: `Trtx<Object>` (e.g., `TrtxCudaEngine`) - **Constants**: `TRTX_<NAME>` (e.g., `TRTX_SUCCESS`) - -### Error Handling Pattern - -All FFI functions should follow this pattern: - -```c -int32_t trtx_function_name( - // ... input parameters ... 
char* error_msg, // Always second-to-last - size_t error_msg_len // Always last -); -``` - -Return values: -- `TRTX_SUCCESS` (0) on success -- Error code on failure (error message written to `error_msg`) - -### Memory Management - -- **Rust owns**: Rust-allocated memory passed to C (must free in Rust) - **C owns**: C-allocated objects (use destroy functions to free) - **Return buffers**: C allocates, Rust calls `trtx_free_buffer()` to free - -Example: -```rust -// C allocates buffer -let mut data: *mut c_void = std::ptr::null_mut(); -let mut size: usize = 0; -trtx_builder_build_serialized_network(..., &mut data, &mut size, ...); - -// Rust must free when done -trtx_free_buffer(data); -``` - -### Type Safety - -Use opaque pointer types in C: -```c -typedef struct TrtxCudaEngine TrtxCudaEngine; // Opaque -``` - -Cast to real type only in C++: -```cpp -auto* engine = reinterpret_cast<nvinfer1::ICudaEngine*>(engine_ptr); -``` - -## Development Workflow - -### Quick Reference - -```bash -# 1. Modify wrapper.hpp and wrapper.cpp -vim wrapper.hpp wrapper.cpp - -# 2. Rebuild bindings -cargo clean -p trtx-sys && cargo build - -# 3. Check generated bindings -cat target/debug/build/trtx-sys-*/out/bindings.rs | rg "your_function" - -# 4. Update safe Rust wrapper in trtx crate -vim ../trtx/src/your_module.rs - -# 5. Test -cargo test - -# 6. Test with mock -cargo test --features mock -``` - -## References - -- [Bindgen User Guide](https://rust-lang.github.io/rust-bindgen/) -- [The Rustonomicon - FFI](https://doc.rust-lang.org/nomicon/ffi.html) -- [TensorRT Documentation](https://docs.nvidia.com/deeplearning/tensorrt/) - -## Questions? - -If you encounter issues not covered here, check: - -1. Bindgen output for errors: `cargo build -vv` -2. Generated bindings: `cat target/debug/build/trtx-sys-*/out/bindings.rs` -3. 
Linker output: `cargo build -vv 2>&1 | grep "ld:"` diff --git a/trtx-sys/Cargo.toml b/trtx-sys/Cargo.toml index 062f5f9..da0a6fb 100644 --- a/trtx-sys/Cargo.toml +++ b/trtx-sys/Cargo.toml @@ -13,14 +13,12 @@ categories = ["external-ffi-bindings"] [build-dependencies] autocxx-build = "0.30" -bindgen = "0.69" +bindgen = "0.72" cc = "1.0" regex = "1.0" [features] default = ["v_1_3"] -# Mock mode for development without TensorRT-RTX installed -mock = [] onnxparser = [] link_tensorrt_rtx = [] link_tensorrt_onnxparser = ["onnxparser"] diff --git a/trtx-sys/build.rs b/trtx-sys/build.rs index 3de6092..fbf7979 100644 --- a/trtx-sys/build.rs +++ b/trtx-sys/build.rs @@ -28,6 +28,20 @@ fn prepare_transformed_headers(trt_dir: &Path, out_dir: &Path) -> PathBuf { let replaced = param_regex.replace_all(&replaced, "- `$1`"); let replaced = replaced .replace("std::size_t", "size_t") + // workaround autocxx limitation where there can't be the same type in different + // namespaces + .replace("namespace v_1_0", "inline namespace v_1_0") + .replace("namespace impl", "inline namespace impl") + .replace("ErrorCode getErrorCode", "int32_t getErrorCode") + .replace( + "bool reportError(ErrorCode val", + "bool reportError(int32_t val", + ) + .replace("noexcept", "") + .replace( + "void log(Severity severity, AsciiChar const* msg)", + "void log(int32_t severity, char const* msg)", + ) .replace("//!", "///") .replace(r"\returns", " - Returns "); let replaced = doxy_regex.replace_all(&replaced, ""); @@ -83,9 +97,15 @@ fn generate_enum_bindings(crate_root: &str, out_path: &Path) { ".*Platform", ".*Level", ".*Capability", + ".*ErrorCode", ".*Flag", ".*Selector", ".*Transformation", + ".*Location", + ".*Role", + ".*Limit", + ".*AttentionNormalizationOp", + ".*SeekPosition", ] { builder = builder.allowlist_type(pattern); } @@ -101,6 +121,8 @@ fn generate_enum_bindings(crate_root: &str, out_path: &Path) { let mut output = bindings.to_string(); output = output.replace("extern \"C\"", "extern 
\"system\""); output = output.replace("nvinfer1_", ""); + output = output.replace("ILogger_", ""); + output = output.replace("impl__EnumMaxImpl", "impl_EnumMaxImpl"); let out_file = out_path.join("enums.rs"); let mut f = File::create(&out_file).expect("Failed to create enums.rs"); @@ -121,15 +143,7 @@ fn main() { generate_enum_bindings(&crate_root, &out_path); // Check if we're in mock mode - if env::var("CARGO_FEATURE_MOCK").is_ok() { - println!("cargo:warning=Building in MOCK mode - no TensorRT-RTX required"); - - // Build mock C implementation - cc::Build::new().file("mock.c").compile("trtx_mock"); - - generate_mock_bindings(&out_path); - return; - } + let is_mock = env::var("CARGO_FEATURE_MOCK").is_ok(); println!("cargo:rerun-if-changed=src/lib.rs"); println!("cargo:rerun-if-changed=logger_bridge.hpp"); @@ -188,6 +202,9 @@ fn main() { .include(&transformed_include_dir) .include(&cuda_shim_include_dir); + if is_mock { + cc_build.define("TRTX_MOCK_MODE", "1"); + } if link_trt { cc_build.define("TRTX_LINK_TENSORRT_RTX", "1"); } @@ -242,706 +259,3 @@ fn main() { println!("cargo:rerun-if-changed=src/lib.rs"); } - -fn generate_mock_bindings(out_path: &std::path::Path) { - let mock_bindings = r#" -// Mock bindings for development without TensorRT-RTX - -// Error codes -pub const TRTX_SUCCESS: i32 = 0; -pub const TRTX_ERROR_INVALID_ARGUMENT: i32 = 1; -pub const TRTX_ERROR_OUT_OF_MEMORY: i32 = 2; -pub const TRTX_ERROR_RUNTIME_ERROR: i32 = 3; -pub const TRTX_ERROR_CUDA_ERROR: i32 = 4; -pub const TRTX_ERROR_UNKNOWN: i32 = 99; - -// Logger severity levels -#[repr(u32)] -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub enum TrtxLoggerSeverity { - TRTX_SEVERITY_INTERNAL_ERROR = 0, - TRTX_SEVERITY_ERROR = 1, - TRTX_SEVERITY_WARNING = 2, - TRTX_SEVERITY_INFO = 3, - TRTX_SEVERITY_VERBOSE = 4, -} - -// Opaque types (just markers in mock mode) -#[repr(C)] -pub struct TrtxLogger { - _unused: [u8; 0], -} - -#[repr(C)] -pub struct TrtxBuilder { - _unused: 
[u8; 0], -} - -#[repr(C)] -pub struct TrtxBuilderConfig { - _unused: [u8; 0], -} - -#[repr(C)] -pub struct TrtxNetworkDefinition { - _unused: [u8; 0], -} - -#[repr(C)] -pub struct TrtxRuntime { - _unused: [u8; 0], -} - -#[repr(C)] -pub struct TrtxCudaEngine { - _unused: [u8; 0], -} - -#[repr(C)] -pub struct TrtxExecutionContext { - _unused: [u8; 0], -} - -#[repr(C)] -pub struct TrtxOnnxParser { - _unused: [u8; 0], -} - -// Logger callback type -pub type TrtxLoggerCallback = ::std::option::Option< - unsafe extern "C" fn( - user_data: *mut ::std::os::raw::c_void, - severity: TrtxLoggerSeverity, - msg: *const ::std::os::raw::c_char, - ), ->; - -// Stub implementations that return success -extern "C" { - pub fn trtx_logger_create( - callback: TrtxLoggerCallback, - user_data: *mut ::std::os::raw::c_void, - out_logger: *mut *mut TrtxLogger, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_logger_destroy(logger: *mut TrtxLogger); - - pub fn trtx_builder_create( - logger: *mut TrtxLogger, - out_builder: *mut *mut TrtxBuilder, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_builder_destroy(builder: *mut TrtxBuilder); - - pub fn trtx_builder_create_network( - builder: *mut TrtxBuilder, - flags: u32, - out_network: *mut *mut TrtxNetworkDefinition, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_builder_create_builder_config( - builder: *mut TrtxBuilder, - out_config: *mut *mut TrtxBuilderConfig, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_builder_build_serialized_network( - builder: *mut TrtxBuilder, - network: *mut TrtxNetworkDefinition, - config: *mut TrtxBuilderConfig, - out_data: *mut *mut ::std::os::raw::c_void, - out_size: *mut usize, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_builder_config_destroy(config: *mut TrtxBuilderConfig); - - 
pub fn trtx_builder_config_set_memory_pool_limit( - config: *mut TrtxBuilderConfig, - pool_type: i32, - pool_size: usize, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_network_destroy(network: *mut TrtxNetworkDefinition); - - pub fn trtx_runtime_create( - logger: *mut TrtxLogger, - out_runtime: *mut *mut TrtxRuntime, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_runtime_destroy(runtime: *mut TrtxRuntime); - - pub fn trtx_runtime_deserialize_cuda_engine( - runtime: *mut TrtxRuntime, - data: *const ::std::os::raw::c_void, - size: usize, - out_engine: *mut *mut TrtxCudaEngine, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_cuda_engine_destroy(engine: *mut TrtxCudaEngine); - - pub fn trtx_cuda_engine_create_execution_context( - engine: *mut TrtxCudaEngine, - out_context: *mut *mut TrtxExecutionContext, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_cuda_engine_get_tensor_name( - engine: *mut TrtxCudaEngine, - index: i32, - out_name: *mut *const ::std::os::raw::c_char, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_cuda_engine_get_nb_io_tensors( - engine: *mut TrtxCudaEngine, - out_count: *mut i32, - ) -> i32; - - pub fn trtx_execution_context_destroy(context: *mut TrtxExecutionContext); - - pub fn trtx_execution_context_set_tensor_address( - context: *mut TrtxExecutionContext, - tensor_name: *const ::std::os::raw::c_char, - data: *mut ::std::os::raw::c_void, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_execution_context_enqueue_v3( - context: *mut TrtxExecutionContext, - cuda_stream: *mut ::std::os::raw::c_void, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_free_buffer(buffer: *mut ::std::os::raw::c_void); - - // ONNX Parser functions - pub 
fn trtx_onnx_parser_create( - network: *mut TrtxNetworkDefinition, - logger: *mut TrtxLogger, - out_parser: *mut *mut TrtxOnnxParser, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_onnx_parser_destroy(parser: *mut TrtxOnnxParser); - - pub fn trtx_onnx_parser_parse( - parser: *mut TrtxOnnxParser, - model_data: *const ::std::os::raw::c_void, - model_size: usize, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - // CUDA Memory Management functions - pub fn trtx_cuda_malloc( - ptr: *mut *mut ::std::os::raw::c_void, - size: usize, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_cuda_free( - ptr: *mut ::std::os::raw::c_void, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_cuda_memcpy_host_to_device( - dst: *mut ::std::os::raw::c_void, - src: *const ::std::os::raw::c_void, - size: usize, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_cuda_memcpy_device_to_host( - dst: *mut ::std::os::raw::c_void, - src: *const ::std::os::raw::c_void, - size: usize, - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_cuda_synchronize( - error_msg: *mut ::std::os::raw::c_char, - error_msg_len: usize, - ) -> i32; - - pub fn trtx_cuda_get_default_stream() -> *mut ::std::os::raw::c_void; -} - -// Mock nvinfer1 module - stub types for trtx crate compatibility in mock mode -pub mod nvinfer1 { - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum DataType { - kFLOAT = 0, - kHALF = 1, - kINT8 = 2, - kINT32 = 3, - kBOOL = 4, - kUINT8 = 5, - kFP8 = 6, - kBF16 = 7, - kINT64 = 8, - kINT4 = 9, - kFP4 = 10, - kE8M0 = 11, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum TensorFormat { - kLINEAR = 0, - kCHW2 = 1, - kHWC8 = 2, - kCHW4 = 3, - kCHW16 = 4, - kCHW32 = 5, - kDHWC8 = 6, - kCDHW32 
= 7, - kHWC = 8, - kDLA_LINEAR = 9, - kDLA_HWC4 = 10, - kHWC16 = 11, - kDHWC = 12, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum ActivationType { - kRELU = 0, - kSIGMOID = 1, - kTANH = 2, - kLEAKY_RELU = 3, - kELU = 4, - kSELU = 5, - kSOFTSIGN = 6, - kSOFTPLUS = 7, - kCLIP = 8, - kHARD_SIGMOID = 9, - kSCALED_TANH = 10, - kTHRESHOLDED_RELU = 11, - kGELU_ERF = 12, - kGELU_TANH = 13, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum PoolingType { - kMAX = 0, - kAVERAGE = 1, - kMAX_AVERAGE_BLEND = 2, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum ElementWiseOperation { - kSUM = 0, - kPROD = 1, - kMAX = 2, - kMIN = 3, - kSUB = 4, - kDIV = 5, - kPOW = 6, - kFLOOR_DIV = 7, - kAND = 8, - kOR = 9, - kXOR = 10, - kEQUAL = 11, - kGREATER = 12, - kLESS = 13, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum MatrixOperation { - kNONE = 0, - kTRANSPOSE = 1, - kVECTOR = 2, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum UnaryOperation { - kEXP = 0, - kLOG = 1, - kSQRT = 2, - kRECIP = 3, - kABS = 4, - kNEG = 5, - kSIN = 6, - kCOS = 7, - kTAN = 8, - kSINH = 9, - kCOSH = 10, - kASIN = 11, - kACOS = 12, - kATAN = 13, - kASINH = 14, - kACOSH = 15, - kATANH = 16, - kCEIL = 17, - kFLOOR = 18, - kERF = 19, - kNOT = 20, - kROUND = 21, - kSIGN = 22, - kISINF = 23, - kISNAN = 24, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum ReduceOperation { - kSUM = 0, - kPROD = 1, - kMAX = 2, - kMIN = 3, - kAVG = 4, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum CumulativeOperation { - kSUM = 0, - kPROD = 1, - kMIN = 2, - kMAX = 3, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum GatherMode { - kDEFAULT = 0, - kELEMENT = 1, - kND = 2, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub 
enum ScatterMode { - kELEMENT = 0, - kND = 1, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum InterpolationMode { - kNEAREST = 0, - kLINEAR = 1, - kCUBIC = 2, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum ResizeCoordinateTransformation { - kASYMMETRIC = 0, - kALIGN_CORNERS = 1, - kHALF_PIXEL = 2, - kHALF_PIXEL_SYMMETRIC = 3, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum ResizeRoundMode { - kFLOOR = 0, - kCEIL = 1, - kROUND = 2, - kHALF_UP = 3, - kHALF_DOWN = 4, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum ResizeSelector { - kFORMULA = 0, - kSIZES = 1, - kUPPER = 2, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum TopKOperation { - kMAX = 0, - kMIN = 1, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum ScaleMode { - kUNIFORM = 0, - kCHANNEL = 1, - kELEMENTWISE = 2, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum ExecutionContextAllocationStrategy { - kSTATIC = 0, - kUSER_MANAGED = 1, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum MemoryPoolType { - kWORKSPACE = 0, - kDLA_MANAGED_SRAM = 1, - kDLA_LOCAL_DRAM = 2, - kDLA_GLOBAL_DRAM = 3, - kTACTIC_DRAM = 4, - kTACTIC_SHARED_MEMORY = 5, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum ProfilingVerbosity { - kLAYER_NAMES_ONLY = 0, - kNONE = 1, - kDETAILED = 2, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum EngineCapability { - kSTANDARD = 0, - kSAFETY = 1, - kDLA_STANDALONE = 2, - } - - #[repr(i32)] - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - pub enum BuilderFlag { - kFP16 = 0, - kINT8 = 1, - kDEBUG = 2, - kGPU_FALLBACK = 3, - kREFIT = 4, - kDISABLE_TIMING_CACHE = 5, - kTF32 = 6, - kSPARSE_WEIGHTS = 7, - kSAFETY_SCOPE = 8, - 
kOBEY_PRECISION_CONSTRAINTS = 9,
-        kPREFER_PRECISION_CONSTRAINTS = 10,
-        kDIRECT_IO = 11,
-        kREJECT_EMPTY_ALGORITHMS = 12,
-        kVERSION_COMPATIBLE = 13,
-        kEXCLUDE_LEAN_RUNTIME = 14,
-        kFP8 = 15,
-        kERROR_ON_TIMING_CACHE_MISS = 16,
-        kBF16 = 17,
-        kDISABLE_COMPILATION_CACHE = 18,
-        kSTRIP_PLAN = 19,
-        kREFIT_IDENTICAL = 20,
-        kWEIGHT_STREAMING = 21,
-        kINT4 = 22,
-        kREFIT_INDIVIDUAL = 23,
-        kSTRICT_NANS = 24,
-        kMONITOR_MEMORY = 25,
-        kFP4 = 26,
-        kEDITABLE_TIMING_CACHE = 27,
-        kDISTRIBUTIVE_INDEPENDENCE = 28,
-    }
-
-    pub type BuilderFlags = u32;
-
-    #[repr(i32)]
-    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-    pub enum DeviceType {
-        kGPU = 0,
-        kDLA = 1,
-    }
-
-    #[repr(i32)]
-    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-    pub enum TacticSource {
-        kCUBLAS = 0,
-        kCUBLAS_LT = 1,
-        kCUDNN = 2,
-        kEDGE_MASK_CONVOLUTIONS = 3,
-        kJIT_CONVOLUTIONS = 4,
-    }
-
-    pub type TacticSources = u32;
-
-    #[repr(i32)]
-    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-    pub enum PreviewFeature {
-        kPROFILE_SHARING_0806 = 0,
-        kALIASED_PLUGIN_IO_10_03 = 1,
-        kRUNTIME_ACTIVATION_RESIZE_10_10 = 2,
-    }
-
-    #[repr(i32)]
-    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-    pub enum HardwareCompatibilityLevel {
-        kNONE = 0,
-        kAMPERE_PLUS = 1,
-        kSAME_COMPUTE_CAPABILITY = 2,
-    }
-
-    #[repr(i32)]
-    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-    pub enum RuntimePlatform {
-        kSAME_AS_BUILD = 0,
-        kWINDOWS_AMD64 = 1,
-    }
-
-    #[repr(i32)]
-    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-    pub enum TilingOptimizationLevel {
-        kNONE = 0,
-        kFAST = 1,
-        kMODERATE = 2,
-        kFULL = 3,
-    }
-
-    #[repr(i32)]
-    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
-    pub enum ComputeCapability {
-        kNONE = 0,
-        kCURRENT = 1,
-        kSM75 = 75,
-        kSM80 = 80,
-        kSM86 = 86,
-        kSM89 = 89,
-        kSM120 = 120,
-    }
-
-    // Layer interface types (opaque stubs for mock - only used in type positions)
-    #[repr(C)]
-    pub struct IShuffleLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IActivationLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IResizeLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct ITopKLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IGatherLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IScatterLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct ISelectLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IMatrixMultiplyLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct ISoftMaxLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IReduceLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct ICumulativeLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IPoolingLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IConvolutionLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IDeconvolutionLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IQuantizeLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IDequantizeLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IConstantLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IConcatenationLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IScaleLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct ISliceLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IUnaryLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IIdentityLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IPaddingLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct ICastLayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct ITensor { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct ILayer { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct INetworkDefinition { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct ICudaEngine { _unused: [u8; 0] }
-    #[repr(C)]
-    pub struct IExecutionContext { _unused: [u8; 0] }
-
-    #[repr(C)]
-    #[derive(Clone, Copy)]
-    pub struct Weights {
-        pub type_: DataType,
-        pub values: *const ::std::ffi::c_void,
-        pub count: i64,
-    }
-
-    impl Weights {
-        pub fn new_float(values_ptr: *const ::std::ffi::c_void, count_val: i64) -> Self {
-            Self { type_: DataType::kFLOAT, values: values_ptr, count: count_val }
-        }
-        pub fn new_with_type(
-            data_type: DataType,
-            values_ptr: *const ::std::ffi::c_void,
-            count_val: i64,
-        ) -> Self {
-            Self { type_: data_type, values: values_ptr, count: count_val }
-        }
-    }
-}
-
-// Dims64/Dims - mock version
-#[repr(C)]
-#[derive(Clone, Copy, Debug)]
-pub struct Dims64 {
-    pub nbDims: i32,
-    pub d: [i64; 8],
-}
-
-pub type Dims = Dims64;
-
-impl Dims64 {
-    pub fn from_slice(dims: &[i64]) -> Self {
-        let mut d = [0i64; 8];
-        let nb_dims = dims.len().min(8) as i32;
-        d[..nb_dims as usize].copy_from_slice(&dims[..nb_dims as usize]);
-        Self { nbDims: nb_dims, d }
-    }
-    pub fn new_2d(d0: i64, d1: i64) -> Self {
-        Self { nbDims: 2, d: [d0, d1, 0, 0, 0, 0, 0, 0] }
-    }
-    pub fn new_3d(d0: i64, d1: i64, d2: i64) -> Self {
-        Self { nbDims: 3, d: [d0, d1, d2, 0, 0, 0, 0, 0] }
-    }
-    pub fn new_4d(d0: i64, d1: i64, d2: i64, d3: i64) -> Self {
-        Self { nbDims: 4, d: [d0, d1, d2, d3, 0, 0, 0, 0] }
-    }
-}
-
-// ResizeMode is InterpolationMode in TensorRT
-pub use nvinfer1::InterpolationMode as ResizeMode;
-"#;
-
-    std::fs::write(out_path.join("bindings.rs"), mock_bindings)
-        .expect("Couldn't write mock bindings!");
-}
diff --git a/trtx-sys/logger_bridge.cpp b/trtx-sys/logger_bridge.cpp
index 8948d6a..60a1148 100644
--- a/trtx-sys/logger_bridge.cpp
+++ b/trtx-sys/logger_bridge.cpp
@@ -1,49 +1,52 @@
 /**
  * Logger Bridge for TensorRT-RTX Rust Bindings
- *
+ *
  * This file provides C wrapper functions for TensorRT-RTX C++ API.
- * While we use autocxx for most C++ bindings, some wrappers are still necessary.
- *
+ * While we use autocxx for most C++ bindings, some wrappers are still
+ * necessary.
+ *
  * ## Architecture
- *
+ *
  * ```
  * Rust (trtx) → Raw FFI (trtx-sys) → logger_bridge.cpp → TensorRT C++ + autocxx
  * ```
- *
+ *
  * ## Why These Wrappers Exist
- *
+ *
  * ### NECESSARY WRAPPERS (Cannot be removed):
- *
+ *
  * 1. **Logger Bridge (lines 29-52)**:
  * - Rust cannot implement C++ virtual classes
  * - RustLoggerImpl forwards virtual method calls to Rust callbacks
  * - REQUIRED: No alternative
- *
+ *
  * 2. **Factory Functions (lines 55-91)**:
  * - createInferBuilder/Runtime take `ILogger&` references
  * - autocxx struggles with C++ reference parameters
  * - REQUIRED: Simplest solution for reference params
- *
+ *
  * 3. **CUDA Wrappers (lines 658-677)**:
  * - Bridge between std::ffi::c_void and autocxx::c_void
  * - Type compatibility issue between Rust and autocxx types
  * - KEEP FOR NOW: Could be removed with codebase-wide type migration
- *
+ *
  * ### POTENTIALLY REDUNDANT WRAPPERS:
- *
+ *
  * 4. **TensorRT Method Wrappers (lines 94-657)**:
  * - Builder, Network, Tensor, Engine, Context methods
- * - autocxx CAN generate these with `generate!("nvinfer1::INetworkDefinition")`
+ * - autocxx CAN generate these with
+ *   `generate!("nvinfer1::INetworkDefinition")`
  * - POTENTIALLY REMOVABLE: ~75% code reduction if refactored
- * - STATUS: Kept for now due to stability, could be migrated to direct autocxx calls
- *
+ * - STATUS: Kept for now due to stability, could be migrated to direct
+ *   autocxx calls
+ *
  * ## Why Not Full autocxx?
- *
+ *
  * We TRIED to use autocxx for everything but encountered:
  * - Type mismatches (autocxx::c_void vs std::ffi::c_void)
  * - Reference parameter handling issues
  * - Virtual method/callback complications
- *
+ *
  * ## See Also
  * - docs/LOGGER_BRIDGE_ANALYSIS.md - Detailed analysis of each function
  * - docs/REFACTORING_SUMMARY.md - Test results and recommendations
@@ -51,6 +54,7 @@
 */

 #include "logger_bridge.hpp"
+#include
 #include
 #include
 #include
@@ -64,50 +68,51 @@

 // C++ implementation of ILogger that bridges to Rust
 class RustLoggerImpl : public nvinfer1::ILogger {
 public:
-    RustLoggerImpl(RustLogCallback callback, void* user_data)
-        : callback_(callback), user_data_(user_data) {}
+  RustLoggerImpl(RustLogCallback callback, void *user_data)
+      : callback_(callback), user_data_(user_data) {}

-    void log(Severity severity, const char* msg) noexcept override {
-        if (callback_) {
-            callback_(user_data_, static_cast<int32_t>(severity), msg);
-        }
+  void log(int32_t severity, const char *msg) noexcept override {
+    if (callback_) {
+      callback_(user_data_, static_cast<int32_t>(severity), msg);
     }
+  }

 private:
-    RustLogCallback callback_;
-    void* user_data_;
+  RustLogCallback callback_;
+  void *user_data_;
 };

 // Opaque struct that holds the logger implementation
 struct RustLoggerBridge {
-    RustLoggerImpl* impl;
+  RustLoggerImpl *impl;
 };

 extern "C" {

-RustLoggerBridge* create_rust_logger_bridge(RustLogCallback callback, void* user_data) {
-    if (!callback) {
-        return nullptr;
-    }
-
-    try {
-        auto* bridge = new RustLoggerBridge();
-        bridge->impl = new RustLoggerImpl(callback, user_data);
-        return bridge;
-    } catch (...) {
-        return nullptr;
-    }
+RustLoggerBridge *create_rust_logger_bridge(RustLogCallback callback,
+                                            void *user_data) {
+  if (!callback) {
+    return nullptr;
+  }
+
+  try {
+    auto *bridge = new RustLoggerBridge();
+    bridge->impl = new RustLoggerImpl(callback, user_data);
+    return bridge;
+  } catch (...) {
+    return nullptr;
+  }
 }

-void destroy_rust_logger_bridge(RustLoggerBridge* logger) {
-    if (logger) {
-        delete logger->impl;
-        delete logger;
-    }
+void destroy_rust_logger_bridge(RustLoggerBridge *logger) {
+  if (logger) {
+    delete logger->impl;
+    delete logger;
+  }
 }

-nvinfer1::ILogger* get_logger_interface(RustLoggerBridge* logger) {
-    return logger ? logger->impl : nullptr;
+nvinfer1::ILogger *get_logger_interface(RustLoggerBridge *logger) {
+  return logger ? logger->impl : nullptr;
 }

 //==============================================================================
@@ -118,316 +123,389 @@ nvinfer1::ILogger* get_logger_interface(RustLoggerBridge* logger) {

 // Factory functions for TensorRT
 #ifdef TRTX_LINK_TENSORRT_RTX

-void* create_infer_builder(void* logger) {
-    if (!logger) {
-        return nullptr;
-    }
-    try {
-        auto* ilogger = static_cast<nvinfer1::ILogger*>(logger);
-        return nvinfer1::createInferBuilder(*ilogger);
-    } catch (...) {
-        return nullptr;
-    }
+void *create_infer_builder(void *logger) {
+  if (!logger) {
+    return nullptr;
+  }
+  try {
+    auto *ilogger = static_cast<nvinfer1::ILogger *>(logger);
+    return nvinfer1::createInferBuilder(*ilogger);
+  } catch (...) {
+    return nullptr;
+  }
 }

-void* create_infer_runtime(void* logger) {
-    if (!logger) {
-        return nullptr;
-    }
-    try {
-        auto* ilogger = static_cast<nvinfer1::ILogger*>(logger);
-        return nvinfer1::createInferRuntime(*ilogger);
-    } catch (...) {
-        return nullptr;
-    }
+void *create_infer_runtime(void *logger) {
+  if (!logger) {
+    return nullptr;
+  }
+  try {
+    auto *ilogger = static_cast<nvinfer1::ILogger *>(logger);
+    return nvinfer1::createInferRuntime(*ilogger);
+  } catch (...) {
+    return nullptr;
+  }
+}
+
+void *create_infer_refitter(void *engine, void *logger) {
+  if (!engine || !logger) {
+    return nullptr;
+  }
+  try {
+    auto *iengine = static_cast<nvinfer1::ICudaEngine *>(engine);
+    auto *ilogger = static_cast<nvinfer1::ILogger *>(logger);
+    return nvinfer1::createInferRefitter(*iengine, *ilogger);
+  } catch (...) {
+    return nullptr;
+  }
 }
 #endif

 #ifdef TRTX_LINK_TENSORRT_ONNXPARSER
 // ONNX Parser factory function
-void* create_onnx_parser(void* network, void* logger) {
-    if (!network || !logger) {
-        return nullptr;
-    }
-    try {
-        auto* inetwork = static_cast<nvinfer1::INetworkDefinition*>(network);
-        auto* ilogger = static_cast<nvinfer1::ILogger*>(logger);
-        return nvonnxparser::createParser(*inetwork, *ilogger);
-    } catch (...) {
-        return nullptr;
-    }
+void *create_onnx_parser(void *network, void *logger) {
+  if (!network || !logger) {
+    return nullptr;
+  }
+  try {
+    auto *inetwork = static_cast<nvinfer1::INetworkDefinition *>(network);
+    auto *ilogger = static_cast<nvinfer1::ILogger *>(logger);
+    return nvonnxparser::createParser(*inetwork, *ilogger);
+  } catch (...) {
+    return nullptr;
+  }
 }
 #endif

-//==============================================================================
-// SECTION 3: BUILDER & CONFIG METHODS (POTENTIALLY REDUNDANT)
-//==============================================================================
-// These wrap IBuilder and IBuilderConfig methods.
-// autocxx CAN generate these with generate!("nvinfer1::IBuilder").
-// FUTURE: Consider migrating to direct autocxx calls (see REFACTORING_SUMMARY.md)
-
-// Builder methods
-void builder_config_set_memory_pool_limit(void* config, int32_t pool_type, size_t limit) {
-    if (!config) return;
-    try {
-        auto* iconfig = static_cast<nvinfer1::IBuilderConfig*>(config);
-        iconfig->setMemoryPoolLimit(static_cast<nvinfer1::MemoryPoolType>(pool_type), limit);
-    } catch (...) {
-        // Ignore errors
-    }
+// Refitter methods that use char const** (pointer-to-pointer); autocxx cannot
+// bind these.
+int32_t trtx_refitter_get_missing(void *refitter, int32_t size,
+                                  char const **layer_names, int32_t *roles) {
+  if (!refitter || !layer_names || !roles)
+    return 0;
+  try {
+    auto *ir = static_cast<nvinfer1::IRefitter *>(refitter);
+    return ir->getMissing(size, layer_names,
+                          reinterpret_cast<nvinfer1::WeightsRole *>(roles));
+  } catch (...) {
+    return 0;
+  }
 }

-//==============================================================================
-// SECTION 4: NETWORK DEFINITION METHODS (POTENTIALLY REDUNDANT)
-//==============================================================================
-// These wrap INetworkDefinition layer building methods.
-// autocxx CAN generate these with generate!("nvinfer1::INetworkDefinition").
-// FUTURE: Consider migrating to direct autocxx calls
-// NOTE: This is the largest section (~350 lines) and biggest refactoring opportunity
-
-// Network methods
-// network_add_input - REMOVED - Now using direct autocxx call in network.rs
-
-// network_add_convolution - REMOVED - Using direct autocxx
-
-// network_add_activation - REMOVED - Now using direct autocxx call in network.rs
-
-// network_add_pooling - REMOVED - Now using direct autocxx call in network.rs
-
-// network_add_matrix_multiply - REMOVED - Using direct autocxx
-
-// network_add_constant - REMOVED - Using direct autocxx
-
-// network_add_elementwise - REMOVED - Now using direct autocxx call in network.rs
-
-// network_add_shuffle - REMOVED - Now using direct autocxx call in network.rs
-
-void* network_add_concatenation(void* network, void** inputs, int32_t nb_inputs) {
-    if (!network || !inputs || nb_inputs <= 0) return nullptr;
-    try {
-        auto* inetwork = static_cast<nvinfer1::INetworkDefinition*>(network);
-        std::vector<nvinfer1::ITensor*> tensors;
-        tensors.reserve(nb_inputs);
-        for (int32_t i = 0; i < nb_inputs; ++i) {
-            tensors.push_back(static_cast<nvinfer1::ITensor*>(inputs[i]));
-        }
-        auto* layer = inetwork->addConcatenation(tensors.data(), nb_inputs);
-        return layer; // Return layer, not output tensor
-    } catch (...) {
-        return nullptr;
-    }
-}
-
-// network_add_softmax - REMOVED - Using direct autocxx
-
-// network_add_scale - REMOVED - Using direct autocxx
-
-// network_add_reduce - REMOVED - Using direct autocxx
-
-// network_add_slice - REMOVED - Now using direct autocxx call in network.rs
-
-// network_add_resize - REMOVED - Using direct autocxx
-
-// network_add_topk - REMOVED - Using direct autocxx
-
-// network_add_gather - REMOVED - Using direct autocxx
-
-// network_add_select - REMOVED - Using direct autocxx
-
-void* network_add_assertion(void* network, void* condition, const char* message) {
-    if (!network || !condition) return nullptr;
-    try {
-        auto* inetwork = static_cast<nvinfer1::INetworkDefinition*>(network);
-        auto* condition_tensor = static_cast<nvinfer1::ITensor*>(condition);
-        auto* layer = inetwork->addAssertion(*condition_tensor, message ? message : "");
-        // Assertion layers don't have outputs, return the layer itself
-        return layer;
-    } catch (...) {
-        return nullptr;
-    }
+int32_t trtx_refitter_get_all(void *refitter, int32_t size,
+                              char const **layer_names, int32_t *roles) {
+  if (!refitter || !layer_names || !roles)
+    return 0;
+  try {
+    auto *ir = static_cast<nvinfer1::IRefitter *>(refitter);
+    return ir->getAll(size, layer_names,
+                      reinterpret_cast<nvinfer1::WeightsRole *>(roles));
+  } catch (...) {
+    return 0;
+  }
 }

-void* network_add_loop(void* network) {
-    if (!network) return nullptr;
-    try {
-        auto* inetwork = static_cast<nvinfer1::INetworkDefinition*>(network);
-        return inetwork->addLoop();
-    } catch (...) {
-        return nullptr;
-    }
+int32_t trtx_refitter_get_missing_weights(void *refitter, int32_t size,
+                                          char const **weights_names) {
+  if (!refitter || !weights_names)
+    return 0;
+  try {
+    auto *ir = static_cast<nvinfer1::IRefitter *>(refitter);
+    return ir->getMissingWeights(size, weights_names);
+  } catch (...) {
+    return 0;
+  }
 }

-void* network_add_if_conditional(void* network) {
-    if (!network) return nullptr;
-    try {
-        auto* inetwork = static_cast<nvinfer1::INetworkDefinition*>(network);
-        return inetwork->addIfConditional();
-    } catch (...) {
-        return nullptr;
-    }
+int32_t trtx_refitter_get_all_weights(void *refitter, int32_t size,
+                                      char const **weights_names) {
+  if (!refitter || !weights_names)
+    return 0;
+  try {
+    auto *ir = static_cast<nvinfer1::IRefitter *>(refitter);
+    return ir->getAllWeights(size, weights_names);
+  } catch (...) {
+    return 0;
+  }
 }

-//==============================================================================
-// SECTION 5: TENSOR METHODS (POTENTIALLY REDUNDANT)
-//==============================================================================
-// Wrap ITensor getter/setter methods.
-// autocxx CAN generate with generate!("nvinfer1::ITensor")
-
-// Tensor methods
-void* tensor_get_dimensions(void* tensor, int32_t* dims, int32_t* nb_dims) {
-    if (!tensor || !dims || !nb_dims) return nullptr;
-    try {
-        auto* itensor = static_cast<nvinfer1::ITensor*>(tensor);
-        nvinfer1::Dims dimensions = itensor->getDimensions();
-        *nb_dims = dimensions.nbDims;
-        for (int32_t i = 0; i < dimensions.nbDims && i < nvinfer1::Dims::MAX_DIMS; ++i) {
-            dims[i] = dimensions.d[i];
-        }
-        return tensor; // Return success
-    } catch (...) {
-        return nullptr;
-    }
+bool parser_parse(void *parser, const void *data, size_t size) {
+  if (!parser || !data)
+    return false;
+  try {
+    auto *iparser = static_cast<nvonnxparser::IParser *>(parser);
+    return iparser->parse(data, size);
+  } catch (...) {
+    return false;
+  }
 }

-int32_t tensor_get_type(void* tensor) {
-    if (!tensor) return -1;
-    try {
-        auto* itensor = static_cast<nvinfer1::ITensor*>(tensor);
-        return static_cast<int32_t>(itensor->getType());
-    } catch (...) {
-        return -1;
-    }
+int32_t parser_get_nb_errors(void *parser) {
+  if (!parser)
+    return 0;
+  try {
+    auto *iparser = static_cast<nvonnxparser::IParser *>(parser);
+    return iparser->getNbErrors();
+  } catch (...) {
+    return 0;
+  }
 }

-void* builder_build_serialized_network(void* builder, void* network, void* config, size_t* out_size) {
-    if (!builder || !network || !config || !out_size) return nullptr;
-    try {
-        auto* ibuilder = static_cast<nvinfer1::IBuilder*>(builder);
-        auto* inetwork = static_cast<nvinfer1::INetworkDefinition*>(network);
-        auto* iconfig = static_cast<nvinfer1::IBuilderConfig*>(config);
-
-        auto* serialized = ibuilder->buildSerializedNetwork(*inetwork, *iconfig);
-        if (!serialized) return nullptr;
-
-        *out_size = serialized->size();
-        // Allocate and copy data
-        void* data = malloc(*out_size);
-        if (data) {
-            memcpy(data, serialized->data(), *out_size);
-        }
-        delete serialized;
-        return data;
-    } catch (...) {
-        return nullptr;
-    }
+void *parser_get_error(void *parser, int32_t index) {
+  if (!parser)
+    return nullptr;
+  try {
+    auto *iparser = static_cast<nvonnxparser::IParser *>(parser);
+    return const_cast<nvonnxparser::IParserError *>(iparser->getError(index));
+  } catch (...) {
+    return nullptr;
+  }
 }

-// Runtime methods
-void* runtime_deserialize_cuda_engine(void* runtime, const void* data, size_t size) {
-    if (!runtime || !data) return nullptr;
-    try {
-        auto* iruntime = static_cast<nvinfer1::IRuntime*>(runtime);
-        return iruntime->deserializeCudaEngine(data, size);
-    } catch (...) {
-        return nullptr;
-    }
+const char *parser_error_desc(void *error) {
+  if (!error)
+    return nullptr;
+  try {
+    auto *ierror = static_cast<nvonnxparser::IParserError *>(error);
+    return ierror->desc();
+  } catch (...) {
+    return nullptr;
+  }
 }

-// Engine methods
-// ExecutionContext methods
-// Parser methods
-bool parser_parse(void* parser, const void* data, size_t size) {
-    if (!parser || !data) return false;
-    try {
-        auto* iparser = static_cast<nvonnxparser::IParser*>(parser);
-        return iparser->parse(data, size);
-    } catch (...) {
-        return false;
+void *network_add_concatenation(void *network, void **inputs,
+                                int32_t nb_inputs) {
+  if (!network || !inputs || nb_inputs <= 0)
+    return nullptr;
+  try {
+    auto *inetwork = static_cast<nvinfer1::INetworkDefinition *>(network);
+    std::vector<nvinfer1::ITensor *> tensors;
+    tensors.reserve(nb_inputs);
+    for (int32_t i = 0; i < nb_inputs; ++i) {
+      tensors.push_back(static_cast<nvinfer1::ITensor *>(inputs[i]));
     }
+    auto *layer = inetwork->addConcatenation(tensors.data(), nb_inputs);
+    return layer; // Return layer, not output tensor
+  } catch (...) {
+    return nullptr;
+  }
 }

-int32_t parser_get_nb_errors(void* parser) {
-    if (!parser) return 0;
-    try {
-        auto* iparser = static_cast<nvonnxparser::IParser*>(parser);
-        return iparser->getNbErrors();
-    } catch (...) {
-        return 0;
-    }
-}
+uint32_t get_tensorrt_version() { return NV_TENSORRT_VERSION; }

-void* parser_get_error(void* parser, int32_t index) {
-    if (!parser) return nullptr;
-    try {
-        auto* iparser = static_cast<nvonnxparser::IParser*>(parser);
-        return const_cast<nvonnxparser::IParserError*>(iparser->getError(index));
-    } catch (...) {
-        return nullptr;
-    }
+namespace nvinfer1 {
+class ProgressMonitor : public IProgressMonitor {
+public:
+  ProgressMonitor(void *self, void *phaseStart, void *stepComplete,
+                  void *phaseFinish)
+      : self(self), m_phaseStart((decltype(m_phaseStart))phaseStart),
+        m_stepComplete((decltype(m_stepComplete))stepComplete),
+        m_phaseFinish((decltype(m_phaseFinish))phaseFinish) {}
+  ~ProgressMonitor() = default;
+  void *self;
+  void (*m_phaseStart)(void *self, char const *phaseName,
+                       char const *parentPhase, int32_t nbSteps);
+  bool (*m_stepComplete)(void *self, char const *phaseName, int32_t step);
+  void (*m_phaseFinish)(void *self, char const *phaseName);
+
+  void phaseStart(char const *phaseName, char const *parentPhase,
+                  int32_t nbSteps) override {
+    m_phaseStart(self, phaseName, parentPhase, nbSteps);
+  };
+  bool stepComplete(char const *phaseName, int32_t step) override {
+    return m_stepComplete(self, phaseName, step);
+  };
+  void phaseFinish(char const *phaseName) override {
+    m_phaseFinish(self, phaseName);
+  };
+};
+} // namespace nvinfer1
+
+void *trtx_create_progress_monitor(void *self, void *phaseStart,
+                                   void *stepComplete, void *phaseFinish) {
+  try {
+    return new nvinfer1::ProgressMonitor(self, phaseStart, stepComplete,
+                                         phaseFinish);
+  } catch (...) {
+    return nullptr;
+  }
 }
-
-const char* parser_error_desc(void* error) {
-    if (!error) return nullptr;
-    try {
-        auto* ierror = static_cast<nvonnxparser::IParserError*>(error);
-        return ierror->desc();
-    } catch (...) {
-        return nullptr;
-    }
+void trtx_destroy_progress_monitor(void *self) {
+  delete (nvinfer1::ProgressMonitor *)(self);
 }

 //==============================================================================
-// SECTION 6: DESTRUCTION METHODS (POTENTIALLY REDUNDANT)
+// ErrorRecorder subclass (bridge to Rust RecordError)
 //==============================================================================
-// These wrap TensorRT object deletion.
-// autocxx CAN handle C++ destructors with RAII wrappers.
-// FUTURE: Consider using UniquePtr or Drop trait implementations
-
-// Destruction methods
-void delete_builder(void* builder) {
-    if (builder) {
-        delete static_cast<nvinfer1::IBuilder*>(builder);
-    }
+namespace nvinfer1 {
+class ErrorRecorderSubclass : public IErrorRecorder {
+public:
+  using ErrorCode = nvinfer1::ErrorCode;
+  ErrorRecorderSubclass(void *self, int32_t (*getNbErrors)(void *),
+                        int32_t (*getErrorCode)(void *, int32_t),
+                        void (*getErrorDesc)(void *, int32_t, char *, size_t),
+                        bool (*hasOverflowed)(void *), void (*clear)(void *),
+                        bool (*reportError)(void *, int32_t, char const *),
+                        int32_t (*incRefCount)(void *),
+                        int32_t (*decRefCount)(void *))
+      : self(self), m_getNbErrors(getNbErrors), m_getErrorCode(getErrorCode),
+        m_getErrorDesc(getErrorDesc), m_hasOverflowed(hasOverflowed),
+        m_clear(clear), m_reportError(reportError), m_incRefCount(incRefCount),
+        m_decRefCount(decRefCount) {}
+  ~ErrorRecorderSubclass() = default;
+
+  void *self;
+  int32_t (*m_getNbErrors)(void *);
+  int32_t (*m_getErrorCode)(void *, int32_t);
+  void (*m_getErrorDesc)(void *, int32_t, char *, size_t);
+  bool (*m_hasOverflowed)(void *);
+  void (*m_clear)(void *);
+  bool (*m_reportError)(void *, int32_t, char const *);
+  int32_t (*m_incRefCount)(void *);
+  int32_t (*m_decRefCount)(void *);
+
+  mutable std::string m_lastDesc;
+
+  int32_t getNbErrors() const noexcept override { return m_getNbErrors(self); }
+  int32_t getErrorCode(int32_t errorIdx) const noexcept override {
+    return m_getErrorCode(self, errorIdx);
+  }
+  ErrorDesc getErrorDesc(int32_t errorIdx) const noexcept override {
+    char buf[128];
+    m_getErrorDesc(self, errorIdx, buf, sizeof(buf));
+    m_lastDesc = buf;
+    return m_lastDesc.c_str();
+  }
+  bool hasOverflowed() const noexcept override { return m_hasOverflowed(self); }
+  void clear() noexcept override { m_clear(self); }
+  bool reportError(int32_t val, ErrorDesc desc) noexcept override {
+    return m_reportError(self, val, desc);
+  }
+  RefCount incRefCount() noexcept override { return m_incRefCount(self); }
+  RefCount decRefCount() noexcept override { return m_decRefCount(self); }
+};
+} // namespace nvinfer1
+
+void *trtx_create_error_recorder(void *self, void *getNbErrors,
+                                 void *getErrorCode, void *getErrorDesc,
+                                 void *hasOverflowed, void *clear,
+                                 void *reportError, void *incRefCount,
+                                 void *decRefCount) {
+  try {
+    return new nvinfer1::ErrorRecorderSubclass(
+        self, (int32_t (*)(void *))getNbErrors,
+        (int32_t (*)(void *, int32_t))getErrorCode,
+        (void (*)(void *, int32_t, char *, size_t))getErrorDesc,
+        (bool (*)(void *))hasOverflowed, (void (*)(void *))clear,
+        (bool (*)(void *, int32_t, char const *))reportError,
+        (int32_t (*)(void *))incRefCount, (int32_t (*)(void *))decRefCount);
+  } catch (...) {
+    return nullptr;
+  }
 }
-
-void delete_network(void* network) {
-    if (network) {
-        delete static_cast<nvinfer1::INetworkDefinition*>(network);
-    }
+void trtx_destroy_error_recorder(void *obj) {
+  delete static_cast<nvinfer1::ErrorRecorderSubclass *>(obj);
 }

-void delete_config(void* config) {
-    if (config) {
-        delete static_cast<nvinfer1::IBuilderConfig*>(config);
-    }
+//==============================================================================
+// GpuAllocator subclass (bridge to Rust AllocateGpu)
+//==============================================================================
+namespace nvinfer1 {
+class GpuAllocatorSubclass : public IGpuAllocator {
+public:
+  GpuAllocatorSubclass(void *self,
+                       void *(*allocateAsync)(void *, uint64_t, uint64_t,
+                                              uint32_t, void *),
+                       void *(*reallocate)(void *, void *, uint64_t, uint64_t),
+                       bool (*deallocateAsync)(void *, void *, void *))
+      : self(self), m_allocateAsync((decltype(m_allocateAsync))allocateAsync),
+        m_reallocate((decltype(m_reallocate))reallocate),
+        m_deallocateAsync((decltype(m_deallocateAsync))deallocateAsync) {}
+  ~GpuAllocatorSubclass() = default;
+
+  void *self;
+  void *(*m_allocateAsync)(void *, uint64_t, uint64_t, uint32_t, void *);
+  void *(*m_reallocate)(void *, void *, uint64_t, uint64_t);
+  bool (*m_deallocateAsync)(void *, void *, void *);
+
+  void *allocate(uint64_t size, uint64_t alignment,
+                 AllocatorFlags flags) noexcept override {
+    return m_allocateAsync(self, size, alignment, static_cast<uint32_t>(flags),
+                           nullptr);
+  }
+  void *reallocate(void *baseAddr, uint64_t alignment,
+                   uint64_t newSize) noexcept override {
+    return m_reallocate(self, baseAddr, alignment, newSize);
+  }
+  bool deallocate(void *memory) noexcept override {
+    return m_deallocateAsync(self, memory, nullptr);
+  }
+  void *allocateAsync(uint64_t size, uint64_t alignment, AllocatorFlags flags,
+                      cudaStream_t stream) noexcept override {
+    return m_allocateAsync(self, size, alignment, static_cast<uint32_t>(flags),
+                           stream);
+  }
+  bool deallocateAsync(void *memory, cudaStream_t stream) noexcept override {
+    return m_deallocateAsync(self, memory, stream);
+  }
+};
+} // namespace nvinfer1
+
+void *trtx_create_gpu_allocator(void *self, void *allocateAsync,
+                                void *reallocate, void *deallocateAsync) {
+
+  try {
+    return new nvinfer1::GpuAllocatorSubclass(
+        self,
+        (void *(*)(void *, uint64_t, uint64_t, uint32_t, void *))allocateAsync,
+        (void *(*)(void *, void *, uint64_t, uint64_t))reallocate,
+        (bool (*)(void *, void *, void *))deallocateAsync);
+  } catch (...) {
+    return nullptr;
+  }
 }
-
-void delete_runtime(void* runtime) {
-    if (runtime) {
-        delete static_cast<nvinfer1::IRuntime*>(runtime);
-    }
+void trtx_destroy_gpu_allocator(void *obj) {
+  delete static_cast<nvinfer1::GpuAllocatorSubclass *>(obj);
 }
-
-void delete_engine(void* engine) {
-    if (engine) {
-        delete static_cast<nvinfer1::ICudaEngine*>(engine);
-    }
-}
+}

-void delete_context(void* context) {
-    if (context) {
-        delete static_cast<nvinfer1::IExecutionContext*>(context);
-    }
-}
+namespace nvinfer1 {
+class DebugListener : public IDebugListener {
+public:
+  DebugListener(void *self,
+                bool (*processDebugTensor)(void *self, void const *addr,
+                                           TensorLocation location,
+                                           DataType type, Dims const *shape,
+                                           char const *name,
+                                           cudaStream_t stream))
+      : self(self), m_processDebugTensor(
+                        (decltype(m_processDebugTensor))processDebugTensor) {}
+  ~DebugListener() = default;
+
+  void *self;
+  bool (*m_processDebugTensor)(void *self, void const *addr,
+                               TensorLocation location, DataType type,
+                               Dims const *shape, char const *name,
+                               cudaStream_t stream);
+
+  bool processDebugTensor(void const *addr, TensorLocation location,
+                          DataType type, Dims const &shape, char const *name,
+                          cudaStream_t stream) override {
+    return m_processDebugTensor(self, addr, location, type, &shape, name,
+                                stream);
+  };
+};

-void delete_parser(void* parser) {
-    if (parser) {
-        delete static_cast<nvonnxparser::IParser*>(parser);
-    }
+extern "C" {
+void *trtx_create_debug_listener(
+    nvinfer1::IDebugListener *self,
+    bool (*processDebugTensor)(void *self, void const *addr,
+                               TensorLocation location, DataType type,
+                               Dims const *shape, char const *name,
+                               cudaStream_t stream)) {
+  return new DebugListener(self, processDebugTensor);
 }
-
-uint32_t get_tensorrt_version() {
-    return NV_TENSORRT_VERSION;
+void trtx_destroy_debug_listener(nvinfer1::IDebugListener *self) {
+  delete self;
 }
-
-} // extern "C"
+}
+} // namespace nvinfer1
diff --git a/trtx-sys/logger_bridge.hpp b/trtx-sys/logger_bridge.hpp
index a21903a..e0dec72 100644
--- a/trtx-sys/logger_bridge.hpp
+++ b/trtx-sys/logger_bridge.hpp
@@ -1,80 +1,30 @@
 #ifndef TRTX_LOGGER_BRIDGE_H
 #define TRTX_LOGGER_BRIDGE_H

-#include
-namespace std {
-typedef std::size_t size_t;
-}
 #include
-
-#include
-
-// This will cause a compiler error if Weights isn't a POD
-static_assert(std::is_standard_layout<nvinfer1::Weights>::value, "Weights must be standard layout");
-static_assert(std::is_trivial<nvinfer1::Weights>::value, "Weights must be trivial");
+#include
+#include

 #ifdef __cplusplus
 extern "C" {
 #endif

 // Rust callback function type
-typedef void (*RustLogCallback)(void* user_data, int32_t severity, const char* msg);
+typedef void (*RustLogCallback)(void *user_data, int32_t severity,
+                                const char *msg);

 // Opaque logger type for C interface
 typedef struct RustLoggerBridge RustLoggerBridge;

 // Create a logger bridge that calls back into Rust
-RustLoggerBridge* create_rust_logger_bridge(RustLogCallback callback, void* user_data);
+RustLoggerBridge *create_rust_logger_bridge(RustLogCallback callback,
+                                            void *user_data);

 // Destroy the logger bridge
-void destroy_rust_logger_bridge(RustLoggerBridge* logger);
+void destroy_rust_logger_bridge(RustLoggerBridge *logger);

 // Get the ILogger pointer (for use with TensorRT C++ API)
-nvinfer1::ILogger* get_logger_interface(RustLoggerBridge* logger);
-
-// Factory functions for TensorRT (return raw pointers)
-void* create_infer_builder(void* logger);
-void* create_infer_runtime(void* logger);
-
-// ONNX Parser factory function
-void* create_onnx_parser(void* network, void* logger);
-
-// Builder methods
-void* builder_build_serialized_network(void* builder, void* network, void* config, size_t* out_size);
-void builder_config_set_memory_pool_limit(void* config, int32_t pool_type, size_t limit);
-
-// Network methods
-// network_add_convolution - REMOVED - Using direct autocxx
-void* network_add_concatenation(void* network, void** inputs, int32_t nb_inputs);
-// network_add_constant - REMOVED - Using direct autocxx
-// network_add_scale - REMOVED - Using direct autocxx
-void* network_add_assertion(void* network, void* condition, const char* message);
-void* network_add_loop(void* network);
-void* network_add_if_conditional(void* network);
-
-// Tensor methods
-void* tensor_get_dimensions(void* tensor, int32_t* dims, int32_t* nb_dims);
-int32_t tensor_get_type(void* tensor);
-
-// Destruction methods
-void delete_builder(void* builder);
-void delete_network(void* network);
-void delete_config(void* config);
-void delete_runtime(void* runtime);
-void delete_engine(void* engine);
-void delete_context(void* context);
-void delete_parser(void* parser);
-
-// Runtime methods
-void* runtime_deserialize_cuda_engine(void* runtime, const void* data, size_t size);
-
-// Engine methods
-
-// ExecutionContext methods
-
-// Parser methods
-bool parser_parse(void* parser, const void* data, size_t size);
-int32_t parser_get_nb_errors(void* parser);
-void* parser_get_error(void* parser, int32_t index);
-const char* parser_error_desc(void* error);
+nvinfer1::ILogger *get_logger_interface(RustLoggerBridge *logger);

 #ifdef __cplusplus
 }
diff --git a/trtx-sys/mock.c b/trtx-sys/mock.c
deleted file mode 100644
index e5b9c5d..0000000
--- a/trtx-sys/mock.c
+++ /dev/null
@@ -1,342 +0,0 @@
-// Mock implementations for development without TensorRT-RTX
-// These are stubs that allow compilation and basic testing
-
-#include
-#include
-#include
-#include
-
-// Mock handles (just use integers)
-typedef struct { int dummy; } TrtxLogger;
-typedef struct { int dummy; } TrtxBuilder;
-typedef struct { int dummy; } TrtxBuilderConfig;
-typedef struct { int dummy; } TrtxNetworkDefinition;
-typedef struct { int dummy; } TrtxRuntime;
-typedef struct { int dummy; } TrtxCudaEngine;
-typedef struct { int dummy; } TrtxExecutionContext;
-
-// Mock implementations - all return success
-
-int32_t trtx_logger_create(
-    void* callback,
-    void* user_data,
-    TrtxLogger** out_logger,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)callback;
-    (void)user_data;
-    (void)error_msg;
-    (void)error_msg_len;
-    *out_logger = malloc(sizeof(TrtxLogger));
-    return 0; // TRTX_SUCCESS
-}
-
-void trtx_logger_destroy(TrtxLogger* logger) {
-    free(logger);
-}
-
-int32_t trtx_builder_create(
-    TrtxLogger* logger,
-    TrtxBuilder** out_builder,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)logger;
-    (void)error_msg;
-    (void)error_msg_len;
-    *out_builder = malloc(sizeof(TrtxBuilder));
-    return 0;
-}
-
-void trtx_builder_destroy(TrtxBuilder* builder) {
-    free(builder);
-}
-
-int32_t trtx_builder_create_network(
-    TrtxBuilder* builder,
-    uint32_t flags,
-    TrtxNetworkDefinition** out_network,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)builder;
-    (void)flags;
-    (void)error_msg;
-    (void)error_msg_len;
-    *out_network = malloc(sizeof(TrtxNetworkDefinition));
-    return 0;
-}
-
-int32_t trtx_builder_create_builder_config(
-    TrtxBuilder* builder,
-    TrtxBuilderConfig** out_config,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)builder;
-    (void)error_msg;
-    (void)error_msg_len;
-    *out_config = malloc(sizeof(TrtxBuilderConfig));
-    return 0;
-}
-
-int32_t trtx_builder_build_serialized_network(
-    TrtxBuilder* builder,
-    TrtxNetworkDefinition* network,
-    TrtxBuilderConfig* config,
-    void** out_data,
-    size_t* out_size,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)builder;
-    (void)network;
-    (void)config;
-    (void)error_msg;
-    (void)error_msg_len;
-    // Return a small dummy buffer
-    *out_size = 16;
-    *out_data = malloc(16);
-    memset(*out_data, 0, 16);
-    return 0;
-}
-
-void trtx_builder_config_destroy(TrtxBuilderConfig* config) {
-    free(config);
-}
-
-int32_t trtx_builder_config_set_memory_pool_limit(
-    TrtxBuilderConfig* config,
-    int32_t pool_type,
-    size_t pool_size,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)config;
-    (void)pool_type;
-    (void)pool_size;
-    (void)error_msg;
-    (void)error_msg_len;
-    return 0;
-}
-
-void trtx_network_destroy(TrtxNetworkDefinition* network) {
-    free(network);
-}
-
-int32_t trtx_runtime_create(
-    TrtxLogger* logger,
-    TrtxRuntime** out_runtime,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)logger;
-    (void)error_msg;
-    (void)error_msg_len;
-    *out_runtime = malloc(sizeof(TrtxRuntime));
-    return 0;
-}
-
-void trtx_runtime_destroy(TrtxRuntime* runtime) {
-    free(runtime);
-}
-
-int32_t trtx_runtime_deserialize_cuda_engine(
-    TrtxRuntime* runtime,
-    const void* data,
-    size_t size,
-    TrtxCudaEngine** out_engine,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)runtime;
-    (void)data;
-    (void)size;
-    (void)error_msg;
-    (void)error_msg_len;
-    *out_engine = malloc(sizeof(TrtxCudaEngine));
-    return 0;
-}
-
-void trtx_cuda_engine_destroy(TrtxCudaEngine* engine) {
-    free(engine);
-}
-
-int32_t trtx_cuda_engine_create_execution_context(
-    TrtxCudaEngine* engine,
-    TrtxExecutionContext** out_context,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)engine;
-    (void)error_msg;
-    (void)error_msg_len;
-    *out_context = malloc(sizeof(TrtxExecutionContext));
-    return 0;
-}
-
-int32_t trtx_cuda_engine_get_tensor_name(
-    TrtxCudaEngine* engine,
-    int32_t index,
-    const char** out_name,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)engine;
-    (void)error_msg;
-    (void)error_msg_len;
-    static const char* mock_names[] = {"input", "output"};
-    if (index < 0 || index >= 2) {
-        return 1; // TRTX_ERROR_INVALID_ARGUMENT
-    }
-    *out_name = mock_names[index];
-    return 0;
-}
-
-int32_t trtx_cuda_engine_get_nb_io_tensors(
-    TrtxCudaEngine* engine,
-    int32_t* out_count
-) {
-    (void)engine;
-    *out_count = 2; // Mock: 1 input, 1 output
-    return 0;
-}
-
-void trtx_execution_context_destroy(TrtxExecutionContext* context) {
-    free(context);
-}
-
-int32_t trtx_execution_context_set_tensor_address(
-    TrtxExecutionContext* context,
-    const char* tensor_name,
-    void* data,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)context;
-    (void)tensor_name;
-    (void)data;
-    (void)error_msg;
-    (void)error_msg_len;
-    return 0;
-}
-
-int32_t trtx_execution_context_enqueue_v3(
-    TrtxExecutionContext* context,
-    void* cuda_stream,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)context;
-    (void)cuda_stream;
-    (void)error_msg;
-    (void)error_msg_len;
-    return 0;
-}
-
-void trtx_free_buffer(void* buffer) {
-    free(buffer);
-}
-
-// ONNX Parser mock implementations
-typedef struct { int dummy; } TrtxOnnxParser;
-
-int32_t trtx_onnx_parser_create(
-    void* network,
-    void* logger,
-    TrtxOnnxParser** out_parser,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)network;
-    (void)logger;
-    (void)error_msg;
-    (void)error_msg_len;
-    *out_parser = malloc(sizeof(TrtxOnnxParser));
-    return 0;
-}
-
-void trtx_onnx_parser_destroy(TrtxOnnxParser* parser) {
-    free(parser);
-}
-
-int32_t trtx_onnx_parser_parse(
-    TrtxOnnxParser* parser,
-    const void* model_data,
-    size_t model_size,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)parser;
-    (void)model_data;
-    (void)model_size;
-    (void)error_msg;
-    (void)error_msg_len;
-    // Mock: always succeeds
-    return 0;
-}
-
-// CUDA Memory Management mock implementations
-int32_t trtx_cuda_malloc(
-    void** ptr,
-    size_t size,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)error_msg;
-    (void)error_msg_len;
-    *ptr = malloc(size);
-    return *ptr ? 0 : 2; // 0 = success, 2 = out of memory
-}
-
-int32_t trtx_cuda_free(
-    void* ptr,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)error_msg;
-    (void)error_msg_len;
-    free(ptr);
-    return 0;
-}
-
-int32_t trtx_cuda_memcpy_host_to_device(
-    void* dst,
-    const void* src,
-    size_t size,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)error_msg;
-    (void)error_msg_len;
-    memcpy(dst, src, size);
-    return 0;
-}
-
-int32_t trtx_cuda_memcpy_device_to_host(
-    void* dst,
-    const void* src,
-    size_t size,
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)error_msg;
-    (void)error_msg_len;
-    memcpy(dst, src, size);
-    return 0;
-}
-
-int32_t trtx_cuda_synchronize(
-    char* error_msg,
-    size_t error_msg_len
-) {
-    (void)error_msg;
-    (void)error_msg_len;
-    // Mock: nothing to synchronize
-    return 0;
-}
-
-void* trtx_cuda_get_default_stream() {
-    return NULL;
-}
diff --git a/trtx-sys/src/lib.rs b/trtx-sys/src/lib.rs
index 9d11b90..05e3db8 100644
--- a/trtx-sys/src/lib.rs
+++ b/trtx-sys/src/lib.rs
@@ -29,10 +29,6 @@
 #![allow(non_snake_case)]
 #![allow(clippy::all)]

-// Mock mode uses old-style bindings
-#[cfg(feature = "mock")]
-include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
-
 #[allow(warnings)]
 mod enums {
     include!(concat!(env!("OUT_DIR"), "/enums.rs"));
@@ -41,23 +37,21 @@ mod enums {
 macro_rules!
better_enum { ($to:ident) => { pub use crate::enums::$to; - #[cfg(not(feature = "mock"))] - impl Into for $to { - fn into(self) -> crate::real_bindings::nvinfer1::$to { + impl Into for $to { + fn into(self) -> crate::nvinfer1::$to { unsafe { transmute(self) } } } - #[cfg(not(feature = "mock"))] - impl From for $to { - fn from(value: crate::real_bindings::nvinfer1::$to) -> Self { + impl From for $to { + fn from(value: crate::nvinfer1::$to) -> Self { unsafe { transmute(value) } } } }; } -#[cfg(not(feature = "mock"))] use std::mem::transmute; +use std::pin::Pin; better_enum!(LayerType); better_enum!(ActivationType); better_enum!(DataType); @@ -85,384 +79,675 @@ better_enum!(ScaleMode); better_enum!(ScatterMode); better_enum!(UnaryOperation); better_enum!(TopKOperation); +better_enum!(LayerInformationFormat); +better_enum!(TensorLocation); +better_enum!(SerializationFlag); +better_enum!(OptProfileSelector); +better_enum!(AttentionNormalizationOp); +better_enum!(SeekPosition); +better_enum!(WeightsRole); +better_enum!(TripLimit); +pub use enums::ErrorCode; + +use autocxx::prelude::*; + +include_cpp! 
{ + #include "NvInfer.h" + #include "NvInferRuntime.h" + #include "NvOnnxParser.h" + + safety!(unsafe_ffi) + + // Core TensorRT types + generate!("nvinfer1::IBuilder") + generate!("nvinfer1::IBuilderConfig") + generate!("nvinfer1::INetworkDefinition") + generate!("nvinfer1::ITensor") + generate!("nvinfer1::ILayer") + generate!("nvinfer1::IVersionedInterface") + generate!("nvinfer1::IProgressMonitor") + generate!("nvinfer1::IStreamWriter") + generate!("nvinfer1::IStreamReaderV2") + generate!("nvinfer1::IErrorRecorder") + generate!("nvinfer1::IProfiler") + generate!("nvinfer1::IGpuAllocator") + generate!("nvinfer1::IDebugListener") + generate!("nvinfer1::ISerializationConfig") + generate!("nvinfer1::IOptimizationProfile") + generate!("nvinfer1::IRefitter") + + // Derived layer types - for inheritance support + generate!("nvinfer1::IActivationLayer") + generate!("nvinfer1::IConvolutionLayer") + generate!("nvinfer1::IPoolingLayer") + generate!("nvinfer1::IElementWiseLayer") + generate!("nvinfer1::IShuffleLayer") + generate!("nvinfer1::IConcatenationLayer") + generate!("nvinfer1::IMatrixMultiplyLayer") + generate!("nvinfer1::IConstantLayer") + generate!("nvinfer1::ISoftMaxLayer") + generate!("nvinfer1::IScaleLayer") + generate!("nvinfer1::IReduceLayer") + generate!("nvinfer1::ISliceLayer") + generate!("nvinfer1::IResizeLayer") + generate!("nvinfer1::ITopKLayer") + generate!("nvinfer1::IGatherLayer") + generate!("nvinfer1::IScatterLayer") + generate!("nvinfer1::ISelectLayer") + generate!("nvinfer1::IUnaryLayer") + generate!("nvinfer1::IIdentityLayer") + generate!("nvinfer1::IPaddingLayer") + generate!("nvinfer1::ICastLayer") + generate!("nvinfer1::IDeconvolutionLayer") + generate!("nvinfer1::IQuantizeLayer") + generate!("nvinfer1::IDequantizeLayer") + generate!("nvinfer1::IAssertionLayer") + generate!("nvinfer1::ICumulativeLayer") + generate!("nvinfer1::ILoop") + generate!("nvinfer1::IIfConditional") + generate!("nvinfer1::INormalizationLayer") + 
generate!("nvinfer1::ISqueezeLayer") + generate!("nvinfer1::IUnsqueezeLayer") + generate!("nvinfer1::ILRNLayer") + generate!("nvinfer1::IShapeLayer") + generate!("nvinfer1::IParametricReLULayer") + generate!("nvinfer1::IFillLayer") + generate!("nvinfer1::IEinsumLayer") + generate!("nvinfer1::IOneHotLayer") + generate!("nvinfer1::INonZeroLayer") + generate!("nvinfer1::IGridSampleLayer") + generate!("nvinfer1::INMSLayer") + generate!("nvinfer1::IReverseSequenceLayer") + generate!("nvinfer1::IDynamicQuantizeLayer") + generate!("nvinfer1::IRotaryEmbeddingLayer") + generate!("nvinfer1::IKVCacheUpdateLayer") + generate!("nvinfer1::IRaggedSoftMaxLayer") + generate!("nvinfer1::ILoopBoundaryLayer") + generate!("nvinfer1::IRecurrenceLayer") + generate!("nvinfer1::ILoopOutputLayer") + generate!("nvinfer1::ITripLimitLayer") + generate!("nvinfer1::IIteratorLayer") + generate!("nvinfer1::IConditionLayer") + generate!("nvinfer1::IIfConditionalOutputLayer") + generate!("nvinfer1::IIfConditionalInputLayer") + generate!("nvinfer1::IAttentionBoundaryLayer") + generate!("nvinfer1::IAttentionInputLayer") + generate!("nvinfer1::IAttentionOutputLayer") + generate!("nvinfer1::IAttention") + // NOTE: IRNNv2Layer is deprecated (TRT_DEPRECATED) and autocxx cannot generate bindings for it + // RNN operations (lstm, lstmCell, gru, gruCell) remain deferred until we can work around this + // generate!("nvinfer1::IRNNv2Layer") + + generate!("nvinfer1::IRuntime") + generate!("nvinfer1::ICudaEngine") + generate!("nvinfer1::IExecutionContext") + generate!("nvinfer1::IEngineInspector") + generate!("nvinfer1::IHostMemory") + generate!("nvinfer1::LayerInformationFormat") + + // Try generating Dims64 directly (base class, not the typedef alias) + generate_pod!("nvinfer1::Dims64") + + generate_pod!("nvinfer1::DataType") + generate_pod!("nvinfer1::TensorIOMode") + generate_pod!("nvinfer1::MemoryPoolType") + generate_pod!("nvinfer1::NetworkDefinitionCreationFlag") + 
generate_pod!("nvinfer1::ActivationType") + generate_pod!("nvinfer1::PoolingType") + generate_pod!("nvinfer1::ElementWiseOperation") + generate_pod!("nvinfer1::MatrixOperation") + generate_pod!("nvinfer1::UnaryOperation") + generate_pod!("nvinfer1::ReduceOperation") + generate_pod!("nvinfer1::CumulativeOperation") + generate_pod!("nvinfer1::GatherMode") + generate_pod!("nvinfer1::ScatterMode") + generate_pod!("nvinfer1::InterpolationMode") + generate_pod!("nvinfer1::ResizeCoordinateTransformation") + generate_pod!("nvinfer1::ResizeSelector") + generate_pod!("nvinfer1::ResizeRoundMode") + generate_pod!("nvinfer1::ProfilingVerbosity") + generate_pod!("nvinfer1::EngineCapability") + generate_pod!("nvinfer1::BuilderFlag") + generate_pod!("nvinfer1::BuilderFlags") + generate_pod!("nvinfer1::DeviceType") + generate_pod!("nvinfer1::TacticSource") + generate_pod!("nvinfer1::TacticSources") + generate_pod!("nvinfer1::PreviewFeature") + generate_pod!("nvinfer1::HardwareCompatibilityLevel") + generate_pod!("nvinfer1::RuntimePlatform") + generate_pod!("nvinfer1::TilingOptimizationLevel") + generate_pod!("nvinfer1::ComputeCapability") + generate_pod!("nvinfer1::APILanguage") + // NOTE: RNN enums commented out because IRNNv2Layer (deprecated) cannot be generated + // generate!("nvinfer1::RNNOperation") + // generate!("nvinfer1::RNNDirection") + // generate!("nvinfer1::RNNInputMode") + // generate!("nvinfer1::RNNGateType") + generate_pod!("nvinfer1::Weights") + generate_pod!("nvinfer1::Permutation") + generate_pod!("nvinfer1::TripLimit") + generate_pod!("nvinfer1::LoopOutput") + generate_pod!("nvinfer1::AttentionNormalizationOp") + generate_pod!("nvinfer1::WeightsRole") + + generate!("nvinfer1::ErrorCode") + generate!("nvinfer1::LayerType") + generate!("nvinfer1::SerializationFlags") + generate!("nvinfer1::SerializationFlag") + generate!("nvinfer1::OptProfileSelector") + + // NOTE: createInferBuilder/Runtime moved to logger_bridge.cpp (autocxx struggles with these) + + // ONNX 
Parser + generate!("nvonnxparser::IParser") + // NOTE: createParser also moved to logger_bridge.cpp -// Real mode uses autocxx -#[cfg(not(feature = "mock"))] -pub mod real_bindings { - use autocxx::prelude::*; - - include_cpp! { - #include "NvInfer.h" - #include "NvOnnxParser.h" - - safety!(unsafe_ffi) - - // Core TensorRT types - generate!("nvinfer1::IBuilder") - generate!("nvinfer1::IBuilderConfig") - generate!("nvinfer1::INetworkDefinition") - generate!("nvinfer1::ITensor") - generate!("nvinfer1::ILayer") - - // Derived layer types - for inheritance support - generate!("nvinfer1::IActivationLayer") - generate!("nvinfer1::IConvolutionLayer") - generate!("nvinfer1::IPoolingLayer") - generate!("nvinfer1::IElementWiseLayer") - generate!("nvinfer1::IShuffleLayer") - generate!("nvinfer1::IConcatenationLayer") - generate!("nvinfer1::IMatrixMultiplyLayer") - generate!("nvinfer1::IConstantLayer") - generate!("nvinfer1::ISoftMaxLayer") - generate!("nvinfer1::IScaleLayer") - generate!("nvinfer1::IReduceLayer") - generate!("nvinfer1::ISliceLayer") - generate!("nvinfer1::IResizeLayer") - generate!("nvinfer1::ITopKLayer") - generate!("nvinfer1::IGatherLayer") - generate!("nvinfer1::IScatterLayer") - generate!("nvinfer1::ISelectLayer") - generate!("nvinfer1::IUnaryLayer") - generate!("nvinfer1::IIdentityLayer") - generate!("nvinfer1::IPaddingLayer") - generate!("nvinfer1::ICastLayer") - generate!("nvinfer1::IDeconvolutionLayer") - generate!("nvinfer1::IQuantizeLayer") - generate!("nvinfer1::IDequantizeLayer") - generate!("nvinfer1::IAssertionLayer") - generate!("nvinfer1::ICumulativeLayer") - generate!("nvinfer1::ILoop") - generate!("nvinfer1::IIfConditional") - // NOTE: IRNNv2Layer is deprecated (TRT_DEPRECATED) and autocxx cannot generate bindings for it - // RNN operations (lstm, lstmCell, gru, gruCell) remain deferred until we can work around this - // generate!("nvinfer1::IRNNv2Layer") - - generate!("nvinfer1::IRuntime") - generate!("nvinfer1::ICudaEngine") - 
generate!("nvinfer1::IExecutionContext") - generate!("nvinfer1::IHostMemory") - - // Try generating Dims64 directly (base class, not the typedef alias) - generate_pod!("nvinfer1::Dims64") - - generate_pod!("nvinfer1::DataType") - generate_pod!("nvinfer1::TensorIOMode") - generate_pod!("nvinfer1::MemoryPoolType") - generate_pod!("nvinfer1::NetworkDefinitionCreationFlag") - generate_pod!("nvinfer1::ActivationType") - generate_pod!("nvinfer1::PoolingType") - generate_pod!("nvinfer1::ElementWiseOperation") - generate_pod!("nvinfer1::MatrixOperation") - generate_pod!("nvinfer1::UnaryOperation") - generate_pod!("nvinfer1::ReduceOperation") - generate_pod!("nvinfer1::CumulativeOperation") - generate_pod!("nvinfer1::GatherMode") - generate_pod!("nvinfer1::ScatterMode") - generate_pod!("nvinfer1::InterpolationMode") - generate_pod!("nvinfer1::ResizeCoordinateTransformation") - generate_pod!("nvinfer1::ResizeSelector") - generate_pod!("nvinfer1::ResizeRoundMode") - generate_pod!("nvinfer1::ProfilingVerbosity") - generate_pod!("nvinfer1::EngineCapability") - generate_pod!("nvinfer1::BuilderFlag") - generate_pod!("nvinfer1::BuilderFlags") - generate_pod!("nvinfer1::DeviceType") - generate_pod!("nvinfer1::TacticSource") - generate_pod!("nvinfer1::TacticSources") - generate_pod!("nvinfer1::PreviewFeature") - generate_pod!("nvinfer1::HardwareCompatibilityLevel") - generate_pod!("nvinfer1::RuntimePlatform") - generate_pod!("nvinfer1::TilingOptimizationLevel") - generate_pod!("nvinfer1::ComputeCapability") - // NOTE: RNN enums commented out because IRNNv2Layer (deprecated) cannot be generated - // generate!("nvinfer1::RNNOperation") - // generate!("nvinfer1::RNNDirection") - // generate!("nvinfer1::RNNInputMode") - // generate!("nvinfer1::RNNGateType") - generate_pod!("nvinfer1::Weights") - generate_pod!("nvinfer1::Permutation") - generate!("nvinfer1::TensorFormat") - - // NOTE: createInferBuilder/Runtime moved to logger_bridge.cpp (autocxx struggles with these) - - // ONNX Parser 
- generate!("nvonnxparser::IParser") - // NOTE: createParser also moved to logger_bridge.cpp +} +pub unsafe trait TrtLayer { + fn as_layer(&self) -> &nvinfer1::ILayer { + // can't use safe `as_ref() -> &nvinfer1::ILayer` because only implemented for direct + // subclasses of ILayer + unsafe { + (self as *const Self as *const nvinfer1::ILayer) + .as_ref() + .unwrap() + } } + fn as_layer_pin_mut(&mut self) -> Pin<&mut nvinfer1::ILayer> { + unsafe { + Pin::new_unchecked( + (self as *mut Self as *mut nvinfer1::ILayer) + .as_mut() + .unwrap(), + ) + } + } +} +pub unsafe trait ConcreteTrtLayer: TrtLayer { + const TYPE: LayerType; +} - // Logger bridge C functions - extern "C" { - pub fn get_tensorrt_version() -> u32; - pub fn create_rust_logger_bridge( - callback: RustLogCallback, - user_data: *mut std::ffi::c_void, - ) -> *mut RustLoggerBridge; - - pub fn destroy_rust_logger_bridge(logger: *mut RustLoggerBridge); +unsafe impl TrtLayer for nvinfer1::IActivationLayer {} +unsafe impl TrtLayer for nvinfer1::IConvolutionLayer {} +unsafe impl TrtLayer for nvinfer1::ICastLayer {} +unsafe impl TrtLayer for nvinfer1::IPoolingLayer {} +unsafe impl TrtLayer for nvinfer1::ILRNLayer {} +unsafe impl TrtLayer for nvinfer1::IScaleLayer {} +unsafe impl TrtLayer for nvinfer1::ISoftMaxLayer {} +unsafe impl TrtLayer for nvinfer1::IDeconvolutionLayer {} +unsafe impl TrtLayer for nvinfer1::IConcatenationLayer {} +unsafe impl TrtLayer for nvinfer1::IElementWiseLayer {} +unsafe impl TrtLayer for nvinfer1::IUnaryLayer {} +unsafe impl TrtLayer for nvinfer1::IPaddingLayer {} +unsafe impl TrtLayer for nvinfer1::IShuffleLayer {} +unsafe impl TrtLayer for nvinfer1::IReduceLayer {} +unsafe impl TrtLayer for nvinfer1::ITopKLayer {} +unsafe impl TrtLayer for nvinfer1::IGatherLayer {} +unsafe impl TrtLayer for nvinfer1::IMatrixMultiplyLayer {} +unsafe impl TrtLayer for nvinfer1::IRaggedSoftMaxLayer {} +unsafe impl TrtLayer for nvinfer1::IConstantLayer {} +unsafe impl TrtLayer for 
nvinfer1::IIdentityLayer {} +unsafe impl TrtLayer for nvinfer1::ISliceLayer {} +unsafe impl TrtLayer for nvinfer1::IShapeLayer {} +unsafe impl TrtLayer for nvinfer1::IParametricReLULayer {} +unsafe impl TrtLayer for nvinfer1::IResizeLayer {} +unsafe impl TrtLayer for nvinfer1::ISelectLayer {} +unsafe impl TrtLayer for nvinfer1::IFillLayer {} +unsafe impl TrtLayer for nvinfer1::IQuantizeLayer {} +unsafe impl TrtLayer for nvinfer1::IDequantizeLayer {} +unsafe impl TrtLayer for nvinfer1::IScatterLayer {} +unsafe impl TrtLayer for nvinfer1::IEinsumLayer {} +unsafe impl TrtLayer for nvinfer1::IAssertionLayer {} +unsafe impl TrtLayer for nvinfer1::IOneHotLayer {} +unsafe impl TrtLayer for nvinfer1::INonZeroLayer {} +unsafe impl TrtLayer for nvinfer1::IGridSampleLayer {} +unsafe impl TrtLayer for nvinfer1::INMSLayer {} +unsafe impl TrtLayer for nvinfer1::IReverseSequenceLayer {} +unsafe impl TrtLayer for nvinfer1::INormalizationLayer {} +unsafe impl TrtLayer for nvinfer1::ISqueezeLayer {} +unsafe impl TrtLayer for nvinfer1::IUnsqueezeLayer {} +unsafe impl TrtLayer for nvinfer1::ICumulativeLayer {} +unsafe impl TrtLayer for nvinfer1::IDynamicQuantizeLayer {} +unsafe impl TrtLayer for nvinfer1::IRotaryEmbeddingLayer {} +unsafe impl TrtLayer for nvinfer1::IKVCacheUpdateLayer {} + +unsafe impl TrtLayer for nvinfer1::IAttentionInputLayer {} +unsafe impl TrtLayer for nvinfer1::IAttentionOutputLayer {} +unsafe impl TrtLayer for nvinfer1::ILoopBoundaryLayer {} +unsafe impl TrtLayer for nvinfer1::ILoopOutputLayer {} +unsafe impl TrtLayer for nvinfer1::IRecurrenceLayer {} +unsafe impl TrtLayer for nvinfer1::ITripLimitLayer {} +unsafe impl TrtLayer for nvinfer1::IIteratorLayer {} +unsafe impl TrtLayer for nvinfer1::IConditionLayer {} +unsafe impl TrtLayer for nvinfer1::IIfConditionalOutputLayer {} +unsafe impl TrtLayer for nvinfer1::IIfConditionalInputLayer {} +unsafe impl TrtLayer for nvinfer1::IAttentionBoundaryLayer {} +// this one is not concrete +unsafe impl TrtLayer for 
nvinfer1::ILayer {} + +// indirect subclasses of ILayer e.g. via ILoopBoundaryLayer, IAttentionBoundaryLayer, IIfConditionalBoundaryLayer + +unsafe impl ConcreteTrtLayer for nvinfer1::IActivationLayer { + const TYPE: LayerType = LayerType::kACTIVATION; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IConvolutionLayer { + const TYPE: LayerType = LayerType::kCONVOLUTION; +} +unsafe impl ConcreteTrtLayer for nvinfer1::ICastLayer { + const TYPE: LayerType = LayerType::kCAST; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IPoolingLayer { + const TYPE: LayerType = LayerType::kPOOLING; +} +unsafe impl ConcreteTrtLayer for nvinfer1::ILRNLayer { + const TYPE: LayerType = LayerType::kLRN; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IScaleLayer { + const TYPE: LayerType = LayerType::kSCALE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::ISoftMaxLayer { + const TYPE: LayerType = LayerType::kSOFTMAX; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IDeconvolutionLayer { + const TYPE: LayerType = LayerType::kDECONVOLUTION; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IConcatenationLayer { + const TYPE: LayerType = LayerType::kCONCATENATION; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IElementWiseLayer { + const TYPE: LayerType = LayerType::kELEMENTWISE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IUnaryLayer { + const TYPE: LayerType = LayerType::kUNARY; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IPaddingLayer { + const TYPE: LayerType = LayerType::kPADDING; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IShuffleLayer { + const TYPE: LayerType = LayerType::kSHUFFLE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IReduceLayer { + const TYPE: LayerType = LayerType::kREDUCE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::ITopKLayer { + const TYPE: LayerType = LayerType::kTOPK; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IGatherLayer { + const TYPE: LayerType = LayerType::kGATHER; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IMatrixMultiplyLayer { + const TYPE: 
LayerType = LayerType::kMATRIX_MULTIPLY; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IRaggedSoftMaxLayer { + const TYPE: LayerType = LayerType::kRAGGED_SOFTMAX; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IConstantLayer { + const TYPE: LayerType = LayerType::kCONSTANT; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IIdentityLayer { + const TYPE: LayerType = LayerType::kIDENTITY; +} +unsafe impl ConcreteTrtLayer for nvinfer1::ISliceLayer { + const TYPE: LayerType = LayerType::kSLICE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IShapeLayer { + const TYPE: LayerType = LayerType::kSHAPE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IParametricReLULayer { + const TYPE: LayerType = LayerType::kPARAMETRIC_RELU; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IResizeLayer { + const TYPE: LayerType = LayerType::kRESIZE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::ISelectLayer { + const TYPE: LayerType = LayerType::kSELECT; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IFillLayer { + const TYPE: LayerType = LayerType::kFILL; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IQuantizeLayer { + const TYPE: LayerType = LayerType::kQUANTIZE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IDequantizeLayer { + const TYPE: LayerType = LayerType::kDEQUANTIZE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IScatterLayer { + const TYPE: LayerType = LayerType::kSCATTER; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IEinsumLayer { + const TYPE: LayerType = LayerType::kEINSUM; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IAssertionLayer { + const TYPE: LayerType = LayerType::kASSERTION; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IOneHotLayer { + const TYPE: LayerType = LayerType::kONE_HOT; +} +unsafe impl ConcreteTrtLayer for nvinfer1::INonZeroLayer { + const TYPE: LayerType = LayerType::kNON_ZERO; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IGridSampleLayer { + const TYPE: LayerType = LayerType::kGRID_SAMPLE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::INMSLayer 
{ + const TYPE: LayerType = LayerType::kNMS; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IReverseSequenceLayer { + const TYPE: LayerType = LayerType::kREVERSE_SEQUENCE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::INormalizationLayer { + const TYPE: LayerType = LayerType::kNORMALIZATION; +} +unsafe impl ConcreteTrtLayer for nvinfer1::ISqueezeLayer { + const TYPE: LayerType = LayerType::kSQUEEZE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IUnsqueezeLayer { + const TYPE: LayerType = LayerType::kUNSQUEEZE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::ICumulativeLayer { + const TYPE: LayerType = LayerType::kCUMULATIVE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IDynamicQuantizeLayer { + const TYPE: LayerType = LayerType::kDYNAMIC_QUANTIZE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IRotaryEmbeddingLayer { + const TYPE: LayerType = LayerType::kROTARY_EMBEDDING; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IKVCacheUpdateLayer { + const TYPE: LayerType = LayerType::kKVCACHE_UPDATE; +} - pub fn get_logger_interface(logger: *mut RustLoggerBridge) -> *mut std::ffi::c_void; // Returns ILogger* +// indirect subclasses of ILayer e.g. 
via ILoopBoundaryLayer, IAttentionBoundaryLayer, IIfConditionalBoundaryLayer - // TensorRT factory functions (wrapped as simple C functions) - #[cfg(feature = "link_tensorrt_rtx")] - pub fn create_infer_builder(logger: *mut std::ffi::c_void) -> *mut std::ffi::c_void; // Returns IBuilder* +unsafe impl ConcreteTrtLayer for nvinfer1::IAttentionInputLayer { + const TYPE: LayerType = LayerType::kATTENTION_INPUT; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IAttentionOutputLayer { + const TYPE: LayerType = LayerType::kATTENTION_OUTPUT; +} +unsafe impl ConcreteTrtLayer for nvinfer1::ILoopBoundaryLayer { + const TYPE: LayerType = LayerType::kTRIP_LIMIT; +} +unsafe impl ConcreteTrtLayer for nvinfer1::ILoopOutputLayer { + const TYPE: LayerType = LayerType::kLOOP_OUTPUT; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IRecurrenceLayer { + const TYPE: LayerType = LayerType::kRECURRENCE; +} +unsafe impl ConcreteTrtLayer for nvinfer1::ITripLimitLayer { + const TYPE: LayerType = LayerType::kTRIP_LIMIT; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IIteratorLayer { + const TYPE: LayerType = LayerType::kITERATOR; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IConditionLayer { + const TYPE: LayerType = LayerType::kCONDITION; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IIfConditionalOutputLayer { + const TYPE: LayerType = LayerType::kCONDITIONAL_OUTPUT; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IIfConditionalInputLayer { + const TYPE: LayerType = LayerType::kCONDITIONAL_INPUT; +} +unsafe impl ConcreteTrtLayer for nvinfer1::IAttentionBoundaryLayer { + const TYPE: LayerType = LayerType::kATTENTION_INPUT; +} - #[cfg(feature = "link_tensorrt_rtx")] - pub fn create_infer_runtime(logger: *mut std::ffi::c_void) -> *mut std::ffi::c_void; // Returns IRuntime* +// Logger bridge C functions +unsafe extern "C" { + pub unsafe fn get_tensorrt_version() -> u32; + pub unsafe fn create_rust_logger_bridge( + callback: RustLogCallback, + user_data: *mut std::ffi::c_void, + ) -> *mut 
RustLoggerBridge; - // ONNX Parser factory function - #[cfg(feature = "link_tensorrt_onnxparser")] - pub fn create_onnx_parser( - network: *mut std::ffi::c_void, - logger: *mut std::ffi::c_void, - ) -> *mut std::ffi::c_void; // Returns IParser* + pub unsafe fn destroy_rust_logger_bridge(logger: *mut RustLoggerBridge); - // Builder methods - pub fn builder_create_network_v2( - builder: *mut std::ffi::c_void, + pub unsafe fn get_logger_interface(logger: *mut RustLoggerBridge) -> *mut std::ffi::c_void; // Returns ILogger* + // + pub unsafe fn trtx_create_progress_monitor( + user_data: *mut std::ffi::c_void, + phaseStart: unsafe extern "system" fn( + user_data: *mut std::ffi::c_void, + phaseName: *const ::std::os::raw::c_char, + parentPhase: *const ::std::os::raw::c_char, + nbSteps: i32, + ), + stepComplete: unsafe extern "system" fn( + user_data: *mut std::ffi::c_void, + phaseName: *const ::std::os::raw::c_char, + step: i32, + ) -> bool, + phaseFinish: unsafe extern "system" fn( + user_data: *mut std::ffi::c_void, + phaseName: *const ::std::os::raw::c_char, + ), + ) -> *mut nvinfer1::IProgressMonitor; + pub unsafe fn trtx_destroy_progress_monitor(cpp_obj: *mut nvinfer1::IProgressMonitor); + pub unsafe fn trtx_create_gpu_allocator( + rust_impl: *mut std::ffi::c_void, + allocateAsync: unsafe extern "system" fn( + this: *const std::ffi::c_void, + size: u64, + alignment: u64, flags: u32, - ) -> *mut std::ffi::c_void; - - pub fn builder_create_config(builder: *mut std::ffi::c_void) -> *mut std::ffi::c_void; - - pub fn builder_build_serialized_network( - builder: *mut std::ffi::c_void, - network: *mut std::ffi::c_void, - config: *mut std::ffi::c_void, - out_size: *mut usize, - ) -> *mut std::ffi::c_void; - - pub fn builder_config_set_memory_pool_limit( - config: *mut std::ffi::c_void, - pool_type: i32, - limit: usize, - ); - - // Network methods - // network_add_input - REMOVED - Using direct autocxx - // network_add_convolution - REMOVED - Using direct autocxx - // 
network_add_constant - REMOVED - Using direct autocxx - // network_add_scale - REMOVED - Using direct autocxx - - pub fn network_mark_output( - network: *mut std::ffi::c_void, - tensor: *mut std::ffi::c_void, - ) -> bool; - - pub fn network_get_nb_inputs(network: *mut std::ffi::c_void) -> i32; - pub fn network_get_nb_outputs(network: *mut std::ffi::c_void) -> i32; - pub fn network_get_input( - network: *mut std::ffi::c_void, - index: i32, - ) -> *mut std::ffi::c_void; - pub fn network_get_output( - network: *mut std::ffi::c_void, - index: i32, - ) -> *mut std::ffi::c_void; - - // network_add_activation - REMOVED - Using direct autocxx - - // network_add_pooling - REMOVED - Using direct autocxx - - // network_add_elementwise - REMOVED - Using direct autocxx - - // network_add_shuffle - REMOVED - Using direct autocxx - - pub fn network_add_concatenation( - network: *mut std::ffi::c_void, - inputs: *mut *mut std::ffi::c_void, - nb_inputs: i32, - ) -> *mut std::ffi::c_void; - - // network_add_reduce - REMOVED - Using direct autocxx - - // network_add_slice - REMOVED - Using direct autocxx - - // network_add_resize - REMOVED - Using direct autocxx - - // network_add_topk - REMOVED - Using direct autocxx - - // network_add_gather - REMOVED - Using direct autocxx - - // network_add_select - REMOVED - Using direct autocxx - - pub fn network_add_assertion( - network: *mut std::ffi::c_void, - condition: *mut std::ffi::c_void, - message: *const std::os::raw::c_char, - ) -> *mut std::ffi::c_void; - - pub fn network_add_loop(network: *mut std::ffi::c_void) -> *mut std::ffi::c_void; - - pub fn network_add_if_conditional(network: *mut std::ffi::c_void) -> *mut std::ffi::c_void; - - // Tensor methods - pub fn tensor_get_name(tensor: *mut std::ffi::c_void) -> *const std::os::raw::c_char; - pub fn tensor_set_name(tensor: *mut std::ffi::c_void, name: *const std::os::raw::c_char); - pub fn tensor_get_dimensions( - tensor: *mut std::ffi::c_void, - dims: *mut i32, - nb_dims: *mut i32, - 
) -> *mut std::ffi::c_void; - pub fn tensor_get_type(tensor: *mut std::ffi::c_void) -> i32; - - // Runtime methods - pub fn runtime_deserialize_cuda_engine( - runtime: *mut std::ffi::c_void, - data: *const std::ffi::c_void, - size: usize, - ) -> *mut std::ffi::c_void; - - // Engine methods - pub fn engine_get_nb_io_tensors(engine: *mut std::ffi::c_void) -> i32; - pub fn engine_get_tensor_name( - engine: *mut std::ffi::c_void, - index: i32, - ) -> *const std::os::raw::c_char; - pub fn engine_create_execution_context( - engine: *mut std::ffi::c_void, - ) -> *mut std::ffi::c_void; - - // ExecutionContext methods - pub fn context_set_tensor_address( - context: *mut std::ffi::c_void, - name: *const std::os::raw::c_char, - data: *mut std::ffi::c_void, - ) -> bool; - pub fn context_enqueue_v3( - context: *mut std::ffi::c_void, + cuda_stream: *mut std::ffi::c_void, + ) -> *mut std::ffi::c_void, + reallocate: unsafe extern "system" fn( + this: *const std::ffi::c_void, + memory: *mut std::ffi::c_void, + alignment: u64, + new_size: u64, + ) -> *mut std::ffi::c_void, + deallocateAsync: unsafe extern "system" fn( + this: *const std::ffi::c_void, + memory: *mut std::ffi::c_void, + cuda_stream: *mut std::ffi::c_void, + ) -> bool, + ) -> *mut nvinfer1::IGpuAllocator; + pub unsafe fn trtx_destroy_gpu_allocator(cpp_obj: *mut nvinfer1::IGpuAllocator); + pub unsafe fn trtx_create_error_recorder( + rust_impl: *mut std::ffi::c_void, + getNbErrors: *mut std::ffi::c_void, + getErrorCode: *mut std::ffi::c_void, + getErrorDesc: *mut std::ffi::c_void, + hasOverflowed: *mut std::ffi::c_void, + clear: *mut std::ffi::c_void, + reportError: *mut std::ffi::c_void, + incRefCount: *mut std::ffi::c_void, + decRefCount: *mut std::ffi::c_void, + ) -> *mut nvinfer1::IErrorRecorder; + pub unsafe fn trtx_destroy_error_recorder(cpp_obj: *mut nvinfer1::IErrorRecorder); + + pub unsafe fn trtx_create_debug_listener( + rust_impl: *mut std::ffi::c_void, + processDebugTensor: unsafe extern "system" fn( + this: 
*const std::ffi::c_void,
+        addr: *const std::ffi::c_void,
+        location: nvinfer1::TensorLocation,
+        type_: nvinfer1::DataType,
+        shape: *const Dims64,
+        name: *const std::ffi::c_char,
         stream: *mut std::ffi::c_void,
-    ) -> bool;
-
-    // Parser methods
-    pub fn parser_parse(
-        parser: *mut std::ffi::c_void,
-        data: *const std::ffi::c_void,
-        size: usize,
-    ) -> bool;
-    pub fn parser_get_nb_errors(parser: *mut std::ffi::c_void) -> i32;
-    pub fn parser_get_error(parser: *mut std::ffi::c_void, index: i32)
-        -> *mut std::ffi::c_void;
-    pub fn parser_error_desc(error: *mut std::ffi::c_void) -> *const std::os::raw::c_char;
-
-    // Destruction methods
-    pub fn delete_builder(builder: *mut std::ffi::c_void);
-    pub fn delete_network(network: *mut std::ffi::c_void);
-    pub fn delete_config(config: *mut std::ffi::c_void);
-    pub fn delete_runtime(runtime: *mut std::ffi::c_void);
-    pub fn delete_engine(engine: *mut std::ffi::c_void);
-    pub fn delete_context(context: *mut std::ffi::c_void);
-    pub fn delete_parser(parser: *mut std::ffi::c_void);
-}
+        ) -> bool,
+    ) -> *mut nvinfer1::IDebugListener;
+
+    // TensorRT factory functions (wrapped as simple C functions)
+    #[cfg(feature = "link_tensorrt_rtx")]
+    pub unsafe fn create_infer_builder(logger: *mut std::ffi::c_void) -> *mut std::ffi::c_void; // Returns IBuilder*
+
+    #[cfg(feature = "link_tensorrt_rtx")]
+    pub unsafe fn create_infer_runtime(logger: *mut std::ffi::c_void) -> *mut std::ffi::c_void; // Returns IRuntime*
+
+    #[cfg(feature = "link_tensorrt_rtx")]
+    pub unsafe fn create_infer_refitter(
+        cuda_engine: *mut std::ffi::c_void,
+        logger: *mut std::ffi::c_void,
+    ) -> *mut std::ffi::c_void; // Returns IRefitter*
+
+    pub unsafe fn trtx_refitter_get_missing(
+        refitter: *mut std::ffi::c_void,
+        size: i32,
+        layer_names: *mut *const std::os::raw::c_char,
+        roles: *mut i32,
+    ) -> i32;
+
+    pub unsafe fn trtx_refitter_get_missing_weights(
+        refitter: *mut std::ffi::c_void,
+        size: i32,
+        weights_names: *mut *const std::os::raw::c_char,
+    ) -> i32;
+
+    pub unsafe fn trtx_refitter_get_all_weights(
+        refitter: *mut std::ffi::c_void,
+        size: i32,
+        weights_names: *mut *const std::os::raw::c_char,
+    ) -> i32;
+
+    // ONNX Parser factory function
+    #[cfg(feature = "link_tensorrt_onnxparser")]
+    pub unsafe fn create_onnx_parser(
+        network: *mut std::ffi::c_void,
+        logger: *mut std::ffi::c_void,
+    ) -> *mut std::ffi::c_void; // Returns IParser*
+    //
+    pub unsafe fn network_add_concatenation(
+        network: *mut std::ffi::c_void,
+        inputs: *mut *mut std::ffi::c_void,
+        nb_inputs: i32,
+    ) -> *mut std::ffi::c_void;
+
+    // Parser methods
+    pub unsafe fn parser_parse(
+        parser: *mut std::ffi::c_void,
+        data: *const std::ffi::c_void,
+        size: usize,
+    ) -> bool;
+    pub unsafe fn parser_get_nb_errors(parser: *mut std::ffi::c_void) -> i32;
+    pub unsafe fn parser_get_error(
+        parser: *mut std::ffi::c_void,
+        index: i32,
+    ) -> *mut std::ffi::c_void;
+    pub unsafe fn parser_error_desc(error: *mut std::ffi::c_void) -> *const std::os::raw::c_char;
 
-    // Opaque type for logger bridge
-    #[repr(C)]
-    pub struct RustLoggerBridge {
-        _unused: [u8; 0],
-    }
+}
 
-    // Rust callback type for logger
-    pub type RustLogCallback = unsafe extern "C" fn(
-        user_data: *mut std::ffi::c_void,
-        severity: i32,
-        msg: *const std::os::raw::c_char,
-    );
+// Opaque type for logger bridge
+#[repr(C)]
+pub struct RustLoggerBridge {
+    _unused: [u8; 0],
+}
 
-    // Re-export TensorRT types from the private ffi module
-    pub mod nvinfer1 {
-        pub use super::ffi::nvinfer1::*;
-    }
+// Rust callback type for logger
+pub type RustLogCallback = unsafe extern "C" fn(
+    user_data: *mut std::ffi::c_void,
+    severity: i32,
+    msg: *const std::os::raw::c_char,
+);
 
-    #[cfg(feature = "onnxparser")]
-    pub mod nvonnxparser {
-        pub use super::ffi::nvonnxparser::*;
-    }
+// Re-export TensorRT types from the private ffi module
+pub mod nvinfer1 {
+    pub use super::ffi::nvinfer1::*;
+}
 
-    // Re-export Dims64 as Dims to match TensorRT's typedef
-    pub use nvinfer1::Dims64;
-    pub type Dims = Dims64;
-
-    // Re-export InterpolationMode as ResizeMode to match TensorRT's typedef
-    pub use nvinfer1::InterpolationMode;
-    pub type ResizeMode = InterpolationMode;
-
-    /// Helper methods for Dims construction (avoiding name collision with generated constructor)
-    impl Dims64 {
-        /// Create a Dims from a slice of dimensions
-        pub fn from_slice(dims: &[i64]) -> Self {
-            let mut d = [0i64; 8];
-            let nb_dims = dims.len().min(8) as i32;
-            d[..nb_dims as usize].copy_from_slice(&dims[..nb_dims as usize]);
-            Self { nbDims: nb_dims, d }
-        }
+#[cfg(feature = "onnxparser")]
+pub mod nvonnxparser {
+    pub use super::ffi::nvonnxparser::*;
+}
 
-        /// Create a 2D Dims
-        pub fn new_2d(d0: i64, d1: i64) -> Self {
-            Self {
-                nbDims: 2,
-                d: [d0, d1, 0, 0, 0, 0, 0, 0],
-            }
+// Re-export Dims64 as Dims to match TensorRT's typedef
+pub use nvinfer1::Dims64;
+pub type Dims = Dims64;
+
+// Re-export InterpolationMode as ResizeMode to match TensorRT's typedef
+pub type ResizeMode = InterpolationMode;
+
+/// Helper methods for Dims construction (avoiding name collision with generated constructor)
+impl Dims64 {
+    /// Create a Dims from a slice of dimensions
+    pub fn from_slice(dims: &[i64]) -> Self {
+        let mut d = [0i64; 8];
+        let nb_dims = dims.len().min(8) as i32;
+        d[..nb_dims as usize].copy_from_slice(&dims[..nb_dims as usize]);
+        Self { nbDims: nb_dims, d }
+    }
+
+    /// Create a 2D Dims
+    pub fn new_2d(d0: i64, d1: i64) -> Self {
+        Self {
+            nbDims: 2,
+            d: [d0, d1, 0, 0, 0, 0, 0, 0],
         }
+    }
 
-        /// Create a 3D Dims
-        pub fn new_3d(d0: i64, d1: i64, d2: i64) -> Self {
-            Self {
-                nbDims: 3,
-                d: [d0, d1, d2, 0, 0, 0, 0, 0],
-            }
+    /// Create a 3D Dims
+    pub fn new_3d(d0: i64, d1: i64, d2: i64) -> Self {
+        Self {
+            nbDims: 3,
+            d: [d0, d1, d2, 0, 0, 0, 0, 0],
         }
+    }
 
-        /// Create a 4D Dims
-        pub fn new_4d(d0: i64, d1: i64, d2: i64, d3: i64) -> Self {
-            Self {
-                nbDims: 4,
-                d: [d0, d1, d2, d3, 0, 0, 0, 0],
-            }
+    /// Create a 4D Dims
+    pub fn new_4d(d0: i64, d1: i64, d2: i64, d3: i64) -> Self {
+        Self {
+            nbDims: 4,
+            d: [d0, d1, d2, d3, 0, 0, 0, 0],
         }
     }
+}
 
-    // Re-export Weights
-    pub use nvinfer1::Weights;
-
-    /// Helper methods for Weights construction
-    impl nvinfer1::Weights {
-        /// Create a Weights with FLOAT data type
-        pub fn new_float(values_ptr: *const std::ffi::c_void, count_val: i64) -> Self {
-            Self {
-                type_: nvinfer1::DataType::kFLOAT,
-                values: values_ptr,
-                count: count_val,
-            }
+// Re-export Weights
+pub use nvinfer1::Weights;
+
+/// Helper methods for Weights construction
+impl nvinfer1::Weights {
+    /// Create a Weights with FLOAT data type
+    pub fn new_float(values_ptr: *const std::ffi::c_void, count_val: i64) -> Self {
+        Self {
+            type_: nvinfer1::DataType::kFLOAT,
+            values: values_ptr,
+            count: count_val,
         }
+    }
 
-        /// Create a Weights with specified data type
-        pub fn new_with_type(
-            data_type: nvinfer1::DataType,
-            values_ptr: *const std::ffi::c_void,
-            count_val: i64,
-        ) -> Self {
-            Self {
-                type_: data_type,
-                values: values_ptr,
-                count: count_val,
-            }
+    /// Create a Weights with specified data type
+    pub fn new_with_type(
+        data_type: nvinfer1::DataType,
+        values_ptr: *const std::ffi::c_void,
+        count_val: i64,
+    ) -> Self {
+        Self {
+            type_: data_type,
+            values: values_ptr,
+            count: count_val,
         }
     }
 }
 
-#[cfg(not(feature = "mock"))]
-pub use real_bindings::*;
-
-#[cfg(test)]
-mod tests {
-    #[cfg(feature = "mock")]
-    use super::*;
-
-    #[test]
-    #[cfg(feature = "mock")]
-    fn test_constants() {
-        // Verify error codes are defined
-        assert_eq!(TRTX_SUCCESS, 0);
-        assert_ne!(TRTX_ERROR_INVALID_ARGUMENT, TRTX_SUCCESS);
+impl DataType {
+    pub const fn size_bits(self) -> usize {
+        match self {
+            DataType::kFLOAT => 32,
+            DataType::kHALF => 16,
+            DataType::kINT8 => 8,
+            DataType::kINT32 => 32,
+            DataType::kBOOL => 8,
+            DataType::kUINT8 => 8,
+            DataType::kFP8 => 8,
+            DataType::kBF16 => 16,
+            DataType::kINT64 => 64,
+            DataType::kINT4 => 4,
+            DataType::kFP4 => 4,
+            DataType::kE8M0 => 8,
+        }
     }
 }
diff --git a/trtx-sys/tests/autocxx_methods_test.rs b/trtx-sys/tests/autocxx_methods_test.rs
deleted file mode 100644
index 94e9568..0000000
--- a/trtx-sys/tests/autocxx_methods_test.rs
+++ /dev/null
@@ -1,38 +0,0 @@
-// Test to check if autocxx generates TensorRT class methods
-// This helps us understand if we can remove manual C wrappers
-
-#[cfg(not(feature = "mock"))]
-#[test]
-#[ignore] // Ignore by default, run with --ignored
-fn test_autocxx_builder_methods_exist() {
-    // This test just checks if the methods compile
-    // We're not actually running them
-
-    // Check if we can access IBuilder type
-    let _builder_type: Option<*mut trtx_sys::nvinfer1::IBuilder> = None;
-
-    // Try to call a method (won't execute, just checking if it compiles)
-    // Uncomment to test:
-    // if let Some(builder) = builder_ptr.as_mut() {
-    //     let _network = builder.createNetworkV2(0);
-    // }
-
-    println!("If this compiles, autocxx generated IBuilder bindings");
-}
-
-#[cfg(not(feature = "mock"))]
-#[test]
-#[ignore]
-fn test_check_available_types() {
-    // List all types we expect autocxx to generate
-
-    // Check if types exist
-    let _: Option<*mut trtx_sys::nvinfer1::IBuilder> = None;
-    let _: Option<*mut trtx_sys::nvinfer1::INetworkDefinition> = None;
-    let _: Option<*mut trtx_sys::nvinfer1::IBuilderConfig> = None;
-    let _: Option<*mut trtx_sys::nvinfer1::ICudaEngine> = None;
-    let _: Option<*mut trtx_sys::nvinfer1::IExecutionContext> = None;
-    let _: Option<*mut trtx_sys::nvinfer1::IRuntime> = None;
-
-    println!("All expected types are available from autocxx");
-}
diff --git a/trtx-sys/tests/method_call_test.rs b/trtx-sys/tests/method_call_test.rs
deleted file mode 100644
index 9149465..0000000
--- a/trtx-sys/tests/method_call_test.rs
+++ /dev/null
@@ -1,81 +0,0 @@
-// Comprehensive test: Can we call TensorRT methods via autocxx?
-// This is the KEY test for determining if we can remove C wrappers
-
-#![allow(unused)]
-
-#[cfg(not(feature = "mock"))]
-#[test]
-#[ignore] // Run with: cargo test --test method_call_test -- --ignored
-#[cfg(feature = "link_tensorrt_rtx")]
-fn test_builder_methods_callable() {
-    use std::ptr;
-
-    // First, we need a logger to create a builder
-    // Using the existing logger bridge (which IS necessary)
-    unsafe {
-        let callback: trtx_sys::RustLogCallback = test_logger_callback;
-        let logger_bridge = trtx_sys::create_rust_logger_bridge(callback, ptr::null_mut());
-        assert!(!logger_bridge.is_null(), "Failed to create logger bridge");
-
-        let logger = trtx_sys::get_logger_interface(logger_bridge);
-        assert!(!logger.is_null(), "Failed to get logger interface");
-
-        // Create builder using factory (also necessary - takes ILogger&)
-        let builder_ptr = trtx_sys::create_infer_builder(logger as *mut _);
-        assert!(!builder_ptr.is_null(), "Failed to create builder");
-
-        // NOW THE KEY TEST: Can we call methods on IBuilder?
-        // Cast void* to IBuilder*
-        let builder = builder_ptr as *mut trtx_sys::nvinfer1::IBuilder;
-
-        // Attempt to call createNetworkV2() - THIS IS THE TEST!
-        // If this compiles and works, we can remove builder_create_network_v2() wrapper
-
-        // Note: We can't actually test this without proper setup, but we can check if it compiles
-        println!("Builder pointer: {:?}", builder);
-
-        // Cleanup
-        trtx_sys::delete_builder(builder_ptr);
-        trtx_sys::destroy_rust_logger_bridge(logger_bridge);
-    }
-}
-
-#[cfg(not(feature = "mock"))]
-unsafe extern "C" fn test_logger_callback(
-    _user_data: *mut std::ffi::c_void,
-    severity: i32,
-    msg: *const std::os::raw::c_char,
-) {
-    if !msg.is_null() {
-        let c_str = std::ffi::CStr::from_ptr(msg);
-        if let Ok(s) = c_str.to_str() {
-            println!("[TensorRT {}] {}", severity, s);
-        }
-    }
-}
-
-// This test checks if autocxx provides method access AT COMPILE TIME
-#[cfg(not(feature = "mock"))]
-#[test]
-fn test_autocxx_method_availability_compile_check() {
-    // This function won't run, but if it COMPILES, we know the methods exist
-
-    #[allow(unreachable_code)]
-    {
-        return; // Don't actually run this
-
-        unsafe {
-            let builder: *mut trtx_sys::nvinfer1::IBuilder = std::ptr::null_mut();
-
-            // Try to access methods - if these compile, autocxx generated them!
-            // Uncomment these one at a time to test:
-
-            // Test 1: Does IBuilder have methods?
-            // let network = (*builder).createNetworkV2(0);
-
-            // Test 2: Does INetworkDefinition have methods?
-            // let network: *mut trtx_sys::nvinfer1::INetworkDefinition = std::ptr::null_mut();
-            // let tensor = (*network).addInput(...);
-        }
-    }
-}
diff --git a/trtx-sys/tests/simple_method_test.rs b/trtx-sys/tests/simple_method_test.rs
deleted file mode 100644
index 1cb52a9..0000000
--- a/trtx-sys/tests/simple_method_test.rs
+++ /dev/null
@@ -1,28 +0,0 @@
-// Ultra-simple test: What does autocxx actually generate for IBuilder?
-
-/// Helper to pin a raw pointer (same as in trtx crate)
-#[cfg(not(feature = "mock"))]
-unsafe fn pin_mut<T>(ptr: *mut T) -> std::pin::Pin<&'static mut T> {
-    std::pin::Pin::new_unchecked(&mut *ptr)
-}
-
-#[cfg(not(feature = "mock"))]
-#[test]
-fn check_what_autocxx_provides() {
-    // Just try to reference methods and see what the compiler tells us
-
-    unsafe {
-        // Get a null pointer of the right type
-        let builder: *mut trtx_sys::nvinfer1::IBuilder = std::ptr::null_mut();
-
-        // Can we dereference it to access methods?
-        if !builder.is_null() {
-            // Using our helper function (same pattern as trtx crate)
-            let _network = pin_mut(builder).createNetworkV2(0);
-
-            println!("SUCCESS! Called createNetworkV2 via autocxx!");
-        }
-
-        println!("Test compiled - Pin helper works!");
-    }
-}
diff --git a/trtx/Cargo.toml b/trtx/Cargo.toml
index c6b522c..3727274 100644
--- a/trtx/Cargo.toml
+++ b/trtx/Cargo.toml
@@ -16,16 +16,21 @@ thiserror = "2.0"
 cxx = "1.0"
 libc = "0.2"
 libloading = { version = "0.9", optional = true }
+autocxx = { version = "0.30" }
 
 # cudarc for safe CUDA operations (required when real mode is enabled)
 # Using cuda-12050 as fallback; CUDA 13.x should be compatible
-cudarc = { version = "0.11", features = ["driver", "cuda-12050"], optional = true }
+cudarc = { version = "0.11", features = ["driver", "cuda-12050"] }
 
 [features]
 # real TensorRT-RTX with cudarc and dynamic loading by default
-default = ["real", "dlopen_tensorrt_onnxparser", "dlopen_tensorrt_rtx", "onnxparser", "v_1_3"]
-mock = ["trtx-sys/mock"] # mock implementation (no CUDA required)
-real = ["dep:cudarc", "trtx-sys/v_1_3"] # real TensorRT-RTX
+default = [
+    "dlopen_tensorrt_onnxparser",
+    "dlopen_tensorrt_rtx",
+    "onnxparser",
+    "v_1_3",
+]
+mock = [] # mock implementation (no CUDA required)
 link_tensorrt_rtx = ["trtx-sys/link_tensorrt_rtx"]
 dlopen_tensorrt_rtx = ["libloading"]
 link_tensorrt_onnxparser = ["trtx-sys/link_tensorrt_onnxparser", "onnxparser"]
diff --git a/trtx/examples/basic_workflow.rs b/trtx/examples/basic_workflow.rs
index 366a746..d53c432 100644
--- a/trtx/examples/basic_workflow.rs
+++ b/trtx/examples/basic_workflow.rs
@@ -32,7 +32,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     // Step 2: Build phase
     println!("2. Building engine...");
-    let builder = Builder::new(&logger)?;
+    let mut builder = Builder::new(&logger)?;
     println!("   ✓ Builder created");
 
     // Create network with explicit batch dimensions
@@ -74,10 +74,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     // Step 3: Inference phase
     println!("3. Loading engine for inference...");
-    let runtime = Runtime::new(&logger)?;
+    let mut runtime = Runtime::new(&logger)?;
     println!("   ✓ Runtime created");
 
-    let engine = runtime.deserialize_cuda_engine(engine_data.as_ref())?;
+    let mut engine = runtime.deserialize_cuda_engine(engine_data.as_ref())?;
     println!("   ✓ Engine deserialized");
 
     // Query engine information
diff --git a/trtx/examples/tiny_network.rs b/trtx/examples/tiny_network.rs
index e6df09a..28d53c7 100644
--- a/trtx/examples/tiny_network.rs
+++ b/trtx/examples/tiny_network.rs
@@ -12,7 +12,6 @@ use trtx::builder::MemoryPoolType;
 use trtx::cuda::{synchronize, DeviceBuffer};
 use trtx::error::Result;
 
-use trtx::network::Layer; // Import Layer trait for get_output method
 use trtx::{ActivationType, Builder, DataType, Logger, Runtime};
 
 fn main() -> Result<()> {
@@ -34,8 +33,8 @@ fn main() -> Result<()> {
 
     // 3. Create runtime and deserialize engine
     println!("\n3. Creating runtime and loading engine...");
-    let runtime = Runtime::new(&logger)?;
-    let engine = runtime.deserialize_cuda_engine(&engine_data)?;
+    let mut runtime = Runtime::new(&logger)?;
+    let mut engine = runtime.deserialize_cuda_engine(&engine_data)?;
 
     // 4. Inspect engine
     println!("4. Engine information:");
@@ -171,27 +170,28 @@ fn main() -> Result<()> {
 
 /// Build a tiny network: Input -> ReLU -> Output
 fn build_tiny_network(logger: &Logger) -> Result> {
     println!("  Creating builder...");
-    let builder = Builder::new(logger)?;
+    let mut builder = Builder::new(logger)?;
 
     println!("  Creating network with explicit batch...");
     let mut network = builder.create_network(0)?;
 
     println!("  Adding input tensor [1, 3, 4, 4]...");
     let input = network.add_input("input", DataType::kFLOAT, &[1, 3, 4, 4])?;
-    println!("  Input tensor name: {:?}", input.name()?);
-    println!("  Input tensor dims: {:?}", input.dimensions()?);
+    println!("  Input tensor name: {:?}", input.name(&network)?);
+    println!("  Input tensor dims: {:?}", input.dimensions(&network)?);
 
     println!("  Adding ReLU activation layer...");
-    let activation_layer = network.add_activation(&input, ActivationType::kRELU)?;
-    let output = activation_layer.get_output(0)?;
+    let mut input = input;
+    let activation_layer = network.add_activation(&mut input, ActivationType::kRELU)?;
+    let output = activation_layer.get_output(&network, 0)?;
 
     println!("  Setting output tensor name...");
     let mut output_named = output;
-    output_named.set_name("output")?;
-    println!("  Output tensor name: {:?}", output_named.name()?);
+    output_named.set_name(&mut network, "output")?;
+    println!("  Output tensor name: {:?}", output_named.name(&network)?);
 
     println!("  Marking output tensor...");
-    network.mark_output(&output_named)?;
+    network.mark_output(&mut output_named);
 
     println!("  Network has {} inputs", network.get_nb_inputs());
     println!("  Network has {} outputs", network.get_nb_outputs());
diff --git a/trtx/src/axes.rs b/trtx/src/axes.rs
new file mode 100644
index 0000000..81eb80c
--- /dev/null
+++ b/trtx/src/axes.rs
@@ -0,0 +1,184 @@
+//! Axis mask type for operations that reduce or normalize over selected axes.
+//!
+//! TensorRT represents axes as a `u32` bitmask: bit `i` set means axis `i` is included.
+
+use std::fmt;
+
+/// Bitmask of axes: each bit set indicates one axis (bit 0 = axis 0, bit 1 = axis 1, etc.).
+#[derive(Clone, Copy, PartialEq, Eq, Default)]
+pub struct Axes(pub u32);
+
+impl fmt::Debug for Axes {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let bits = self.0;
+        if bits == 0 {
+            return f.write_str("Axes()");
+        }
+        let indices: Vec<u32> = (0u32..32).filter(|&i| (bits & (1u32 << i)) != 0).collect();
+        f.write_str("Axes(")?;
+        for (i, &axis) in indices.iter().enumerate() {
+            if i > 0 {
+                f.write_str("|")?;
+            }
+            write!(f, "{}", axis)?;
+        }
+        f.write_str(")")
+    }
+}
+
+impl Axes {
+    /// No axes selected.
+    pub const fn empty() -> Self {
+        Axes(0)
+    }
+
+    /// Wrap a raw axes bitmask.
+    #[inline]
+    pub const fn from_bits(bits: u32) -> Self {
+        Axes(bits)
+    }
+
+    /// Build from a list of axis indices. Each index sets the corresponding bit.
+    ///
+    /// # Example
+    /// ```
+    /// # use trtx::Axes;
+    /// // Axes 1 and 2 (e.g. for channel normalization over NCHW)
+    /// let axes = Axes::new([1, 2]);
+    /// assert_eq!(axes.to_bits(), 0b110);
+    /// ```
+    #[inline]
+    pub const fn new<const N: usize>(indices: [u32; N]) -> Self {
+        let mut bits = 0u32;
+        let mut i = 0;
+        while i < N {
+            bits |= 1u32.wrapping_shl(indices[i]);
+            i += 1;
+        }
+        Axes(bits)
+    }
+
+    /// Build from a list of axis indices. Each index sets the corresponding bit.
+    ///
+    /// # Example
+    /// ```
+    /// # use trtx::Axes;
+    /// // Axes 1 and 2 (e.g. for channel normalization over NCHW)
+    /// let axes = Axes::from_slice(&[1, 2]);
+    /// assert_eq!(axes.to_bits(), 0b110);
+    /// ```
+    pub fn from_slice(indices: &[u32]) -> Self {
+        let mut bits = 0u32;
+        for &index in indices {
+            bits |= 1u32.wrapping_shl(index);
+        }
+        Axes(bits)
+    }
+
+    /// Add one axis by index. Chainable for const construction.
+    #[inline]
+    pub const fn with_axis(self, axis: u32) -> Self {
+        Axes(self.0 | 1u32.wrapping_shl(axis))
+    }
+
+    /// Return the raw bitmask.
+    #[inline]
+    pub const fn to_bits(self) -> u32 {
+        self.0
+    }
+}
+
+impl From<u32> for Axes {
+    #[inline]
+    fn from(bits: u32) -> Self {
+        Axes(bits)
+    }
+}
+
+impl From<Axes> for u32 {
+    #[inline]
+    fn from(axes: Axes) -> u32 {
+        axes.0
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn empty_is_zero() {
+        assert_eq!(Axes::empty().to_bits(), 0);
+        assert_eq!(Axes::default().to_bits(), 0);
+    }
+
+    #[test]
+    fn from_bits_roundtrip() {
+        assert_eq!(Axes::from_bits(0).to_bits(), 0);
+        assert_eq!(Axes::from_bits(0b111).to_bits(), 0b111);
+        assert_eq!(Axes::from_bits(0xFFFF_FFFF).to_bits(), 0xFFFF_FFFF);
+    }
+
+    #[test]
+    fn from_indices_empty() {
+        let axes = Axes::new([]);
+        assert_eq!(axes.to_bits(), 0);
+    }
+
+    #[test]
+    fn from_indices_single() {
+        assert_eq!(Axes::new([0]).to_bits(), 1);
+        assert_eq!(Axes::new([3]).to_bits(), 1 << 3);
+        assert_eq!(Axes::new([31]).to_bits(), 1 << 31);
+    }
+
+    #[test]
+    fn from_indices_multiple() {
+        let axes = Axes::new([1, 2]);
+        assert_eq!(axes.to_bits(), 0b110);
+        let axes = Axes::new([0, 2, 4]);
+        assert_eq!(axes.to_bits(), 0b10101);
+    }
+
+    #[test]
+    fn from_indices_duplicate_same_as_unique() {
+        let a = Axes::new([1, 1, 2]);
+        let b = Axes::new([1, 2]);
+        assert_eq!(a.to_bits(), b.to_bits());
+    }
+
+    #[test]
+    fn with_axis_builder() {
+        let axes = Axes::empty().with_axis(0).with_axis(2);
+        assert_eq!(axes.to_bits(), 0b101);
+        let axes = Axes::from_bits(1).with_axis(1);
+        assert_eq!(axes.to_bits(), 0b11);
+    }
+
+    #[test]
+    fn from_u32_into_u32() {
+        let raw: u32 = 0b10110;
+        let axes: Axes = raw.into();
+        assert_eq!(axes.to_bits(), raw);
+        let back: u32 = axes.into();
+        assert_eq!(back, raw);
+    }
+
+    #[test]
+    fn eq_and_clone() {
+        let a = Axes::new([1, 2]);
+        let b = a;
+        let c = a.clone();
+        assert_eq!(a, b);
+        assert_eq!(a, c);
+        assert_eq!(b.to_bits(), 0b110);
+    }
+
+    #[test]
+    fn debug_shows_indices_separated_by_pipe() {
+        assert_eq!(format!("{:?}", Axes::empty()), "Axes()");
+        assert_eq!(format!("{:?}", Axes::new([0])), "Axes(0)");
+        assert_eq!(format!("{:?}", Axes::new([1, 2])), "Axes(1|2)");
+        assert_eq!(format!("{:?}", Axes::new([0, 2, 4])), "Axes(0|2|4)");
+    }
+}
diff --git a/trtx/src/builder.rs b/trtx/src/builder.rs
index d3d3ccb..6525a99 100644
--- a/trtx/src/builder.rs
+++ b/trtx/src/builder.rs
@@ -1,6 +1,15 @@
 //! Builder for creating TensorRT engines
-//!
-//! Delegates to real/ or mock/ based on feature flag.
+
+use crate::error::{Error, Result};
+use crate::host_memory::HostMemory;
+use crate::interfaces::{ErrorRecorder, RecordError};
+use crate::logger::Logger;
+use crate::network::NetworkDefinition;
+use crate::optimization_profile::OptimizationProfile;
+use autocxx::cxx::memory::UniquePtr;
+use std::marker::PhantomData;
+use std::pin::Pin;
+use trtx_sys::nvinfer1::IBuilder;
 
 /// Network definition builder flags
 pub mod network_flags {
@@ -8,12 +17,150 @@ pub mod network_flags {
     pub const EXPLICIT_BATCH: u32 = 1 << 0;
 }
 
+pub use crate::builder_config::BuilderConfig;
 pub use trtx_sys::{
     BuilderFlag, ComputeCapability, DeviceType, EngineCapability, HardwareCompatibilityLevel,
     MemoryPoolType, PreviewFeature, ProfilingVerbosity, RuntimePlatform, TilingOptimizationLevel,
 };
 
-#[cfg(feature = "mock")]
-pub use crate::mock::builder::{Builder, BuilderConfig};
-#[cfg(not(feature = "mock"))]
-pub use crate::real::builder::{Builder, BuilderConfig};
+/// Builder for creating TensorRT engines
+pub struct Builder<'a> {
+    inner: UniquePtr<IBuilder>,
+    _logger: PhantomData<&'a Logger>,
+    error_recorder: Option<Pin<Box<ErrorRecorder>>>,
+}
+
+impl<'builder> Builder<'builder> {
+    #[cfg(not(feature = "link_tensorrt_rtx"))]
+    #[cfg(not(feature = "dlopen_tensorrt_rtx"))]
+    pub fn new(logger: &'builder Logger) -> Result<Self> {
+        Err(Error::TrtRtxLibraryNotLoaded)
+    }
+
+    #[cfg(any(feature = "link_tensorrt_rtx", feature = "dlopen_tensorrt_rtx"))]
+    pub fn new(logger: &'builder Logger) -> Result<Self> {
+        #[cfg(not(feature = "mock"))]
+        {
+            use trtx_sys::nvinfer1::IBuilder;
+
+            let logger_ptr = logger.as_logger_ptr();
+            let builder_ptr = {
+                #[cfg(feature = "link_tensorrt_rtx")]
+                unsafe {
+                    trtx_sys::create_infer_builder(logger_ptr)
+                }
+                #[cfg(not(feature = "link_tensorrt_rtx"))]
+                #[cfg(feature = "dlopen_tensorrt_rtx")]
+                unsafe {
+                    use libloading::Symbol;
+                    use std::ffi::c_void;
+
+                    use crate::TRTLIB;
+                    if !TRTLIB.read()?.is_some() {
+                        crate::dynamically_load_tensorrt(None::<String>)?;
+                    }
+
+                    let lock = TRTLIB.read()?;
+                    let create_infer_builder: Symbol<unsafe extern "C" fn(*mut c_void, i32) -> *mut IBuilder> = lock
+                        .as_ref()
+                        .ok_or(Error::TrtRtxLibraryNotLoaded)?
+                        .get(b"createInferBuilder_INTERNAL")?;
+                    create_infer_builder(logger_ptr, trtx_sys::get_tensorrt_version())
+                }
+            };
+            if builder_ptr.is_null() {
+                return Err(Error::Runtime("Failed to create builder".to_string()));
+            }
+            Ok(Builder {
+                inner: unsafe { UniquePtr::from_raw(builder_ptr) },
+                error_recorder: None,
+                _logger: Default::default(),
+            })
+        }
+        #[cfg(feature = "mock")]
+        Ok(Builder {
+            inner: UniquePtr::null(),
+            error_recorder: None,
+            _logger: Default::default(),
+        })
+    }
+
+    pub fn create_network(&'_ mut self, flags: u32) -> Result<NetworkDefinition<'_>> {
+        if cfg!(feature = "mock") {
+            Ok(NetworkDefinition::from_ptr(std::ptr::null_mut()))
+        } else {
+            let network_ptr = self.inner.pin_mut().createNetworkV2(flags);
+            let network = unsafe { network_ptr.as_mut() }
+                .ok_or_else(|| Error::Runtime("Failed to create network".to_string()))?;
+            Ok(NetworkDefinition::from_ptr(network))
+        }
+    }
+
+    pub fn create_config(&'_ mut self) -> Result<BuilderConfig> {
+        #[cfg(not(feature = "mock"))]
+        let config_ptr = self.inner.pin_mut().createBuilderConfig();
+        #[cfg(feature = "mock")]
+        let config_ptr = std::ptr::null_mut();
+        BuilderConfig::new(config_ptr)
+    }
+
+    pub fn build_serialized_network<'config_borrow, 'output>(
+        &mut self,
+        network: &mut NetworkDefinition,
+        config: &'config_borrow mut BuilderConfig,
+    ) -> Result<HostMemory<'output>>
+    where
+        'output: 'config_borrow + 'builder,
+    {
+        if cfg!(feature = "mock") {
+            Ok(unsafe { HostMemory::from_raw(std::ptr::null_mut()) })
+        } else {
+            let serialized_engine = unsafe {
+                self.inner
+                    .pin_mut()
+                    .buildSerializedNetwork(network.inner.pin_mut(), config.inner.pin_mut())
+                    .as_mut()
+            }
+            .ok_or_else(|| Error::Runtime("Failed to build serialized network".to_string()))?;
+
+            Ok(unsafe { HostMemory::from_raw(serialized_engine) })
+        }
+    }
+
+    pub fn create_optimization_profile(&mut self) -> Result<OptimizationProfile<'_>> {
+        let profile = unsafe {
+            self.inner
+                .pin_mut()
+                .createOptimizationProfile()
+                .as_mut()
+                .ok_or_else(|| {
+                    Error::Runtime("Failed to create optimization profile".to_string())
+                })?
+        };
+        Ok(OptimizationProfile::from_raw(profile))
+    }
+
+    /// See [trtx_sys::nvinfer1::IBuilder::setErrorRecorder]
+    ///
+    /// The Rust bindings only allow setting the error recorder once
+    pub fn set_error_recorder(&mut self, error_recorder: Box<dyn RecordError>) -> Result<()> {
+        let error_recorder = ErrorRecorder::new(error_recorder)?;
+        if self.error_recorder.is_some() {
+            // would need to make sure that we don't destroy an error recorder still in use
+            // could offer this as an unsafe method for users who only set this when there is no
+            // build process active. Or we only accept a ref to the error recorder and force user
+            // via lifetimes to keep this alive for the builder lifetime
            panic!("Setting an error recorder more than once not supported at the moment");
+        }
+        self.error_recorder = Some(error_recorder);
+        let rec = self
+            .error_recorder
+            .as_mut()
+            .unwrap()
+            .as_trt_error_recorder();
+        #[cfg(not(feature = "mock"))]
+        unsafe {
+            self.inner.pin_mut().setErrorRecorder(rec)
+        };
+        Ok(())
+    }
+}
diff --git a/trtx/src/builder_config.rs b/trtx/src/builder_config.rs
new file mode 100644
index 0000000..d5b11f9
--- /dev/null
+++ b/trtx/src/builder_config.rs
@@ -0,0 +1,566 @@
+//! Real TensorRT builder config implementation
+
+use std::pin::Pin;
+
+use crate::error::PropertySetAttempt;
+use crate::interfaces::MonitorProgress;
+use crate::interfaces::ProgressMonitor;
+use crate::optimization_profile::OptimizationProfile;
+use crate::Error;
+use crate::Result;
+use cxx::UniquePtr;
+use trtx_sys::nvinfer1::{self, IBuilderConfig};
+use trtx_sys::{
+    BuilderFlag, ComputeCapability, DeviceType, EngineCapability, HardwareCompatibilityLevel,
+    MemoryPoolType, PreviewFeature, ProfilingVerbosity, RuntimePlatform, TilingOptimizationLevel,
+};
+
+/// Builder configuration (real mode)
+pub struct BuilderConfig {
+    pub(crate) inner: UniquePtr<IBuilderConfig>,
+    progress_monitor: Option<Pin<Box<ProgressMonitor>>>,
+}
+
+impl BuilderConfig {
+    pub(crate) fn new(builder_config: *mut nvinfer1::IBuilderConfig) -> Result<Self> {
+        #[cfg(not(feature = "mock"))]
+        if builder_config.is_null() {
+            return Err(Error::BuilderConfigCreationFailed);
+        }
+        Ok(Self {
+            inner: unsafe { UniquePtr::from_raw(builder_config) },
+            progress_monitor: None,
+        })
+    }
+
+    /// See [IBuilderConfig::setProgressMonitor]
+    /// The Rust bindings only allow setting the progress monitor once per builder config object
+    pub fn set_progress_monitor(
+        &mut self,
+        progress_monitor: Box<dyn MonitorProgress>,
+    ) -> Result<()> {
+        let progress_monitor = ProgressMonitor::new(progress_monitor)?;
+        if self.progress_monitor.is_some() {
+            // would need to make sure that we don't destroy a monitor still in use
+            // could offer this as an unsafe method for users who only set this when there is no
+            // build process active. Or we only accept a ref to progress monitor and force user
+            // via lifetimes to keep this alive for builder config lifetime
+            panic!("Setting a progress monitor more than once not supported at the moment");
+        }
+        self.progress_monitor = Some(progress_monitor);
+        #[cfg(not(feature = "mock"))]
+        unsafe {
+            self.inner.pin_mut().setProgressMonitor(
+                self.progress_monitor
+                    .as_mut()
+                    .expect("progress_monitor can't be empty. we just set it")
+                    .as_trt_progress_monitor(),
+            )
+        };
+        Ok(())
+    }
+
+    /// See [IBuilderConfig::setMemoryPoolLimit]
+    pub fn set_memory_pool_limit(&mut self, pool: MemoryPoolType, size: usize) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().setMemoryPoolLimit(pool.into(), size);
+    }
+
+    /// See [IBuilderConfig::setProfilingVerbosity]
+    pub fn set_profiling_verbosity(&mut self, verbosity: ProfilingVerbosity) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().setProfilingVerbosity(verbosity.into());
+    }
+
+    /// See [IBuilderConfig::getProfilingVerbosity]
+    pub fn get_profiling_verbosity(&self) -> ProfilingVerbosity {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getProfilingVerbosity().into()
+        } else {
+            ProfilingVerbosity::kNONE
+        }
+    }
+
+    /// See [IBuilderConfig::setAvgTimingIterations]
+    pub fn set_avg_timing_iterations(&mut self, avg_timing: i32) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().setAvgTimingIterations(avg_timing);
+    }
+
+    /// See [IBuilderConfig::getAvgTimingIterations]
+    pub fn get_avg_timing_iterations(&self) -> i32 {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getAvgTimingIterations()
+        } else {
+            0
+        }
+    }
+
+    /// See [IBuilderConfig::setEngineCapability]
+    pub fn set_engine_capability(&mut self, capability: EngineCapability) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().setEngineCapability(capability.into());
+    }
+
+    /// See [IBuilderConfig::getEngineCapability]
+    pub fn get_engine_capability(&self) -> EngineCapability {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getEngineCapability().into()
+        } else {
+            EngineCapability::kSTANDARD
+        }
+    }
+
+    /// See [IBuilderConfig::setFlags]
+    pub fn set_flags(&mut self, flags: u32) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().setFlags(flags);
+    }
+
+    /// See [IBuilderConfig::getFlags]
+    pub fn get_flags(&self) -> u32 {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getFlags()
+        } else {
+            0
+        }
+    }
+
+    /// See [IBuilderConfig::setFlag]
+    pub fn set_flag(&mut self, flag: BuilderFlag) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().setFlag(flag.into());
+    }
+
+    /// See [IBuilderConfig::clearFlag]
+    pub fn clear_flag(&mut self, flag: BuilderFlag) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().clearFlag(flag.into());
+    }
+
+    /// See [IBuilderConfig::getFlag]
+    pub fn get_flag(&self, flag: BuilderFlag) -> bool {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getFlag(flag.into())
+        } else {
+            false
+        }
+    }
+
+    /// See [IBuilderConfig::setDLACore]
+    pub fn set_dla_core(&mut self, dla_core: i32) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().setDLACore(dla_core);
+    }
+
+    /// See [IBuilderConfig::getDLACore]
+    pub fn get_dla_core(&self) -> i32 {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getDLACore()
+        } else {
+            0
+        }
+    }
+
+    /// See [IBuilderConfig::setDefaultDeviceType]
+    pub fn set_default_device_type(&mut self, device_type: DeviceType) {
+        #[cfg(not(feature = "mock"))]
+        self.inner
+            .pin_mut()
+            .setDefaultDeviceType(device_type.into());
+    }
+
+    /// See [IBuilderConfig::getDefaultDeviceType]
+    pub fn get_default_device_type(&self) -> DeviceType {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getDefaultDeviceType().into()
+        } else {
+            DeviceType::kGPU
+        }
+    }
+
+    /// See [IBuilderConfig::reset]
+    pub fn reset(&mut self) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().reset();
+    }
+
+    /// See [IBuilderConfig::getNbOptimizationProfiles]
+    pub fn get_nb_optimization_profiles(&self) -> i32 {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getNbOptimizationProfiles()
+        } else {
+            0
+        }
+    }
+
+    /// See [IBuilderConfig::addOptimizationProfile].
+    /// Returns the profile index (0-based) on success.
+    pub fn add_optimization_profile(
+        &mut self,
+        profile: &mut OptimizationProfile<'_>,
+    ) -> Result<i32> {
+        #[cfg(not(feature = "mock"))]
+        {
+            let idx = unsafe {
+                self.inner
+                    .pin_mut()
+                    .addOptimizationProfile(profile.inner.as_mut().get_unchecked_mut())
+            };
+            if idx >= 0 {
+                Ok(idx)
+            } else {
+                Err(Error::Runtime("addOptimizationProfile failed".to_string()))
+            }
+        }
+        #[cfg(feature = "mock")]
+        Ok(0)
+    }
+
+    /// See [IBuilderConfig::setTacticSources]
+    pub fn set_tactic_sources(&mut self, sources: u32) -> crate::Result<()> {
+        if cfg!(not(feature = "mock")) {
+            if self.inner.pin_mut().setTacticSources(sources) {
+                Ok(())
+            } else {
+                Err(crate::Error::FailedToSetProperty(
+                    PropertySetAttempt::BuilderConfigTacticSources,
+                ))
+            }
+        } else {
+            Ok(())
+        }
+    }
+
+    /// See [IBuilderConfig::getTacticSources]
+    pub fn get_tactic_sources(&self) -> u32 {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getTacticSources()
+        } else {
+            0
+        }
+    }
+
+    /// See [IBuilderConfig::getMemoryPoolLimit]
+    pub fn get_memory_pool_limit(&self, pool: MemoryPoolType) -> usize {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getMemoryPoolLimit(pool.into())
+        } else {
+            0
+        }
+    }
+
+    /// See [IBuilderConfig::setPreviewFeature]
+    pub fn set_preview_feature(&mut self, feature: PreviewFeature, enable: bool) {
+        #[cfg(not(feature = "mock"))]
+        self.inner
+            .pin_mut()
+            .setPreviewFeature(feature.into(), enable);
+    }
+
+    /// See [IBuilderConfig::getPreviewFeature]
+    pub fn get_preview_feature(&self, feature: PreviewFeature) -> bool {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getPreviewFeature(feature.into())
+        } else {
+            false
+        }
+    }
+
+    /// See [IBuilderConfig::setBuilderOptimizationLevel]
+    pub fn set_builder_optimization_level(&mut self, level: i32) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().setBuilderOptimizationLevel(level);
+    }
+
+    /// See [IBuilderConfig::getBuilderOptimizationLevel]
+    pub fn get_builder_optimization_level(&mut self) -> i32 {
+        if cfg!(not(feature = "mock")) {
+            self.inner.pin_mut().getBuilderOptimizationLevel()
+        } else {
+            0
+        }
+    }
+
+    /// See [IBuilderConfig::setHardwareCompatibilityLevel]
+    pub fn set_hardware_compatibility_level(&mut self, level: HardwareCompatibilityLevel) {
+        #[cfg(not(feature = "mock"))]
+        self.inner
+            .pin_mut()
+            .setHardwareCompatibilityLevel(level.into());
+    }
+
+    /// See [IBuilderConfig::getHardwareCompatibilityLevel]
+    pub fn get_hardware_compatibility_level(&self) -> HardwareCompatibilityLevel {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getHardwareCompatibilityLevel().into()
+        } else {
+            HardwareCompatibilityLevel::kNONE
+        }
+    }
+
+    /// See [IBuilderConfig::setMaxAuxStreams]
+    pub fn set_max_aux_streams(&mut self, nb_streams: i32) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().setMaxAuxStreams(nb_streams);
+    }
+
+    /// See [IBuilderConfig::getMaxAuxStreams]
+    pub fn get_max_aux_streams(&self) -> i32 {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getMaxAuxStreams()
+        } else {
+            0
+        }
+    }
+
+    /// See [IBuilderConfig::setRuntimePlatform]
+    pub fn set_runtime_platform(&mut self, platform: RuntimePlatform) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().setRuntimePlatform(platform.into());
+    }
+
+    /// See [IBuilderConfig::getRuntimePlatform]
+    pub fn get_runtime_platform(&self) -> RuntimePlatform {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getRuntimePlatform().into()
+        } else {
+            RuntimePlatform::kSAME_AS_BUILD
+        }
+    }
+
+    /// See [IBuilderConfig::setMaxNbTactics]
+    pub fn set_max_nb_tactics(&mut self, max_nb_tactics: i32) {
+        #[cfg(not(feature = "mock"))]
+        self.inner.pin_mut().setMaxNbTactics(max_nb_tactics);
+    }
+
+    /// See [IBuilderConfig::getMaxNbTactics]
+    pub fn get_max_nb_tactics(&self) -> i32 {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getMaxNbTactics()
+        } else {
+            0
+        }
+    }
+
+    /// See [IBuilderConfig::setTilingOptimizationLevel]
+    pub fn set_tiling_optimization_level(
+        &mut self,
+        level: TilingOptimizationLevel,
+    ) -> crate::Result<()> {
+        if cfg!(not(feature = "mock")) {
+            if self
+                .inner
+                .pin_mut()
+                .setTilingOptimizationLevel(level.into())
+            {
+                Ok(())
+            } else {
+                Err(crate::Error::FailedToSetProperty(
+                    PropertySetAttempt::BuilderConfigTilingOptimizationLevel,
+                ))
+            }
+        } else {
+            Ok(())
+        }
+    }
+
+    /// See [IBuilderConfig::getTilingOptimizationLevel]
+    pub fn get_tiling_optimization_level(&self) -> TilingOptimizationLevel {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getTilingOptimizationLevel().into()
+        } else {
+            TilingOptimizationLevel::kNONE
+        }
+    }
+
+    /// See [IBuilderConfig::setL2LimitForTiling]
+    pub fn set_l2_limit_for_tiling(&mut self, size: i64) -> crate::Result<()> {
+        if cfg!(not(feature = "mock")) {
+            if self.inner.pin_mut().setL2LimitForTiling(size) {
+                Ok(())
+            } else {
+                Err(crate::Error::FailedToSetProperty(
+                    PropertySetAttempt::BuilderConfigL2LimitForTiling,
+                ))
+            }
+        } else {
+            Ok(())
+        }
+    }
+
+    /// See [IBuilderConfig::getL2LimitForTiling]
+    pub fn get_l2_limit_for_tiling(&self) -> i64 {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getL2LimitForTiling()
+        } else {
+            0
+        }
+    }
+
+    /// See [IBuilderConfig::setNbComputeCapabilities]
+    pub fn set_nb_compute_capabilities(
+        &mut self,
+        max_nb_compute_capabilities: i32,
+    ) -> crate::Result<()> {
+        if cfg!(not(feature = "mock")) {
+            if self
+                .inner
+                .pin_mut()
+                .setNbComputeCapabilities(max_nb_compute_capabilities)
+            {
+                Ok(())
+            } else {
+                Err(crate::Error::FailedToSetProperty(
+                    PropertySetAttempt::BuilderConfigNbComputeCapabilities,
+                ))
+            }
+        } else {
+            Ok(())
+        }
+    }
+
+    /// See [IBuilderConfig::getNbComputeCapabilities]
+    pub fn get_nb_compute_capabilities(&self) -> i32 {
+        if cfg!(not(feature = "mock")) {
+            self.inner.getNbComputeCapabilities()
+        } else {
+            0
+        }
+    }
+
+    /// See [IBuilderConfig::setComputeCapability]
+    pub fn set_compute_capability(
+        &mut self,
+        compute_capability: ComputeCapability,
+        index: i32,
+    ) -> crate::Result<()> {
+        if cfg!(not(feature = "mock")) {
+            if self
+                .inner
+                .pin_mut()
.setComputeCapability(compute_capability.into(), index) + { + Ok(()) + } else { + Err(crate::Error::FailedToSetProperty( + PropertySetAttempt::BuilderConfigComputeCapability, + )) + } + } else { + Ok(()) + } + } + + /// See [IBuilderConfig::getComputeCapability] + pub fn get_compute_capability(&self, index: i32) -> ComputeCapability { + if cfg!(not(feature = "mock")) { + self.inner.getComputeCapability(index).into() + } else { + ComputeCapability::kNONE + } + } +} + +#[cfg(test)] +#[cfg(not(feature = "mock"))] +mod tests { + use crate::builder::MemoryPoolType; + use crate::interfaces::MonitorProgress; + use crate::{Builder, DataType, Logger, NetworkDefinition}; + use std::ops::ControlFlow; + use std::sync::atomic::{AtomicU32, Ordering}; + + const NUM_LAYERS: usize = 40; + + /// Progress monitor that writes to stdout and cancels the build after a few steps. + struct StdoutProgressMonitor { + step_count: AtomicU32, + cancel_after: u32, + } + + impl StdoutProgressMonitor { + fn new(cancel_after: u32) -> Self { + Self { + step_count: AtomicU32::new(0), + cancel_after, + } + } + } + + impl MonitorProgress for StdoutProgressMonitor { + fn phase_start(&self, phase_name: &str, parent_phase: Option<&str>, num_steps: i32) { + println!( + "[progress] phase_start phase={:?} parent={:?} num_steps={}", + phase_name, parent_phase, num_steps + ); + } + + fn step_complete(&self, phase_name: &str, step: i32) -> ControlFlow<()> { + let n = self.step_count.fetch_add(1, Ordering::SeqCst); + println!( + "[progress] step_complete phase={:?} step={}", + phase_name, step + ); + if n + 1 >= self.cancel_after { + println!("[progress] cancel requested after {} steps", n + 1); + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + } + + fn phase_finish(&self, phase_name: &str) { + println!("[progress] phase_finish phase={:?}", phase_name); + } + } + + /// Build a network with many repeated identity layers, each named. 
+ fn build_heavy_network(logger: &Logger) -> crate::Result<(Builder<'_>, NetworkDefinition<'_>)> { + let mut builder = Builder::new(logger)?; + let mut network = builder.create_network(0)?; + + let mut tensor = network.add_input("input", DataType::kFLOAT, &[1, 4])?; + for i in 0..NUM_LAYERS { + let mut layer = network.add_identity(&mut tensor)?; + layer.set_name(&mut network, &format!("layer_{}", i))?; + tensor = layer.get_output(&network, 0)?; + } + tensor.set_name(&mut network, "output")?; + network.mark_output(&mut tensor); + + Ok((builder, network)) + } + + #[test] + fn set_progress_monitor_cancel_build() { + let logger = Logger::stderr().expect("logger"); + let (mut builder, mut network) = build_heavy_network(&logger).expect("build network"); + + let mut config = builder.create_config().expect("config"); + config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 24); + + let monitor = Box::new(StdoutProgressMonitor::new(3)); + config.set_progress_monitor(monitor).unwrap(); + + let result = builder.build_serialized_network(&mut network, &mut config); + + assert!( + result.is_err(), + "build should fail (cancelled by progress monitor)" + ); + } + + #[test] + fn set_progress_monitor_progress_to_stdout() { + let logger = Logger::stderr().expect("logger"); + let (mut builder, mut network) = build_heavy_network(&logger).expect("build network"); + + let mut config = builder.create_config().expect("config"); + config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 24); + + let monitor = Box::new(StdoutProgressMonitor::new(10000)); + config.set_progress_monitor(monitor).unwrap(); + + let result = builder.build_serialized_network(&mut network, &mut config); + + assert!(result.is_ok(), "build should succeed when not cancelling"); + } +} diff --git a/trtx/src/cuda.rs b/trtx/src/cuda.rs index 273ee2c..fc71027 100644 --- a/trtx/src/cuda.rs +++ b/trtx/src/cuda.rs @@ -1,11 +1,72 @@ //! CUDA memory management utilities -//! -//! 
Delegates to real/ or mock/ based on feature flag. -#[cfg(feature = "mock")] -pub use crate::mock::cuda::*; -#[cfg(not(feature = "mock"))] -pub use crate::real::cuda::*; +use crate::error::{Error, Result}; + +use cudarc::driver::{CudaDevice, CudaSlice, DevicePtr}; + +/// RAII wrapper for CUDA device memory +pub struct DeviceBuffer { + ptr: CudaSlice<u8>, + device: std::sync::Arc<CudaDevice>, + size: usize, +} + +impl DeviceBuffer { + pub fn new(size: usize) -> Result<Self> { + let device = CudaDevice::new(0) + .map_err(|e| Error::Cuda(format!("Failed to initialize CUDA device: {:?}", e)))?; + let ptr = device + .alloc_zeros::<u8>(size) + .map_err(|e| Error::Cuda(format!("Failed to allocate CUDA memory: {:?}", e)))?; + Ok(DeviceBuffer { ptr, device, size }) + } + + pub fn as_ptr(&self) -> *mut std::ffi::c_void { + *self.ptr.device_ptr() as *mut std::ffi::c_void + } + + pub fn size(&self) -> usize { + self.size + } + + pub fn copy_from_host(&mut self, data: &[u8]) -> Result<()> { + if data.len() > self.size { + return Err(Error::InvalidArgument( + "Data size exceeds buffer size".to_string(), + )); + } + self.device + .htod_copy_into(data.to_vec(), &mut self.ptr) + .map_err(|e| Error::Cuda(format!("Failed to copy to device: {:?}", e))) + } + + pub fn copy_to_host(&self, data: &mut [u8]) -> Result<()> { + if data.len() > self.size { + return Err(Error::InvalidArgument( + "Data size exceeds buffer size".to_string(), + )); + } + self.device + .dtoh_sync_copy_into(&self.ptr, data) + .map_err(|e| Error::Cuda(format!("Failed to copy from device: {:?}", e))) + } +} + +unsafe impl Send for DeviceBuffer {} + +/// Synchronize CUDA device +pub fn synchronize() -> Result<()> { + let device = CudaDevice::new(0) + .map_err(|e| Error::Cuda(format!("Failed to get CUDA device: {:?}", e)))?; + device + .synchronize() + .map_err(|e| Error::Cuda(format!("Failed to synchronize device: {:?}", e))) +} + +/// Get the default CUDA stream +pub fn get_default_stream() -> *mut std::ffi::c_void { + std::ptr::null_mut() +} 
#[cfg(test)] mod tests { diff --git a/trtx/src/cuda_engine.rs b/trtx/src/cuda_engine.rs new file mode 100644 index 0000000..ee7f8a0 --- /dev/null +++ b/trtx/src/cuda_engine.rs @@ -0,0 +1,404 @@ +use std::{ffi::CStr, marker::PhantomData}; + +use crate::engine_inspector::EngineInspector; +use crate::error::PropertySetAttempt; +use crate::host_memory::HostMemory; +use crate::{DataType, Error, ExecutionContext, Result}; +use autocxx::cxx::UniquePtr; +use trtx_sys::{ + nvinfer1::{self, ICudaEngine, TensorIOMode}, + SerializationFlag, +}; + +pub struct SerializationConfig<'cuda_engine> { + inner: UniquePtr<nvinfer1::ISerializationConfig>, + _runtime: PhantomData<&'cuda_engine nvinfer1::ICudaEngine>, +} +impl SerializationConfig<'_> { + pub fn get_flag(&self, flag: SerializationFlag) -> bool { + self.inner.getFlag(flag.into()) + } + pub fn get_flags(&self) -> u32 { + self.inner.getFlags() + } + pub fn set_flag(&mut self, flag: SerializationFlag) -> Result<()> { + if self.inner.pin_mut().setFlag(flag.into()) { + Ok(()) + } else { + Err(Error::FailedToSetProperty( + PropertySetAttempt::SerializationFlag, + )) + } + } + pub fn set_flags(&mut self, flags: u32) -> Result<()> { + if self.inner.pin_mut().setFlags(flags) { + Ok(()) + } else { + Err(Error::FailedToSetProperty( + PropertySetAttempt::SerializationFlag, + )) + } + } + pub fn clear_flag(&mut self, flag: SerializationFlag) -> Result<()> { + if self.inner.pin_mut().clearFlag(flag.into()) { + Ok(()) + } else { + Err(Error::FailedToSetProperty( + PropertySetAttempt::SerializationFlag, + )) + } + } +} + +pub struct CudaEngine<'runtime> { + pub(crate) inner: UniquePtr<nvinfer1::ICudaEngine>, + _runtime: PhantomData<&'runtime nvinfer1::IRuntime>, +} + +impl<'engine> CudaEngine<'engine> { + pub(crate) unsafe fn from_ptr(ptr: *mut ICudaEngine) -> Self { + Self { + inner: unsafe { UniquePtr::from_raw(ptr) }, + _runtime: Default::default(), + } + } + + pub fn get_nb_io_tensors(&self) -> Result<i32> { + if cfg!(feature = "mock") { + Ok(0) + } else { + Ok(self.inner.getNbIOTensors()) + } 
} + + pub fn get_tensor_name(&self, index: i32) -> Result<String> { + if cfg!(feature = "mock") { + Ok("mock".to_string()) + } else { + let name_ptr = self.inner.getIOTensorName(index); + if name_ptr.is_null() { + return Err(Error::InvalidArgument("Invalid tensor index".to_string())); + } + Ok(unsafe { CStr::from_ptr(name_ptr) }.to_str()?.to_string()) + } + } + + pub fn get_tensor_shape(&self, name: &str) -> Result<Vec<i64>> { + let name_cstr = std::ffi::CString::new(name)?; + let dims = unsafe { self.inner.getTensorShape(name_cstr.as_ptr()) }; + let nb_dims = dims.nbDims as usize; + if nb_dims > 8 { + return Err(Error::Runtime("Tensor has too many dimensions".to_string())); + } + Ok((0..nb_dims).map(|i| dims.d[i]).collect()) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorDataType`]. + pub fn get_tensor_data_type(&self, name: &str) -> Result<DataType> { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { self.inner.getTensorDataType(name_cstr.as_ptr()) }.into()) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getNbLayers`]. + pub fn get_nb_layers(&self) -> Result<i32> { + Ok(self.inner.getNbLayers()) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getNbOptimizationProfiles`]. + pub fn get_nb_optimization_profiles(&self) -> Result<i32> { + Ok(self.inner.getNbOptimizationProfiles()) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getNbAuxStreams`]. + pub fn get_nb_aux_streams(&self) -> Result<i32> { + Ok(self.inner.getNbAuxStreams()) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorIOMode`]. + pub fn get_tensor_io_mode(&self, name: &str) -> Result<TensorIOMode> { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { self.inner.getTensorIOMode(name_cstr.as_ptr()) }) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorLocation`]. 
+ pub fn get_tensor_location(&self, name: &str) -> Result<nvinfer1::TensorLocation> { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { self.inner.getTensorLocation(name_cstr.as_ptr()) }) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorFormat`]. + pub fn get_tensor_format(&self, name: &str) -> Result<nvinfer1::TensorFormat> { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { self.inner.getTensorFormat(name_cstr.as_ptr()) }) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorFormat`] (profile variant). + pub fn get_tensor_format_for_profile( + &self, + name: &str, + profile_index: i32, + ) -> Result<nvinfer1::TensorFormat> { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { + self.inner + .getTensorFormat1(name_cstr.as_ptr(), profile_index) + }) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorFormatDesc`]. + pub fn get_tensor_format_desc(&self, name: &str) -> Result<String> { + let name_cstr = std::ffi::CString::new(name)?; + let ptr = unsafe { self.inner.getTensorFormatDesc(name_cstr.as_ptr()) }; + if ptr.is_null() { + return Ok(String::new()); + } + Ok(unsafe { CStr::from_ptr(ptr) }.to_str()?.to_string()) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorFormatDesc`] (profile variant). + pub fn get_tensor_format_desc_for_profile( + &self, + name: &str, + profile_index: i32, + ) -> Result<String> { + let name_cstr = std::ffi::CString::new(name)?; + let ptr = unsafe { + self.inner + .getTensorFormatDesc1(name_cstr.as_ptr(), profile_index) + }; + if ptr.is_null() { + return Ok(String::new()); + } + Ok(unsafe { CStr::from_ptr(ptr) }.to_str()?.to_string()) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorVectorizedDim`]. + pub fn get_tensor_vectorized_dim(&self, name: &str) -> Result<i32> { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { self.inner.getTensorVectorizedDim(name_cstr.as_ptr()) }) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorVectorizedDim`] (profile variant). 
+ pub fn get_tensor_vectorized_dim_for_profile( + &self, + name: &str, + profile_index: i32, + ) -> Result<i32> { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { + self.inner + .getTensorVectorizedDim1(name_cstr.as_ptr(), profile_index) + }) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorBytesPerComponent`]. + pub fn get_tensor_bytes_per_component(&self, name: &str) -> Result<i32> { + #[cfg(not(feature = "mock"))] + { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { self.inner.getTensorBytesPerComponent(name_cstr.as_ptr()) }) + } + #[cfg(feature = "mock")] + Ok(42) + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorBytesPerComponent`] (profile variant). + pub fn get_tensor_bytes_per_component_for_profile( + &self, + name: &str, + profile_index: i32, + ) -> Result<i32> { + if !self.inner.is_null() { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { + self.inner + .getTensorBytesPerComponent1(name_cstr.as_ptr(), profile_index) + }) + } else { + Ok(0) + } + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorComponentsPerElement`]. + pub fn get_tensor_components_per_element(&self, name: &str) -> Result<i32> { + #[cfg(not(feature = "mock"))] + { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { self.inner.getTensorComponentsPerElement(name_cstr.as_ptr()) }) + } + #[cfg(feature = "mock")] + { + Ok(42) + } + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::getTensorComponentsPerElement`] (profile variant). + pub fn get_tensor_components_per_element_for_profile( + &self, + name: &str, + profile_index: i32, + ) -> Result<i32> { + #[cfg(not(feature = "mock"))] + { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { + self.inner + .getTensorComponentsPerElement1(name_cstr.as_ptr(), profile_index) + }) + } + #[cfg(feature = "mock")] + { + Ok(42) + } + } + + /// See [`trtx_sys::nvinfer1::ICudaEngine::createEngineInspector`]. + /// Returns an inspector that can print layer and engine information (e.g. 
JSON or one-line format). + pub fn create_engine_inspector(&self) -> Result<EngineInspector<'_>> { + #[cfg(not(feature = "mock"))] + { + use crate::engine_inspector::EngineInspector; + + let inspector = self.inner.createEngineInspector(); + let inspector = unsafe { + inspector.as_mut().ok_or_else(|| { + Error::Runtime("Failed to create engine inspector".to_string()) + })? + }; + Ok(EngineInspector { + inner: unsafe { UniquePtr::from_raw(inspector) }, + _engine: Default::default(), + }) + } + #[cfg(feature = "mock")] + { + Ok(EngineInspector { + inner: UniquePtr::null(), + _engine: Default::default(), + }) + } + } + + /// Returns the data type of the tensor (e.g. kFLOAT, kHALF). + /// Required for correct buffer sizing and f32/f16 conversion when I/O uses half precision. + pub fn get_tensor_dtype(&self, name: &str) -> Result<DataType> { + #[cfg(not(feature = "mock"))] + { + let name_cstr = std::ffi::CString::new(name)?; + Ok(unsafe { self.inner.getTensorDataType(name_cstr.as_ptr()).into() }) + } + #[cfg(feature = "mock")] + Ok(trtx_sys::DataType::kFLOAT) + } + + pub fn create_execution_context(&'_ mut self) -> Result<ExecutionContext<'_>> { + #[cfg(not(feature = "mock"))] + { + use crate::ExecutionContext; + + let context_ptr = self.inner.pin_mut().createExecutionContext( + trtx_sys::nvinfer1::ExecutionContextAllocationStrategy::kSTATIC, + ); + Ok(unsafe { ExecutionContext::from_ptr(context_ptr)? }) + } + #[cfg(feature = "mock")] + Ok(unsafe { ExecutionContext::from_ptr(std::ptr::null_mut())? }) + } + + pub fn create_serialization_config(&mut self) -> Result<SerializationConfig<'_>> { + let config = unsafe { + self.inner + .pin_mut() + .createSerializationConfig() + .as_mut() + .ok_or_else(|| Error::Runtime("SerializationConfig creation failed".to_string()))? 
+ }; + Ok(SerializationConfig { + inner: unsafe { UniquePtr::from_raw(config) }, + _runtime: Default::default(), + }) + } + + /// See [nvinfer1::ICudaEngine::serializeWithConfig] + pub fn serialize_with_config( + &'_ self, + config: &mut SerializationConfig, + ) -> Result<HostMemory<'_>> { + if !cfg!(feature = "mock") { + let host_mem = unsafe { + self.inner + .serializeWithConfig(config.inner.pin_mut()) + .as_mut() + .ok_or_else(|| { + Error::Runtime("Failed to serialize ICudaEngine with config".to_string()) + })? + }; + Ok(unsafe { HostMemory::from_raw(host_mem) }) + } else { + Ok(unsafe { HostMemory::from_raw(std::ptr::null_mut()) }) + } + } +} + +#[cfg(test)] +#[cfg(not(feature = "mock"))] +mod tests { + use crate::builder::network_flags; + use crate::builder::{Builder, MemoryPoolType}; + use crate::logger::Logger; + use crate::runtime::Runtime; + use crate::{CudaEngine, DataType}; + use trtx_sys::LayerInformationFormat; + + /// Build a minimal serialized engine with ProfilingVerbosity::kDETAILED so the inspector has layer info. 
+ fn build_minimal_engine_with_verbose_profiling(logger: &Logger) -> crate::Result<Vec<u8>> { + let mut builder = Builder::new(logger)?; + let mut network = builder.create_network(network_flags::EXPLICIT_BATCH)?; + let mut tensor = network.add_input("input", DataType::kFLOAT, &[1, 4])?; + tensor = network + .add_activation(&tensor, trtx_sys::ActivationType::kRELU) + .unwrap() + .get_output(&network, 0) + .unwrap(); + tensor = network + .add_activation(&tensor, trtx_sys::ActivationType::kRELU) + .unwrap() + .get_output(&network, 0) + .unwrap(); + network.mark_output(&tensor); + + let mut config = builder.create_config()?; + config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 20); + config.set_profiling_verbosity(crate::ProfilingVerbosity::kDETAILED); + + let engine_data = builder.build_serialized_network(&mut network, &mut config)?; + Ok(engine_data.to_vec()) + } + + #[test] + fn engine_inspector_json_verbose_profiling() { + let logger = Logger::stderr().expect("logger"); + let engine_data = + build_minimal_engine_with_verbose_profiling(&logger).expect("build engine"); + + let mut runtime = Runtime::new(&logger).expect("runtime"); + let engine: CudaEngine<'_> = runtime + .deserialize_cuda_engine(&engine_data) + .expect("deserialize"); + + let inspector = engine.create_engine_inspector().expect("engine inspector"); + let json = inspector + .get_engine_information(LayerInformationFormat::kJSON) + .expect("get_engine_information JSON"); + + assert!( + !json.is_empty(), + "engine information JSON should not be empty" + ); + assert!( + json.trim_start().starts_with('{'), + "engine information should be JSON (starts with '{{'); got: {}...", + json.chars().take(80).collect::<String>() + ); + } +} diff --git a/trtx/src/engine_inspector.rs b/trtx/src/engine_inspector.rs new file mode 100644 index 0000000..9dbfb6c --- /dev/null +++ b/trtx/src/engine_inspector.rs @@ -0,0 +1,44 @@ +use crate::{Error, Result}; +use std::{ffi::CStr, marker::PhantomData}; + +use autocxx::cxx::UniquePtr; 
+use trtx_sys::{nvinfer1, LayerInformationFormat}; + +/// Engine inspector for layer/engine information (real mode). +/// See [`trtx_sys::nvinfer1::IEngineInspector`]. +pub struct EngineInspector<'engine> { + pub(crate) inner: UniquePtr<nvinfer1::IEngineInspector>, + pub(crate) _engine: PhantomData<&'engine nvinfer1::ICudaEngine>, +} + +impl EngineInspector<'_> { + /// Returns layer information for the given layer index in the requested format. + /// See [`trtx_sys::nvinfer1::IEngineInspector::getLayerInformation`]. + pub fn get_layer_information( + &mut self, + layer_index: i32, + format: LayerInformationFormat, + ) -> Result<String> { + let ptr = self.inner.getLayerInformation(layer_index, format.into()); + if ptr.is_null() { + return Err(Error::Runtime( + "Could not get layer information".to_string(), + )); + } + Ok(unsafe { CStr::from_ptr(ptr) }.to_str()?.to_string()) + } + + /// Returns engine information in the requested format. + /// See [`trtx_sys::nvinfer1::IEngineInspector::getEngineInformation`]. + pub fn get_engine_information(&self, format: LayerInformationFormat) -> Result<String> { + let ptr = self.inner.getEngineInformation(format.into()); + if ptr.is_null() { + return Err(Error::Runtime( + "Could not get engine information".to_string(), + )); + } + Ok(unsafe { CStr::from_ptr(ptr) }.to_str()?.to_string()) + } +} diff --git a/trtx/src/enum_helpers.rs b/trtx/src/enum_helpers.rs deleted file mode 100644 index f22acd9..0000000 --- a/trtx/src/enum_helpers.rs +++ /dev/null @@ -1,293 +0,0 @@ -//! Helper functions for converting TensorRT enums to strings -//! -//! This module provides helper functions for getting string names of TensorRT enum variants. -//! These are useful for error messages, debugging, and logging. 
- -use trtx_sys::nvinfer1::{ - ActivationType, CumulativeOperation, DataType, ElementWiseOperation, GatherMode, - InterpolationMode, PoolingType, ReduceOperation, ResizeCoordinateTransformation, - ResizeRoundMode, ResizeSelector, ScatterMode, UnaryOperation, -}; - -// ============================================================================ -// DataType -// ============================================================================ - -/// Get the string name of a DataType enum variant -pub fn datatype_name(dt: &DataType) -> &'static str { - match *dt { - DataType::kFLOAT => "kFLOAT", - DataType::kHALF => "kHALF", - DataType::kINT8 => "kINT8", - DataType::kINT32 => "kINT32", - DataType::kUINT8 => "kUINT8", - DataType::kBOOL => "kBOOL", - DataType::kFP8 => "kFP8", - DataType::kBF16 => "kBF16", - DataType::kINT64 => "kINT64", - DataType::kINT4 => "kINT4", - DataType::kFP4 => "kFP4", - DataType::kE8M0 => "kE8M0", - } -} - -// ============================================================================ -// ElementWiseOperation -// ============================================================================ - -/// Get the string name of an ElementWiseOperation enum variant -pub fn elementwise_op_name(op: &ElementWiseOperation) -> &'static str { - match *op { - ElementWiseOperation::kSUM => "kSUM", - ElementWiseOperation::kPROD => "kPROD", - ElementWiseOperation::kMAX => "kMAX", - ElementWiseOperation::kMIN => "kMIN", - ElementWiseOperation::kSUB => "kSUB", - ElementWiseOperation::kDIV => "kDIV", - ElementWiseOperation::kPOW => "kPOW", - ElementWiseOperation::kFLOOR_DIV => "kFLOOR_DIV", - ElementWiseOperation::kAND => "kAND", - ElementWiseOperation::kOR => "kOR", - ElementWiseOperation::kXOR => "kXOR", - ElementWiseOperation::kEQUAL => "kEQUAL", - ElementWiseOperation::kGREATER => "kGREATER", - ElementWiseOperation::kLESS => "kLESS", - } -} - -// ============================================================================ -// UnaryOperation -// 
============================================================================ - -/// Get the string name of a UnaryOperation enum variant -pub fn unary_op_name(op: &UnaryOperation) -> &'static str { - match *op { - UnaryOperation::kEXP => "kEXP", - UnaryOperation::kLOG => "kLOG", - UnaryOperation::kSQRT => "kSQRT", - UnaryOperation::kRECIP => "kRECIP", - UnaryOperation::kABS => "kABS", - UnaryOperation::kNEG => "kNEG", - UnaryOperation::kSIN => "kSIN", - UnaryOperation::kCOS => "kCOS", - UnaryOperation::kTAN => "kTAN", - UnaryOperation::kSINH => "kSINH", - UnaryOperation::kCOSH => "kCOSH", - UnaryOperation::kASIN => "kASIN", - UnaryOperation::kACOS => "kACOS", - UnaryOperation::kATAN => "kATAN", - UnaryOperation::kASINH => "kASINH", - UnaryOperation::kACOSH => "kACOSH", - UnaryOperation::kATANH => "kATANH", - UnaryOperation::kCEIL => "kCEIL", - UnaryOperation::kFLOOR => "kFLOOR", - UnaryOperation::kERF => "kERF", - UnaryOperation::kNOT => "kNOT", - UnaryOperation::kSIGN => "kSIGN", - UnaryOperation::kROUND => "kROUND", - UnaryOperation::kISINF => "kISINF", - UnaryOperation::kISNAN => "kISNAN", - } -} - -// ============================================================================ -// ActivationType -// ============================================================================ - -/// Get the string name of an ActivationType enum variant -pub fn activation_type_name(at: &ActivationType) -> &'static str { - match *at { - ActivationType::kRELU => "kRELU", - ActivationType::kSIGMOID => "kSIGMOID", - ActivationType::kTANH => "kTANH", - ActivationType::kLEAKY_RELU => "kLEAKY_RELU", - ActivationType::kELU => "kELU", - ActivationType::kSELU => "kSELU", - ActivationType::kSOFTSIGN => "kSOFTSIGN", - ActivationType::kSOFTPLUS => "kSOFTPLUS", - ActivationType::kCLIP => "kCLIP", - ActivationType::kHARD_SIGMOID => "kHARD_SIGMOID", - ActivationType::kSCALED_TANH => "kSCALED_TANH", - ActivationType::kTHRESHOLDED_RELU => "kTHRESHOLDED_RELU", - ActivationType::kGELU_ERF => 
"kGELU_ERF", - ActivationType::kGELU_TANH => "kGELU_TANH", - } -} - -// ============================================================================ -// PoolingType -// ============================================================================ - -/// Get the string name of a PoolingType enum variant -pub fn pooling_type_name(pt: &PoolingType) -> &'static str { - match *pt { - PoolingType::kMAX => "kMAX", - PoolingType::kAVERAGE => "kAVERAGE", - PoolingType::kMAX_AVERAGE_BLEND => "kMAX_AVERAGE_BLEND", - } -} - -// ============================================================================ -// ReduceOperation -// ============================================================================ - -/// Get the string name of a ReduceOperation enum variant -pub fn reduce_op_name(op: &ReduceOperation) -> &'static str { - match *op { - ReduceOperation::kSUM => "kSUM", - ReduceOperation::kPROD => "kPROD", - ReduceOperation::kMAX => "kMAX", - ReduceOperation::kMIN => "kMIN", - ReduceOperation::kAVG => "kAVG", - } -} - -// ============================================================================ -// CumulativeOperation -// ============================================================================ - -/// Get the string name of a CumulativeOperation enum variant -pub fn cumulative_op_name(op: &CumulativeOperation) -> &'static str { - match *op { - CumulativeOperation::kSUM => "kSUM", - #[cfg(feature = "mock")] - CumulativeOperation::kPROD => "kPROD", - #[cfg(feature = "mock")] - CumulativeOperation::kMIN => "kMIN", - #[cfg(feature = "mock")] - CumulativeOperation::kMAX => "kMAX", - } -} - -// ============================================================================ -// GatherMode -// ============================================================================ - -/// Get the string name of a GatherMode enum variant -pub fn gather_mode_name(mode: &GatherMode) -> &'static str { - match *mode { - GatherMode::kDEFAULT => "kDEFAULT", - GatherMode::kELEMENT => "kELEMENT", - 
GatherMode::kND => "kND", - } -} - -// ============================================================================ -// ScatterMode -// ============================================================================ - -/// Get the string name of a ScatterMode enum variant -pub fn scatter_mode_name(mode: &ScatterMode) -> &'static str { - match *mode { - ScatterMode::kELEMENT => "kELEMENT", - ScatterMode::kND => "kND", - } -} - -// ============================================================================ -// InterpolationMode (ResizeMode) -// ============================================================================ - -/// Get the string name of an InterpolationMode enum variant -/// Note: ResizeMode is a typedef for InterpolationMode -pub fn interpolation_mode_name(mode: &InterpolationMode) -> &'static str { - match *mode { - InterpolationMode::kNEAREST => "kNEAREST", - InterpolationMode::kLINEAR => "kLINEAR", - InterpolationMode::kCUBIC => "kCUBIC", - } -} - -// ============================================================================ -// ResizeCoordinateTransformation -// ============================================================================ - -/// Get the string name of a ResizeCoordinateTransformation enum variant -pub fn resize_coord_transform_name(transform: &ResizeCoordinateTransformation) -> &'static str { - match *transform { - ResizeCoordinateTransformation::kALIGN_CORNERS => "kALIGN_CORNERS", - ResizeCoordinateTransformation::kASYMMETRIC => "kASYMMETRIC", - ResizeCoordinateTransformation::kHALF_PIXEL => "kHALF_PIXEL", - #[cfg(feature = "mock")] - ResizeCoordinateTransformation::kHALF_PIXEL_SYMMETRIC => "kHALF_PIXEL_SYMMETRIC", - } -} - -// ============================================================================ -// ResizeSelector -// ============================================================================ - -/// Get the string name of a ResizeSelector enum variant -pub fn resize_selector_name(selector: &ResizeSelector) -> &'static str { 
-    match *selector {
-        ResizeSelector::kFORMULA => "kFORMULA",
-        #[cfg(feature = "mock")]
-        ResizeSelector::kSIZES => "kSIZES",
-        ResizeSelector::kUPPER => "kUPPER",
-    }
-}
-
-// ============================================================================
-// ResizeRoundMode
-// ============================================================================
-
-/// Get the string name of a ResizeRoundMode enum variant
-pub fn resize_round_mode_name(mode: &ResizeRoundMode) -> &'static str {
-    match *mode {
-        ResizeRoundMode::kFLOOR => "kFLOOR",
-        ResizeRoundMode::kCEIL => "kCEIL",
-        #[cfg(feature = "mock")]
-        ResizeRoundMode::kROUND => "kROUND",
-        ResizeRoundMode::kHALF_UP => "kHALF_UP",
-        ResizeRoundMode::kHALF_DOWN => "kHALF_DOWN",
-    }
-}
-
-// NOTE: RNN enum helper functions removed - IRNNv2Layer is deprecated and bindings unavailable
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_datatype_name() {
-        assert_eq!(datatype_name(&DataType::kFLOAT), "kFLOAT");
-        assert_eq!(datatype_name(&DataType::kBOOL), "kBOOL");
-    }
-
-    #[test]
-    fn test_elementwise_op_name() {
-        assert_eq!(elementwise_op_name(&ElementWiseOperation::kSUM), "kSUM");
-        assert_eq!(
-            elementwise_op_name(&ElementWiseOperation::kGREATER),
-            "kGREATER"
-        );
-    }
-
-    #[test]
-    fn test_unary_op_name() {
-        assert_eq!(unary_op_name(&UnaryOperation::kEXP), "kEXP");
-        assert_eq!(unary_op_name(&UnaryOperation::kNOT), "kNOT");
-    }
-
-    #[test]
-    fn test_activation_type_name() {
-        assert_eq!(activation_type_name(&ActivationType::kRELU), "kRELU");
-        assert_eq!(
-            activation_type_name(&ActivationType::kGELU_ERF),
-            "kGELU_ERF"
-        );
-    }
-
-    #[test]
-    fn test_pooling_type_name() {
-        assert_eq!(pooling_type_name(&PoolingType::kMAX), "kMAX");
-        assert_eq!(pooling_type_name(&PoolingType::kAVERAGE), "kAVERAGE");
-    }
-
-    #[test]
-    fn test_reduce_op_name() {
-        assert_eq!(reduce_op_name(&ReduceOperation::kSUM), "kSUM");
-        assert_eq!(reduce_op_name(&ReduceOperation::kAVG), "kAVG");
-    }
-}
diff --git a/trtx/src/error.rs b/trtx/src/error.rs
index e57d912..a6103d5 100644
--- a/trtx/src/error.rs
+++ b/trtx/src/error.rs
@@ -2,10 +2,25 @@
 use std::ffi::NulError;
 use thiserror::Error;
+use trtx_sys::LayerType;
 
 /// Result type for TensorRT-RTX operations
 pub type Result<T> = std::result::Result<T, Error>;
 
+#[derive(Debug, Eq, PartialEq)]
+pub enum PropertySetAttempt {
+    SerializationFlag,
+    OptimizationProfileSetDimensions,
+    OptimizationProfileSetExtraMemoryTarget,
+    OptimizationProfileSetShapeValues,
+    BuilderConfigTacticSources,
+    BuilderConfigTilingOptimizationLevel,
+    BuilderConfigL2LimitForTiling,
+    BuilderConfigNbComputeCapabilities,
+    BuilderConfigComputeCapability,
+    ExecutionContextTensorDebugState,
+}
+
 /// Errors that can occur when using TensorRT-RTX
 #[derive(Debug, Error)]
 pub enum Error {
@@ -56,41 +71,26 @@ pub enum Error {
 
     #[error("Would unwrap a poisened lock")]
     LockPoisining,
-}
 
-impl<T> From<std::sync::PoisonError<T>> for Error {
-    fn from(_: std::sync::PoisonError<T>) -> Self {
-        Error::LockPoisining
-    }
-}
+    #[error("Failed to create layer: {0:?}")]
+    LayerCreationFailed(LayerType),
 
-impl Error {
-    /// Create error from FFI error code and message buffer (mock mode)
-    #[cfg(feature = "mock")]
-    pub(crate) fn from_ffi(code: i32, error_msg: &[i8]) -> Self {
-        let msg = Self::parse_error_msg(error_msg);
-
-        match code {
-            code if code == trtx_sys::TRTX_ERROR_INVALID_ARGUMENT as i32 => {
-                Error::InvalidArgument(msg)
-            }
-            code if code == trtx_sys::TRTX_ERROR_OUT_OF_MEMORY as i32 => Error::OutOfMemory(msg),
-            code if code == trtx_sys::TRTX_ERROR_RUNTIME_ERROR as i32 => Error::Runtime(msg),
-            code if code == trtx_sys::TRTX_ERROR_CUDA_ERROR as i32 => Error::Cuda(msg),
-            _ => Error::Unknown(msg),
-        }
-    }
+    #[error("Failed to get generic layer from network")]
+    GetLayerFailed,
+
+    #[error("Failed to get a tensor from the network")]
+    GetTensorFailed,
 
-    /// Parse error message from C string buffer (mock mode)
-    #[cfg(feature = "mock")]
-    fn parse_error_msg(buffer: &[i8]) -> String {
-        // Find null terminator
-        let len = buffer.iter().position(|&c| c == 0).unwrap_or(buffer.len());
+    #[error("Failed to create BuilderConfig")]
+    BuilderConfigCreationFailed,
-
-        // Convert i8 to u8 safely
-        let bytes: Vec<u8> = buffer[..len].iter().map(|&c| c as u8).collect();
+
+    #[error("Failed to set property: {0:?}")]
+    FailedToSetProperty(PropertySetAttempt),
+}
 
-        String::from_utf8_lossy(&bytes).into_owned()
+impl<T> From<std::sync::PoisonError<T>> for Error {
+    fn from(_: std::sync::PoisonError<T>) -> Self {
+        Error::LockPoisining
     }
 }
 
@@ -103,23 +103,4 @@ mod tests {
         let err = Error::InvalidArgument("test".to_string());
         assert_eq!(err.to_string(), "Invalid argument: test");
     }
-
-    #[test]
-    #[cfg(feature = "mock")]
-    fn test_parse_error_msg() {
-        let msg = b"test error\0".map(|b| b as i8);
-        let parsed = Error::parse_error_msg(&msg);
-        assert_eq!(parsed, "test error");
-    }
-
-    #[test]
-    #[cfg(feature = "mock")]
-    fn test_from_ffi() {
-        let msg = b"test\0".map(|b| b as i8);
-        let err = Error::from_ffi(trtx_sys::TRTX_ERROR_INVALID_ARGUMENT as i32, &msg);
-        match err {
-            Error::InvalidArgument(s) => assert_eq!(s, "test"),
-            _ => panic!("Wrong error type"),
-        }
-    }
 }
diff --git a/trtx/src/executor.rs b/trtx/src/executor.rs
index c0a6874..1fc6204 100644
--- a/trtx/src/executor.rs
+++ b/trtx/src/executor.rs
@@ -64,13 +64,13 @@ fn build_engine_from_onnx(logger: &Logger, onnx_bytes: &[u8]) -> Result<Vec<u8>>
     // Create builder
     use crate::builder::MemoryPoolType;
-    let builder = Builder::new(logger)?;
+    let mut builder = Builder::new(logger)?;
 
     // Create network with explicit batch
     let mut network = builder.create_network(network_flags::EXPLICIT_BATCH)?;
 
     // Parse ONNX model
-    let parser = OnnxParser::new(&mut network, logger)?;
+    let mut parser = OnnxParser::new(&mut network, logger)?;
     parser.parse(onnx_bytes)?;
 
     // Configure builder
@@ -93,8 +93,8 @@ fn execute_engine(
     inputs: &[TensorInput],
 ) -> Result<Vec<u8>> {
     // Create runtime and deserialize engine
-    let runtime = Runtime::new(logger)?;
-    let engine = runtime.deserialize_cuda_engine(engine_data)?;
+    let mut runtime = Runtime::new(logger)?;
+    let mut engine = runtime.deserialize_cuda_engine(engine_data)?;
     let mut context = engine.create_execution_context()?;
 
     // Get tensor information
diff --git a/trtx/src/host_memory.rs b/trtx/src/host_memory.rs
new file mode 100644
index 0000000..c6e9eb2
--- /dev/null
+++ b/trtx/src/host_memory.rs
@@ -0,0 +1,57 @@
+use core::slice;
+use cxx::UniquePtr;
+use std::marker::PhantomData;
+use std::ops::Deref;
+use trtx_sys::nvinfer1::{self};
+use trtx_sys::DataType;
+
+pub struct HostMemory<'builder> {
+    pub(crate) inner: UniquePtr<nvinfer1::IHostMemory>,
+    _builder: PhantomData<&'builder nvinfer1::IBuilder>,
+    #[cfg(feature = "mock")]
+    mock_data: Vec<u8>,
+}
+
+impl<'builder> HostMemory<'builder> {
+    /// Assumes ownership of the raw pointer
+    pub(crate) unsafe fn from_raw(ptr: *mut nvinfer1::IHostMemory) -> Self {
+        unsafe {
+            HostMemory {
+                inner: UniquePtr::from_raw(ptr),
+                _builder: Default::default(),
+                #[cfg(feature = "mock")]
+                mock_data: Default::default(),
+            }
+        }
+    }
+
+    pub fn data_type(&self) -> DataType {
+        if cfg!(feature = "mock") {
+            DataType::kINT8
+        } else {
+            self.inner.type_().into()
+        }
+    }
+}
+
+#[cfg(feature = "mock")]
+impl<'memory> AsRef<[u8]> for HostMemory<'memory> {
+    fn as_ref(&self) -> &[u8] {
+        &self.mock_data
+    }
+}
+#[cfg(not(feature = "mock"))]
+impl<'memory> AsRef<[u8]> for HostMemory<'memory> {
+    fn as_ref(&self) -> &[u8] {
+        unsafe { slice::from_raw_parts(self.inner.data() as *const u8, self.inner.size()) }
+    }
+}
+
+impl<'builder> Deref for HostMemory<'builder> {
+    type Target = [u8];
+
+    fn deref(&self) -> &Self::Target {
+        // Delegate to the AsRef implementation above
+        self.as_ref()
+    }
+}
diff --git a/trtx/src/interfaces.rs b/trtx/src/interfaces.rs
new file mode 100644
index 0000000..c91de5c
--- /dev/null
+++ b/trtx/src/interfaces.rs
@@ -0,0 +1,462 @@
+use crate::{Error, Result};
+use cxx::UniquePtr;
+use std::{ffi::CStr, pin::Pin};
+use trtx_sys::{
+    nvinfer1, trtx_create_debug_listener, trtx_create_error_recorder, trtx_create_gpu_allocator,
+    trtx_create_progress_monitor,
+};
+use trtx_sys::{DataType, Dims64, ErrorCode, TensorLocation};
+
+/// Rust trait that corresponds to [nvinfer1::IProgressMonitor]
+///
+/// Use with [crate::BuilderConfig::set_progress_monitor]
+pub trait MonitorProgress: Send + Sync {
+    /// See [nvinfer1::IProgressMonitor::phaseStart]
+    fn phase_start(&self, phase_name: &str, parent_phase: Option<&str>, num_steps: i32);
+    /// See [nvinfer1::IProgressMonitor::stepComplete]. Return whether to continue building or cancel
+    fn step_complete(&self, phase_name: &str, step: i32) -> std::ops::ControlFlow<()>;
+    /// See [nvinfer1::IProgressMonitor::phaseFinish]
+    fn phase_finish(&self, phase_name: &str);
+}
+
+#[allow(non_snake_case)]
+unsafe extern "system" fn ProgressMonitor_phaseStart(
+    this: *mut std::ffi::c_void,
+    phaseName: *const ::std::os::raw::c_char,
+    parentPhase: *const ::std::os::raw::c_char,
+    nbSteps: i32,
+) {
+    let this = this as *mut ProgressMonitor;
+    let phase_name = CStr::from_ptr(phaseName);
+    let parent_phase =
+        (!parentPhase.is_null()).then(|| CStr::from_ptr(parentPhase).to_string_lossy());
+    this.as_mut().unwrap().rust_impl.phase_start(
+        &phase_name.to_string_lossy(),
+        parent_phase.as_deref(),
+        nbSteps,
+    );
+}
+#[allow(non_snake_case)]
+unsafe extern "system" fn ProgressMonitor_stepComplete(
+    this: *mut std::ffi::c_void,
+    phaseName: *const ::std::os::raw::c_char,
+    step: i32,
+) -> bool {
+    let this = this as *mut ProgressMonitor;
+    let phase_name = CStr::from_ptr(phaseName);
+    this.as_mut()
+        .unwrap()
+        .rust_impl
+        .step_complete(&phase_name.to_string_lossy(), step)
+        .is_continue()
+}
+#[allow(non_snake_case)]
+unsafe extern "system" fn ProgressMonitor_phaseFinish(
+    this: *mut std::ffi::c_void,
+    phaseName: *const ::std::os::raw::c_char,
+) {
+    let this = this as *mut ProgressMonitor;
+    let phase_name = CStr::from_ptr(phaseName);
+    this.as_mut()
+        .unwrap()
+        .rust_impl
+        .phase_finish(&phase_name.to_string_lossy());
+}
+
+///
+/// Subclasses [nvinfer1::IProgressMonitor]
+///
+/// Construct an object with a dyn [MonitorProgress] to implement
+/// [nvinfer1::IProgressMonitor] from Rust
+#[repr(C)]
+pub(crate) struct ProgressMonitor {
+    cpp_obj: UniquePtr<nvinfer1::IProgressMonitor>,
+    rust_impl: Box<dyn MonitorProgress>,
+}
+
+impl ProgressMonitor {
+    pub(crate) fn new(inner: Box<dyn MonitorProgress>) -> Result<Pin<Box<Self>>> {
+        let mut rust_obj = Box::pin(ProgressMonitor {
+            cpp_obj: UniquePtr::null(),
+            rust_impl: inner,
+        });
+
+        unsafe {
+            let cpp_obj = UniquePtr::from_raw(trtx_create_progress_monitor(
+                rust_obj.as_mut().get_unchecked_mut() as *mut ProgressMonitor
+                    as *mut std::ffi::c_void,
+                ProgressMonitor_phaseStart,
+                ProgressMonitor_stepComplete,
+                ProgressMonitor_phaseFinish,
+            ));
+            if cpp_obj.is_null() {
+                return Err(Error::Runtime(
+                    "Failed to allocate object for IProgressMonitor subclass".to_string(),
+                ));
+            }
+            rust_obj.cpp_obj = cpp_obj;
+        }
+        Ok(rust_obj)
+    }
+    pub fn as_trt_progress_monitor(&self) -> *mut nvinfer1::IProgressMonitor {
+        self.cpp_obj.as_mut_ptr()
+    }
+}
+
+/// C callbacks for GpuAllocatorSubclass (bridge to Rust). `this` is a *mut GpuAllocator.
+#[allow(non_snake_case)]
+unsafe extern "system" fn GpuAllocator_allocateAsync(
+    this: *const std::ffi::c_void,
+    size: u64,
+    alignment: u64,
+    flags: u32,
+    cuda_stream: *mut std::ffi::c_void,
+) -> *mut std::ffi::c_void {
+    let this = this as *const GpuAllocator;
+    this.as_ref()
+        .unwrap()
+        .rust_impl
+        .allocate_async(size, alignment, flags, cuda_stream)
+}
+#[allow(non_snake_case)]
+unsafe extern "system" fn GpuAllocator_reallocate(
+    this: *const std::ffi::c_void,
+    memory: *mut std::ffi::c_void,
+    alignment: u64,
+    new_size: u64,
+) -> *mut std::ffi::c_void {
+    let this = this as *const GpuAllocator;
+    this.as_ref()
+        .unwrap()
+        .rust_impl
+        .reallocate(memory, alignment, new_size)
+}
+#[allow(non_snake_case)]
+unsafe extern "system" fn GpuAllocator_deallocateAsync(
+    this: *const std::ffi::c_void,
+    memory: *mut std::ffi::c_void,
+    cuda_stream: *mut std::ffi::c_void,
+) -> bool {
+    let this = this as *const GpuAllocator;
+    this.as_ref()
+        .unwrap()
+        .rust_impl
+        .deallocate_async(memory, cuda_stream)
+}
+
+///
+/// Subclasses [nvinfer1::IGpuAllocator] via C++ bridge.
+///
+/// Construct with an [AllocateGpu] to implement [nvinfer1::IGpuAllocator] from Rust.
+#[repr(C)]
+pub struct GpuAllocator {
+    cpp_obj: UniquePtr<nvinfer1::IGpuAllocator>,
+    rust_impl: Box<dyn AllocateGpu>,
+}
+
+impl GpuAllocator {
+    pub fn new(inner: Box<dyn AllocateGpu>) -> Result<Pin<Box<Self>>> {
+        let mut rust_obj = Box::pin(GpuAllocator {
+            cpp_obj: UniquePtr::null(),
+            rust_impl: inner,
+        });
+        unsafe {
+            let cpp_obj = UniquePtr::from_raw(trtx_create_gpu_allocator(
+                rust_obj.as_mut().get_unchecked_mut() as *mut GpuAllocator as *mut std::ffi::c_void,
+                GpuAllocator_allocateAsync,
+                GpuAllocator_reallocate,
+                GpuAllocator_deallocateAsync,
+            ));
+            if cpp_obj.is_null() {
+                return Err(Error::Runtime(
+                    "Failed to allocate object for IGpuAllocator subclass".to_string(),
+                ));
+            }
+            rust_obj.cpp_obj = cpp_obj;
+        }
+        Ok(rust_obj)
+    }
+
+    pub fn as_trt_gpu_allocator(&self) -> *mut nvinfer1::IGpuAllocator {
+        self.cpp_obj.as_mut_ptr()
+    }
+}
+
+pub trait AllocateGpu: Send + Sync {
+    // we omit the following deprecated methods
+    //fn allocate(&mut self, size: u64, alignment: u64, flags: u32) -> *mut autocxx::c_void;
+    //unsafe fn deallocate(&mut self, data: *mut autocxx::c_void) -> bool;
+
+    /// # Safety
+    /// User needs to ensure memory safety for CUDA device pointers and follow regular CUDA rules
+    unsafe fn allocate_async(
+        &self,
+        size: u64,
+        alignment: u64,
+        flags: u32,
+        cuda_stream: *mut std::ffi::c_void,
+    ) -> *mut std::ffi::c_void;
+    /// # Safety
+    /// User needs to ensure memory safety for CUDA device pointers and follow regular CUDA rules
+    unsafe fn reallocate(
+        &self,
+        memory: *mut std::ffi::c_void,
+        alignment: u64,
+        new_size: u64,
+    ) -> *mut std::ffi::c_void;
+    /// # Safety
+    /// User needs to ensure memory safety for CUDA device pointers and follow regular CUDA rules
+    unsafe fn deallocate_async(
+        &self,
+        data: *mut std::ffi::c_void,
+        cuda_stream: *mut std::ffi::c_void,
+    ) -> bool;
+}
+
+/// C callbacks for ErrorRecorderSubclass (bridge to Rust). `this` is a *mut ErrorRecorder.
+#[allow(non_snake_case)]
+unsafe extern "system" fn ErrorRecorder_getNbErrors(this: *mut ErrorRecorder) -> i32 {
+    this.as_ref().unwrap().rust_impl.nb_errors()
+}
+#[allow(non_snake_case)]
+unsafe extern "system" fn ErrorRecorder_getErrorCode(
+    this: *const ErrorRecorder,
+    error_idx: i32,
+) -> i32 {
+    this.as_ref().unwrap().rust_impl.error_code(error_idx) as i32
+}
+#[allow(non_snake_case)]
+unsafe extern "system" fn ErrorRecorder_getErrorDesc(
+    this: *const ErrorRecorder,
+    error_idx: i32,
+    out_buf: *mut ::std::os::raw::c_char,
+    out_buf_size: usize,
+) {
+    if out_buf.is_null() || out_buf_size == 0 {
+        return;
+    }
+    let desc = this.as_ref().unwrap().rust_impl.error_desc(error_idx);
+    let bytes = desc.to_bytes_with_nul();
+    let copy_len = bytes.len().min(out_buf_size);
+    std::ptr::copy_nonoverlapping(bytes.as_ptr(), out_buf as *mut u8, copy_len);
+    if copy_len < out_buf_size {
+        *out_buf.add(copy_len) = 0;
+    }
+}
+#[allow(non_snake_case)]
+unsafe extern "system" fn ErrorRecorder_hasOverflowed(this: *mut ErrorRecorder) -> bool {
+    this.as_ref().unwrap().rust_impl.has_overflowed()
+}
+#[allow(non_snake_case)]
+unsafe extern "system" fn ErrorRecorder_clear(this: *mut ErrorRecorder) {
+    this.as_mut().unwrap().rust_impl.clear();
+}
+#[allow(non_snake_case)]
+unsafe extern "system" fn ErrorRecorder_reportError(
+    this: *mut ErrorRecorder,
+    val: i32,
+    desc: *const ::std::os::raw::c_char,
+) -> bool {
+    let desc_str = CStr::from_ptr(desc).to_string_lossy();
+    this.as_mut().unwrap().rust_impl.report_error(
+        match val {
+            0 => ErrorCode::kSUCCESS,
+            1 => ErrorCode::kUNSPECIFIED_ERROR,
+            2 => ErrorCode::kINTERNAL_ERROR,
+            3 => ErrorCode::kINVALID_ARGUMENT,
+            4 => ErrorCode::kINVALID_CONFIG,
+            5 => ErrorCode::kFAILED_ALLOCATION,
+            6 => ErrorCode::kFAILED_INITIALIZATION,
+            7 => ErrorCode::kFAILED_EXECUTION,
+            8 => ErrorCode::kFAILED_COMPUTATION,
+            9 => ErrorCode::kINVALID_STATE,
+            10 => ErrorCode::kUNSUPPORTED_STATE,
+            _ => ErrorCode::kUNSPECIFIED_ERROR,
+        },
+        &desc_str,
+    )
+}
+#[allow(non_snake_case)]
+unsafe extern "system" fn ErrorRecorder_incRefCount(this: *mut ErrorRecorder) -> i32 {
+    this.as_mut().unwrap().rust_impl.inc_ref_count()
+}
+#[allow(non_snake_case)]
+unsafe extern "system" fn ErrorRecorder_decRefCount(this: *mut ErrorRecorder) -> i32 {
+    this.as_mut().unwrap().rust_impl.dec_ref_count()
+}
+
+///
+/// Subclasses [nvinfer1::IErrorRecorder] via C++ bridge.
+///
+/// Construct with a [RecordError] to implement [nvinfer1::IErrorRecorder] from Rust.
+#[repr(C)]
+pub struct ErrorRecorder {
+    cpp_obj: UniquePtr<nvinfer1::IErrorRecorder>,
+    rust_impl: Box<dyn RecordError>,
+}
+
+impl ErrorRecorder {
+    pub fn new(inner: Box<dyn RecordError>) -> Result<Pin<Box<Self>>> {
+        let mut rust_obj = Box::pin(ErrorRecorder {
+            cpp_obj: UniquePtr::null(),
+            rust_impl: inner,
+        });
+        unsafe {
+            let cpp_obj = UniquePtr::from_raw(trtx_create_error_recorder(
+                rust_obj.as_mut().get_unchecked_mut() as *mut ErrorRecorder
+                    as *mut std::ffi::c_void,
+                ErrorRecorder_getNbErrors as *mut std::ffi::c_void,
+                ErrorRecorder_getErrorCode as *mut std::ffi::c_void,
+                ErrorRecorder_getErrorDesc as *mut std::ffi::c_void,
+                ErrorRecorder_hasOverflowed as *mut std::ffi::c_void,
+                ErrorRecorder_clear as *mut std::ffi::c_void,
+                ErrorRecorder_reportError as *mut std::ffi::c_void,
+                ErrorRecorder_incRefCount as *mut std::ffi::c_void,
+                ErrorRecorder_decRefCount as *mut std::ffi::c_void,
+            ));
+            if cpp_obj.is_null() {
+                return Err(Error::Runtime(
+                    "Failed to allocate object for IErrorRecorder subclass".to_string(),
+                ));
+            }
+            rust_obj.cpp_obj = cpp_obj;
+        }
+        Ok(rust_obj)
+    }
+
+    pub fn as_trt_error_recorder(&self) -> *mut nvinfer1::IErrorRecorder {
+        self.cpp_obj.as_mut_ptr()
+    }
+}
+
+pub trait RecordError: Send + Sync {
+    fn nb_errors(&self) -> i32;
+    fn error_code(&self, error_idx: i32) -> ErrorCode;
+    fn error_desc(&self, error_idx: i32) -> &CStr;
+    fn has_overflowed(&self) -> bool;
+    fn clear(&self);
+    fn report_error(&self, val: ErrorCode, desc: &str) -> bool;
+    fn inc_ref_count(&self) -> i32;
+    fn dec_ref_count(&self) -> i32;
+}
+
+#[allow(non_snake_case)]
+unsafe extern "system" fn DebugListener_processDebugTensor(
+    this: *const std::ffi::c_void,
+    addr: *const std::ffi::c_void,
+    location: nvinfer1::TensorLocation,
+    type_: nvinfer1::DataType,
+    shape: *const Dims64,
+    name: *const std::ffi::c_char,
+    stream: *mut std::ffi::c_void,
+) -> bool {
+    let this = this as *const DebugListener;
+    let name = (!name.is_null()).then(|| CStr::from_ptr(name));
+    let name = name.map(|s| s.to_string_lossy());
+    this.as_ref()
+        .unwrap()
+        .rust_impl
+        .process_debug_tensor(
+            addr,
+            location.into(),
+            type_.into(),
+            shape.as_ref().unwrap(),
+            name.as_deref(),
+            stream,
+        )
+        .is_ok()
+}
+
+///
+/// Subclasses [nvinfer1::IDebugListener] via C++ bridge.
+#[repr(C)]
+pub struct DebugListener {
+    cpp_obj: UniquePtr<nvinfer1::IDebugListener>,
+    rust_impl: Box<dyn ProcessDebugTensor>,
+}
+
+pub type ProcessDebugTensorResult = std::result::Result<(), ()>;
+
+impl DebugListener {
+    pub fn new(inner: Box<dyn ProcessDebugTensor>) -> Result<Pin<Box<Self>>> {
+        let mut rust_obj = Box::pin(Self {
+            cpp_obj: UniquePtr::null(),
+            rust_impl: inner,
+        });
+        unsafe {
+            let cpp_obj = UniquePtr::from_raw(trtx_create_debug_listener(
+                rust_obj.as_mut().get_unchecked_mut() as *mut DebugListener
+                    as *mut std::ffi::c_void,
+                DebugListener_processDebugTensor,
+            ));
+            if cpp_obj.is_null() {
+                return Err(Error::Runtime(
+                    "Failed to allocate object for IDebugListener subclass".to_string(),
+                ));
+            }
+            rust_obj.cpp_obj = cpp_obj;
+        }
+        Ok(rust_obj)
+    }
+
+    pub fn as_raw(&self) -> *mut nvinfer1::IDebugListener {
+        self.cpp_obj.as_mut_ptr()
+    }
+}
+
+pub trait ProcessDebugTensor: Send + Sync {
+    /// # Safety
+    ///
+    /// User needs to ensure memory safety for CUDA pointers and ensure correct lifetimes for CUDA
+    /// objects
+    unsafe fn process_debug_tensor(
+        &self,
+        addr: *const std::ffi::c_void,
+        location: TensorLocation,
+        type_: DataType,
+        shape: &Dims64,
+        name: Option<&str>,
+        stream: *mut std::ffi::c_void,
+    ) -> ProcessDebugTensorResult;
+}
+
+//#[subclass]
+//#[derive(Default)]
+//pub struct StreamReaderV2 {
+//inner: Option<Box<dyn ReadStreamV2>>,
+//}
+
+//impl StreamReaderV2 {
+//pub fn new(inner: Box<dyn ReadStreamV2>) -> Rc<RefCell<Self>> {
+//let rtn = Self::default_rust_owned();
+//rtn.borrow_mut().inner = Some(inner);
+//rtn
+//}
+//}
+
+//impl nvinfer1::IStreamReaderV2_methods for StreamReaderV2 {
+//unsafe fn read(
+//&mut self,
+//destination: *mut autocxx::c_void,
+//nbBytes: i64,
+//stream: *mut crate::ffi::CUstream_st,
+//) -> i64 {
+//self.inner
+//.as_mut()
+//.unwrap()
+//.read(destination, nbBytes, stream)
+//}
+//fn seek(&mut self, offset: i64, where_: nvinfer1::SeekPosition) -> bool {
+//self.inner.as_mut().unwrap().seek(offset, where_.into())
+//}
+//}
+
+//pub trait ReadStreamV2: Send + Sync {
+//unsafe fn read(
+//&mut self,
+//destination: *mut autocxx::c_void,
+//nbBytes: i64,
+//stream: *mut crate::ffi::CUstream_st,
+//) -> i64;
+//fn seek(&mut self, offset: i64, where_: crate::SeekPosition) -> bool;
+//}
diff --git a/trtx/src/lib.rs b/trtx/src/lib.rs
index 14ff98e..8730508 100644
--- a/trtx/src/lib.rs
+++ b/trtx/src/lib.rs
@@ -51,7 +51,7 @@
 //! let logger = Logger::stderr()?;
 //!
 //! // Build phase
-//! let builder = Builder::new(&logger)?;
+//! let mut builder = Builder::new(&logger)?;
 //! let mut network = builder.create_network(network_flags::EXPLICIT_BATCH)?;
 //! let mut config = builder.create_config()?;
 //!
@@ -63,8 +63,8 @@
 //! std::fs::write("model.engine", &engine_data)?;
 //!
 //! // Inference phase
-//! let runtime = Runtime::new(&logger)?;
-//! let engine = runtime.deserialize_cuda_engine(&engine_data)?;
+//! let mut runtime = Runtime::new(&logger)?;
+//! let mut engine = runtime.deserialize_cuda_engine(&engine_data)?;
 //! let context = engine.create_execution_context()?;
 //!
 //! // List I/O tensors
@@ -103,32 +103,33 @@
 // Allow unnecessary casts - they're needed for real mode (u32) but not mock mode (i32)
 #![cfg_attr(feature = "mock", allow(clippy::unnecessary_cast))]
-
-#[cfg(not(feature = "mock"))]
-mod real;
-
-#[cfg(feature = "mock")]
-pub mod mock;
+// We don't use real parameters in mocks
+#![cfg_attr(feature = "mock", allow(unused))]
+#![cfg_attr(feature = "mock", allow(unused_variables))]
 
 pub mod autocxx_helpers;
+pub mod axes;
 pub mod builder;
+pub mod builder_config;
 pub mod cuda;
-pub mod enum_helpers;
+pub mod cuda_engine;
+pub mod engine_inspector;
 pub mod error;
 pub mod executor;
+pub mod host_memory;
+pub mod interfaces;
 pub mod logger;
 pub mod network;
 #[cfg(feature = "onnxparser")]
 pub mod onnx_parser;
+pub mod optimization_profile;
+pub mod refitter;
 pub mod runtime;
 
 // Re-export commonly used types
-pub use builder::{Builder, BuilderConfig};
+pub use axes::Axes;
+pub use builder::{Builder, BuilderConfig, ProfilingVerbosity};
 pub use cuda::{get_default_stream, synchronize, DeviceBuffer};
-pub use enum_helpers::{
-    activation_type_name, datatype_name, elementwise_op_name, pooling_type_name, reduce_op_name,
-    unary_op_name,
-};
 pub use error::{Error, Result};
 #[cfg(feature = "onnxparser")]
 pub use executor::{run_onnx_with_tensorrt, run_onnx_zeroed};
@@ -139,7 +140,8 @@ pub use logger::{LogHandler, Logger, Severity, StderrLogger};
 pub use network::{ConvWeights, NetworkDefinition, Tensor};
 #[cfg(feature = "onnxparser")]
 pub use onnx_parser::OnnxParser;
-pub use runtime::{CudaEngine, ExecutionContext, Runtime};
+pub use refitter::Refitter;
+pub use runtime::{CudaEngine, EngineInspector, ExecutionContext, Runtime};
 
 #[cfg(feature = "dlopen_tensorrt_rtx")]
 #[cfg(not(any(feature = "link_tensorrt_rtx", feature = "mock")))]
@@ -209,9 +211,9 @@ pub fn dynamically_load_tensorrt_onnxparser(_filename: Option<String>)
 // Re-export TensorRT operation enums
 pub use trtx_sys::{
     ActivationType, CumulativeOperation, DataType, ElementWiseOperation, GatherMode,
-    InterpolationMode, MatrixOperation, PoolingType, ReduceOperation,
-    ResizeCoordinateTransformation, ResizeRoundMode, ResizeSelector, ScaleMode, ScatterMode,
-    TopKOperation, UnaryOperation,
+    InterpolationMode, LayerInformationFormat, LayerType, MatrixOperation, PoolingType,
+    ReduceOperation, ResizeCoordinateTransformation, ResizeRoundMode, ResizeSelector, ScaleMode,
+    ScatterMode, TopKOperation, UnaryOperation,
 };
 
 // Re-export ResizeMode typedef (InterpolationMode alias)
diff --git a/trtx/src/logger.rs b/trtx/src/logger.rs
index 583d4bc..e205a8d 100644
--- a/trtx/src/logger.rs
+++ b/trtx/src/logger.rs
@@ -1,6 +1,4 @@
 //! Logger interface for TensorRT-RTX
-//!
-//! Delegates to real/ or mock/ based on feature flag.
 
 /// Severity level for log messages
 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
@@ -34,10 +32,84 @@ impl LogHandler for StderrLogger {
     }
 }
 
-#[cfg(feature = "mock")]
-pub use crate::mock::logger::Logger;
-#[cfg(not(feature = "mock"))]
-pub use crate::real::logger::Logger;
+/// Logger (uses Rust bridge to TensorRT)
+pub struct Logger {
+    bridge: *mut trtx_sys::RustLoggerBridge,
+    user_data: *mut std::ffi::c_void,
+}
+
+impl Logger {
+    pub fn new<H: LogHandler + 'static>(handler: H) -> crate::Result<Self> {
+        let handler_box: Box<dyn LogHandler> = Box::new(handler);
+        let user_data = Box::into_raw(Box::new(handler_box)) as *mut std::ffi::c_void;
+
+        let bridge = unsafe { trtx_sys::create_rust_logger_bridge(Self::log_callback, user_data) };
+
+        if bridge.is_null() {
+            unsafe {
+                let outer = Box::from_raw(user_data as *mut Box<dyn LogHandler>);
+                let _ = *outer;
+            }
+            return Err(crate::error::Error::Runtime(
+                "Failed to create logger bridge".to_string(),
+            ));
+        }
+
+        Ok(Logger { bridge, user_data })
+    }
+
+    pub fn stderr() -> crate::Result<Self> {
+        Self::new(StderrLogger)
+    }
+
+    pub(crate) fn as_logger_ptr(&self) -> *mut std::ffi::c_void {
+        unsafe { trtx_sys::get_logger_interface(self.bridge) }
+    }
+
+    extern "C" fn log_callback(
+        user_data: *mut std::ffi::c_void,
+        severity: i32,
+        msg: *const std::os::raw::c_char,
+    ) {
+        if user_data.is_null() || msg.is_null() {
+            return;
+        }
+        unsafe {
+            let handler_box = &*(user_data as *const Box<dyn LogHandler>);
+            let msg_str = std::ffi::CStr::from_ptr(msg);
+            let severity = match severity {
+                0 => Severity::InternalError,
+                1 => Severity::Error,
+                2 => Severity::Warning,
+                3 => Severity::Info,
+                4 => Severity::Verbose,
+                _ => Severity::Verbose,
+            };
+            if let Ok(msg) = msg_str.to_str() {
+                handler_box.log(severity, msg);
+            }
+        }
+    }
+}
+
+impl Drop for Logger {
+    fn drop(&mut self) {
+        if !self.bridge.is_null() {
+            unsafe {
+                trtx_sys::destroy_rust_logger_bridge(self.bridge);
+            }
+        }
+        if !self.user_data.is_null() {
+            unsafe {
+                let outer = Box::from_raw(self.user_data as *mut Box<dyn LogHandler>);
+                let _ = *outer;
+            }
+        }
+    }
+}
+
+unsafe impl Send for Logger {}
+unsafe impl Sync for Logger {}
 
 #[cfg(test)]
 mod tests {
diff --git a/trtx/src/mock/builder.rs b/trtx/src/mock/builder.rs
deleted file mode 100644
index 19f80bd..0000000
--- a/trtx/src/mock/builder.rs
+++ /dev/null
@@ -1,349 +0,0 @@
-//! Mock builder implementations
-
-use crate::error::Result;
-use crate::logger::Logger;
-use crate::network::NetworkDefinition;
-use trtx_sys::{
-    BuilderFlag, ComputeCapability, DeviceType, EngineCapability, HardwareCompatibilityLevel,
-    MemoryPoolType, PreviewFeature, ProfilingVerbosity, RuntimePlatform, TilingOptimizationLevel,
-};
-
-/// Builder configuration (mock mode)
-pub struct BuilderConfig {
-    pub(crate) inner: *mut trtx_sys::TrtxBuilderConfig,
-}
-
-impl BuilderConfig {
-    pub fn set_memory_pool_limit(&mut self, pool: MemoryPoolType, size: usize) {
-        set_memory_pool_limit(self.inner, pool as i32, size)
-    }
-
-    pub(crate) fn as_ptr(&self) -> *mut trtx_sys::TrtxBuilderConfig {
-        self.inner
-    }
-
-    pub fn set_profiling_verbosity(&mut self, _verbosity: ProfilingVerbosity) {}
-
-    pub fn get_profiling_verbosity(&self) -> ProfilingVerbosity {
-        ProfilingVerbosity::kDETAILED
-    }
-
-    pub fn set_avg_timing_iterations(&mut self, _avg_timing: i32) {}
-
-    pub fn get_avg_timing_iterations(&self) -> i32 {
-        1
-    }
-
-    pub fn set_engine_capability(&mut self, _capability: EngineCapability) {}
-
-    pub fn get_engine_capability(&self) -> EngineCapability {
-        EngineCapability::kSTANDARD
-    }
-
-    pub fn set_flags(&mut self, _flags: u32) {}
-
-    pub fn get_flags(&self) -> u32 {
-        0
-    }
-
-    pub fn set_flag(&mut self, _flag: BuilderFlag) {}
-
-    pub fn clear_flag(&mut self, _flag: BuilderFlag) {}
-
-    pub fn get_flag(&self, _flag: BuilderFlag) -> bool {
-        false
-    }
-
-    pub fn set_dla_core(&mut self, _dla_core: i32) {}
-
-    pub fn get_dla_core(&self) -> i32 {
-        -1
-    }
-
-    pub fn set_default_device_type(&mut self, _device_type: DeviceType) {}
-
-    pub fn get_default_device_type(&self) -> DeviceType {
-        DeviceType::kGPU
-    }
-
-    pub fn reset(&mut self) {}
-
-    pub fn get_nb_optimization_profiles(&self) -> i32 {
-        0
-    }
-
-    pub fn set_tactic_sources(&mut self, _sources: u32) -> bool {
-        true
-    }
-
-    pub fn get_tactic_sources(&self) -> u32 {
-        0
-    }
-
-    pub fn get_memory_pool_limit(&self, _pool: MemoryPoolType) -> usize {
-        0
-    }
-
-    pub fn set_preview_feature(&mut self, _feature: PreviewFeature, _enable: bool) {}
-
-    pub fn get_preview_feature(&self, _feature: PreviewFeature) -> bool {
-        false
-    }
-
-    pub fn set_builder_optimization_level(&mut self, _level: i32) {}
-
-    pub fn get_builder_optimization_level(&mut self) -> i32 {
-        3
-    }
-
-    pub fn set_hardware_compatibility_level(&mut self, _level: HardwareCompatibilityLevel) {}
-
-    pub fn get_hardware_compatibility_level(&self) -> HardwareCompatibilityLevel {
-        HardwareCompatibilityLevel::kNONE
-    }
-
-    pub fn set_max_aux_streams(&mut self, _nb_streams: i32) {}
-
-    pub fn get_max_aux_streams(&self) -> i32 {
-        0
-    }
-
-    pub fn set_runtime_platform(&mut self, _platform: RuntimePlatform) {}
-
-    pub fn get_runtime_platform(&self) -> RuntimePlatform {
-        RuntimePlatform::kSAME_AS_BUILD
-    }
-
-    pub fn set_max_nb_tactics(&mut self, _max_nb_tactics: i32) {}
-
-    pub fn get_max_nb_tactics(&self) -> i32 {
-        -1
-    }
-
-    pub fn set_tiling_optimization_level(&mut self, _level: TilingOptimizationLevel) -> bool {
-        true
-    }
-
-    pub fn get_tiling_optimization_level(&self) -> TilingOptimizationLevel {
-        TilingOptimizationLevel::kNONE
-    }
-
-    pub fn set_l2_limit_for_tiling(&mut self, _size: i64) -> bool {
-        true
-    }
-
-    pub fn get_l2_limit_for_tiling(&self) -> i64 {
-        0
-    }
-
-    pub fn set_nb_compute_capabilities(&mut self, _max_nb_compute_capabilities: i32) -> bool {
-        true
-    }
-
-    pub fn get_nb_compute_capabilities(&self) -> i32 {
-        0
-    }
-
-    pub fn set_compute_capability(
-        &mut self,
-        _compute_capability: ComputeCapability,
-        _index: i32,
-    ) -> bool {
-        true
-    }
-
-    pub fn get_compute_capability(&self, _index: i32) -> ComputeCapability {
-        ComputeCapability::kNONE
-    }
-}
-
-impl Drop for BuilderConfig {
-    fn drop(&mut self) {
-        destroy_config(self.inner);
-    }
-}
-
-unsafe impl Send for BuilderConfig {}
-
-/// Builder (mock mode)
-pub struct Builder<'a> {
-    inner: *mut trtx_sys::TrtxBuilder,
-    _logger: &'a Logger,
-}
- -impl<'a> Builder<'a> { - pub fn new(logger: &'a Logger) -> Result { - let builder_ptr = trtx_builder_create(logger.as_ptr())?; - Ok(Builder { - inner: builder_ptr, - _logger: logger, - }) - } - - pub fn create_network(&self, flags: u32) -> Result { - create_network(self.inner, flags) - } - - pub fn create_config(&self) -> Result { - let config_ptr = create_config(self.inner)?; - Ok(BuilderConfig { inner: config_ptr }) - } - - pub fn build_serialized_network( - &self, - network: &mut NetworkDefinition, - config: &mut BuilderConfig, - ) -> Result> { - build_serialized_network(self.inner, network.as_mut_ptr(), config.as_ptr()) - } -} - -impl Drop for Builder<'_> { - fn drop(&mut self) { - destroy_builder(self.inner); - } -} - -unsafe impl Send for Builder<'_> {} - -//------------------------------------------------------------------------------ -// Helper functions (used by above impls) -//------------------------------------------------------------------------------ - -/// Create builder via FFI (mock mode) -pub(crate) fn trtx_builder_create( - logger_ptr: *mut trtx_sys::TrtxLogger, -) -> Result<*mut trtx_sys::TrtxBuilder> { - let mut builder_ptr: *mut trtx_sys::TrtxBuilder = std::ptr::null_mut(); - let mut error_msg = [0i8; 1024]; - - let result = unsafe { - trtx_sys::trtx_builder_create( - logger_ptr, - &mut builder_ptr, - error_msg.as_mut_ptr(), - error_msg.len(), - ) - }; - - if result != trtx_sys::TRTX_SUCCESS as i32 { - return Err(super::from_ffi(result, &error_msg)); - } - - Ok(builder_ptr) -} - -fn set_memory_pool_limit(config_ptr: *mut trtx_sys::TrtxBuilderConfig, pool: i32, size: usize) { - let mut error_msg = [0i8; 1024]; - - unsafe { - trtx_sys::trtx_builder_config_set_memory_pool_limit( - config_ptr, - pool, - size, - error_msg.as_mut_ptr(), - error_msg.len(), - ) - }; -} - -fn destroy_config(config_ptr: *mut trtx_sys::TrtxBuilderConfig) { - if !config_ptr.is_null() { - unsafe { - trtx_sys::trtx_builder_config_destroy(config_ptr); - } - } -} - -fn 
create_network( - builder_ptr: *mut trtx_sys::TrtxBuilder, - flags: u32, -) -> Result { - let mut network_ptr: *mut trtx_sys::TrtxNetworkDefinition = std::ptr::null_mut(); - let mut error_msg = [0i8; 1024]; - - let result = unsafe { - trtx_sys::trtx_builder_create_network( - builder_ptr, - flags, - &mut network_ptr, - error_msg.as_mut_ptr(), - error_msg.len(), - ) - }; - - if result != trtx_sys::TRTX_SUCCESS as i32 { - return Err(super::from_ffi(result, &error_msg)); - } - - Ok(NetworkDefinition::from_ptr( - network_ptr as *mut std::ffi::c_void, - )) -} - -fn create_config( - builder_ptr: *mut trtx_sys::TrtxBuilder, -) -> Result<*mut trtx_sys::TrtxBuilderConfig> { - let mut config_ptr: *mut trtx_sys::TrtxBuilderConfig = std::ptr::null_mut(); - let mut error_msg = [0i8; 1024]; - - let result = unsafe { - trtx_sys::trtx_builder_create_builder_config( - builder_ptr, - &mut config_ptr, - error_msg.as_mut_ptr(), - error_msg.len(), - ) - }; - - if result != trtx_sys::TRTX_SUCCESS as i32 { - return Err(super::from_ffi(result, &error_msg)); - } - - Ok(config_ptr) -} - -fn build_serialized_network( - builder_ptr: *mut trtx_sys::TrtxBuilder, - network_ptr: *mut std::ffi::c_void, - config_ptr: *mut trtx_sys::TrtxBuilderConfig, -) -> Result> { - let network_ptr = network_ptr as *mut trtx_sys::TrtxNetworkDefinition; - let mut data_ptr: *mut std::ffi::c_void = std::ptr::null_mut(); - let mut size: usize = 0; - let mut error_msg = [0i8; 1024]; - - let result = unsafe { - trtx_sys::trtx_builder_build_serialized_network( - builder_ptr, - network_ptr, - config_ptr, - &mut data_ptr, - &mut size, - error_msg.as_mut_ptr(), - error_msg.len(), - ) - }; - - if result != trtx_sys::TRTX_SUCCESS as i32 { - return Err(super::from_ffi(result, &error_msg)); - } - - // Copy data to Vec and free C buffer - let data = unsafe { - let slice = std::slice::from_raw_parts(data_ptr as *const u8, size); - let vec = slice.to_vec(); - trtx_sys::trtx_free_buffer(data_ptr); - vec - }; - - Ok(data) -} - -fn 
destroy_builder(builder_ptr: *mut trtx_sys::TrtxBuilder) {
-    if !builder_ptr.is_null() {
-        unsafe {
-            trtx_sys::trtx_builder_destroy(builder_ptr);
-        }
-    }
-}
diff --git a/trtx/src/mock/cuda.rs b/trtx/src/mock/cuda.rs
deleted file mode 100644
index c70e344..0000000
--- a/trtx/src/mock/cuda.rs
+++ /dev/null
@@ -1,127 +0,0 @@
-//! Mock CUDA implementations
-
-use crate::error::{Error, Result};
-
-use super::from_ffi;
-
-/// RAII wrapper for CUDA device memory (mock mode)
-pub struct DeviceBuffer {
-    ptr: *mut std::ffi::c_void,
-    size: usize,
-}
-
-impl DeviceBuffer {
-    pub fn new(size: usize) -> Result<Self> {
-        let ptr = cuda_malloc(size)?;
-        Ok(DeviceBuffer { ptr, size })
-    }
-
-    pub fn as_ptr(&self) -> *mut std::ffi::c_void {
-        self.ptr
-    }
-
-    pub fn size(&self) -> usize {
-        self.size
-    }
-
-    pub fn copy_from_host(&mut self, data: &[u8]) -> Result<()> {
-        if data.len() > self.size {
-            return Err(Error::InvalidArgument(
-                "Data size exceeds buffer size".to_string(),
-            ));
-        }
-        memcpy_host_to_device(self.ptr, data)
-    }
-
-    pub fn copy_to_host(&self, data: &mut [u8]) -> Result<()> {
-        if data.len() > self.size {
-            return Err(Error::InvalidArgument(
-                "Data size exceeds buffer size".to_string(),
-            ));
-        }
-        memcpy_device_to_host(self.ptr, data)
-    }
-}
-
-impl Drop for DeviceBuffer {
-    fn drop(&mut self) {
-        cuda_free(self.ptr);
-    }
-}
-
-unsafe impl Send for DeviceBuffer {}
-
-/// Synchronize CUDA device
-pub fn synchronize() -> Result<()> {
-    let mut error_msg = [0i8; 1024];
-    let result =
-        unsafe { trtx_sys::trtx_cuda_synchronize(error_msg.as_mut_ptr(), error_msg.len()) };
-    if result != trtx_sys::TRTX_SUCCESS as i32 {
-        return Err(from_ffi(result, &error_msg));
-    }
-    Ok(())
-}
-
-/// Get default CUDA stream
-pub fn get_default_stream() -> *mut std::ffi::c_void {
-    unsafe { trtx_sys::trtx_cuda_get_default_stream() }
-}
-
-//------------------------------------------------------------------------------
-// Helper functions 
-//------------------------------------------------------------------------------ - -fn cuda_malloc(size: usize) -> Result<*mut std::ffi::c_void> { - let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut(); - let mut error_msg = [0i8; 1024]; - let result = unsafe { - trtx_sys::trtx_cuda_malloc(&mut ptr, size, error_msg.as_mut_ptr(), error_msg.len()) - }; - if result != trtx_sys::TRTX_SUCCESS as i32 { - return Err(from_ffi(result, &error_msg)); - } - Ok(ptr) -} - -fn memcpy_host_to_device(ptr: *mut std::ffi::c_void, data: &[u8]) -> Result<()> { - let mut error_msg = [0i8; 1024]; - let result = unsafe { - trtx_sys::trtx_cuda_memcpy_host_to_device( - ptr, - data.as_ptr() as *const std::ffi::c_void, - data.len(), - error_msg.as_mut_ptr(), - error_msg.len(), - ) - }; - if result != trtx_sys::TRTX_SUCCESS as i32 { - return Err(from_ffi(result, &error_msg)); - } - Ok(()) -} - -fn memcpy_device_to_host(ptr: *mut std::ffi::c_void, data: &mut [u8]) -> Result<()> { - let mut error_msg = [0i8; 1024]; - let result = unsafe { - trtx_sys::trtx_cuda_memcpy_device_to_host( - data.as_mut_ptr() as *mut std::ffi::c_void, - ptr, - data.len(), - error_msg.as_mut_ptr(), - error_msg.len(), - ) - }; - if result != trtx_sys::TRTX_SUCCESS as i32 { - return Err(from_ffi(result, &error_msg)); - } - Ok(()) -} - -fn cuda_free(ptr: *mut std::ffi::c_void) { - if !ptr.is_null() { - let mut error_msg = [0i8; 1024]; - unsafe { - let _ = trtx_sys::trtx_cuda_free(ptr, error_msg.as_mut_ptr(), error_msg.len()); - } - } -} diff --git a/trtx/src/mock/error.rs b/trtx/src/mock/error.rs deleted file mode 100644 index a02a62c..0000000 --- a/trtx/src/mock/error.rs +++ /dev/null @@ -1,10 +0,0 @@ -//! Mock error handling - FFI error code conversion -//! -//! Delegates to Error::from_ffi in the error module to avoid circular dependency. 
-
-use crate::error::Error;
-
-/// Create error from FFI error code and message buffer (mock mode)
-pub(crate) fn from_ffi(code: i32, error_msg: &[i8]) -> Error {
-    Error::from_ffi(code, error_msg)
-}
diff --git a/trtx/src/mock/logger.rs b/trtx/src/mock/logger.rs
deleted file mode 100644
index eb0b852..0000000
--- a/trtx/src/mock/logger.rs
+++ /dev/null
@@ -1,108 +0,0 @@
-//! Mock logger implementations
-
-use crate::error::Result;
-use crate::logger::{LogHandler, Severity};
-use std::os::raw::c_char;
-
-/// Logger (mock mode)
-pub struct Logger {
-    pub(crate) inner: *mut trtx_sys::TrtxLogger,
-    _handler: Box<dyn LogHandler>,
-}
-
-impl Logger {
-    pub fn new<H: LogHandler + 'static>(handler: H) -> Result<Self> {
-        let handler_box: Box<dyn LogHandler> = Box::new(handler);
-        let user_data = Box::into_raw(Box::new(handler_box)) as *mut std::ffi::c_void;
-
-        let logger_ptr = trtx_logger_create(user_data, Some(log_callback_mock))?;
-
-        let outer_box = unsafe { Box::from_raw(user_data as *mut Box<dyn LogHandler>) };
-        let handler_box = *outer_box;
-
-        Ok(Logger {
-            inner: logger_ptr,
-            _handler: handler_box,
-        })
-    }
-
-    pub fn stderr() -> Result<Self> {
-        Self::new(crate::logger::StderrLogger)
-    }
-
-    pub(crate) fn as_ptr(&self) -> *mut trtx_sys::TrtxLogger {
-        self.inner
-    }
-}
-
-impl Drop for Logger {
-    fn drop(&mut self) {
-        trtx_logger_destroy(self.inner);
-    }
-}
-
-unsafe impl Send for Logger {}
-unsafe impl Sync for Logger {}
-
-//------------------------------------------------------------------------------
-// Helper functions
-//------------------------------------------------------------------------------
-
-fn trtx_logger_create(
-    user_data: *mut std::ffi::c_void,
-    callback: trtx_sys::TrtxLoggerCallback,
-) -> Result<*mut trtx_sys::TrtxLogger> {
-    let mut logger_ptr: *mut trtx_sys::TrtxLogger = std::ptr::null_mut();
-    let mut error_msg = [0i8; 1024];
-
-    let result = unsafe {
-        trtx_sys::trtx_logger_create(
-            callback,
-            user_data,
-            &mut logger_ptr,
-            error_msg.as_mut_ptr(),
-            error_msg.len(),
-        )
-    };
-
-    if result != 
trtx_sys::TRTX_SUCCESS as i32 {
-        return Err(super::from_ffi(result, &error_msg));
-    }
-
-    Ok(logger_ptr)
-}
-
-fn trtx_logger_destroy(logger_ptr: *mut trtx_sys::TrtxLogger) {
-    if !logger_ptr.is_null() {
-        unsafe {
-            trtx_sys::trtx_logger_destroy(logger_ptr);
-        }
-    }
-}
-
-extern "C" fn log_callback_mock(
-    user_data: *mut std::ffi::c_void,
-    severity: trtx_sys::TrtxLoggerSeverity,
-    msg: *const c_char,
-) {
-    if user_data.is_null() || msg.is_null() {
-        return;
-    }
-
-    unsafe {
-        let handler = &*(user_data as *const Box<dyn LogHandler>);
-        let msg_str = std::ffi::CStr::from_ptr(msg);
-
-        let severity = match severity {
-            trtx_sys::TrtxLoggerSeverity::TRTX_SEVERITY_INTERNAL_ERROR => Severity::InternalError,
-            trtx_sys::TrtxLoggerSeverity::TRTX_SEVERITY_ERROR => Severity::Error,
-            trtx_sys::TrtxLoggerSeverity::TRTX_SEVERITY_WARNING => Severity::Warning,
-            trtx_sys::TrtxLoggerSeverity::TRTX_SEVERITY_INFO => Severity::Info,
-            trtx_sys::TrtxLoggerSeverity::TRTX_SEVERITY_VERBOSE => Severity::Verbose,
-        };
-
-        if let Ok(msg) = msg_str.to_str() {
-            handler.log(severity, msg);
-        }
-    }
-}
diff --git a/trtx/src/mock/mod.rs b/trtx/src/mock/mod.rs
deleted file mode 100644
index e45ad89..0000000
--- a/trtx/src/mock/mod.rs
+++ /dev/null
@@ -1,17 +0,0 @@
-//! Unified mock for TensorRT-RTX and CUDA
-//!
-//! This module provides a single mock implementation used when the `mock` feature
-//! is enabled. Mock mode allows development without TensorRT-RTX or GPU hardware.
-//! All TensorRT (builder, network, runtime, engine) and CUDA (DeviceBuffer, streams)
-//! operations are stubbed here - no real GPU or TensorRT libraries are required.
-
-pub(crate) mod builder;
-pub(crate) mod cuda;
-mod error;
-pub(crate) mod logger;
-pub(crate) mod network;
-pub(crate) mod onnx_parser;
-pub(crate) mod runtime;
-
-pub(crate) use error::*;
-pub(crate) use network::default_engine_tensor_shape;
diff --git a/trtx/src/mock/network.rs b/trtx/src/mock/network.rs
deleted file mode 100644
index 009dd6d..0000000
--- a/trtx/src/mock/network.rs
+++ /dev/null
@@ -1,480 +0,0 @@
-//! Mock network implementations
-//!
-//! Helper functions and impl blocks for mock mode.
-//! Most mock implementations return null pointers or default values.
-
-use crate::error::Result;
-use crate::network::*;
-use trtx_sys::MatrixOperation;
-use trtx_sys::TopKOperation;
-
-//==============================================================================
-// Helper functions (used by other mock modules via crate::mock::)
-//==============================================================================
-
-/// Default tensor shape for mock (e.g., [1, 3, 224, 224])
-pub(crate) fn default_tensor_dimensions() -> Vec<i32> {
-    vec![1, 3, 224, 224]
-}
-
-/// Default tensor name for mock
-pub(crate) fn default_tensor_name() -> &'static str {
-    "mock_tensor"
-}
-
-/// Default data type for mock (kFLOAT = 0)
-pub(crate) fn default_tensor_type() -> i32 {
-    0
-}
-
-/// Default shape for CudaEngine::get_tensor_shape (mock mode)
-pub(crate) fn default_engine_tensor_shape() -> Vec<i64> {
-    vec![1_i64, 1000]
-}
-
-/// Destroy network (mock mode)
-pub(crate) fn destroy_network(network_ptr: *mut trtx_sys::TrtxNetworkDefinition) {
-    if !network_ptr.is_null() {
-        unsafe {
-            trtx_sys::trtx_network_destroy(network_ptr);
-        }
-    }
-}
-
-//==============================================================================
-// Impl blocks for network types (mock stubs)
-//==============================================================================
-
-/// Macro to implement Layer trait for mock (stub implementations)
-macro_rules! 
impl_layer_mock {
-    ($name:ident) => {
-        impl Layer for $name {
-            fn get_output(&self, _index: i32) -> Result<Tensor> {
-                Ok(Tensor {
-                    inner: std::ptr::null_mut(),
-                })
-            }
-            fn as_ptr(&self) -> *mut std::ffi::c_void {
-                self.inner
-            }
-        }
-    };
-}
-
-impl_layer_mock!(ShuffleLayer);
-impl_layer_mock!(ActivationLayer);
-impl_layer_mock!(ElementWiseLayer);
-impl_layer_mock!(ResizeLayer);
-impl_layer_mock!(TopKLayer);
-impl_layer_mock!(GatherLayer);
-impl_layer_mock!(ScatterLayer);
-impl_layer_mock!(SelectLayer);
-impl_layer_mock!(MatrixMultiplyLayer);
-impl_layer_mock!(SoftMaxLayer);
-impl_layer_mock!(ReduceLayer);
-impl_layer_mock!(CumulativeLayer);
-impl_layer_mock!(PoolingLayer);
-impl_layer_mock!(ConvolutionLayer);
-impl_layer_mock!(DeconvolutionLayer);
-impl_layer_mock!(QuantizeLayer);
-impl_layer_mock!(DequantizeLayer);
-impl_layer_mock!(ConstantLayer);
-impl_layer_mock!(ConcatenationLayer);
-impl_layer_mock!(ScaleLayer);
-impl_layer_mock!(SliceLayer);
-impl_layer_mock!(UnaryLayer);
-impl_layer_mock!(IdentityLayer);
-impl_layer_mock!(PaddingLayer);
-impl_layer_mock!(CastLayer);
-
-/// Macro to implement set_layer_name for mock (no-op).
-macro_rules! impl_layer_set_name_mock {
-    ($($name:ident),* $(,)?) 
=> { - $( - impl $name { - pub fn set_layer_name(&mut self, _name: &str) -> Result<()> { - Ok(()) - } - } - )* - }; -} -impl_layer_set_name_mock!( - ShuffleLayer, - ActivationLayer, - ElementWiseLayer, - ResizeLayer, - TopKLayer, - GatherLayer, - ScatterLayer, - SelectLayer, - MatrixMultiplyLayer, - SoftMaxLayer, - ReduceLayer, - CumulativeLayer, - PoolingLayer, - ConvolutionLayer, - DeconvolutionLayer, - QuantizeLayer, - DequantizeLayer, - ConstantLayer, - ConcatenationLayer, - ScaleLayer, - SliceLayer, - UnaryLayer, - IdentityLayer, - PaddingLayer, - CastLayer, -); - -// Layer-specific impls - all no-ops for mock -impl ShuffleLayer { - pub fn set_reshape_dimensions(&mut self, _dims: &[i32]) -> Result<()> { - Ok(()) - } - pub fn set_first_transpose(&mut self, _order: &[i32]) -> Result<()> { - Ok(()) - } -} -impl ResizeLayer { - pub fn set_output_dimensions(&mut self, _dims: &[i32]) -> Result<()> { - Ok(()) - } - pub fn set_resize_mode(&mut self, _mode: trtx_sys::ResizeMode) -> Result<()> { - Ok(()) - } -} -impl GatherLayer { - pub fn set_gather_mode(&mut self, _mode: trtx_sys::GatherMode) -> Result<()> { - Ok(()) - } -} -impl ScatterLayer { - pub fn set_scatter_mode(&mut self, _mode: trtx_sys::ScatterMode) -> Result<()> { - Ok(()) - } - pub fn set_axis(&mut self, _axis: i32) -> Result<()> { - Ok(()) - } -} -impl ConvolutionLayer { - pub fn set_stride(&mut self, _stride: &[i32; 2]) -> Result<()> { - Ok(()) - } - pub fn set_padding(&mut self, _padding: &[i32; 2]) -> Result<()> { - Ok(()) - } - pub fn set_dilation(&mut self, _dilation: &[i32; 2]) -> Result<()> { - Ok(()) - } - pub fn set_num_groups(&mut self, _num_groups: i32) -> Result<()> { - Ok(()) - } - pub fn set_input(&mut self, _index: i32, _tensor: &Tensor) -> Result<()> { - Ok(()) - } -} -impl DeconvolutionLayer { - pub fn set_stride(&mut self, _stride: &[i32; 2]) -> Result<()> { - Ok(()) - } - pub fn set_padding(&mut self, _padding: &[i32; 2]) -> Result<()> { - Ok(()) - } - pub fn set_pre_padding(&mut self, 
_padding: &[i32; 2]) -> Result<()> {
-        Ok(())
-    }
-    pub fn set_post_padding(&mut self, _padding: &[i32; 2]) -> Result<()> {
-        Ok(())
-    }
-    pub fn set_dilation(&mut self, _dilation: &[i32; 2]) -> Result<()> {
-        Ok(())
-    }
-    pub fn set_num_groups(&mut self, _num_groups: i32) -> Result<()> {
-        Ok(())
-    }
-    pub fn set_input(&mut self, _index: i32, _tensor: &Tensor) -> Result<()> {
-        Ok(())
-    }
-}
-impl ConcatenationLayer {
-    pub fn set_axis(&mut self, _axis: i32) -> Result<()> {
-        Ok(())
-    }
-}
-
-impl Tensor {
-    pub fn name(&self) -> Result<String> {
-        Ok(default_tensor_name().to_string())
-    }
-    pub fn set_name(&mut self, _name: &str) -> Result<()> {
-        Ok(())
-    }
-    pub fn dimensions(&self) -> Result<Vec<i32>> {
-        Ok(default_tensor_dimensions())
-    }
-    pub fn get_type(&self) -> Result<i32> {
-        Ok(default_tensor_type())
-    }
-    pub fn set_allowed_formats(&mut self, _formats: u32) -> Result<()> {
-        Ok(())
-    }
-}
-
-impl NetworkDefinition {
-    pub(crate) fn from_ptr(ptr: *mut std::ffi::c_void) -> Self {
-        NetworkDefinition { inner: ptr }
-    }
-    pub(crate) fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
-        self.inner
-    }
-
-    pub fn add_input(
-        &mut self,
-        _name: &str,
-        _data_type: trtx_sys::DataType,
-        _dims: &[i32],
-    ) -> Result<Tensor> {
-        Ok(Tensor {
-            inner: std::ptr::null_mut(),
-        })
-    }
-    pub fn mark_output(&mut self, _tensor: &Tensor) -> Result<()> {
-        Ok(())
-    }
-    pub fn get_nb_inputs(&self) -> i32 {
-        0
-    }
-    pub fn get_nb_outputs(&self) -> i32 {
-        0
-    }
-    pub fn get_input(&self, _index: i32) -> Result<Tensor> {
-        Ok(Tensor {
-            inner: std::ptr::null_mut(),
-        })
-    }
-    pub fn get_output(&self, _index: i32) -> Result<Tensor> {
-        Ok(Tensor {
-            inner: std::ptr::null_mut(),
-        })
-    }
-
-    pub fn get_nb_layers(&self) -> i32 {
-        0
-    }
-    pub fn get_layer_name(&self, _layer_index: i32) -> Result<String> {
-        Ok("(mock)".to_string())
-    }
-    pub fn get_layer_type(&self, _layer_index: i32) -> Result<i32> {
-        Ok(0)
-    }
-
-    pub fn add_activation(
-        &mut self,
-        _input: &Tensor,
-        _activation_type: trtx_sys::ActivationType,
-    ) -> Result<ActivationLayer> {
-        Ok(ActivationLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_unary(
-        &mut self,
-        _input: &Tensor,
-        _op: trtx_sys::UnaryOperation,
-    ) -> Result<UnaryLayer> {
-        Ok(UnaryLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_identity(&mut self, _input: &Tensor) -> Result<IdentityLayer> {
-        Ok(IdentityLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_cast(&mut self, _input: &Tensor, _to_type: trtx_sys::DataType) -> Result<CastLayer> {
-        Ok(CastLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_elementwise(
-        &mut self,
-        _input1: &Tensor,
-        _input2: &Tensor,
-        _op: trtx_sys::ElementWiseOperation,
-    ) -> Result<ElementWiseLayer> {
-        Ok(ElementWiseLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_pooling(
-        &mut self,
-        _input: &Tensor,
-        _pooling_type: trtx_sys::PoolingType,
-        _window_size: &[i32; 2],
-    ) -> Result<PoolingLayer> {
-        Ok(PoolingLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_shuffle(&mut self, _input: &Tensor) -> Result<ShuffleLayer> {
-        Ok(ShuffleLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_matrix_multiply(
-        &mut self,
-        _input0: &Tensor,
-        _op0: MatrixOperation,
-        _input1: &Tensor,
-        _op1: MatrixOperation,
-    ) -> Result<MatrixMultiplyLayer> {
-        Ok(MatrixMultiplyLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_convolution(
-        &mut self,
-        _input: &Tensor,
-        _nb_output_maps: i32,
-        _kernel_size: &[i32; 2],
-        _weights: &ConvWeights<'_>,
-    ) -> Result<ConvolutionLayer> {
-        Ok(ConvolutionLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_deconvolution(
-        &mut self,
-        _input: &Tensor,
-        _nb_output_maps: i32,
-        _kernel_size: &[i32; 2],
-        _weights: &ConvWeights<'_>,
-    ) -> Result<DeconvolutionLayer> {
-        Ok(DeconvolutionLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_concatenation(&mut self, _inputs: &[&Tensor]) -> Result<ConcatenationLayer> {
-        Ok(ConcatenationLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_constant(
-        &mut self,
-        _dims: &[i32],
-        _weights: &[u8],
-        _data_type: trtx_sys::DataType,
-    ) -> Result<ConstantLayer> {
-        Ok(ConstantLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_softmax(&mut self, _input: &Tensor, _axes: u32) -> Result<SoftMaxLayer> {
-        
Ok(SoftMaxLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_scale(
-        &mut self,
-        _input: &Tensor,
-        _mode: i32,
-        _shift: &[u8],
-        _scale: &[u8],
-        _power: &[u8],
-    ) -> Result<ScaleLayer> {
-        Ok(ScaleLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_reduce(
-        &mut self,
-        _input: &Tensor,
-        _op: trtx_sys::ReduceOperation,
-        _axes: u32,
-        _keep_dims: bool,
-    ) -> Result<ReduceLayer> {
-        Ok(ReduceLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_cumulative(
-        &mut self,
-        _input: &Tensor,
-        _axis: i32,
-        _op: trtx_sys::CumulativeOperation,
-        _exclusive: bool,
-        _reverse: bool,
-    ) -> Result<CumulativeLayer> {
-        Ok(CumulativeLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_cumulative_with_axis_tensor(
-        &mut self,
-        _input: &Tensor,
-        _axis_tensor: &Tensor,
-        _op: trtx_sys::CumulativeOperation,
-        _exclusive: bool,
-        _reverse: bool,
-    ) -> Result<CumulativeLayer> {
-        Ok(CumulativeLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_slice(
-        &mut self,
-        _input: &Tensor,
-        _start: &[i32],
-        _size: &[i32],
-        _stride: &[i32],
-    ) -> Result<SliceLayer> {
-        Ok(SliceLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_resize(&mut self, _input: &Tensor) -> Result<ResizeLayer> {
-        Ok(ResizeLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_topk(
-        &mut self,
-        _input: &Tensor,
-        _op: TopKOperation,
-        _k: i32,
-        _axes: u32,
-    ) -> Result<TopKLayer> {
-        Ok(TopKLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_gather(
-        &mut self,
-        _data: &Tensor,
-        _indices: &Tensor,
-        _axis: i32,
-    ) -> Result<GatherLayer> {
-        Ok(GatherLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_scatter(
-        &mut self,
-        _data: &Tensor,
-        _indices: &Tensor,
-        _updates: &Tensor,
-        _mode: trtx_sys::ScatterMode,
-    ) -> Result<ScatterLayer> {
-        Ok(ScatterLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_quantize(
-        &mut self,
-        _input: &Tensor,
-        _scale: &Tensor,
-        _output_type: trtx_sys::DataType,
-    ) -> Result<QuantizeLayer> {
-        Ok(QuantizeLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_dequantize(
-        &mut self,
-        _input: &Tensor,
-        _scale: &Tensor,
-        _output_type: 
trtx_sys::DataType,
-    ) -> Result<DequantizeLayer> {
-        Ok(DequantizeLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_select(
-        &mut self,
-        _condition: &Tensor,
-        _then_input: &Tensor,
-        _else_input: &Tensor,
-    ) -> Result<SelectLayer> {
-        Ok(SelectLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_padding(
-        &mut self,
-        _input: &Tensor,
-        _pre_padding: &[i32],
-        _post_padding: &[i32],
-    ) -> Result<PaddingLayer> {
-        Ok(PaddingLayer::from_ptr(std::ptr::null_mut()))
-    }
-    pub fn add_assertion(&mut self, _condition: &Tensor, _message: &str) -> Result<()> {
-        Ok(())
-    }
-    pub fn add_loop(&mut self) -> Result<*mut std::ffi::c_void> {
-        Ok(std::ptr::null_mut())
-    }
-    pub fn add_if_conditional(&mut self) -> Result<*mut std::ffi::c_void> {
-        Ok(std::ptr::null_mut())
-    }
-}
-
-impl Drop for NetworkDefinition {
-    fn drop(&mut self) {
-        if !self.inner.is_null() {
-            destroy_network(self.inner as *mut trtx_sys::TrtxNetworkDefinition);
-        }
-    }
-}
-
-unsafe impl Send for NetworkDefinition {}
diff --git a/trtx/src/mock/onnx_parser.rs b/trtx/src/mock/onnx_parser.rs
deleted file mode 100644
index 0472f75..0000000
--- a/trtx/src/mock/onnx_parser.rs
+++ /dev/null
@@ -1,90 +0,0 @@
-//! 
Mock ONNX parser implementations
-
-use crate::error::Result;
-use crate::logger::Logger;
-use crate::network::NetworkDefinition;
-
-use super::from_ffi;
-
-/// ONNX parser (mock mode)
-pub struct OnnxParser {
-    inner: *mut trtx_sys::TrtxOnnxParser,
-}
-
-impl OnnxParser {
-    pub fn new(network: &mut NetworkDefinition, logger: &Logger) -> Result<Self> {
-        let parser_ptr = trtx_onnx_parser_create(
-            network.as_mut_ptr() as *mut trtx_sys::TrtxNetworkDefinition,
-            logger.as_ptr(),
-        )?;
-        Ok(OnnxParser { inner: parser_ptr })
-    }
-
-    pub fn parse(&self, model_bytes: &[u8]) -> Result<()> {
-        parse(self.inner, model_bytes)
-    }
-}
-
-impl Drop for OnnxParser {
-    fn drop(&mut self) {
-        destroy_parser(self.inner);
-    }
-}
-
-unsafe impl Send for OnnxParser {}
-
-//------------------------------------------------------------------------------
-// Helper functions
-//------------------------------------------------------------------------------
-
-fn trtx_onnx_parser_create(
-    network_ptr: *mut trtx_sys::TrtxNetworkDefinition,
-    logger_ptr: *mut trtx_sys::TrtxLogger,
-) -> Result<*mut trtx_sys::TrtxOnnxParser> {
-    let mut parser_ptr: *mut trtx_sys::TrtxOnnxParser = std::ptr::null_mut();
-    let mut error_msg = [0i8; 1024];
-
-    let result = unsafe {
-        trtx_sys::trtx_onnx_parser_create(
-            network_ptr,
-            logger_ptr,
-            &mut parser_ptr,
-            error_msg.as_mut_ptr(),
-            error_msg.len(),
-        )
-    };
-
-    if result != trtx_sys::TRTX_SUCCESS as i32 {
-        return Err(from_ffi(result, &error_msg));
-    }
-
-    Ok(parser_ptr)
-}
-
-fn parse(parser_ptr: *mut trtx_sys::TrtxOnnxParser, model_bytes: &[u8]) -> Result<()> {
-    let mut error_msg = [0i8; 1024];
-
-    let result = unsafe {
-        trtx_sys::trtx_onnx_parser_parse(
-            parser_ptr,
-            model_bytes.as_ptr() as *const std::ffi::c_void,
-            model_bytes.len(),
-            error_msg.as_mut_ptr(),
-            error_msg.len(),
-        )
-    };
-
-    if result != trtx_sys::TRTX_SUCCESS as i32 {
-        return Err(from_ffi(result, &error_msg));
-    }
-
-    Ok(())
-}
-
-fn destroy_parser(parser_ptr: *mut 
trtx_sys::TrtxOnnxParser) {
-    if !parser_ptr.is_null() {
-        unsafe {
-            trtx_sys::trtx_onnx_parser_destroy(parser_ptr);
-        }
-    }
-}
diff --git a/trtx/src/mock/runtime.rs b/trtx/src/mock/runtime.rs
deleted file mode 100644
index 182b8b0..0000000
--- a/trtx/src/mock/runtime.rs
+++ /dev/null
@@ -1,302 +0,0 @@
-//! Mock runtime implementations
-
-use crate::error::Result;
-use crate::logger::Logger;
-use std::ffi::CStr;
-
-use super::from_ffi;
-
-/// CUDA engine (mock mode)
-pub struct CudaEngine {
-    pub(crate) inner: *mut trtx_sys::TrtxCudaEngine,
-}
-
-impl CudaEngine {
-    pub(crate) fn from_mock_ptr(ptr: *mut trtx_sys::TrtxCudaEngine) -> Self {
-        CudaEngine { inner: ptr }
-    }
-
-    pub fn get_nb_io_tensors(&self) -> Result<i32> {
-        get_nb_io_tensors(self.inner)
-    }
-
-    pub fn get_tensor_name(&self, index: i32) -> Result<String> {
-        get_tensor_name(self.inner, index)
-    }
-
-    pub fn get_tensor_shape(&self, _name: &str) -> Result<Vec<i64>> {
-        Ok(crate::mock::default_engine_tensor_shape())
-    }
-
-    /// Mock always returns kFLOAT (buffer sizing uses 4 bytes per element).
-    pub fn get_tensor_dtype(&self, _name: &str) -> Result<trtx_sys::nvinfer1::DataType> {
-        Ok(trtx_sys::nvinfer1::DataType::kFLOAT)
-    }
-
-    pub fn create_execution_context(&self) -> Result<ExecutionContext<'_>> {
-        let context_ptr = create_execution_context(self.inner)?;
-        Ok(ExecutionContext::from_mock_ptr(context_ptr))
-    }
-}
-
-impl Drop for CudaEngine {
-    fn drop(&mut self) {
-        destroy_engine(self.inner);
-    }
-}
-
-unsafe impl Send for CudaEngine {}
-unsafe impl Sync for CudaEngine {}
-
-/// Execution context (mock mode)
-pub struct ExecutionContext<'a> {
-    inner: *mut trtx_sys::TrtxExecutionContext,
-    _engine: std::marker::PhantomData<&'a CudaEngine>,
-}
-
-impl<'a> ExecutionContext<'a> {
-    pub(crate) fn from_mock_ptr(ptr: *mut trtx_sys::TrtxExecutionContext) -> Self {
-        ExecutionContext {
-            inner: ptr,
-            _engine: std::marker::PhantomData,
-        }
-    }
-
-    /// Binds a tensor to a device memory address. 
-    ///
-    /// # Safety
-    /// `data` must point to valid CUDA memory with at least the tensor's size in bytes,
-    /// and remain valid for the duration of inference.
-    pub unsafe fn set_tensor_address(
-        &mut self,
-        name: &str,
-        data: *mut std::ffi::c_void,
-    ) -> Result<()> {
-        set_tensor_address(self.inner, name, data)
-    }
-
-    /// Enqueues inference on the given CUDA stream.
-    ///
-    /// # Safety
-    /// `cuda_stream` must be a valid CUDA stream, and all tensor addresses must
-    /// point to valid device memory.
-    pub unsafe fn enqueue_v3(&mut self, cuda_stream: *mut std::ffi::c_void) -> Result<()> {
-        enqueue_v3(self.inner, cuda_stream)
-    }
-}
-
-impl Drop for ExecutionContext<'_> {
-    fn drop(&mut self) {
-        destroy_context(self.inner);
-    }
-}
-
-unsafe impl Send for ExecutionContext<'_> {}
-
-/// Runtime (mock mode)
-pub struct Runtime<'a> {
-    inner: *mut trtx_sys::TrtxRuntime,
-    _logger: &'a Logger,
-}
-
-impl<'a> Runtime<'a> {
-    pub fn new(logger: &'a Logger) -> Result<Self> {
-        let runtime_ptr = trtx_runtime_create(logger.as_ptr())?;
-        Ok(Runtime {
-            inner: runtime_ptr,
-            _logger: logger,
-        })
-    }
-
-    pub fn deserialize_cuda_engine(&self, data: &[u8]) -> Result<CudaEngine> {
-        let engine_ptr = deserialize_cuda_engine(self.inner, data)?;
-        Ok(CudaEngine::from_mock_ptr(engine_ptr))
-    }
-}
-
-impl Drop for Runtime<'_> {
-    fn drop(&mut self) {
-        destroy_runtime(self.inner);
-    }
-}
-
-unsafe impl Send for Runtime<'_> {}
-
-//------------------------------------------------------------------------------
-// Helper functions
-//------------------------------------------------------------------------------
-
-fn get_nb_io_tensors(engine_ptr: *mut trtx_sys::TrtxCudaEngine) -> Result<i32> {
-    let mut count: i32 = 0;
-
-    let result = unsafe { trtx_sys::trtx_cuda_engine_get_nb_io_tensors(engine_ptr, &mut count) };
-
-    if result != trtx_sys::TRTX_SUCCESS as i32 {
-        return Err(from_ffi(result, &[]));
-    }
-
-    Ok(count)
-}
-
-fn get_tensor_name(engine_ptr: *mut trtx_sys::TrtxCudaEngine, index: i32) -> 
Result<String> {
-    let mut name_ptr: *const i8 = std::ptr::null();
-    let mut error_msg = [0i8; 1024];
-
-    let result = unsafe {
-        trtx_sys::trtx_cuda_engine_get_tensor_name(
-            engine_ptr,
-            index,
-            &mut name_ptr,
-            error_msg.as_mut_ptr(),
-            error_msg.len(),
-        )
-    };
-
-    if result != trtx_sys::TRTX_SUCCESS as i32 {
-        return Err(from_ffi(result, &error_msg));
-    }
-
-    let name = unsafe { CStr::from_ptr(name_ptr) }.to_str()?.to_string();
-
-    Ok(name)
-}
-
-fn trtx_runtime_create(
-    logger_ptr: *mut trtx_sys::TrtxLogger,
-) -> Result<*mut trtx_sys::TrtxRuntime> {
-    let mut runtime_ptr: *mut trtx_sys::TrtxRuntime = std::ptr::null_mut();
-    let mut error_msg = [0i8; 1024];
-
-    let result = unsafe {
-        trtx_sys::trtx_runtime_create(
-            logger_ptr,
-            &mut runtime_ptr,
-            error_msg.as_mut_ptr(),
-            error_msg.len(),
-        )
-    };
-
-    if result != trtx_sys::TRTX_SUCCESS as i32 {
-        return Err(from_ffi(result, &error_msg));
-    }
-
-    Ok(runtime_ptr)
-}
-
-fn deserialize_cuda_engine(
-    runtime_ptr: *mut trtx_sys::TrtxRuntime,
-    data: &[u8],
-) -> Result<*mut trtx_sys::TrtxCudaEngine> {
-    let mut engine_ptr: *mut trtx_sys::TrtxCudaEngine = std::ptr::null_mut();
-    let mut error_msg = [0i8; 1024];
-
-    let result = unsafe {
-        trtx_sys::trtx_runtime_deserialize_cuda_engine(
-            runtime_ptr,
-            data.as_ptr() as *const std::ffi::c_void,
-            data.len(),
-            &mut engine_ptr,
-            error_msg.as_mut_ptr(),
-            error_msg.len(),
-        )
-    };
-
-    if result != trtx_sys::TRTX_SUCCESS as i32 {
-        return Err(from_ffi(result, &error_msg));
-    }
-
-    Ok(engine_ptr)
-}
-
-fn create_execution_context(
-    engine_ptr: *mut trtx_sys::TrtxCudaEngine,
-) -> Result<*mut trtx_sys::TrtxExecutionContext> {
-    let mut context_ptr: *mut trtx_sys::TrtxExecutionContext = std::ptr::null_mut();
-    let mut error_msg = [0i8; 1024];
-
-    let result = unsafe {
-        trtx_sys::trtx_cuda_engine_create_execution_context(
-            engine_ptr,
-            &mut context_ptr,
-            error_msg.as_mut_ptr(),
-            error_msg.len(),
-        )
-    };
-
-    if result != trtx_sys::TRTX_SUCCESS as i32 {
-        return 
Err(from_ffi(result, &error_msg)); - } - - Ok(context_ptr) -} - -fn set_tensor_address( - context_ptr: *mut trtx_sys::TrtxExecutionContext, - name: &str, - data: *mut std::ffi::c_void, -) -> Result<()> { - let name_cstr = std::ffi::CString::new(name)?; - let mut error_msg = [0i8; 1024]; - - let result = unsafe { - trtx_sys::trtx_execution_context_set_tensor_address( - context_ptr, - name_cstr.as_ptr(), - data, - error_msg.as_mut_ptr(), - error_msg.len(), - ) - }; - - if result != trtx_sys::TRTX_SUCCESS as i32 { - return Err(from_ffi(result, &error_msg)); - } - - Ok(()) -} - -fn enqueue_v3( - context_ptr: *mut trtx_sys::TrtxExecutionContext, - cuda_stream: *mut std::ffi::c_void, -) -> Result<()> { - let mut error_msg = [0i8; 1024]; - - let result = unsafe { - trtx_sys::trtx_execution_context_enqueue_v3( - context_ptr, - cuda_stream, - error_msg.as_mut_ptr(), - error_msg.len(), - ) - }; - - if result != trtx_sys::TRTX_SUCCESS as i32 { - return Err(from_ffi(result, &error_msg)); - } - - Ok(()) -} - -fn destroy_engine(engine_ptr: *mut trtx_sys::TrtxCudaEngine) { - if !engine_ptr.is_null() { - unsafe { - trtx_sys::trtx_cuda_engine_destroy(engine_ptr); - } - } -} - -fn destroy_context(context_ptr: *mut trtx_sys::TrtxExecutionContext) { - if !context_ptr.is_null() { - unsafe { - trtx_sys::trtx_execution_context_destroy(context_ptr); - } - } -} - -fn destroy_runtime(runtime_ptr: *mut trtx_sys::TrtxRuntime) { - if !runtime_ptr.is_null() { - unsafe { - trtx_sys::trtx_runtime_destroy(runtime_ptr); - } - } -} diff --git a/trtx/src/network.rs b/trtx/src/network.rs index fccad21..760ec5b 100644 --- a/trtx/src/network.rs +++ b/trtx/src/network.rs @@ -1,8 +1,32 @@ //! Network definition for building TensorRT engines -//! -//! Types and implementations. Real/mock impls live in real/ and mock/ folders. 
-use crate::error::Result;
+use crate::interfaces::RecordError;
+use cxx::UniquePtr;
+use std::ffi::{CStr, CString};
+use std::marker::PhantomData;
+use std::pin::Pin;
+use trtx_sys::nvinfer1::{IConcatenationLayer, INetworkDefinition, ITensor};
+use trtx_sys::{nvinfer1, LayerType};
+use trtx_sys::{ConcreteTrtLayer, TrtLayer};
+use trtx_sys::{DataType, MatrixOperation, ScaleMode, TopKOperation};
+
+/// Panics if the layer or tensor was created from a different network.
+#[macro_export]
+macro_rules! check_network {
+    ($network:ident, $this:ident) => {
+        if $network.inner.as_ptr() != $this.network {
+            panic!("Layer or tensor was created from different network")
+        }
+    };
+    ($network:ident, $tensor:expr) => {
+        if $network.inner.as_ptr() != $tensor.network {
+            panic!("Layer or tensor was created from different network")
+        }
+    };
+}
+
+use crate::error::{Error, Result};
+use crate::interfaces::ErrorRecorder;
 
 /// Kernel and optional bias weights for convolution and deconvolution layers.
 #[derive(Clone)]
@@ -14,62 +38,1590 @@ pub struct ConvWeights<'a> {
 }
 
 /// Tensor handle (opaque pointer)
-#[allow(dead_code)] // inner is used by real::network and mock::network impls
-pub struct Tensor {
-    pub(crate) inner: *mut std::ffi::c_void,
+pub struct Tensor<'network> {
+    pub(crate) inner: *mut nvinfer1::ITensor,
+    pub(crate) network: &'network nvinfer1::INetworkDefinition,
 }
+impl Tensor<'_> {
+    pub(crate) unsafe fn new(
+        network: *const nvinfer1::INetworkDefinition,
+        ptr: *mut nvinfer1::ITensor,
+    ) -> Result<Self> {
+        unsafe {
+            if ptr.is_null() {
+                return Err(Error::GetTensorFailed);
+            }
+            Ok(Self {
+                inner: ptr,
+                network: network.as_ref().unwrap(),
+            })
+        }
+    }
 
-/// Base trait for all layer types
-pub trait Layer {
-    /// Get the output tensor at the specified index
-    fn get_output(&self, index: i32) -> Result<Tensor>;
+    #[allow(clippy::mut_from_ref)]
+    fn pin_mut(&self) -> Pin<&mut nvinfer1::ITensor> {
+        unsafe { Pin::new_unchecked(self.inner.as_mut().unwrap()) }
+    }
+    fn as_ref(&self) -> &nvinfer1::ITensor {
+        unsafe { self.inner.as_ref().unwrap() }
+    }
+    #[allow(clippy::mut_from_ref)]
+    fn as_mut(&self) -> &mut nvinfer1::ITensor {
+        unsafe { self.inner.as_mut().unwrap() }
+    }
+}
 
-    /// Get the raw layer pointer (for internal use)
-    fn as_ptr(&self) -> *mut std::ffi::c_void;
+pub struct Layer<'network, Inner: TrtLayer> {
+    pub(crate) inner: Pin<&'network mut Inner>,
+    pub(crate) network: *const nvinfer1::INetworkDefinition,
 }
 
-/// Macro to define layer struct
-macro_rules! define_network_layer {
-    ($name:ident) => {
-        pub struct $name {
-            pub(crate) inner: *mut std::ffi::c_void,
+impl<'network, Inner: ConcreteTrtLayer> Layer<'network, Inner> {
+    /// See [nvinfer1::ILayer::getType] (compile time dispatch)
+    pub const fn layer_type(&self) -> LayerType {
+        Inner::TYPE
+    }
+
+    pub(crate) fn new(
+        network: *const nvinfer1::INetworkDefinition,
+        ptr: *mut Inner,
+    ) -> Result<Self> {
+        unsafe {
+            let ptr = ptr
+                .as_mut()
+                .ok_or(Error::LayerCreationFailed(Inner::TYPE))?;
+            Ok(Self {
+                inner: Pin::new_unchecked(ptr),
+                network,
+            })
         }
+    }
+}
+impl<'network> Layer<'network, nvinfer1::ILayer> {
+    /// See [nvinfer1::ILayer::getType] (dynamic dispatch)
+    pub fn layer_type_dynamic(&self) -> LayerType {
+        self.inner.as_layer().getType().into()
+    }
 
-    impl $name {
-        pub(crate) fn from_ptr(ptr: *mut std::ffi::c_void) -> Self {
-            Self { inner: ptr }
-        }
+    /// Create a generic ILayer (of unknown type)
+    pub(crate) fn new_dyn(
+        network: *const nvinfer1::INetworkDefinition,
+        ptr: *mut nvinfer1::ILayer,
+    ) -> Result<Self> {
+        unsafe {
+            let ptr = ptr.as_mut().ok_or(Error::GetLayerFailed)?;
+            Ok(Self {
+                inner: Pin::new_unchecked(ptr),
+                network,
+            })
         }
-    };
+    }
+}
+
+impl<'network, Inner: TrtLayer> Layer<'network, Inner> {
+    /// See [nvinfer1::ILayer::getInput]
+    pub fn get_input(
+        &self,
+        network: &'_ NetworkDefinition,
+        index: i32,
+    ) -> Result<Tensor<'network>> {
+        check_network!(network, self);
+        let tensor = self.inner.as_layer().getInput(index);
+        unsafe { 
Tensor::new(self.network, tensor) } + } + + /// See [nvinfer1::ILayer::getOutput] + pub fn get_output( + &self, + network: &'_ NetworkDefinition, + index: i32, + ) -> Result> { + check_network!(network, self); + let tensor = self.inner.as_layer().getOutput(index); + unsafe { Tensor::new(self.network, tensor) } + } + + /// See [nvinfer1::ILayer::setName] + pub fn set_name(&mut self, network: &'_ mut NetworkDefinition, name: &str) -> Result<()> { + check_network!(network, self); + let name = CString::new(name)?; + unsafe { + self.inner + .as_mut() + .get_unchecked_mut() + .as_layer_pin_mut() + .setName(name.as_ptr()) + }; + Ok(()) + } + + /// See [nvinfer1::ILayer::getName] + pub fn name(&self, network: &NetworkDefinition) -> String { + check_network!(network, self); + let name = self.inner.as_layer().getName(); + // must clone since layer may change name at any time! Cow from to_string_lossy() only + // possible if name immutable + if name.is_null() { + "(unamed)".to_string() + } else { + unsafe { CStr::from_ptr(name).to_string_lossy().to_string() } + } + } +} + +// Type aliases for every layer (Layer<_, I*Layer> where I*Layer: TrtLayer) +pub type ActivationLayer<'layer> = Layer<'layer, nvinfer1::IActivationLayer>; +pub type AssertionLayer<'layer> = Layer<'layer, nvinfer1::IAssertionLayer>; +pub type CastLayer<'layer> = Layer<'layer, nvinfer1::ICastLayer>; +pub type ConcatenationLayer<'layer> = Layer<'layer, nvinfer1::IConcatenationLayer>; +pub type ConstantLayer<'layer> = Layer<'layer, nvinfer1::IConstantLayer>; +pub type ConvolutionLayer<'layer> = Layer<'layer, nvinfer1::IConvolutionLayer>; +pub type CumulativeLayer<'layer> = Layer<'layer, nvinfer1::ICumulativeLayer>; +pub type DeconvolutionLayer<'layer> = Layer<'layer, nvinfer1::IDeconvolutionLayer>; +pub type DequantizeLayer<'layer> = Layer<'layer, nvinfer1::IDequantizeLayer>; +pub type DynamicQuantizeLayer<'layer> = Layer<'layer, nvinfer1::IDynamicQuantizeLayer>; +pub type ElementWiseLayer<'layer> = 
Layer<'layer, nvinfer1::IElementWiseLayer>; +pub type EinsumLayer<'layer> = Layer<'layer, nvinfer1::IEinsumLayer>; +pub type FillLayer<'layer> = Layer<'layer, nvinfer1::IFillLayer>; +pub type GatherLayer<'layer> = Layer<'layer, nvinfer1::IGatherLayer>; +pub type GridSampleLayer<'layer> = Layer<'layer, nvinfer1::IGridSampleLayer>; +pub type IdentityLayer<'layer> = Layer<'layer, nvinfer1::IIdentityLayer>; +pub type MatrixMultiplyLayer<'layer> = Layer<'layer, nvinfer1::IMatrixMultiplyLayer>; +pub type NMSLayer<'layer> = Layer<'layer, nvinfer1::INMSLayer>; +pub type NonZeroLayer<'layer> = Layer<'layer, nvinfer1::INonZeroLayer>; +pub type NormalizationLayer<'layer> = Layer<'layer, nvinfer1::INormalizationLayer>; +pub type PaddingLayer<'layer> = Layer<'layer, nvinfer1::IPaddingLayer>; +pub type ParametricReLULayer<'layer> = Layer<'layer, nvinfer1::IParametricReLULayer>; +pub type PoolingLayer<'layer> = Layer<'layer, nvinfer1::IPoolingLayer>; +pub type QuantizeLayer<'layer> = Layer<'layer, nvinfer1::IQuantizeLayer>; +pub type RaggedSoftMaxLayer<'layer> = Layer<'layer, nvinfer1::IRaggedSoftMaxLayer>; +pub type ReduceLayer<'layer> = Layer<'layer, nvinfer1::IReduceLayer>; +pub type ResizeLayer<'layer> = Layer<'layer, nvinfer1::IResizeLayer>; +pub type RotaryEmbeddingLayer<'layer> = Layer<'layer, nvinfer1::IRotaryEmbeddingLayer>; +pub type ScaleLayer<'layer> = Layer<'layer, nvinfer1::IScaleLayer>; +pub type ScatterLayer<'layer> = Layer<'layer, nvinfer1::IScatterLayer>; +pub type SelectLayer<'layer> = Layer<'layer, nvinfer1::ISelectLayer>; +pub type ShapeLayer<'layer> = Layer<'layer, nvinfer1::IShapeLayer>; +pub type ShuffleLayer<'layer> = Layer<'layer, nvinfer1::IShuffleLayer>; +pub type SliceLayer<'layer> = Layer<'layer, nvinfer1::ISliceLayer>; +pub type SoftMaxLayer<'layer> = Layer<'layer, nvinfer1::ISoftMaxLayer>; +pub type SqueezeLayer<'layer> = Layer<'layer, nvinfer1::ISqueezeLayer>; +pub type TopKLayer<'layer> = Layer<'layer, nvinfer1::ITopKLayer>; +pub type 
UnaryLayer<'layer> = Layer<'layer, nvinfer1::IUnaryLayer>; +pub type UnsqueezeLayer<'layer> = Layer<'layer, nvinfer1::IUnsqueezeLayer>; +pub type ReverseSequenceLayer<'layer> = Layer<'layer, nvinfer1::IReverseSequenceLayer>; +pub type KVCacheUpdateLayer<'layer> = Layer<'layer, nvinfer1::IKVCacheUpdateLayer>; +pub type LrnLayer<'layer> = Layer<'layer, nvinfer1::ILRNLayer>; +pub type OneHotLayer<'layer> = Layer<'layer, nvinfer1::IOneHotLayer>; + +// Loop and conditional boundary layers (created via Loop / IfConditional / add_attention) +pub type AttentionInputLayer<'layer> = Layer<'layer, nvinfer1::IAttentionInputLayer>; +pub type AttentionOutputLayer<'layer> = Layer<'layer, nvinfer1::IAttentionOutputLayer>; +pub type AttentionBoundaryLayer<'layer> = Layer<'layer, nvinfer1::IAttentionBoundaryLayer>; +pub type LoopBoundaryLayer<'layer> = Layer<'layer, nvinfer1::ILoopBoundaryLayer>; +pub type RecurrenceLayer<'layer> = Layer<'layer, nvinfer1::IRecurrenceLayer>; +pub type LoopOutputLayer<'layer> = Layer<'layer, nvinfer1::ILoopOutputLayer>; +pub type TripLimitLayer<'layer> = Layer<'layer, nvinfer1::ITripLimitLayer>; +pub type IteratorLayer<'layer> = Layer<'layer, nvinfer1::IIteratorLayer>; +pub type ConditionLayer<'layer> = Layer<'layer, nvinfer1::IConditionLayer>; +pub type IfConditionalOutputLayer<'layer> = Layer<'layer, nvinfer1::IIfConditionalOutputLayer>; +pub type IfConditionalInputLayer<'layer> = Layer<'layer, nvinfer1::IIfConditionalInputLayer>; + +pub type DynLayer<'layer> = Layer<'layer, nvinfer1::ILayer>; + +/// Attention block (query, key, value → output). Created by [`NetworkDefinition::add_attention`]. +/// Input/output layers are managed internally by TensorRT. 
+pub struct Attention<'network> { + pub(crate) inner: Pin<&'network mut nvinfer1::IAttention>, + pub(crate) network: *const nvinfer1::INetworkDefinition, } -define_network_layer!(ShuffleLayer); -define_network_layer!(ActivationLayer); -define_network_layer!(ElementWiseLayer); -define_network_layer!(ResizeLayer); -define_network_layer!(TopKLayer); -define_network_layer!(GatherLayer); -define_network_layer!(ScatterLayer); -define_network_layer!(SelectLayer); -define_network_layer!(MatrixMultiplyLayer); -define_network_layer!(SoftMaxLayer); -define_network_layer!(ReduceLayer); -define_network_layer!(CumulativeLayer); -define_network_layer!(PoolingLayer); -define_network_layer!(ConvolutionLayer); -define_network_layer!(DeconvolutionLayer); -define_network_layer!(QuantizeLayer); -define_network_layer!(DequantizeLayer); -define_network_layer!(ConstantLayer); -define_network_layer!(ConcatenationLayer); -define_network_layer!(ScaleLayer); -define_network_layer!(SliceLayer); -define_network_layer!(UnaryLayer); -define_network_layer!(IdentityLayer); -define_network_layer!(PaddingLayer); -define_network_layer!(CastLayer); +/// Loop construct for recurrent subgraphs. Created by [`NetworkDefinition::add_loop`]. +pub struct Loop<'network> { + pub(crate) inner: Pin<&'network mut nvinfer1::ILoop>, + pub(crate) network: *const nvinfer1::INetworkDefinition, +} + +/// If-conditional construct. Created by [`NetworkDefinition::add_if_conditional`]. 
+pub struct IfConditional<'network> { + pub(crate) inner: Pin<&'network mut nvinfer1::IIfConditional>, + pub(crate) network: *const nvinfer1::INetworkDefinition, +} +impl ShuffleLayer<'_> { + pub fn set_reshape_dimensions( + &mut self, + network: &mut NetworkDefinition, + dims: &[i64], + ) -> Result<()> { + crate::check_network!(network, self); + let dims_obj = trtx_sys::Dims::from_slice(dims); + self.inner.as_mut().setReshapeDimensions(&dims_obj); + Ok(()) + } + + pub fn set_first_transpose( + &mut self, + network: &mut NetworkDefinition, + order: &[i32], + ) -> Result<()> { + crate::check_network!(network, self); + let mut order_arr = [0i32; 8]; + let n = order.len().min(8); + order_arr[..n].copy_from_slice(&order[..n]); + let perm = trtx_sys::nvinfer1::Permutation { order: order_arr }; + self.inner.as_mut().setFirstTranspose(perm); + Ok(()) + } +} + +impl ResizeLayer<'_> { + pub fn set_output_dimensions(&mut self, network: &mut NetworkDefinition, dims: &[i64]) { + crate::check_network!(network, self); + let dims_obj = trtx_sys::Dims::from_slice(dims); + self.inner.as_mut().setOutputDimensions(&dims_obj); + } + pub fn set_resize_mode(&mut self, network: &mut NetworkDefinition, mode: trtx_sys::ResizeMode) { + crate::check_network!(network, self); + self.inner.as_mut().setResizeMode(mode.into()); + } +} + +impl GatherLayer<'_> { + pub fn set_gather_mode(&mut self, network: &mut NetworkDefinition, mode: trtx_sys::GatherMode) { + crate::check_network!(network, self); + self.inner.as_mut().setMode(mode.into()); + } +} + +impl<'network> ScatterLayer<'network> { + pub fn set_scatter_mode( + &mut self, + network: &mut NetworkDefinition, + mode: trtx_sys::ScatterMode, + ) { + crate::check_network!(network, self); + self.inner.as_mut().setMode(mode.into()); + } + pub fn set_axis(&mut self, network: &'_ mut NetworkDefinition, axis: i32) { + crate::check_network!(network, self); + self.inner.as_mut().setAxis(axis); + } +} + +impl<'network> ConvolutionLayer<'network> { + pub 
fn set_stride(&mut self, network: &mut NetworkDefinition, stride: &[i64; 2]) {
+        crate::check_network!(network, self);
+        let dims_obj = trtx_sys::Dims::from_slice(stride);
+        self.inner.as_mut().setStrideNd(&dims_obj);
+    }
+    pub fn set_padding(&mut self, network: &mut NetworkDefinition, padding: &[i64; 2]) {
+        crate::check_network!(network, self);
+        let dims_obj = trtx_sys::Dims::from_slice(padding);
+        self.inner.as_mut().setPaddingNd(&dims_obj);
+    }
+    pub fn set_dilation(&mut self, network: &mut NetworkDefinition, dilation: &[i64; 2]) {
+        crate::check_network!(network, self);
+        let dims_obj = trtx_sys::Dims::from_slice(dilation);
+        self.inner.as_mut().setDilationNd(&dims_obj);
+    }
+    pub fn set_num_groups(&mut self, network: &mut NetworkDefinition, num_groups: i64) {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setNbGroups(num_groups);
+    }
+
+    /// Set an input tensor by index. Input 0 is the activation; 1 is the kernel tensor; 2 is the bias tensor.
+    /// When using input 1 or 2, the layer must have been created with empty weights for that slot.
+    pub fn set_input(
+        &mut self,
+        network: &mut NetworkDefinition,
+        index: i32,
+        tensor: &'_ Tensor<'network>,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, tensor);
+        unsafe {
+            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<nvinfer1::ILayer>(
+                self.inner.as_mut().get_unchecked_mut() as *mut _ as *mut _,
+            );
+            layer_pin.as_mut().setInput(index, tensor.pin_mut());
+        }
+        Ok(())
+    }
+}
+
+impl<'network> DeconvolutionLayer<'network> {
+    pub fn set_stride(&mut self, network: &mut NetworkDefinition, stride: &[i64; 2]) -> Result<()> {
+        crate::check_network!(network, self);
+        let dims_obj = trtx_sys::Dims::from_slice(stride);
+        self.inner.as_mut().setStrideNd(&dims_obj);
+        Ok(())
+    }
+
+    /// Set pre-padding (trim this many elements at the start of each spatial dimension of the output).
+ /// Pass [pre_h, pre_w] for 2D deconv; TensorRT applies to the spatial dimensions only. + pub fn set_pre_padding( + &mut self, + network: &mut NetworkDefinition, + padding: &[i64; 2], + ) -> Result<()> { + crate::check_network!(network, self); + let dims_obj = trtx_sys::Dims::from_slice(padding); + self.inner.as_mut().setPrePadding(&dims_obj); + Ok(()) + } + /// Set post-padding (trim this many elements at the end of each spatial dimension of the output). + /// Pass [post_h, post_w] for 2D deconv; TensorRT applies to the spatial dimensions only. + pub fn set_post_padding( + &mut self, + network: &mut NetworkDefinition, + padding: &[i64; 2], + ) -> Result<()> { + crate::check_network!(network, self); + let dims_obj = trtx_sys::Dims::from_slice(padding); + self.inner.as_mut().setPostPadding(&dims_obj); + Ok(()) + } + pub fn set_dilation( + &mut self, + network: &mut NetworkDefinition, + dilation: &[i64; 2], + ) -> Result<()> { + crate::check_network!(network, self); + let dims_obj = trtx_sys::Dims::from_slice(dilation); + self.inner.as_mut().setDilationNd(&dims_obj); + Ok(()) + } + + pub fn set_num_groups( + &mut self, + network: &mut NetworkDefinition, + num_groups: i64, + ) -> Result<()> { + crate::check_network!(network, self); + self.inner.as_mut().setNbGroups(num_groups); + Ok(()) + } + /// Set an input tensor by index. Input 0 is the activation; 1 is the kernel tensor; 2 is the bias tensor. + /// When using input 1 or 2, the layer must have been created with empty weights for that slot. 
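+    ///
+    /// A sketch of wiring a build-time constant as the kernel (names are illustrative;
+    /// assumes `net` is the owning `NetworkDefinition` and `deconv` was created with
+    /// empty kernel weights):
+    /// ```ignore
+    /// let kernel = net.add_constant(&[8, 3, 3, 3], kernel_bytes, DataType::kFLOAT)?;
+    /// let kernel_out = kernel.get_output(&net, 0)?;
+    /// deconv.set_input(&mut net, 1, &kernel_out)?; // input 1 = kernel tensor
+    /// ```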
+    pub fn set_input(
+        &mut self,
+        network: &mut NetworkDefinition,
+        index: i32,
+        tensor: &'_ Tensor<'network>,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, tensor);
+        unsafe {
+            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<nvinfer1::ILayer>(
+                self.inner.as_mut().get_unchecked_mut() as *mut _ as *mut _,
+            );
+            layer_pin.as_mut().setInput(index, tensor.pin_mut());
+        }
+        Ok(())
+    }
+}
+
+impl ConcatenationLayer<'_> {
+    pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setAxis(axis);
+    }
+}
+impl NormalizationLayer<'_> {
+    pub fn set_epsilon(&mut self, network: &mut NetworkDefinition, eps: f32) {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setEpsilon(eps);
+    }
+    pub fn get_epsilon(&self, network: &NetworkDefinition) -> f32 {
+        crate::check_network!(network, self);
+        self.inner.as_ref().getEpsilon()
+    }
+    pub fn set_axes(&mut self, network: &mut NetworkDefinition, axes: crate::Axes) {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setAxes(axes.to_bits());
+    }
+    pub fn get_axes(&self, network: &NetworkDefinition) -> crate::Axes {
+        crate::check_network!(network, self);
+        crate::Axes::from_bits(self.inner.as_ref().getAxes())
+    }
+    pub fn set_num_groups(&mut self, network: &mut NetworkDefinition, groups: i64) {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setNbGroups(groups);
+    }
+    pub fn get_num_groups(&self, network: &NetworkDefinition) -> i64 {
+        crate::check_network!(network, self);
+        self.inner.as_ref().getNbGroups()
+    }
+    pub fn set_compute_precision(&mut self, network: &mut NetworkDefinition, data_type: DataType) {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setComputePrecision(data_type.into());
+    }
+    pub fn get_compute_precision(&self, network: &NetworkDefinition) -> DataType {
+        crate::check_network!(network, self);
+        self.inner.as_ref().getComputePrecision().into()
+    }
+    pub fn is_v2(&self, network: &NetworkDefinition) -> bool {
+        crate::check_network!(network, self);
+        self.inner.as_ref().isV2()
+    }
+}
+
+impl Tensor<'_> {
+    pub fn name(&self, network: &NetworkDefinition) -> Result<String> {
+        crate::check_network!(network, self);
+        let name_ptr = self.as_ref().getName();
+        if name_ptr.is_null() {
+            return Err(Error::Runtime("Failed to get tensor name".to_string()));
+        }
+        unsafe { Ok(std::ffi::CStr::from_ptr(name_ptr).to_str()?.to_string()) }
+    }
+
+    pub fn set_name(&self, network: &'_ mut NetworkDefinition, name: &str) -> Result<()> {
+        crate::check_network!(network, self);
+        let name_cstr = std::ffi::CString::new(name)?;
+        unsafe {
+            self.pin_mut().setName(name_cstr.as_ptr());
+        }
+        Ok(())
+    }
+
+    pub fn dimensions(&self, network: &NetworkDefinition) -> Result<Vec<i64>> {
+        crate::check_network!(network, self);
+        let result = self.as_ref().getDimensions();
+        Ok(result.d[..result.nbDims as usize].to_vec())
+    }
+
+    pub fn get_type(&self, network: &NetworkDefinition) -> DataType {
+        crate::check_network!(network, self);
+        self.as_ref().getType().into()
+    }
+
+    /// Set allowed tensor formats (bitmask of TensorFormat). E.g. 1u32 << TensorFormat::kHWC for channels-last.
+    /// TensorRT may insert reformat layers when connecting tensors with different formats.
+    pub fn set_allowed_formats(
+        &mut self,
+        network: &mut NetworkDefinition,
+        formats: u32,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        self.pin_mut().setAllowedFormats(formats);
+        Ok(())
+    }
+}
 
 /// Network definition for building TensorRT engines
-pub struct NetworkDefinition {
-    pub(crate) inner: *mut std::ffi::c_void,
+pub struct NetworkDefinition<'builder> {
+    pub(crate) inner: UniquePtr<INetworkDefinition>,
+    _builder: PhantomData<&'builder trtx_sys::nvinfer1::IBuilder>,
+    /// Owned copies of small weights (e.g. scalars), kept alive as long as the network.
+    small_copied_weights: Vec<Vec<u8>>,
+    error_recorder: Option<Pin<Box<dyn ErrorRecorder>>>,
+}
+
+impl<'network> NetworkDefinition<'network> {
+    pub(crate) fn from_ptr(ptr: *mut INetworkDefinition) -> Self {
+        Self {
+            inner: unsafe { UniquePtr::from_raw(ptr) },
+            error_recorder: None,
+            _builder: Default::default(),
+            small_copied_weights: Default::default(),
+        }
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addInput`].
+    pub fn add_input(
+        &mut self,
+        name: &str,
+        data_type: trtx_sys::DataType,
+        dims: &[i32],
+    ) -> Result<Tensor<'network>> {
+        let name_cstr = std::ffi::CString::new(name)?;
+        let dims_i64: Vec<i64> = dims.iter().map(|&d| d as i64).collect();
+        let dims_struct = trtx_sys::Dims::from_slice(&dims_i64);
+        let tensor_ptr = unsafe {
+            self.inner
+                .pin_mut()
+                .addInput(name_cstr.as_ptr(), data_type.into(), &dims_struct)
+        };
+        unsafe { Tensor::new(self.inner.as_ptr(), tensor_ptr) }
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::markOutput`].
+    pub fn mark_output(&mut self, tensor: &'_ Tensor) {
+        crate::check_network!(self, tensor);
+        self.inner.pin_mut().markOutput(tensor.pin_mut());
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::markDebug`].
+    /// Mark a tensor for debugging; [nvinfer1::IExecutionContext::setDebugListener] will receive it during execution.
+    pub fn mark_tensor_debug(&mut self, tensor: &'_ Tensor) -> Result<()> {
+        crate::check_network!(self, tensor);
+        let success = self.inner.pin_mut().markDebug(tensor.pin_mut());
+        if success {
+            Ok(())
+        } else {
+            Err(Error::Runtime("markDebug failed".to_string()))
+        }
+    }
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::isDebugTensor`].
+    /// Returns whether the tensor was marked for debugging via [`Self::mark_tensor_debug`].
+    pub fn is_debug_tensor(&self, tensor: &'_ Tensor) -> bool {
+        crate::check_network!(self, tensor);
+        self.inner.isDebugTensor(tensor.as_ref())
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::getNbInputs`].
+    pub fn get_nb_inputs(&self) -> i32 {
+        self.inner.getNbInputs()
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::getNbOutputs`].
+    pub fn get_nb_outputs(&self) -> i32 {
+        self.inner.getNbOutputs()
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::getInput`].
+    pub fn get_input(&self, index: i32) -> Result<Tensor<'network>> {
+        let tensor_ptr = self.inner.getInput(index);
+        if tensor_ptr.is_null() {
+            return Err(Error::Runtime(format!(
+                "Failed to get input at index {}",
+                index
+            )));
+        }
+        unsafe { Tensor::new(self.inner.as_ptr(), tensor_ptr) }
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::getOutput`].
+    pub fn get_output(&self, index: i32) -> Result<Tensor<'network>> {
+        let tensor_ptr = self.inner.getOutput(index);
+        if tensor_ptr.is_null() {
+            return Err(Error::Runtime(format!(
+                "Failed to get output at index {}",
+                index
+            )));
+        }
+        unsafe { Tensor::new(self.inner.as_ptr(), tensor_ptr) }
+    }
+
+    /// Number of layers in the network (for introspection/dumping).
+    pub fn get_nb_layers(&self) -> i32 {
+        self.inner.getNbLayers()
+    }
+
+    pub fn get_layer(&self, layer_index: i32) -> Result<DynLayer<'network>> {
+        let layer_ptr = self.inner.getLayer(layer_index);
+        DynLayer::new_dyn(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// Layer name at index (for introspection/dumping). Returns "(Unnamed)" if null.
+    #[deprecated = "use network.get_layer(index)?.name(&network)"]
+    pub fn get_layer_name(&self, layer_index: i32) -> Result<String> {
+        let layer_ptr = self.inner.getLayer(layer_index);
+        unsafe { layer_ptr.as_mut() }
+            .ok_or_else(|| Error::Runtime(format!("No layer at index {}", layer_index)))?;
+        let name_ptr = unsafe {
+            crate::autocxx_helpers::cast_and_pin::<nvinfer1::ILayer>(layer_ptr as *mut _)
+                .getName()
+        };
+        Ok(if name_ptr.is_null() {
+            "(Unnamed)".to_string()
+        } else {
+            unsafe { std::ffi::CStr::from_ptr(name_ptr) }
+                .to_str()
+                .map_err(|e| Error::Runtime(e.to_string()))?
+                .to_string()
+        })
+    }
+
+    /// Layer type enum value at index (for introspection/dumping). See TensorRT LayerType.
+    #[deprecated = "use network.get_layer(index)?.layer_type_dynamic()"]
+    pub fn get_layer_type(&self, layer_index: i32) -> Result<i32> {
+        let layer_ptr = self.inner.getLayer(layer_index);
+        if layer_ptr.is_null() {
+            return Err(Error::Runtime(format!("No layer at index {}", layer_index)));
+        }
+        let layer_type = unsafe {
+            crate::autocxx_helpers::cast_and_pin::<nvinfer1::ILayer>(layer_ptr as *mut _)
+                .getType()
+        };
+        Ok(layer_type as i32)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addActivation`].
+    pub fn add_activation(
+        &mut self,
+        input: &'_ Tensor,
+        activation_type: trtx_sys::ActivationType,
+    ) -> Result<ActivationLayer<'network>> {
+        crate::check_network!(self, input);
+        let layer_ptr = self
+            .inner
+            .pin_mut()
+            .addActivation(input.pin_mut(), activation_type.into());
+        ActivationLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addUnary`].
+    pub fn add_unary(
+        &mut self,
+        input: &'_ Tensor,
+        op: trtx_sys::UnaryOperation,
+    ) -> Result<UnaryLayer<'network>> {
+        crate::check_network!(self, input);
+        let layer_ptr = self.inner.pin_mut().addUnary(input.pin_mut(), op.into());
+        UnaryLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addIdentity`].
+    pub fn add_identity(&mut self, input: &'_ Tensor) -> Result<IdentityLayer<'network>> {
+        crate::check_network!(self, input);
+        let layer_ptr = self.inner.pin_mut().addIdentity(input.pin_mut());
+        IdentityLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addCast`].
+    pub fn add_cast(
+        &mut self,
+        input: &'_ Tensor,
+        to_type: trtx_sys::DataType,
+    ) -> Result<CastLayer<'network>> {
+        crate::check_network!(self, input);
+        let layer_ptr = self
+            .inner
+            .pin_mut()
+            .addCast(input.pin_mut(), to_type.into());
+        CastLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addElementWise`].
+    pub fn add_elementwise(
+        &mut self,
+        input1: &'_ Tensor,
+        input2: &'_ Tensor,
+        op: trtx_sys::ElementWiseOperation,
+    ) -> Result<ElementWiseLayer<'network>> {
+        crate::check_network!(self, input1);
+        crate::check_network!(self, input2);
+        let layer_ptr =
+            self.inner
+                .pin_mut()
+                .addElementWise(input1.pin_mut(), input2.pin_mut(), op.into());
+        ElementWiseLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addPoolingNd`].
+    pub fn add_pooling(
+        &'_ mut self,
+        input: &'_ Tensor,
+        pooling_type: trtx_sys::PoolingType,
+        window_size: &[i64; 2],
+    ) -> Result<PoolingLayer<'network>> {
+        crate::check_network!(self, input);
+        let window_dims = trtx_sys::Dims::new_2d(window_size[0], window_size[1]);
+        let layer_ptr =
+            self.inner
+                .pin_mut()
+                .addPoolingNd(input.pin_mut(), pooling_type.into(), &window_dims);
+        PoolingLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addShuffle`].
+    pub fn add_shuffle(&'_ mut self, input: &'_ Tensor) -> Result<ShuffleLayer<'network>> {
+        crate::check_network!(self, input);
+        let layer_ptr = self.inner.pin_mut().addShuffle(input.pin_mut());
+        ShuffleLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addMatrixMultiply`].
+    pub fn add_matrix_multiply(
+        &'_ mut self,
+        input0: &'_ Tensor,
+        op0: MatrixOperation,
+        input1: &'_ Tensor,
+        op1: MatrixOperation,
+    ) -> Result<MatrixMultiplyLayer<'network>> {
+        crate::check_network!(self, input0);
+        crate::check_network!(self, input1);
+        let layer_ptr = self.inner.pin_mut().addMatrixMultiply(
+            input0.pin_mut(),
+            op0.into(),
+            input1.pin_mut(),
+            op1.into(),
+        );
+        MatrixMultiplyLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addConvolutionNd`].
+    pub fn add_convolution(
+        &'_ mut self,
+        input: &'_ Tensor,
+        nb_output_maps: i32,
+        kernel_size: &[i32; 2],
+        weights: &ConvWeights<'network>,
+    ) -> Result<ConvolutionLayer<'network>> {
+        crate::check_network!(self, input);
+        let kernel_dtype = weights.kernel_dtype;
+        let kernel_weights = weights.kernel_weights;
+        let bias_weights = weights.bias_weights;
+        let bias_dtype = weights.bias_dtype;
+        let kernel_bpe = kernel_dtype.size_bits() / 8;
+        let weight_count = (kernel_weights.len() / kernel_bpe) as i64;
+        let bias_dtype_val = bias_dtype.unwrap_or(kernel_dtype);
+        let bias_bpe = bias_dtype_val.size_bits() / 8;
+        let bias_count = bias_weights
+            .map(|b| (b.len() / bias_bpe) as i64)
+            .unwrap_or(0);
+        let kernel_ptr = if weight_count > 0 {
+            kernel_weights.as_ptr() as *const std::ffi::c_void
+        } else {
+            std::ptr::null()
+        };
+        let bias_ptr = if bias_count > 0 {
+            bias_weights
+                .map(|b| b.as_ptr() as *const std::ffi::c_void)
+                .unwrap_or(std::ptr::null())
+        } else {
+            std::ptr::null()
+        };
+        let kernel_dims = trtx_sys::Dims::new_2d(kernel_size[0] as i64, kernel_size[1] as i64);
+        let kernel_w = trtx_sys::nvinfer1::Weights::new_with_type(
+            kernel_dtype.into(),
+            kernel_ptr,
+            weight_count,
+        );
+        let bias_w =
+            trtx_sys::nvinfer1::Weights::new_with_type(bias_dtype_val.into(), bias_ptr, bias_count);
+        let layer_ptr = self.inner.pin_mut().addConvolutionNd(
+            input.pin_mut(),
+            nb_output_maps as i64,
+            &kernel_dims,
+            kernel_w,
+            bias_w,
+        );
+        ConvolutionLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// Add a 2D deconvolution layer. Same input semantics as convolution: input 0 = activation,
+    /// input 1 = kernel tensor (use set_input(1, tensor) when kernel_weights is empty),
+    /// input 2 = bias tensor (use set_input(2, tensor) when bias_weights is None/empty).
+    ///
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addDeconvolutionNd`].
+    pub fn add_deconvolution(
+        &mut self,
+        input: &'_ Tensor,
+        nb_output_maps: i64,
+        kernel_size: &[i64; 2],
+        weights: &ConvWeights<'network>,
+    ) -> Result<DeconvolutionLayer<'network>> {
+        crate::check_network!(self, input);
+        let kernel_dtype = weights.kernel_dtype;
+        let kernel_weights = weights.kernel_weights;
+        let bias_weights = weights.bias_weights;
+        let bias_dtype = weights.bias_dtype;
+        let kernel_bpe = kernel_dtype.size_bits() / 8;
+        let weight_count = (kernel_weights.len() / kernel_bpe) as i64;
+        let bias_dtype_val = bias_dtype.unwrap_or(kernel_dtype);
+        let bias_bpe = bias_dtype_val.size_bits() / 8;
+        let bias_count = bias_weights
+            .map(|b| (b.len() / bias_bpe) as i64)
+            .unwrap_or(0);
+        let kernel_ptr = if weight_count > 0 {
+            kernel_weights.as_ptr() as *const std::ffi::c_void
+        } else {
+            std::ptr::null()
+        };
+        let bias_ptr = if bias_count > 0 {
+            bias_weights
+                .map(|b| b.as_ptr() as *const std::ffi::c_void)
+                .unwrap_or(std::ptr::null())
+        } else {
+            std::ptr::null()
+        };
+        let kernel_dims = trtx_sys::Dims::new_2d(kernel_size[0], kernel_size[1]);
+        let kernel_w = trtx_sys::nvinfer1::Weights::new_with_type(
+            kernel_dtype.into(),
+            kernel_ptr,
+            weight_count,
+        );
+        let bias_w =
+            trtx_sys::nvinfer1::Weights::new_with_type(bias_dtype_val.into(), bias_ptr, bias_count);
+        let layer_ptr = self.inner.pin_mut().addDeconvolutionNd(
+            input.pin_mut(),
+            nb_output_maps,
+            kernel_dims,
+            kernel_w,
+            bias_w,
+        );
+        DeconvolutionLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addConcatenation`].
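+    ///
+    /// A sketch of concatenating two tensors (names are illustrative; assumes `a` and
+    /// `b` were created on this network):
+    /// ```ignore
+    /// let mut concat = net.add_concatenation(&[&a, &b])?;
+    /// concat.set_axis(&mut net, 1); // e.g. concatenate along the channel axis
+    /// ```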
+ pub fn add_concatenation(&self, inputs: &[&'_ Tensor]) -> Result> { + for t in inputs.iter() { + crate::check_network!(self, t); + } + let mut input_ptrs: Vec<*mut std::ffi::c_void> = inputs + .iter() + .map(|t| t.as_mut() as *mut ITensor as *mut _) + .collect(); + let layer_ptr = unsafe { + trtx_sys::network_add_concatenation( + self.inner.as_mut_ptr() as *mut std::ffi::c_void, + input_ptrs.as_mut_ptr(), + inputs.len() as i32, + ) + } as *mut IConcatenationLayer; + ConcatenationLayer::new(self.inner.as_ptr(), layer_ptr) + } + + /// See [`trtx_sys::nvinfer1::INetworkDefinition::addConstant`]. + /// Same as [`Self::add_constant`] just copying the provided weights for small weights like scalars + pub fn add_small_constant_copied( + &mut self, + dims: &[i64], + weights: &[u8], + data_type: trtx_sys::DataType, + ) -> Result> { + unsafe { self.add_constant_unsafe(dims, weights, data_type, true) } + } + + unsafe fn add_constant_unsafe( + &mut self, + dims: &[i64], + weights: &[u8], + data_type: trtx_sys::DataType, + copy: bool, + ) -> Result> { + let element_count: i64 = dims.iter().product(); + let expected_bytes = element_count * data_type.size_bits() as i64 / 8; + if weights.len() as i64 != expected_bytes { + panic!( + "Weight size mismatch: expected {expected_bytes} bytes, got {} bytes", + weights.len() + ); + } + let dims_struct = trtx_sys::Dims::from_slice(dims); + let weights_struct = trtx_sys::nvinfer1::Weights::new_with_type( + data_type.into(), + if copy { + self.small_copied_weights.push(weights.to_vec()); + self.small_copied_weights + .last() + .expect("can't be empty. we just pushed") + .as_ptr() + } else { + weights.as_ptr() + } as *const std::ffi::c_void, + element_count, + ); + let layer_ptr = self + .inner + .pin_mut() + .addConstant(&dims_struct, weights_struct); + ConstantLayer::new(self.inner.as_ptr(), layer_ptr) + } + + /// See [`trtx_sys::nvinfer1::INetworkDefinition::addConstant`]. 
+ pub fn add_constant( + &mut self, + dims: &[i64], + weights: &'network [u8], + data_type: trtx_sys::DataType, + ) -> Result> { + unsafe { self.add_constant_unsafe(dims, weights, data_type, false) } + } + + /// See [`trtx_sys::nvinfer1::INetworkDefinition::addSoftMax`]. + pub fn add_softmax( + &mut self, + input: &'_ Tensor, + axes: crate::Axes, + ) -> Result> { + crate::check_network!(self, input); + let layer_ptr = self.inner.pin_mut().addSoftMax(input.pin_mut()); + let mut rtn = SoftMaxLayer::new(self.inner.as_ptr(), layer_ptr)?; + rtn.inner.as_mut().setAxes(axes.to_bits()); + Ok(rtn) + } + + /// See [`trtx_sys::nvinfer1::INetworkDefinition::addScale`]. + pub fn add_scale( + &mut self, + input: &'_ Tensor, + mode: ScaleMode, + shift: &[u8], + scale: &[u8], + power: &[u8], + ) -> Result> { + crate::check_network!(self, input); + let weight_count = match mode { + ScaleMode::kUNIFORM => 1i64, + ScaleMode::kCHANNEL => { + let input_dims = input.dimensions(self)?; + if input_dims.len() >= 4 { + input_dims[1] + } else if !input_dims.is_empty() { + input_dims[0] + } else { + 1i64 + } + } + ScaleMode::kELEMENTWISE => { + let input_dims = input.dimensions(self)?; + input_dims.iter().product::() + } + }; + + let shift_w = trtx_sys::nvinfer1::Weights::new_float( + shift.as_ptr() as *const std::ffi::c_void, + weight_count, + ); + let scale_w = trtx_sys::nvinfer1::Weights::new_float( + scale.as_ptr() as *const std::ffi::c_void, + weight_count, + ); + let power_w = trtx_sys::nvinfer1::Weights::new_float( + power.as_ptr() as *const std::ffi::c_void, + weight_count, + ); + let layer_ptr = + self.inner + .pin_mut() + .addScale(input.pin_mut(), mode.into(), shift_w, scale_w, power_w); + ScaleLayer::new(self.inner.as_ptr(), layer_ptr) + } + + /// See [`trtx_sys::nvinfer1::INetworkDefinition::addReduce`]. 
+ pub fn add_reduce( + &mut self, + input: &'_ Tensor, + op: trtx_sys::nvinfer1::ReduceOperation, + axes: crate::Axes, + keep_dims: bool, + ) -> Result> { + crate::check_network!(self, input); + let axes_bits = axes.to_bits(); + let layer_ptr = self + .inner + .pin_mut() + .addReduce(input.pin_mut(), op, axes_bits, keep_dims); + ReduceLayer::new(self.inner.as_ptr(), layer_ptr) + } + + /// See [`trtx_sys::nvinfer1::INetworkDefinition::addCumulative`]. + pub fn add_cumulative( + &mut self, + input: &'_ Tensor, + axis: i32, + op: trtx_sys::CumulativeOperation, + exclusive: bool, + reverse: bool, + ) -> Result> { + crate::check_network!(self, input); + let axis_bytes = axis.to_le_bytes(); + let axis_constant = + self.add_small_constant_copied(&[], &axis_bytes, trtx_sys::DataType::kINT32)?; + let axis_tensor = axis_constant.get_output(self, 0)?; + self.add_cumulative_with_axis_tensor(input, &axis_tensor, op, exclusive, reverse) + } + + /// See [`trtx_sys::nvinfer1::INetworkDefinition::addCumulative`]. + pub fn add_cumulative_with_axis_tensor( + &mut self, + input: &'_ Tensor, + axis_tensor: &'_ Tensor, + op: trtx_sys::CumulativeOperation, + exclusive: bool, + reverse: bool, + ) -> Result> { + crate::check_network!(self, input); + crate::check_network!(self, axis_tensor); + let layer_ptr = self.inner.pin_mut().addCumulative( + input.pin_mut(), + axis_tensor.pin_mut(), + op.into(), + exclusive, + reverse, + ); + CumulativeLayer::new(self.inner.as_ptr(), layer_ptr) + } + + /// See [`trtx_sys::nvinfer1::INetworkDefinition::addSlice`]. 
+    pub fn add_slice(
+        &mut self,
+        input: &'_ Tensor,
+        start: &[i64],
+        size: &[i64],
+        stride: &[i64],
+    ) -> Result<SliceLayer<'network>> {
+        crate::check_network!(self, input);
+        if start.len() != size.len() || start.len() != stride.len() {
+            return Err(Error::Runtime(
+                "start, size, and stride must have the same length".to_string(),
+            ));
+        }
+        let start_dims = trtx_sys::Dims::from_slice(start);
+        let size_dims = trtx_sys::Dims::from_slice(size);
+        let stride_dims = trtx_sys::Dims::from_slice(stride);
+        let layer_ptr =
+            self.inner
+                .pin_mut()
+                .addSlice(input.pin_mut(), &start_dims, &size_dims, &stride_dims);
+        SliceLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addTopK`].
+    pub fn add_topk(
+        &mut self,
+        input: &'_ Tensor,
+        op: TopKOperation,
+        k: i32,
+        axes: crate::Axes,
+    ) -> Result<TopKLayer<'network>> {
+        crate::check_network!(self, input);
+        let axes_bits = axes.to_bits();
+        let layer_ptr = self
+            .inner
+            .pin_mut()
+            .addTopK(input.pin_mut(), op.into(), k, axes_bits);
+        TopKLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addResize`].
+    pub fn add_resize(&mut self, input: &'_ Tensor) -> Result<ResizeLayer<'network>> {
+        crate::check_network!(self, input);
+        let layer_ptr = self.inner.pin_mut().addResize(input.pin_mut());
+        ResizeLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addGather`].
+    pub fn add_gather(
+        &'_ mut self,
+        data: &'_ Tensor,
+        indices: &'_ Tensor,
+        axis: i32,
+    ) -> Result<GatherLayer<'network>> {
+        crate::check_network!(self, data);
+        crate::check_network!(self, indices);
+        let layer_ptr = self
+            .inner
+            .pin_mut()
+            .addGather(data.pin_mut(), indices.pin_mut(), axis);
+        GatherLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addScatter`].
+    pub fn add_scatter(
+        &mut self,
+        data: &'_ Tensor,
+        indices: &'_ Tensor,
+        updates: &'_ Tensor,
+        mode: trtx_sys::ScatterMode,
+    ) -> Result<ScatterLayer<'network>> {
+        crate::check_network!(self, data);
+        crate::check_network!(self, indices);
+        crate::check_network!(self, updates);
+        let layer_ptr = self.inner.pin_mut().addScatter(
+            data.pin_mut(),
+            indices.pin_mut(),
+            updates.pin_mut(),
+            mode.into(),
+        );
+        ScatterLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addQuantize`].
+    pub fn add_quantize(
+        &'_ mut self,
+        input: &'_ Tensor,
+        scale: &'_ Tensor,
+        output_type: trtx_sys::DataType,
+    ) -> Result<QuantizeLayer<'network>> {
+        crate::check_network!(self, input);
+        crate::check_network!(self, scale);
+        let layer_ptr =
+            self.inner
+                .pin_mut()
+                .addQuantize(input.pin_mut(), scale.pin_mut(), output_type.into());
+        QuantizeLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addDequantize`].
+    pub fn add_dequantize(
+        &mut self,
+        input: &'_ Tensor,
+        scale: &'_ Tensor,
+        output_type: trtx_sys::DataType,
+    ) -> Result<DequantizeLayer<'network>> {
+        crate::check_network!(self, input);
+        crate::check_network!(self, scale);
+        let layer_ptr = self.inner.pin_mut().addDequantize(
+            input.pin_mut(),
+            scale.pin_mut(),
+            output_type.into(),
+        );
+        DequantizeLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addSelect`].
+    pub fn add_select(
+        &mut self,
+        condition: &'_ Tensor,
+        then_input: &'_ Tensor,
+        else_input: &'_ Tensor,
+    ) -> Result<SelectLayer<'network>> {
+        crate::check_network!(self, condition);
+        crate::check_network!(self, then_input);
+        crate::check_network!(self, else_input);
+        let layer_ptr = self.inner.pin_mut().addSelect(
+            condition.pin_mut(),
+            then_input.pin_mut(),
+            else_input.pin_mut(),
+        );
+        SelectLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addPaddingNd`].
+    pub fn add_padding(
+        &mut self,
+        input: &'_ Tensor,
+        pre_padding: &[i64],
+        post_padding: &[i64],
+    ) -> Result<PaddingLayer<'network>> {
+        crate::check_network!(self, input);
+        if pre_padding.len() != post_padding.len() {
+            return Err(Error::Runtime(
+                "pre_padding and post_padding must have the same length".to_string(),
+            ));
+        }
+        let pre_dims = trtx_sys::Dims::from_slice(pre_padding);
+        let post_dims = trtx_sys::Dims::from_slice(post_padding);
+        let layer_ptr = self
+            .inner
+            .pin_mut()
+            .addPaddingNd(input.pin_mut(), &pre_dims, &post_dims);
+        PaddingLayer::new(self.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addAssertion`].
+    pub fn add_assertion(&mut self, condition: &'_ Tensor, message: &str) -> Result<()> {
+        crate::check_network!(self, condition);
+        let message_cstr = std::ffi::CString::new(message)?;
+        let layer_ptr = unsafe {
+            self.inner
+                .pin_mut()
+                .addAssertion(condition.pin_mut(), message_cstr.as_ptr())
+        };
+        let _ = AssertionLayer::new(self.inner.as_ptr(), layer_ptr)?;
+        Ok(())
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addLoop`].
+    pub fn add_loop(&mut self) -> Result<Loop<'network>> {
+        let loop_ptr = self.inner.pin_mut().addLoop();
+        let loop_ptr = unsafe { loop_ptr.as_mut() }
+            .ok_or_else(|| Error::Runtime("Failed to add loop".to_string()))?;
+        Ok(Loop {
+            inner: unsafe { Pin::new_unchecked(loop_ptr) },
+            network: self.inner.as_ptr(),
+        })
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addIfConditional`].
+    pub fn add_if_conditional(&mut self) -> Result<IfConditional<'network>> {
+        let if_ptr = self.inner.pin_mut().addIfConditional();
+        let if_ptr = unsafe { if_ptr.as_mut() }
+            .ok_or_else(|| Error::Runtime("Failed to add if conditional".to_string()))?;
+        Ok(IfConditional {
+            inner: unsafe { Pin::new_unchecked(if_ptr) },
+            network: self.inner.as_ptr(),
+        })
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addAttention`].
+    /// Creates an attention block (internally creates [`AttentionInputLayer`] and [`AttentionOutputLayer`]).
+    pub fn add_attention(
+        &mut self,
+        query: &'_ Tensor,
+        key: &'_ Tensor,
+        value: &'_ Tensor,
+        norm_op: trtx_sys::AttentionNormalizationOp,
+        causal: bool,
+    ) -> Result<Attention<'network>> {
+        crate::check_network!(self, query);
+        crate::check_network!(self, key);
+        crate::check_network!(self, value);
+        let attn_ptr = self.inner.pin_mut().addAttention(
+            query.pin_mut(),
+            key.pin_mut(),
+            value.pin_mut(),
+            norm_op.into(),
+            causal,
+        );
+        let attn = unsafe { attn_ptr.as_mut() }
+            .ok_or_else(|| Error::Runtime("Failed to add attention".to_string()))?;
+        Ok(Attention {
+            inner: unsafe { Pin::new_unchecked(attn) },
+            network: self.inner.as_ptr(),
+        })
+    }
+}
+
+// --- Attention: get_output ---
+
+impl<'network> Attention<'network> {
+    /// See [`trtx_sys::nvinfer1::IAttention::getOutput`]. IAttention has one output (index 0).
+    pub fn get_output(&self, network: &NetworkDefinition, index: i32) -> Result<Tensor<'network>> {
+        crate::check_network!(network, self);
+        let tensor_ptr = self.inner.getOutput(index);
+        unsafe { Tensor::new(self.network, tensor_ptr) }
+    }
+}
+
+// --- Loop boundary layers (ILoop::addRecurrence, addTripLimit, addIterator, addLoopOutput) ---
+
+impl<'network> Loop<'network> {
+    /// See [`trtx_sys::nvinfer1::ILoop::addRecurrence`].
+    pub fn add_recurrence(
+        &mut self,
+        network: &mut NetworkDefinition,
+        initial_value: &'_ Tensor,
+    ) -> Result<RecurrenceLayer<'network>> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, initial_value);
+        let layer_ptr = { self.inner.as_mut().addRecurrence(initial_value.pin_mut()) };
+        RecurrenceLayer::new(network.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::ILoop::addTripLimit`].
+    pub fn add_trip_limit(
+        &mut self,
+        network: &mut NetworkDefinition,
+        tensor: &'_ Tensor,
+        limit: trtx_sys::TripLimit,
+    ) -> Result<TripLimitLayer<'network>> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, tensor);
+        let layer_ptr = {
+            self.inner
+                .as_mut()
+                .addTripLimit(tensor.pin_mut(), limit.into())
+        };
+        TripLimitLayer::new(network.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::ILoop::addIterator`].
+    pub fn add_iterator(
+        &mut self,
+        network: &mut NetworkDefinition,
+        tensor: &'_ Tensor,
+        axis: i32,
+        reverse: bool,
+    ) -> Result<IteratorLayer<'network>> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, tensor);
+        let layer_ptr = {
+            self.inner
+                .as_mut()
+                .addIterator(tensor.pin_mut(), axis, reverse)
+        };
+        IteratorLayer::new(network.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::ILoop::addLoopOutput`].
+    pub fn add_loop_output(
+        &mut self,
+        network: &mut NetworkDefinition,
+        tensor: &'_ Tensor,
+        output_kind: trtx_sys::nvinfer1::LoopOutput,
+        axis: i32,
+    ) -> Result<LoopOutputLayer<'network>> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, tensor);
+        let layer_ptr = self
+            .inner
+            .as_mut()
+            .addLoopOutput(tensor.pin_mut(), output_kind, axis);
+        LoopOutputLayer::new(network.inner.as_ptr(), layer_ptr)
+    }
+}
+
+// --- IfConditional boundary layers (IIfConditional::setCondition, addInput, addOutput) ---
+
+impl<'network> IfConditional<'network> {
+    /// See [`trtx_sys::nvinfer1::IIfConditional::setCondition`].
+    pub fn set_condition(
+        &mut self,
+        network: &mut NetworkDefinition,
+        condition: &'_ Tensor,
+    ) -> Result<ConditionLayer<'network>> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, condition);
+        let layer_ptr = self.inner.as_mut().setCondition(condition.pin_mut());
+        ConditionLayer::new(network.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::IIfConditional::addInput`].
+    pub fn add_input(
+        &mut self,
+        network: &mut NetworkDefinition,
+        input: &'_ Tensor,
+    ) -> Result<IfConditionalInputLayer<'network>> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, input);
+        let layer_ptr = self.inner.as_mut().addInput(input.pin_mut());
+        IfConditionalInputLayer::new(network.inner.as_ptr(), layer_ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::IIfConditional::addOutput`].
+    pub fn add_output(
+        &mut self,
+        network: &mut NetworkDefinition,
+        true_output: &'_ Tensor,
+        false_output: &'_ Tensor,
+    ) -> Result<IfConditionalOutputLayer<'network>> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, true_output);
+        crate::check_network!(network, false_output);
+        let layer_ptr = self
+            .inner
+            .as_mut()
+            .addOutput(true_output.pin_mut(), false_output.pin_mut());
+        IfConditionalOutputLayer::new(network.inner.as_ptr(), layer_ptr)
+    }
+}
+
+// --- RecurrenceLayer: set_input(1, tensor) for value from inside loop ---
+
+impl<'network> RecurrenceLayer<'network> {
+    /// See [`trtx_sys::nvinfer1::IRecurrenceLayer`]. Input 0 = initial value (set at creation); input 1 = value from previous iteration (from inside loop).
+    pub fn set_input(
+        &mut self,
+        network: &mut NetworkDefinition,
+        index: i32,
+        tensor: &'_ Tensor<'network>,
+    ) -> Result<()> {
+        crate::check_network!(network, self);
+        crate::check_network!(network, tensor);
+        unsafe {
+            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::ILayer>(
+                self.inner.as_mut().get_unchecked_mut() as *mut _ as *mut _,
+            );
+            layer_pin.as_mut().setInput(index, tensor.pin_mut());
+        }
+        Ok(())
+    }
+}
+
+// --- IteratorLayer: set_axis, set_reverse ---
+
+impl IteratorLayer<'_> {
+    /// See [`trtx_sys::nvinfer1::IIteratorLayer::setAxis`].
+    pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setAxis(axis);
+    }
+
+    /// See [`trtx_sys::nvinfer1::IIteratorLayer::setReverse`].
+    pub fn set_reverse(&mut self, network: &mut NetworkDefinition, reverse: bool) {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setReverse(reverse);
+    }
+}
+
+// --- LoopOutputLayer: get_loop_output, set_axis (for concatenation), set_input for index 1 ---
+
+impl LoopOutputLayer<'_> {
+    /// See [`trtx_sys::nvinfer1::ILoopOutputLayer::getLoopOutput`].
+    pub fn get_loop_output(&self, network: &NetworkDefinition) -> trtx_sys::nvinfer1::LoopOutput {
+        crate::check_network!(network, self);
+        self.inner.as_ref().getLoopOutput()
+    }
+
+    /// See [`trtx_sys::nvinfer1::ILoopOutputLayer::setAxis`]. Ignored if output kind is kLAST_VALUE.
+    pub fn set_axis(&mut self, network: &mut NetworkDefinition, axis: i32) {
+        crate::check_network!(network, self);
+        self.inner.as_mut().setAxis(axis);
+    }
+}
+
+// --- TripLimitLayer: get_trip_limit (getter only) ---
+
+impl TripLimitLayer<'_> {
+    /// See [`trtx_sys::nvinfer1::ITripLimitLayer::getTripLimit`].
+    pub fn get_trip_limit(&self, network: &NetworkDefinition) -> trtx_sys::nvinfer1::TripLimit {
+        crate::check_network!(network, self);
+        self.inner.as_ref().getTripLimit()
+    }
+}
+
+impl<'builder> NetworkDefinition<'builder> {
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addNormalization`].
+    pub fn add_normalization(
+        &mut self,
+        input: &'_ Tensor,
+        scale: &'_ Tensor,
+        bias: &'_ Tensor,
+        axes_mask: crate::Axes,
+    ) -> Result<NormalizationLayer<'builder>> {
+        crate::check_network!(self, input);
+        crate::check_network!(self, scale);
+        crate::check_network!(self, bias);
+        let axes_bits = axes_mask.to_bits();
+        let ptr = self.inner.pin_mut().addNormalization(
+            input.pin_mut(),
+            scale.pin_mut(),
+            bias.pin_mut(),
+            axes_bits,
+        );
+        NormalizationLayer::new(self.inner.as_ptr(), ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::addNormalizationV2`].
+    pub fn add_normalization_v2(
+        &mut self,
+        input: &'_ Tensor,
+        scale: &'_ Tensor,
+        bias: &'_ Tensor,
+        axes_mask: crate::Axes,
+    ) -> Result<NormalizationLayer<'builder>> {
+        crate::check_network!(self, input);
+        crate::check_network!(self, scale);
+        crate::check_network!(self, bias);
+        let axes_bits = axes_mask.to_bits();
+        let ptr = self.inner.pin_mut().addNormalizationV2(
+            input.pin_mut(),
+            scale.pin_mut(),
+            bias.pin_mut(),
+            axes_bits,
+        );
+        NormalizationLayer::new(self.inner.as_ptr(), ptr)
+    }
+
+    /// See [`trtx_sys::nvinfer1::INetworkDefinition::setErrorRecorder`].
+    ///
+    /// The Rust bindings only allow setting the error recorder once.
+    pub fn set_error_recorder(&mut self, error_recorder: Box) -> Result<()> {
+        let error_recorder = ErrorRecorder::new(error_recorder)?;
+        if self.error_recorder.is_some() {
+            // We would need to make sure that we don't destroy an error recorder that is
+            // still in use. We could offer this as an unsafe method for users who only set
+            // it when no build process is active, or accept only a reference and force the
+            // user via lifetimes to keep it alive for the builder config lifetime.
+            panic!("Setting an error recorder more than once is not supported at the moment");
+        }
+        self.error_recorder = Some(error_recorder);
+        let rec = self
+            .error_recorder
+            .as_mut()
+            .unwrap()
+            .as_trt_error_recorder();
+        #[cfg(not(feature = "mock"))]
+        unsafe {
+            self.inner.pin_mut().setErrorRecorder(rec)
+        };
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use trtx_sys::LayerType;
+
+    use crate::{Builder, Logger};
+
+    #[test]
+    #[cfg(not(feature = "mock"))]
+    fn test_get_layer() {
+        let logger = Logger::stderr().unwrap();
+        let mut builder = Builder::new(&logger).unwrap();
+        let mut network = builder.create_network(0).unwrap();
+        let input = network
+            .add_input("a", trtx_sys::DataType::kFLOAT, &[1])
+            .unwrap();
+        let a = network
+            .add_activation(&input, trtx_sys::ActivationType::kRELU)
+            .unwrap()
+            .get_output(&network, 0)
+            .unwrap();
+        let b = network
+            .add_activation(&a,
 trtx_sys::ActivationType::kRELU)
+            .unwrap()
+            .get_output(&network, 0)
+            .unwrap();
+        let c = network
+            .add_activation(&b, trtx_sys::ActivationType::kRELU)
+            .unwrap()
+            .get_output(&network, 0)
+            .unwrap();
+        a.set_name(&mut network, "Fritz").unwrap();
+        b.set_name(&mut network, "Adam").unwrap();
+        c.set_name(&mut network, "James").unwrap();
+
+        assert_eq!(
+            &network
+                .get_layer(0)
+                .unwrap()
+                .get_output(&network, 0)
+                .unwrap()
+                .name(&network)
+                .unwrap(),
+            "Fritz"
+        );
+        assert_eq!(
+            &network
+                .get_layer(1)
+                .unwrap()
+                .get_output(&network, 0)
+                .unwrap()
+                .name(&network)
+                .unwrap(),
+            "Adam"
+        );
+        assert_eq!(
+            &network
+                .get_layer(2)
+                .unwrap()
+                .get_output(&network, 0)
+                .unwrap()
+                .name(&network)
+                .unwrap(),
+            "James"
+        );
+        assert_eq!(
+            network.get_layer(2).unwrap().layer_type_dynamic(),
+            LayerType::kACTIVATION
+        );
+        network
+            .get_layer(1)
+            .unwrap()
+            .set_name(&mut network, "Eva")
+            .unwrap();
+        assert_eq!(
+            &network
+                .get_layer(1)
+                .unwrap()
+                .get_output(&network, 0)
+                .unwrap()
+                .name(&network)
+                .unwrap(),
+            &network
+                .get_layer(2)
+                .unwrap()
+                .get_input(&network, 0)
+                .unwrap()
+                .name(&network)
+                .unwrap(),
+        );
+        assert_eq!(
+            "Adam",
+            &network
+                .get_layer(2)
+                .unwrap()
+                .get_input(&network, 0)
+                .unwrap()
+                .name(&network)
+                .unwrap()
+        );
+        assert_eq!(&network.get_layer(1).unwrap().name(&network).unwrap(), "Eva");
+    }
 }
diff --git a/trtx/src/onnx_parser.rs b/trtx/src/onnx_parser.rs
index 88cf355..99b5fd6 100644
--- a/trtx/src/onnx_parser.rs
+++ b/trtx/src/onnx_parser.rs
@@ -1,11 +1,125 @@
 //! ONNX model parser for TensorRT
-//!
-//! Delegates to real/ or mock/ based on feature flag.
-#[cfg(feature = "mock")]
-pub use crate::mock::onnx_parser::OnnxParser;
-#[cfg(not(feature = "mock"))]
-pub use crate::real::onnx_parser::OnnxParser;
+use std::marker::PhantomData;
+
+use cxx::UniquePtr;
+use std::ffi::c_void;
+use trtx_sys::{nvinfer1, nvonnxparser};
+
+use crate::error::{Error, Result};
+use crate::logger::Logger;
+use crate::network::NetworkDefinition;
+
+/// ONNX parser
+pub struct OnnxParser<'network> {
+    inner: UniquePtr<nvonnxparser::IParser>,
+    _network: PhantomData<&'network nvinfer1::INetworkDefinition>,
+}
+
+impl OnnxParser<'_> {
+    #[cfg(not(any(
+        feature = "link_tensorrt_onnxparser",
+        feature = "dlopen_tensorrt_onnxparser"
+    )))]
+    pub fn new(_network: &mut NetworkDefinition, _logger: &Logger) -> Result<Self> {
+        Err(Error::TrtOnnxParserLibraryNotLoaded)
+    }
+
+    #[cfg(any(
+        feature = "link_tensorrt_onnxparser",
+        feature = "dlopen_tensorrt_onnxparser"
+    ))]
+    pub fn new(network: &mut NetworkDefinition, logger: &Logger) -> Result<Self> {
+        #[cfg(not(feature = "mock"))]
+        {
+            let network_ptr = network.inner.as_mut_ptr();
+            let logger_ptr = logger.as_logger_ptr();
+            let parser_ptr = {
+                #[cfg(feature = "link_tensorrt_onnxparser")]
+                unsafe {
+                    trtx_sys::create_onnx_parser(network_ptr, logger_ptr)
+                }
+                #[cfg(not(feature = "link_tensorrt_onnxparser"))]
+                #[cfg(feature = "dlopen_tensorrt_rtx")]
+                unsafe {
+                    use libloading::Symbol;
+                    use trtx_sys::nvinfer1::INetworkDefinition;
+
+                    use crate::TRT_ONNXPARSER_LIB;
+
+                    if TRT_ONNXPARSER_LIB.read()?.is_none() {
+                        crate::dynamically_load_tensorrt_onnxparser(None::<&str>)?;
+                    }
+
+                    let lock = TRT_ONNXPARSER_LIB
+                        .read()
+                        .map_err(|_| Error::LockPoisining)?;
+                    let create_onnx_parser: Symbol<
+                        fn(*mut INetworkDefinition, *mut c_void, u32) -> *mut c_void,
+                    > = lock
+                        .as_ref()
+                        .ok_or(Error::TrtOnnxParserLibraryNotLoaded)?
+ .get(b"createNvOnnxParser_INTERNAL")?; + create_onnx_parser(network_ptr, logger_ptr, trtx_sys::get_tensorrt_version()) + } + } as *mut nvonnxparser::IParser; + if parser_ptr.is_null() { + return Err(Error::Runtime("Failed to create ONNX parser".to_string())); + } + Ok(OnnxParser { + inner: unsafe { UniquePtr::from_raw(parser_ptr) }, + _network: Default::default(), + }) + } + #[cfg(feature = "mock")] + Ok(OnnxParser { + inner: UniquePtr::null(), + _network: Default::default(), + }) + } + + pub fn parse(&mut self, model_bytes: &[u8]) -> Result<()> { + #[cfg(not(feature = "mock"))] + { + if self.inner.is_null() { + return Err(Error::Runtime("Invalid parser".to_string())); + } + let parser_ptr = self.inner.as_mut_ptr() as *mut c_void; + let success = unsafe { + trtx_sys::parser_parse( + parser_ptr, + model_bytes.as_ptr() as *const std::ffi::c_void, + model_bytes.len(), + ) + }; + if !success { + let error_msg = unsafe { + let num_errors = trtx_sys::parser_get_nb_errors(parser_ptr); + if num_errors > 0 { + let err_ptr = trtx_sys::parser_get_error(parser_ptr, 0); + if !err_ptr.is_null() { + let desc_ptr = trtx_sys::parser_error_desc(err_ptr); + if !desc_ptr.is_null() { + std::ffi::CStr::from_ptr(desc_ptr) + .to_str() + .unwrap_or("Failed to parse ONNX model") + .to_string() + } else { + "Failed to parse ONNX model".to_string() + } + } else { + "Failed to parse ONNX model".to_string() + } + } else { + "Failed to parse ONNX model".to_string() + } + }; + return Err(Error::Runtime(error_msg)); + } + } + Ok(()) + } +} #[cfg(test)] mod tests { @@ -17,7 +131,7 @@ mod tests { #[ignore] // Requires TensorRT runtime initialization (can hang in test context) fn test_onnx_parser_creation() { let logger = Logger::stderr().unwrap(); - let builder = Builder::new(&logger).unwrap(); + let mut builder = Builder::new(&logger).unwrap(); let mut network = builder .create_network(network_flags::EXPLICIT_BATCH) .unwrap(); @@ -34,11 +148,11 @@ mod tests { ); let model_bytes = 
std::fs::read(model_path).expect("Failed to read test ONNX model"); let logger = Logger::stderr().unwrap(); - let builder = Builder::new(&logger).unwrap(); + let mut builder = Builder::new(&logger).unwrap(); let mut network = builder .create_network(network_flags::EXPLICIT_BATCH) .unwrap(); - let parser = OnnxParser::new(&mut network, &logger).unwrap(); + let mut parser = OnnxParser::new(&mut network, &logger).unwrap(); let result = parser.parse(&model_bytes); assert!( result.is_ok(), diff --git a/trtx/src/optimization_profile.rs b/trtx/src/optimization_profile.rs new file mode 100644 index 0000000..8e0ec9a --- /dev/null +++ b/trtx/src/optimization_profile.rs @@ -0,0 +1,98 @@ +use std::{ffi::CString, marker::PhantomData, pin::Pin}; + +use crate::{error::PropertySetAttempt, Error, Result}; +use trtx_sys::{nvinfer1, Dims64, OptProfileSelector}; + +/// See [nvinfer1::IOptimizationProfile] +pub struct OptimizationProfile<'builder> { + pub(crate) inner: Pin<&'builder mut nvinfer1::IOptimizationProfile>, + _builder: PhantomData<&'builder nvinfer1::IBuilder>, +} + +impl<'builder> OptimizationProfile<'builder> { + pub fn from_raw(profile: &'builder mut nvinfer1::IOptimizationProfile) -> Self { + Self { + inner: unsafe { Pin::new_unchecked(profile) }, + _builder: Default::default(), + } + } + + pub fn get_dimensions(&self, input_name: &str, select: OptProfileSelector) -> Dims64 { + let input_name_c = + CString::new(input_name).expect("User provided string that contains \\0 characters"); + unsafe { + self.inner + .getDimensions(input_name_c.as_ptr(), select.into()) + } + } + pub fn set_dimensions( + &mut self, + input_name: &str, + select: OptProfileSelector, + dims: &Dims64, + ) -> Result<()> { + let input_name_c = + CString::new(input_name).expect("User provided string that contains \\0 characters"); + unsafe { + if self + .inner + .as_mut() + .setDimensions(input_name_c.as_ptr(), select.into(), dims) + { + Ok(()) + } else { + Err(Error::FailedToSetProperty( + 
 PropertySetAttempt::OptimizationProfileSetDimensions,
+                ))
+            }
+        }
+    }
+
+    pub fn is_valid(&self) -> bool {
+        self.inner.isValid()
+    }
+
+    pub fn set_extra_memory_target(&mut self, target: f32) -> Result<()> {
+        if self.inner.as_mut().setExtraMemoryTarget(target) {
+            Ok(())
+        } else {
+            Err(Error::FailedToSetProperty(
+                PropertySetAttempt::OptimizationProfileSetExtraMemoryTarget,
+            ))
+        }
+    }
+
+    pub fn get_extra_memory_target(&self) -> f32 {
+        self.inner.getExtraMemoryTarget()
+    }
+
+    pub fn set_shape_values_v2(
+        &mut self,
+        input_name: &str,
+        select: OptProfileSelector,
+        values: &[i64],
+    ) -> Result<()> {
+        let input_name_c =
+            CString::new(input_name).expect("User provided string that contains \\0 characters");
+        if unsafe {
+            self.inner.as_mut().setShapeValuesV2(
+                input_name_c.as_ptr(),
+                select.into(),
+                values.as_ptr(),
+                values.len().try_into().expect("Vector too long for an i32"),
+            )
+        } {
+            Ok(())
+        } else {
+            Err(Error::FailedToSetProperty(
+                PropertySetAttempt::OptimizationProfileSetShapeValues,
+            ))
+        }
+    }
+}
+
+impl Drop for OptimizationProfile<'_> {
+    fn drop(&mut self) {
+        unsafe {
+            std::ptr::drop_in_place(self.inner.as_mut().get_unchecked_mut());
+        }
+    }
+}
diff --git a/trtx/src/real/builder.rs b/trtx/src/real/builder.rs
deleted file mode 100644
index b395f32..0000000
--- a/trtx/src/real/builder.rs
+++ /dev/null
@@ -1,124 +0,0 @@
-//!
 Real TensorRT builder implementation
-use crate::error::{Error, Result};
-use crate::logger::Logger;
-use crate::network::NetworkDefinition;
-use crate::real::host_memory::HostMemory;
-
-pub use super::builder_config::BuilderConfig;
-
-/// Builder (real mode)
-pub struct Builder<'a> {
-    inner: *mut std::ffi::c_void,
-    _logger: &'a Logger,
-}
-
-impl<'builder> Builder<'builder> {
-    #[cfg(not(feature = "link_tensorrt_rtx"))]
-    #[cfg(not(feature = "dlopen_tensorrt_rtx"))]
-    pub fn new(logger: &'a Logger) -> Result<Self> {
-        Err(Error::TrtRtxLibraryNotLoaded)
-    }
-
-    #[cfg(any(feature = "link_tensorrt_rtx", feature = "dlopen_tensorrt_rtx"))]
-    pub fn new(logger: &'builder Logger) -> Result<Self> {
-        let logger_ptr = logger.as_logger_ptr();
-        let builder_ptr = {
-            #[cfg(feature = "link_tensorrt_rtx")]
-            unsafe {
-                trtx_sys::create_infer_builder(logger_ptr)
-            }
-            #[cfg(not(feature = "link_tensorrt_rtx"))]
-            #[cfg(feature = "dlopen_tensorrt_rtx")]
-            unsafe {
-                use libloading::Symbol;
-                use std::ffi::c_void;
-
-                use crate::TRTLIB;
-                if !TRTLIB.read()?.is_some() {
-                    crate::dynamically_load_tensorrt(None::<&str>)?;
-                }
-
-                let lock = TRTLIB.read()?;
-                let create_infer_builder: Symbol<fn(*mut c_void, u32) -> *mut c_void> = lock
-                    .as_ref()
-                    .ok_or(Error::TrtRtxLibraryNotLoaded)?
-                    .get(b"createInferBuilder_INTERNAL")?;
-                create_infer_builder(logger_ptr, trtx_sys::get_tensorrt_version())
-            }
-        };
-        if builder_ptr.is_null() {
-            return Err(Error::Runtime("Failed to create builder".to_string()));
-        }
-        Ok(Builder {
-            inner: builder_ptr,
-            _logger: logger,
-        })
-    }
-
-    pub fn create_network(&self, flags: u32) -> Result<NetworkDefinition> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid builder".to_string()));
-        }
-        let network_ptr = unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::IBuilder>(self.inner)
-                .createNetworkV2(flags)
-        };
-        if network_ptr.is_null() {
-            return Err(Error::Runtime("Failed to create network".to_string()));
-        }
-        Ok(NetworkDefinition::from_ptr(network_ptr as *mut _))
-    }
-
-    pub fn create_config(&self) -> Result<BuilderConfig<'builder>> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid builder".to_string()));
-        }
-        unsafe {
-            let config_ptr =
-                crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::IBuilder>(self.inner)
-                    .createBuilderConfig()
-                    .as_mut()
-                    .ok_or_else(|| Error::Runtime("Failed to create builder config".to_string()))?;
-            Ok(BuilderConfig {
-                inner: std::pin::Pin::new_unchecked(config_ptr),
-            })
-        }
-    }
-
-    pub fn build_serialized_network<'network, 'config, 'config_borrow>(
-        &self,
-        network: &'network mut NetworkDefinition,
-        config: &'config_borrow mut BuilderConfig<'config>,
-    ) -> Result<HostMemory<'network>> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid builder".to_string()));
-        }
-        let network_ptr = network.as_mut_ptr();
-
-        let serialized_engine = unsafe {
-            let builder = &mut *(self.inner as *mut trtx_sys::nvinfer1::IBuilder);
-            let network = &mut *(network_ptr as *mut trtx_sys::nvinfer1::INetworkDefinition);
-            let mut builder_pin = std::pin::Pin::new_unchecked(builder);
-            builder_pin
-                .as_mut()
-                .buildSerializedNetwork(
-                    std::pin::Pin::new_unchecked(network),
-                    config.inner.as_mut(),
-                )
-                .as_mut()
-        }
-        .ok_or_else(|| Error::Runtime("Failed to build serialized network".to_string()))?;
-
-        Ok(unsafe { HostMemory::from_raw_ref(self,
serialized_engine) }) - } -} - -impl Drop for Builder<'_> { - fn drop(&mut self) { - if !self.inner.is_null() { - unsafe { - trtx_sys::delete_builder(self.inner); - } - } - } -} diff --git a/trtx/src/real/builder_config.rs b/trtx/src/real/builder_config.rs deleted file mode 100644 index 8f627ce..0000000 --- a/trtx/src/real/builder_config.rs +++ /dev/null @@ -1,242 +0,0 @@ -//! Real TensorRT builder config implementation - -use std::pin::Pin; -use std::ptr; - -use trtx_sys::nvinfer1::IBuilderConfig; -use trtx_sys::{ - BuilderFlag, ComputeCapability, DeviceType, EngineCapability, HardwareCompatibilityLevel, - MemoryPoolType, PreviewFeature, ProfilingVerbosity, RuntimePlatform, TilingOptimizationLevel, -}; - -/// Builder configuration (real mode) -pub struct BuilderConfig<'builder> { - pub(crate) inner: Pin<&'builder mut IBuilderConfig>, -} - -impl<'builder> BuilderConfig<'builder> { - /// See [IBuilderConfig::setMemoryPoolLimit] - pub fn set_memory_pool_limit(&mut self, pool: MemoryPoolType, size: usize) { - self.inner.as_mut().setMemoryPoolLimit(pool.into(), size); - } - - /// See [IBuilderConfig::setProfilingVerbosity] - pub fn set_profiling_verbosity(&mut self, verbosity: ProfilingVerbosity) { - self.inner.as_mut().setProfilingVerbosity(verbosity.into()); - } - - /// See [IBuilderConfig::getProfilingVerbosity] - pub fn get_profiling_verbosity(&self) -> ProfilingVerbosity { - self.inner.as_ref().getProfilingVerbosity().into() - } - - /// See [IBuilderConfig::setAvgTimingIterations] - pub fn set_avg_timing_iterations(&mut self, avg_timing: i32) { - self.inner.as_mut().setAvgTimingIterations(avg_timing); - } - - /// See [IBuilderConfig::getAvgTimingIterations] - pub fn get_avg_timing_iterations(&self) -> i32 { - self.inner.as_ref().getAvgTimingIterations() - } - - /// See [IBuilderConfig::setEngineCapability] - pub fn set_engine_capability(&mut self, capability: EngineCapability) { - self.inner.as_mut().setEngineCapability(capability.into()); - } - - /// See 
[IBuilderConfig::getEngineCapability]
-    pub fn get_engine_capability(&self) -> EngineCapability {
-        self.inner.as_ref().getEngineCapability().into()
-    }
-
-    /// See [IBuilderConfig::setFlags]
-    pub fn set_flags(&mut self, flags: u32) {
-        self.inner.as_mut().setFlags(flags);
-    }
-
-    /// See [IBuilderConfig::getFlags]
-    pub fn get_flags(&self) -> u32 {
-        self.inner.as_ref().getFlags()
-    }
-
-    /// See [IBuilderConfig::setFlag]
-    pub fn set_flag(&mut self, flag: BuilderFlag) {
-        self.inner.as_mut().setFlag(flag.into());
-    }
-
-    /// See [IBuilderConfig::clearFlag]
-    pub fn clear_flag(&mut self, flag: BuilderFlag) {
-        self.inner.as_mut().clearFlag(flag.into());
-    }
-
-    /// See [IBuilderConfig::getFlag]
-    pub fn get_flag(&self, flag: BuilderFlag) -> bool {
-        self.inner.as_ref().getFlag(flag.into())
-    }
-
-    /// See [IBuilderConfig::setDLACore]
-    pub fn set_dla_core(&mut self, dla_core: i32) {
-        self.inner.as_mut().setDLACore(dla_core);
-    }
-
-    /// See [IBuilderConfig::getDLACore]
-    pub fn get_dla_core(&self) -> i32 {
-        self.inner.as_ref().getDLACore()
-    }
-
-    /// See [IBuilderConfig::setDefaultDeviceType]
-    pub fn set_default_device_type(&mut self, device_type: DeviceType) {
-        self.inner.as_mut().setDefaultDeviceType(device_type.into());
-    }
-
-    /// See [IBuilderConfig::getDefaultDeviceType]
-    pub fn get_default_device_type(&self) -> DeviceType {
-        self.inner.as_ref().getDefaultDeviceType().into()
-    }
-
-    /// See [IBuilderConfig::reset]
-    pub fn reset(&mut self) {
-        self.inner.as_mut().reset();
-    }
-
-    /// See [IBuilderConfig::getNbOptimizationProfiles]
-    pub fn get_nb_optimization_profiles(&self) -> i32 {
-        self.inner.as_ref().getNbOptimizationProfiles()
-    }
-
-    /// See [IBuilderConfig::setTacticSources]
-    pub fn set_tactic_sources(&mut self, sources: u32) -> bool {
-        self.inner.as_mut().setTacticSources(sources)
-    }
-
-    /// See [IBuilderConfig::getTacticSources]
-    pub fn get_tactic_sources(&self) -> u32 {
-        self.inner.as_ref().getTacticSources()
-    }
-
-    /// See [IBuilderConfig::getMemoryPoolLimit]
-    pub fn get_memory_pool_limit(&self, pool: MemoryPoolType) -> usize {
-        self.inner.as_ref().getMemoryPoolLimit(pool.into())
-    }
-
-    /// See [IBuilderConfig::setPreviewFeature]
-    pub fn set_preview_feature(&mut self, feature: PreviewFeature, enable: bool) {
-        self.inner
-            .as_mut()
-            .setPreviewFeature(feature.into(), enable);
-    }
-
-    /// See [IBuilderConfig::getPreviewFeature]
-    pub fn get_preview_feature(&self, feature: PreviewFeature) -> bool {
-        self.inner.as_ref().getPreviewFeature(feature.into())
-    }
-
-    /// See [IBuilderConfig::setBuilderOptimizationLevel]
-    pub fn set_builder_optimization_level(&mut self, level: i32) {
-        self.inner.as_mut().setBuilderOptimizationLevel(level);
-    }
-
-    /// See [IBuilderConfig::getBuilderOptimizationLevel]
-    pub fn get_builder_optimization_level(&mut self) -> i32 {
-        self.inner.as_mut().getBuilderOptimizationLevel()
-    }
-
-    /// See [IBuilderConfig::setHardwareCompatibilityLevel]
-    pub fn set_hardware_compatibility_level(&mut self, level: HardwareCompatibilityLevel) {
-        self.inner
-            .as_mut()
-            .setHardwareCompatibilityLevel(level.into());
-    }
-
-    /// See [IBuilderConfig::getHardwareCompatibilityLevel]
-    pub fn get_hardware_compatibility_level(&self) -> HardwareCompatibilityLevel {
-        self.inner.as_ref().getHardwareCompatibilityLevel().into()
-    }
-
-    /// See [IBuilderConfig::setMaxAuxStreams]
-    pub fn set_max_aux_streams(&mut self, nb_streams: i32) {
-        self.inner.as_mut().setMaxAuxStreams(nb_streams);
-    }
-
-    /// See [IBuilderConfig::getMaxAuxStreams]
-    pub fn get_max_aux_streams(&self) -> i32 {
-        self.inner.as_ref().getMaxAuxStreams()
-    }
-
-    /// See [IBuilderConfig::setRuntimePlatform]
-    pub fn set_runtime_platform(&mut self, platform: RuntimePlatform) {
-        self.inner.as_mut().setRuntimePlatform(platform.into());
-    }
-
-    /// See [IBuilderConfig::getRuntimePlatform]
-    pub fn get_runtime_platform(&self) -> RuntimePlatform {
-        self.inner.as_ref().getRuntimePlatform().into()
-    }
-
-    /// See [IBuilderConfig::setMaxNbTactics]
-    pub fn set_max_nb_tactics(&mut self, max_nb_tactics: i32) {
-        self.inner.as_mut().setMaxNbTactics(max_nb_tactics);
-    }
-
-    /// See [IBuilderConfig::getMaxNbTactics]
-    pub fn get_max_nb_tactics(&self) -> i32 {
-        self.inner.as_ref().getMaxNbTactics()
-    }
-
-    /// See [IBuilderConfig::setTilingOptimizationLevel]
-    pub fn set_tiling_optimization_level(&mut self, level: TilingOptimizationLevel) -> bool {
-        self.inner.as_mut().setTilingOptimizationLevel(level.into())
-    }
-
-    /// See [IBuilderConfig::getTilingOptimizationLevel]
-    pub fn get_tiling_optimization_level(&self) -> TilingOptimizationLevel {
-        self.inner.as_ref().getTilingOptimizationLevel().into()
-    }
-
-    /// See [IBuilderConfig::setL2LimitForTiling]
-    pub fn set_l2_limit_for_tiling(&mut self, size: i64) -> bool {
-        self.inner.as_mut().setL2LimitForTiling(size)
-    }
-
-    /// See [IBuilderConfig::getL2LimitForTiling]
-    pub fn get_l2_limit_for_tiling(&self) -> i64 {
-        self.inner.as_ref().getL2LimitForTiling()
-    }
-
-    /// See [IBuilderConfig::setNbComputeCapabilities]
-    pub fn set_nb_compute_capabilities(&mut self, max_nb_compute_capabilities: i32) -> bool {
-        self.inner
-            .as_mut()
-            .setNbComputeCapabilities(max_nb_compute_capabilities)
-    }
-
-    /// See [IBuilderConfig::getNbComputeCapabilities]
-    pub fn get_nb_compute_capabilities(&self) -> i32 {
-        self.inner.as_ref().getNbComputeCapabilities()
-    }
-
-    /// See [IBuilderConfig::setComputeCapability]
-    pub fn set_compute_capability(
-        &mut self,
-        compute_capability: ComputeCapability,
-        index: i32,
-    ) -> bool {
-        self.inner
-            .as_mut()
-            .setComputeCapability(compute_capability.into(), index)
-    }
-
-    /// See [IBuilderConfig::getComputeCapability]
-    pub fn get_compute_capability(&self, index: i32) -> ComputeCapability {
-        self.inner.as_ref().getComputeCapability(index).into()
-    }
-}
-
-impl Drop for BuilderConfig<'_> {
-    fn drop(&mut self) {
-        unsafe {
-            ptr::drop_in_place(self.inner.as_mut().get_unchecked_mut());
-        }
-    }
-}
diff 
--git a/trtx/src/real/cuda.rs b/trtx/src/real/cuda.rs
deleted file mode 100644
index dcfe220..0000000
--- a/trtx/src/real/cuda.rs
+++ /dev/null
@@ -1,69 +0,0 @@
-//! Real CUDA implementation (cudarc is always required when real mode is enabled)
-
-use crate::error::{Error, Result};
-
-use cudarc::driver::{CudaDevice, CudaSlice, DevicePtr};
-
-/// RAII wrapper for CUDA device memory (real mode)
-pub struct DeviceBuffer {
-    ptr: CudaSlice<u8>,
-    device: std::sync::Arc<CudaDevice>,
-    size: usize,
-}
-
-impl DeviceBuffer {
-    pub fn new(size: usize) -> Result<Self> {
-        let device = CudaDevice::new(0)
-            .map_err(|e| Error::Cuda(format!("Failed to initialize CUDA device: {:?}", e)))?;
-        let ptr = device
-            .alloc_zeros::<u8>(size)
-            .map_err(|e| Error::Cuda(format!("Failed to allocate CUDA memory: {:?}", e)))?;
-        Ok(DeviceBuffer { ptr, device, size })
-    }
-
-    pub fn as_ptr(&self) -> *mut std::ffi::c_void {
-        *self.ptr.device_ptr() as *mut std::ffi::c_void
-    }
-
-    pub fn size(&self) -> usize {
-        self.size
-    }
-
-    pub fn copy_from_host(&mut self, data: &[u8]) -> Result<()> {
-        if data.len() > self.size {
-            return Err(Error::InvalidArgument(
-                "Data size exceeds buffer size".to_string(),
-            ));
-        }
-        self.device
-            .htod_copy_into(data.to_vec(), &mut self.ptr)
-            .map_err(|e| Error::Cuda(format!("Failed to copy to device: {:?}", e)))
-    }
-
-    pub fn copy_to_host(&self, data: &mut [u8]) -> Result<()> {
-        if data.len() > self.size {
-            return Err(Error::InvalidArgument(
-                "Data size exceeds buffer size".to_string(),
-            ));
-        }
-        self.device
-            .dtoh_sync_copy_into(&self.ptr, data)
-            .map_err(|e| Error::Cuda(format!("Failed to copy from device: {:?}", e)))
-    }
-}
-
-unsafe impl Send for DeviceBuffer {}
-
-/// Synchronize CUDA device
-pub fn synchronize() -> Result<()> {
-    let device = CudaDevice::new(0)
-        .map_err(|e| Error::Cuda(format!("Failed to get CUDA device: {:?}", e)))?;
-    device
-        .synchronize()
-        .map_err(|e| Error::Cuda(format!("Failed to synchronize device: {:?}", e)))
-}
-
-/// Get the default CUDA 
stream
-pub fn get_default_stream() -> *mut std::ffi::c_void {
-    std::ptr::null_mut()
-}
diff --git a/trtx/src/real/host_memory.rs b/trtx/src/real/host_memory.rs
deleted file mode 100644
index 370cbd9..0000000
--- a/trtx/src/real/host_memory.rs
+++ /dev/null
@@ -1,56 +0,0 @@
-use core::slice;
-use std::ops::Deref;
-use std::pin::Pin;
-use std::ptr;
-
-use super::builder::Builder;
-use trtx_sys::nvinfer1::IHostMemory;
-use trtx_sys::DataType;
-
-pub struct HostMemory<'builder> {
-    pub(crate) inner: Pin<&'builder mut IHostMemory>,
-}
-
-impl<'builder> HostMemory<'builder> {
-    /// assumes ownership of ref
-    pub(crate) unsafe fn from_raw_ref(
-        _builder: &'builder Builder,
-        ptr: &'builder mut IHostMemory,
-    ) -> Self {
-        unsafe {
-            HostMemory {
-                inner: Pin::new_unchecked(ptr),
-            }
-        }
-    }
-
-    pub fn data_type(&self) -> DataType {
-        self.inner.as_ref().type_().into()
-    }
-}
-
-impl<'memory> AsRef<[u8]> for HostMemory<'memory> {
-    fn as_ref(&self) -> &'memory [u8] {
-        unsafe {
-            slice::from_raw_parts(
-                self.inner.as_ref().data() as *const u8,
-                self.inner.as_ref().size(),
-            )
-        }
-    }
-}
-
-impl<'builder> Deref for HostMemory<'builder> {
-    type Target = [u8];
-
-    fn deref(&self) -> &Self::Target {
-        // Delegate to the AsRef<[u8]> implementation
-        self.as_ref()
-    }
-}
-
-impl Drop for HostMemory<'_> {
-    fn drop(&mut self) {
-        unsafe { ptr::drop_in_place(self.inner.as_mut().get_unchecked_mut()) };
-    }
-}
diff --git a/trtx/src/real/logger.rs b/trtx/src/real/logger.rs
deleted file mode 100644
index f2b2461..0000000
--- a/trtx/src/real/logger.rs
+++ /dev/null
@@ -1,81 +0,0 @@
-//! 
Real TensorRT logger implementation
-
-use crate::error::Result;
-use crate::logger::{LogHandler, Severity};
-use std::ffi::{c_void, CStr};
-use std::os::raw::c_char;
-
-/// Logger (real mode - uses Rust bridge)
-pub struct Logger {
-    bridge: *mut trtx_sys::RustLoggerBridge,
-    user_data: *mut std::ffi::c_void,
-}
-
-impl Logger {
-    pub fn new<H: LogHandler + 'static>(handler: H) -> Result<Self> {
-        let handler_box: Box<dyn LogHandler> = Box::new(handler);
-        let user_data = Box::into_raw(Box::new(handler_box)) as *mut c_void;
-
-        let bridge = unsafe { trtx_sys::create_rust_logger_bridge(Self::log_callback, user_data) };
-
-        if bridge.is_null() {
-            unsafe {
-                let outer = Box::from_raw(user_data as *mut Box<dyn LogHandler>);
-                let _ = *outer;
-            }
-            return Err(crate::error::Error::Runtime(
-                "Failed to create logger bridge".to_string(),
-            ));
-        }
-
-        Ok(Logger { bridge, user_data })
-    }
-
-    pub fn stderr() -> Result<Self> {
-        Self::new(crate::logger::StderrLogger)
-    }
-
-    pub(crate) fn as_logger_ptr(&self) -> *mut c_void {
-        unsafe { trtx_sys::get_logger_interface(self.bridge) }
-    }
-
-    extern "C" fn log_callback(user_data: *mut c_void, severity: i32, msg: *const c_char) {
-        if user_data.is_null() || msg.is_null() {
-            return;
-        }
-        unsafe {
-            let handler_box = &*(user_data as *const Box<dyn LogHandler>);
-            let msg_str = CStr::from_ptr(msg);
-            let severity = match severity {
-                0 => Severity::InternalError,
-                1 => Severity::Error,
-                2 => Severity::Warning,
-                3 => Severity::Info,
-                4 => Severity::Verbose,
-                _ => Severity::Verbose,
-            };
-            if let Ok(msg) = msg_str.to_str() {
-                handler_box.log(severity, msg);
-            }
-        }
-    }
-}
-
-impl Drop for Logger {
-    fn drop(&mut self) {
-        if !self.bridge.is_null() {
-            unsafe {
-                trtx_sys::destroy_rust_logger_bridge(self.bridge);
-            }
-        }
-        if !self.user_data.is_null() {
-            unsafe {
-                let outer = Box::from_raw(self.user_data as *mut Box<dyn LogHandler>);
-                let _ = *outer;
-            }
-        }
-    }
-}
-
-unsafe impl Send for Logger {}
-unsafe impl Sync for Logger {}
diff --git a/trtx/src/real/mod.rs b/trtx/src/real/mod.rs
deleted file mode 100644
index 
156295a..0000000
--- a/trtx/src/real/mod.rs
+++ /dev/null
@@ -1,11 +0,0 @@
-//! Real TensorRT implementations
-//! No #[cfg] - this module is only compiled when mock feature is disabled
-
-pub mod builder;
-pub mod builder_config;
-pub mod cuda;
-pub mod host_memory;
-pub mod logger;
-pub mod network;
-pub mod onnx_parser;
-pub mod runtime;
diff --git a/trtx/src/real/network.rs b/trtx/src/real/network.rs
deleted file mode 100644
index 1477b37..0000000
--- a/trtx/src/real/network.rs
+++ /dev/null
@@ -1,1440 +0,0 @@
-//! Real TensorRT network implementation
-//! No #[cfg] - this module is only compiled when mock feature is disabled
-
-use trtx_sys::{DataType, MatrixOperation, ScaleMode, TopKOperation};
-
-use crate::error::{Error, Result};
-use crate::network::*;
-
-/// Macro to implement Layer trait for real TensorRT types
-macro_rules! impl_layer_real {
-    ($name:ident, $trt_type:path) => {
-        impl Layer for $name {
-            fn get_output(&self, index: i32) -> Result<Tensor> {
-                if self.inner.is_null() {
-                    return Err(Error::Runtime("Invalid layer".to_string()));
-                }
-                let tensor_ptr = unsafe {
-                    let layer_ref = &mut *(self.inner as *mut $trt_type);
-                    layer_ref.as_ref().getOutput(index)
-                };
-                if tensor_ptr.is_null() {
-                    return Err(Error::Runtime("Failed to get output tensor".to_string()));
-                }
-                Ok(Tensor {
-                    inner: tensor_ptr as *mut _,
-                })
-            }
-            fn as_ptr(&self) -> *mut std::ffi::c_void {
-                self.inner
-            }
-        }
-    };
-}
-
-impl_layer_real!(ShuffleLayer, trtx_sys::nvinfer1::IShuffleLayer);
-impl_layer_real!(ActivationLayer, trtx_sys::nvinfer1::IActivationLayer);
-impl_layer_real!(ElementWiseLayer, trtx_sys::nvinfer1::IElementWiseLayer);
-impl_layer_real!(ResizeLayer, trtx_sys::nvinfer1::IResizeLayer);
-impl_layer_real!(TopKLayer, trtx_sys::nvinfer1::ITopKLayer);
-impl_layer_real!(GatherLayer, trtx_sys::nvinfer1::IGatherLayer);
-impl_layer_real!(ScatterLayer, trtx_sys::nvinfer1::IScatterLayer);
-impl_layer_real!(SelectLayer, trtx_sys::nvinfer1::ISelectLayer);
-impl_layer_real!(
-    
MatrixMultiplyLayer,
-    trtx_sys::nvinfer1::IMatrixMultiplyLayer
-);
-impl_layer_real!(SoftMaxLayer, trtx_sys::nvinfer1::ISoftMaxLayer);
-impl_layer_real!(ReduceLayer, trtx_sys::nvinfer1::IReduceLayer);
-impl_layer_real!(CumulativeLayer, trtx_sys::nvinfer1::ICumulativeLayer);
-impl_layer_real!(PoolingLayer, trtx_sys::nvinfer1::IPoolingLayer);
-impl_layer_real!(ConvolutionLayer, trtx_sys::nvinfer1::IConvolutionLayer);
-impl_layer_real!(DeconvolutionLayer, trtx_sys::nvinfer1::IDeconvolutionLayer);
-impl_layer_real!(QuantizeLayer, trtx_sys::nvinfer1::IQuantizeLayer);
-impl_layer_real!(DequantizeLayer, trtx_sys::nvinfer1::IDequantizeLayer);
-impl_layer_real!(ConstantLayer, trtx_sys::nvinfer1::IConstantLayer);
-impl_layer_real!(ConcatenationLayer, trtx_sys::nvinfer1::IConcatenationLayer);
-impl_layer_real!(ScaleLayer, trtx_sys::nvinfer1::IScaleLayer);
-impl_layer_real!(SliceLayer, trtx_sys::nvinfer1::ISliceLayer);
-impl_layer_real!(UnaryLayer, trtx_sys::nvinfer1::IUnaryLayer);
-impl_layer_real!(IdentityLayer, trtx_sys::nvinfer1::IIdentityLayer);
-impl_layer_real!(PaddingLayer, trtx_sys::nvinfer1::IPaddingLayer);
-impl_layer_real!(CastLayer, trtx_sys::nvinfer1::ICastLayer);
-
-/// Macro to implement set_layer_name for all layer types (sets ILayer::setName in TensorRT).
-macro_rules! impl_layer_set_name {
-    ($($name:ident),* $(,)?) => {
-        $(
-            impl $name {
-                /// Set the TensorRT layer name (used in error messages and introspection).
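The `impl_layer_real!` / `impl_layer_set_name!` macros above implement one trait (or method) for a family of thin wrapper structs in a single declaration. A minimal standalone sketch of the same pattern, using a dummy `Layer` trait and unit structs rather than the real trtx types:

```rust
// Sketch of the "one macro, many wrapper types" pattern used above.
// `Layer`, `ShuffleLayer`, etc. here are local stand-ins, not trtx items.
trait Layer {
    fn kind(&self) -> &'static str;
}

macro_rules! impl_layer {
    ($($name:ident),* $(,)?) => {
        $(
            struct $name;
            impl Layer for $name {
                // stringify! turns the struct name into a &'static str
                fn kind(&self) -> &'static str { stringify!($name) }
            }
        )*
    };
}

impl_layer!(ShuffleLayer, ActivationLayer, CastLayer);

fn main() {
    // Trait objects let heterogeneous layer wrappers share one interface.
    let layers: Vec<Box<dyn Layer>> = vec![
        Box::new(ShuffleLayer),
        Box::new(ActivationLayer),
        Box::new(CastLayer),
    ];
    for l in &layers {
        println!("{}", l.kind());
    }
}
```

The real macro bodies additionally cast an opaque `*mut c_void` to the concrete `$trt_type` before calling into C++, which is why each expansion takes the TensorRT interface type as a second argument.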
-                pub fn set_layer_name(&mut self, name: &str) -> Result<()> {
-                    let name_cstr = std::ffi::CString::new(name)?;
-                    unsafe {
-                        crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::ILayer>(self.inner)
-                            .as_mut()
-                            .setName(name_cstr.as_ptr());
-                    }
-                    Ok(())
-                }
-            }
-        )*
-    };
-}
-impl_layer_set_name!(
-    ShuffleLayer,
-    ActivationLayer,
-    ElementWiseLayer,
-    ResizeLayer,
-    TopKLayer,
-    GatherLayer,
-    ScatterLayer,
-    SelectLayer,
-    MatrixMultiplyLayer,
-    SoftMaxLayer,
-    ReduceLayer,
-    CumulativeLayer,
-    PoolingLayer,
-    ConvolutionLayer,
-    DeconvolutionLayer,
-    QuantizeLayer,
-    DequantizeLayer,
-    ConstantLayer,
-    ConcatenationLayer,
-    ScaleLayer,
-    SliceLayer,
-    UnaryLayer,
-    IdentityLayer,
-    PaddingLayer,
-    CastLayer,
-);
-
-impl ShuffleLayer {
-    pub fn set_reshape_dimensions(&mut self, dims: &[i32]) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid shuffle layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IShuffleLayer,
-            >(self.inner);
-            let dims_i64: Vec<i64> = dims.iter().map(|&d| d as i64).collect();
-            let dims_obj = trtx_sys::Dims::from_slice(&dims_i64);
-            layer_pin.as_mut().setReshapeDimensions(&dims_obj);
-        }
-        Ok(())
-    }
-    pub fn set_first_transpose(&mut self, order: &[i32]) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid shuffle layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IShuffleLayer,
-            >(self.inner);
-            let mut order_arr = [0i32; 8];
-            let n = order.len().min(8);
-            order_arr[..n].copy_from_slice(&order[..n]);
-            let perm = trtx_sys::nvinfer1::Permutation { order: order_arr };
-            layer_pin.as_mut().setFirstTranspose(perm);
-        }
-        Ok(())
-    }
-}
-
-impl ResizeLayer {
-    pub fn set_output_dimensions(&mut self, dims: &[i32]) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid resize layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IResizeLayer,
-            >(self.inner);
-            let dims_i64: Vec<i64> = dims.iter().map(|&d| d as i64).collect();
-            let dims_obj = trtx_sys::Dims::from_slice(&dims_i64);
-            layer_pin.as_mut().setOutputDimensions(&dims_obj);
-        }
-        Ok(())
-    }
-    pub fn set_resize_mode(&mut self, mode: trtx_sys::ResizeMode) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid resize layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IResizeLayer,
-            >(self.inner);
-            layer_pin.as_mut().setResizeMode(mode);
-        }
-        Ok(())
-    }
-}
-
-impl GatherLayer {
-    pub fn set_gather_mode(&mut self, mode: trtx_sys::GatherMode) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid gather layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IGatherLayer,
-            >(self.inner);
-            layer_pin.as_mut().setMode(mode.into());
-        }
-        Ok(())
-    }
-}
-
-impl ScatterLayer {
-    pub fn set_scatter_mode(&mut self, mode: trtx_sys::ScatterMode) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid scatter layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IScatterLayer,
-            >(self.inner);
-            layer_pin.as_mut().setMode(mode.into());
-        }
-        Ok(())
-    }
-    pub fn set_axis(&mut self, axis: i32) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid scatter layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IScatterLayer,
-            >(self.inner);
-            layer_pin.as_mut().setAxis(axis);
-        }
-        Ok(())
-    }
-}
-
-impl ConvolutionLayer {
-    pub fn set_stride(&mut self, stride: &[i32; 2]) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid convolution layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IConvolutionLayer,
-            >(self.inner);
-            let dims_i64: Vec<i64> = stride.iter().map(|&s| s as i64).collect();
-            let dims_obj = trtx_sys::Dims::from_slice(&dims_i64);
-            layer_pin.as_mut().setStrideNd(&dims_obj);
-        }
-        Ok(())
-    }
-    pub fn set_padding(&mut self, padding: &[i32; 2]) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid convolution layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IConvolutionLayer,
-            >(self.inner);
-            let dims_i64: Vec<i64> = padding.iter().map(|&p| p as i64).collect();
-            let dims_obj = trtx_sys::Dims::from_slice(&dims_i64);
-            layer_pin.as_mut().setPaddingNd(&dims_obj);
-        }
-        Ok(())
-    }
-    pub fn set_dilation(&mut self, dilation: &[i32; 2]) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid convolution layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IConvolutionLayer,
-            >(self.inner);
-            let dims_i64: Vec<i64> = dilation.iter().map(|&d| d as i64).collect();
-            let dims_obj = trtx_sys::Dims::from_slice(&dims_i64);
-            layer_pin.as_mut().setDilationNd(&dims_obj);
-        }
-        Ok(())
-    }
-    pub fn set_num_groups(&mut self, num_groups: i32) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid convolution layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IConvolutionLayer,
-            >(self.inner);
-            layer_pin.as_mut().setNbGroups(num_groups as i64);
-        }
-        Ok(())
-    }
-
-    /// Set an input tensor by index. Input 0 is the activation; 1 is the kernel tensor; 2 is the bias tensor.
-    /// When using input 1 or 2, the layer must have been created with empty weights for that slot.
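The stride/padding/dilation setters above all share one conversion step: the public API takes `&[i32]` for ergonomics, while TensorRT's `Dims` stores `int64_t`, so each call widens the slice before building the dims struct. A standalone sketch of that conversion (with a local `Dims` stand-in for `trtx_sys::Dims`, assumed to hold up to 8 slots):

```rust
// Sketch of the i32 -> i64 dims widening done by the layer setters above.
// `Dims` is a local stand-in for trtx_sys::Dims, not the real FFI type.
const MAX_DIMS: usize = 8;

#[derive(Debug, PartialEq)]
struct Dims {
    nb_dims: i32,
    d: [i64; MAX_DIMS],
}

impl Dims {
    // Mirrors the assumed Dims::from_slice helper: copy up to MAX_DIMS entries.
    fn from_slice(dims: &[i64]) -> Self {
        let mut d = [0i64; MAX_DIMS];
        let n = dims.len().min(MAX_DIMS);
        d[..n].copy_from_slice(&dims[..n]);
        Dims { nb_dims: n as i32, d }
    }
}

fn widen(dims: &[i32]) -> Dims {
    // The same map/collect step used in set_stride, set_padding, set_dilation.
    let dims_i64: Vec<i64> = dims.iter().map(|&x| x as i64).collect();
    Dims::from_slice(&dims_i64)
}

fn main() {
    let d = widen(&[2, 3]);
    assert_eq!(d.nb_dims, 2);
    assert_eq!(&d.d[..2], &[2, 3]);
    println!("ok");
}
```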
-    pub fn set_input(&mut self, index: i32, tensor: &Tensor) -> Result<()> {
-        if self.inner.is_null() || tensor.inner.is_null() {
-            return Err(Error::Runtime(
-                "Invalid convolution layer or tensor".to_string(),
-            ));
-        }
-        unsafe {
-            let mut layer_pin =
-                crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::IConvolutionLayer>(self.inner);
-            let mut tensor_pin =
-                crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::ITensor>(tensor.inner);
-            layer_pin.as_mut().setInput(index, tensor_pin.as_mut());
-        }
-        Ok(())
-    }
-}
-
-impl DeconvolutionLayer {
-    pub fn set_stride(&mut self, stride: &[i32; 2]) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid deconvolution layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IDeconvolutionLayer,
-            >(self.inner);
-            let dims_i64: Vec<i64> = stride.iter().map(|&s| s as i64).collect();
-            let dims_obj = trtx_sys::Dims::from_slice(&dims_i64);
-            layer_pin.as_mut().setStrideNd(&dims_obj);
-        }
-        Ok(())
-    }
-    pub fn set_padding(&mut self, padding: &[i32; 2]) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid deconvolution layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IDeconvolutionLayer,
-            >(self.inner);
-            let dims_i64: Vec<i64> = padding.iter().map(|&p| p as i64).collect();
-            let dims_obj = trtx_sys::Dims::from_slice(&dims_i64);
-            layer_pin.as_mut().setPaddingNd(&dims_obj);
-        }
-        Ok(())
-    }
-    /// Set pre-padding (trim this many elements at the start of each spatial dimension of the output).
-    /// Pass [pre_h, pre_w] for 2D deconv; TensorRT applies to the spatial dimensions only.
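Every setter above follows the same defensive shape: validate the raw pointer, then perform the unsafe pin-cast and C++ call, and return `Result<()>` so callers never touch a null handle. A standalone sketch of that guard pattern, with stand-in `Error` and handle types (not the real trtx ones):

```rust
// Sketch of the null-guard-then-FFI pattern shared by the setters above.
// `Error` and `LayerHandle` are local stand-ins, not trtx types.
use std::ffi::c_void;

#[derive(Debug)]
enum Error {
    Runtime(String),
}

type Result<T> = std::result::Result<T, Error>;

struct LayerHandle {
    inner: *mut c_void, // opaque pointer owned by the C++ side
}

impl LayerHandle {
    fn set_axis(&mut self, _axis: i32) -> Result<()> {
        // Guard first: a null handle must never reach the unsafe call.
        if self.inner.is_null() {
            return Err(Error::Runtime("Invalid layer".to_string()));
        }
        // The real code would cast_and_pin `inner` and call the C++ setter here.
        Ok(())
    }
}

fn main() {
    let mut bad = LayerHandle { inner: std::ptr::null_mut() };
    assert!(bad.set_axis(0).is_err());
    println!("guarded");
}
```

The payoff of this design is that the `unsafe` block's precondition (non-null, valid pointer) is established immediately before it, keeping each method independently auditable.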
-    pub fn set_pre_padding(&mut self, padding: &[i32; 2]) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid deconvolution layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IDeconvolutionLayer,
-            >(self.inner);
-            let dims_i64: Vec<i64> = padding.iter().map(|&p| p as i64).collect();
-            let dims_obj = trtx_sys::Dims::from_slice(&dims_i64);
-            layer_pin.as_mut().setPrePadding(&dims_obj);
-        }
-        Ok(())
-    }
-    /// Set post-padding (trim this many elements at the end of each spatial dimension of the output).
-    /// Pass [post_h, post_w] for 2D deconv; TensorRT applies to the spatial dimensions only.
-    pub fn set_post_padding(&mut self, padding: &[i32; 2]) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid deconvolution layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IDeconvolutionLayer,
-            >(self.inner);
-            let dims_i64: Vec<i64> = padding.iter().map(|&p| p as i64).collect();
-            let dims_obj = trtx_sys::Dims::from_slice(&dims_i64);
-            layer_pin.as_mut().setPostPadding(&dims_obj);
-        }
-        Ok(())
-    }
-    pub fn set_dilation(&mut self, dilation: &[i32; 2]) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid deconvolution layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IDeconvolutionLayer,
-            >(self.inner);
-            let dims_i64: Vec<i64> = dilation.iter().map(|&d| d as i64).collect();
-            let dims_obj = trtx_sys::Dims::from_slice(&dims_i64);
-            layer_pin.as_mut().setDilationNd(&dims_obj);
-        }
-        Ok(())
-    }
-    pub fn set_num_groups(&mut self, num_groups: i32) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid deconvolution layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IDeconvolutionLayer,
-            >(self.inner);
-            layer_pin.as_mut().setNbGroups(num_groups as i64);
-        }
-        Ok(())
-    }
-    /// Set an input tensor by index. Input 0 is the activation; 1 is the kernel tensor; 2 is the bias tensor.
-    /// When using input 1 or 2, the layer must have been created with empty weights for that slot.
-    pub fn set_input(&mut self, index: i32, tensor: &Tensor) -> Result<()> {
-        if self.inner.is_null() || tensor.inner.is_null() {
-            return Err(Error::Runtime(
-                "Invalid deconvolution layer or tensor".to_string(),
-            ));
-        }
-        unsafe {
-            let mut layer_pin =
-                crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::IDeconvolutionLayer>(self.inner);
-            let mut tensor_pin =
-                crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::ITensor>(tensor.inner);
-            layer_pin.as_mut().setInput(index, tensor_pin.as_mut());
-        }
-        Ok(())
-    }
-}
-
-impl ConcatenationLayer {
-    pub fn set_axis(&mut self, axis: i32) -> Result<()> {
-        if self.inner.is_null() {
-            return Err(Error::Runtime("Invalid concatenation layer".to_string()));
-        }
-        unsafe {
-            let mut layer_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::IConcatenationLayer,
-            >(self.inner);
-            layer_pin.as_mut().setAxis(axis);
-        }
-        Ok(())
-    }
-}
-
-impl Tensor {
-    pub fn name(&self) -> Result<String> {
-        let name_ptr = unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::ITensor>(self.inner)
-                .getName()
-        };
-        if name_ptr.is_null() {
-            return Err(Error::Runtime("Failed to get tensor name".to_string()));
-        }
-        unsafe { Ok(std::ffi::CStr::from_ptr(name_ptr).to_str()?.to_string()) }
-    }
-
-    pub fn set_name(&mut self, name: &str) -> Result<()> {
-        let name_cstr = std::ffi::CString::new(name)?;
-        unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::ITensor>(self.inner)
-                .setName(name_cstr.as_ptr());
-        }
-        Ok(())
-    }
-
-    pub fn dimensions(&self) -> Result<Vec<i32>> {
-        let mut dims = [0i32; 8];
-        let mut nb_dims = 0i32;
-        let result =
-            unsafe { trtx_sys::tensor_get_dimensions(self.inner, dims.as_mut_ptr(), &mut nb_dims) };
-        if result.is_null() {
-            return Err(Error::Runtime(
-                "Failed to get tensor dimensions".to_string(),
-            ));
-        }
-        if nb_dims < 0 {
-            return Err(Error::Runtime(
-                "Tensor dimensions not set (nbDims = -1)".to_string(),
-            ));
-        }
-        Ok(dims[..nb_dims as usize].to_vec())
-    }
-
-    pub fn get_type(&self) -> Result<DataType> {
-        let data_type = unsafe { trtx_sys::tensor_get_type(self.inner) };
-        Ok(data_type)
-    }
-
-    /// Set allowed tensor formats (bitmask of TensorFormat). E.g. 1u32 << TensorFormat::kHWC for channels-last.
-    /// TensorRT may insert reformat layers when connecting tensors with different formats.
-    pub fn set_allowed_formats(&mut self, formats: u32) -> Result<()> {
-        unsafe {
-            let tensor_ref = &mut *(self.inner as *mut trtx_sys::nvinfer1::ITensor);
-            let mut tensor_pin = std::pin::Pin::new_unchecked(tensor_ref);
-            tensor_pin.as_mut().setAllowedFormats(formats);
-        }
-        Ok(())
-    }
-}
-
-impl NetworkDefinition {
-    pub(crate) fn from_ptr(ptr: *mut std::ffi::c_void) -> Self {
-        NetworkDefinition { inner: ptr }
-    }
-
-    pub(crate) fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
-        self.inner
-    }
-
-    pub fn add_input(
-        &mut self,
-        name: &str,
-        data_type: trtx_sys::DataType,
-        dims: &[i32],
-    ) -> Result<Tensor> {
-        let name_cstr = std::ffi::CString::new(name)?;
-        let dims_i64: Vec<i64> = dims.iter().map(|&d| d as i64).collect();
-        let dims_struct = trtx_sys::Dims::from_slice(&dims_i64);
-        let network_ref =
-            unsafe { &mut *(self.inner as *mut trtx_sys::nvinfer1::INetworkDefinition) };
-        let mut network_pin = unsafe { std::pin::Pin::new_unchecked(network_ref) };
-        let tensor_ptr = unsafe {
-            network_pin
-                .as_mut()
-                .addInput(name_cstr.as_ptr(), data_type.into(), &dims_struct)
-        };
-        if tensor_ptr.is_null() {
-            return Err(Error::Runtime(format!("Failed to add input: {}", name)));
-        }
-        Ok(Tensor {
-            inner: tensor_ptr as *mut _,
-        })
-    }
-
-    pub fn mark_output(&mut self, tensor: &Tensor) -> Result<()> {
-        unsafe {
-            let tensor_ref = &mut *(tensor.inner as *mut trtx_sys::nvinfer1::ITensor);
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::INetworkDefinition>(
-                self.inner,
-            )
-            .markOutput(std::pin::Pin::new_unchecked(tensor_ref));
-        }
-        Ok(())
-    }
-
-    pub fn get_nb_inputs(&self) -> i32 {
-        unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::INetworkDefinition>(
-                self.inner,
-            )
-            .getNbInputs()
-        }
-    }
-
-    pub fn get_nb_outputs(&self) -> i32 {
-        unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::INetworkDefinition>(
-                self.inner,
-            )
-            .getNbOutputs()
-        }
-    }
-
-    pub fn get_input(&self, index: i32) -> Result<Tensor> {
-        let tensor_ptr = unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::INetworkDefinition>(
-                self.inner,
-            )
-            .getInput(index)
-        };
-        if tensor_ptr.is_null() {
-            return Err(Error::Runtime(format!(
-                "Failed to get input at index {}",
-                index
-            )));
-        }
-        Ok(Tensor {
-            inner: tensor_ptr as *mut _,
-        })
-    }
-
-    pub fn get_output(&self, index: i32) -> Result<Tensor> {
-        let tensor_ptr = unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::INetworkDefinition>(
-                self.inner,
-            )
-            .getOutput(index)
-        };
-        if tensor_ptr.is_null() {
-            return Err(Error::Runtime(format!(
-                "Failed to get output at index {}",
-                index
-            )));
-        }
-        Ok(Tensor {
-            inner: tensor_ptr as *mut _,
-        })
-    }
-
-    /// Number of layers in the network (for introspection/dumping).
-    pub fn get_nb_layers(&self) -> i32 {
-        unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::INetworkDefinition>(
-                self.inner,
-            )
-            .getNbLayers()
-        }
-    }
-
-    /// Layer name at index (for introspection/dumping). Returns "(Unnamed)" if null.
-    pub fn get_layer_name(&self, layer_index: i32) -> Result<String> {
-        let layer_ptr = unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::INetworkDefinition>(
-                self.inner,
-            )
-            .getLayer(layer_index)
-        };
-        if layer_ptr.is_null() {
-            return Err(Error::Runtime(format!("No layer at index {}", layer_index)));
-        }
-        let name_ptr = unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::ILayer>(layer_ptr as *mut _)
-                .getName()
-        };
-        Ok(if name_ptr.is_null() {
-            "(Unnamed)".to_string()
-        } else {
-            unsafe { std::ffi::CStr::from_ptr(name_ptr) }
-                .to_str()
-                .map_err(|e| Error::Runtime(e.to_string()))?
-                .to_string()
-        })
-    }
-
-    /// Layer type enum value at index (for introspection/dumping). See TensorRT LayerType.
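The introspection getters above must handle a null `const char*` coming back from the C++ side; the convention used is to substitute a `"(Unnamed)"` placeholder rather than fail. A standalone sketch of that conversion (plain `CStr` handling, no trtx types; `to_string_lossy` is used here instead of the original's fallible `to_str`):

```rust
// Sketch of the null-name fallback used by the layer-name getter above.
use std::ffi::{CStr, CString};
use std::os::raw::c_char;

fn layer_name(ptr: *const c_char) -> String {
    if ptr.is_null() {
        // TensorRT returned no name for this layer.
        "(Unnamed)".to_string()
    } else {
        // Safety: caller guarantees `ptr` is a valid NUL-terminated string.
        unsafe { CStr::from_ptr(ptr) }.to_string_lossy().into_owned()
    }
}

fn main() {
    assert_eq!(layer_name(std::ptr::null()), "(Unnamed)");
    let name = CString::new("conv1").unwrap();
    assert_eq!(layer_name(name.as_ptr()), "conv1");
    println!("ok");
}
```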
-    pub fn get_layer_type(&self, layer_index: i32) -> Result<i32> {
-        let layer_ptr = unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::INetworkDefinition>(
-                self.inner,
-            )
-            .getLayer(layer_index)
-        };
-        if layer_ptr.is_null() {
-            return Err(Error::Runtime(format!("No layer at index {}", layer_index)));
-        }
-        let layer_type = unsafe {
-            crate::autocxx_helpers::cast_and_pin::<trtx_sys::nvinfer1::ILayer>(layer_ptr as *mut _)
-                .getType()
-        };
-        Ok(layer_type as i32)
-    }
-
-    pub fn add_activation(
-        &mut self,
-        input: &Tensor,
-        activation_type: trtx_sys::ActivationType,
-    ) -> Result<ActivationLayer> {
-        let layer_ptr = unsafe {
-            let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor);
-            let mut network_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::INetworkDefinition,
-            >(self.inner);
-            let layer_ptr = network_pin.as_mut().addActivation(
-                std::pin::Pin::new_unchecked(input_ref),
-                activation_type.into(),
-            );
-            if layer_ptr.is_null() {
-                return Err(Error::Runtime("Failed to add activation layer".to_string()));
-            }
-            layer_ptr as *mut std::ffi::c_void
-        };
-        Ok(ActivationLayer::from_ptr(layer_ptr))
-    }
-
-    pub fn add_unary(
-        &mut self,
-        input: &Tensor,
-        op: trtx_sys::UnaryOperation,
-    ) -> Result<UnaryLayer> {
-        let layer_ptr = unsafe {
-            let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor);
-            let mut network_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::INetworkDefinition,
-            >(self.inner);
-            let layer_ptr = network_pin
-                .as_mut()
-                .addUnary(std::pin::Pin::new_unchecked(input_ref), op.into());
-            if layer_ptr.is_null() {
-                return Err(Error::Runtime("Failed to add unary layer".to_string()));
-            }
-            layer_ptr as *mut std::ffi::c_void
-        };
-        Ok(UnaryLayer::from_ptr(layer_ptr))
-    }
-
-    pub fn add_identity(&mut self, input: &Tensor) -> Result<IdentityLayer> {
-        let layer_ptr = unsafe {
-            let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor);
-            let mut network_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::INetworkDefinition,
-            >(self.inner);
-            let layer_ptr = network_pin
-                .as_mut()
-                .addIdentity(std::pin::Pin::new_unchecked(input_ref));
-            if layer_ptr.is_null() {
-                return Err(Error::Runtime("Failed to add identity layer".to_string()));
-            }
-            layer_ptr as *mut std::ffi::c_void
-        };
-        Ok(IdentityLayer::from_ptr(layer_ptr))
-    }
-
-    pub fn add_cast(
-        &mut self,
-        input: &Tensor,
-        to_type: trtx_sys::nvinfer1::DataType,
-    ) -> Result<CastLayer> {
-        let layer_ptr = unsafe {
-            let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor);
-            let mut network_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::INetworkDefinition,
-            >(self.inner);
-            let layer_ptr = network_pin
-                .as_mut()
-                .addCast(std::pin::Pin::new_unchecked(input_ref), to_type);
-            if layer_ptr.is_null() {
-                return Err(Error::Runtime("Failed to add cast layer".to_string()));
-            }
-            layer_ptr as *mut std::ffi::c_void
-        };
-        Ok(CastLayer::from_ptr(layer_ptr))
-    }
-
-    pub fn add_elementwise(
-        &mut self,
-        input1: &Tensor,
-        input2: &Tensor,
-        op: trtx_sys::ElementWiseOperation,
-    ) -> Result<ElementWiseLayer> {
-        let layer_ptr = unsafe {
-            let input1_ref = &mut *(input1.inner as *mut trtx_sys::nvinfer1::ITensor);
-            let input2_ref = &mut *(input2.inner as *mut trtx_sys::nvinfer1::ITensor);
-            let mut network_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::INetworkDefinition,
-            >(self.inner);
-            let layer_ptr = network_pin.as_mut().addElementWise(
-                std::pin::Pin::new_unchecked(input1_ref),
-                std::pin::Pin::new_unchecked(input2_ref),
-                op.into(),
-            );
-            if layer_ptr.is_null() {
-                return Err(Error::Runtime(
-                    "Failed to add elementwise layer".to_string(),
-                ));
-            }
-            layer_ptr as *mut std::ffi::c_void
-        };
-        Ok(ElementWiseLayer::from_ptr(layer_ptr))
-    }
-
-    pub fn add_pooling(
-        &mut self,
-        input: &Tensor,
-        pooling_type: trtx_sys::PoolingType,
-        window_size: &[i32; 2],
-    ) -> Result<PoolingLayer> {
-        let window_dims = trtx_sys::Dims::new_2d(window_size[0] as i64, window_size[1] as i64);
-        let network_ref =
-            unsafe { &mut *(self.inner as *mut trtx_sys::nvinfer1::INetworkDefinition) };
-        let mut network_pin = unsafe { std::pin::Pin::new_unchecked(network_ref) };
-        let input_ref = unsafe { &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor) };
-        let mut input_pin = unsafe { std::pin::Pin::new_unchecked(input_ref) };
-        let layer_ptr = network_pin.as_mut().addPoolingNd(
-            input_pin.as_mut(),
-            pooling_type.into(),
-            &window_dims,
-        );
-        if layer_ptr.is_null() {
-            return Err(Error::Runtime("Failed to add pooling layer".to_string()));
-        }
-        Ok(PoolingLayer::from_ptr(layer_ptr as *mut _))
-    }
-
-    pub fn add_shuffle(&mut self, input: &Tensor) -> Result<ShuffleLayer> {
-        let layer_ptr = unsafe {
-            let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor);
-            let mut network_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::INetworkDefinition,
-            >(self.inner);
-            let layer_ptr = network_pin
-                .as_mut()
-                .addShuffle(std::pin::Pin::new_unchecked(input_ref));
-            if layer_ptr.is_null() {
-                return Err(Error::Runtime("Failed to add shuffle layer".to_string()));
-            }
-            layer_ptr as *mut std::ffi::c_void
-        };
-        Ok(ShuffleLayer::from_ptr(layer_ptr))
-    }
-
-    pub fn add_matrix_multiply(
-        &mut self,
-        input0: &Tensor,
-        op0: MatrixOperation,
-        input1: &Tensor,
-        op1: MatrixOperation,
-    ) -> Result<MatrixMultiplyLayer> {
-        let layer_ptr = unsafe {
-            let input0_ref = &mut *(input0.inner as *mut trtx_sys::nvinfer1::ITensor);
-            let input1_ref = &mut *(input1.inner as *mut trtx_sys::nvinfer1::ITensor);
-            let mut network_pin = crate::autocxx_helpers::cast_and_pin::<
-                trtx_sys::nvinfer1::INetworkDefinition,
-            >(self.inner);
-            let layer_ptr = network_pin.as_mut().addMatrixMultiply(
-                std::pin::Pin::new_unchecked(input0_ref),
-                op0.into(),
-                std::pin::Pin::new_unchecked(input1_ref),
-                op1.into(),
-            );
-            if layer_ptr.is_null() {
-                return Err(Error::Runtime(
-                    "Failed to add matrix multiply layer".to_string(),
-                ));
-            }
-            layer_ptr as *mut std::ffi::c_void
-        };
-        Ok(MatrixMultiplyLayer::from_ptr(layer_ptr))
-    }
-
-    pub fn add_convolution(
-        &mut self,
-        input: &Tensor,
-        nb_output_maps: i32,
-        kernel_size: &[i32; 2],
-        weights: &ConvWeights<'_>,
-    ) -> Result<ConvolutionLayer> {
-        let kernel_dtype = weights.kernel_dtype;
-        let kernel_weights = weights.kernel_weights;
-        let bias_weights = weights.bias_weights;
-        let bias_dtype = weights.bias_dtype;
-        let kernel_bpe = match kernel_dtype {
-            DataType::kFLOAT => 4,
-            DataType::kHALF => 2,
-            DataType::kINT8 => 1,
-            DataType::kINT32 => 4,
-            _ => {
-                return Err(Error::Runtime(format!(
-                    "Unsupported kernel weight type for convolution: {kernel_dtype:?}",
-                )))
-            }
-        };
-        let weight_count = (kernel_weights.len() / kernel_bpe) as i64;
-        let bias_dtype_val = bias_dtype.unwrap_or(kernel_dtype);
-        let bias_bpe = match bias_dtype_val {
-            DataType::kFLOAT => 4,
-            DataType::kHALF => 2,
-            DataType::kINT8 => 1,
-            DataType::kINT32 => 4,
-            _ => {
-                return Err(Error::Runtime(format!(
-                    "Unsupported bias weight type for convolution: {bias_dtype_val:?}",
-                )))
-            }
-        };
-        let bias_count = bias_weights
-            .map(|b| (b.len() / bias_bpe) as i64)
-            .unwrap_or(0);
-        let kernel_ptr = if weight_count > 0 {
-            kernel_weights.as_ptr() as *const std::ffi::c_void
-        } else {
-            std::ptr::null()
-        };
-        let bias_ptr = if bias_count > 0 {
-            bias_weights
-                .map(|b| b.as_ptr() as *const std::ffi::c_void)
-                .unwrap_or(std::ptr::null())
-        } else {
-            std::ptr::null()
-        };
-        let kernel_dims = trtx_sys::Dims::new_2d(kernel_size[0] as i64, kernel_size[1] as i64);
-        let kernel_w = trtx_sys::nvinfer1::Weights::new_with_type(
-            kernel_dtype.into(),
-            kernel_ptr,
-            weight_count,
-        );
-        let bias_w =
-            trtx_sys::nvinfer1::Weights::new_with_type(bias_dtype_val.into(), bias_ptr, bias_count);
-        let layer_ptr = unsafe {
-            let network_ptr = self.inner as *mut trtx_sys::nvinfer1::INetworkDefinition;
-            let mut network_pin = std::pin::Pin::new_unchecked(&mut *network_ptr);
-            let input_ptr = input.inner as *mut trtx_sys::nvinfer1::ITensor;
-            let input_ref = std::pin::Pin::new_unchecked(&mut *input_ptr);
-            network_pin.as_mut().addConvolutionNd(
-                input_ref,
-                nb_output_maps as i64,
-                &kernel_dims,
-                kernel_w,
-                bias_w,
-            ) as *mut std::ffi::c_void
-        };
-        if layer_ptr.is_null() {
-            return Err(Error::Runtime(
-                "Failed to add convolution layer".to_string(),
-            ));
-        }
-        Ok(ConvolutionLayer::from_ptr(layer_ptr))
-    }
-
-    /// Add a 2D deconvolution layer. Same input semantics as convolution: input 0 = activation,
-    /// input 1 = kernel tensor (use set_input(1, tensor) when kernel_weights is empty),
-    /// input 2 = bias tensor (use set_input(2, tensor) when bias_weights is None/empty).
-    pub fn add_deconvolution(
-        &mut self,
-        input: &Tensor,
-        nb_output_maps: i32,
-        kernel_size: &[i32; 2],
-        weights: &ConvWeights<'_>,
-    ) -> Result<DeconvolutionLayer> {
-        let kernel_dtype = weights.kernel_dtype;
-        let kernel_weights = weights.kernel_weights;
-        let bias_weights = weights.bias_weights;
-        let bias_dtype = weights.bias_dtype;
-        let kernel_bpe = match kernel_dtype {
-            DataType::kFLOAT => 4,
-            DataType::kHALF => 2,
-            DataType::kINT8 => 1,
-            DataType::kINT32 => 4,
-            _ => {
-                return Err(Error::Runtime(format!(
-                    "Unsupported kernel weight type for deconvolution: {kernel_dtype:?}",
-                )))
-            }
-        };
-        let weight_count = (kernel_weights.len() / kernel_bpe) as i64;
-        let bias_dtype_val = bias_dtype.unwrap_or(kernel_dtype);
-        let bias_bpe = match bias_dtype_val {
-            DataType::kFLOAT => 4,
-            DataType::kHALF => 2,
-            DataType::kINT8 => 1,
-            DataType::kINT32 => 4,
-            _ => {
-                return Err(Error::Runtime(format!(
-                    "Unsupported bias weight type for deconvolution: {bias_dtype_val:?}",
-                )))
-            }
-        };
-        let bias_count = bias_weights
-            .map(|b| (b.len() / bias_bpe) as i64)
-            .unwrap_or(0);
-        let kernel_ptr = if weight_count > 0 {
-            kernel_weights.as_ptr() as *const std::ffi::c_void
-        } else {
-            std::ptr::null()
-        };
-        let bias_ptr = if bias_count > 0 {
-            bias_weights
-                .map(|b| b.as_ptr() as *const std::ffi::c_void)
-                .unwrap_or(std::ptr::null())
-        } else {
-            std::ptr::null()
-        };
-        let kernel_dims = 
trtx_sys::Dims::new_2d(kernel_size[0] as i64, kernel_size[1] as i64); - let kernel_w = trtx_sys::nvinfer1::Weights::new_with_type( - kernel_dtype.into(), - kernel_ptr, - weight_count, - ); - let bias_w = - trtx_sys::nvinfer1::Weights::new_with_type(bias_dtype_val.into(), bias_ptr, bias_count); - let layer_ptr = unsafe { - let network_ptr = self.inner as *mut trtx_sys::nvinfer1::INetworkDefinition; - let mut network_pin = std::pin::Pin::new_unchecked(&mut *network_ptr); - let input_ptr = input.inner as *mut trtx_sys::nvinfer1::ITensor; - let input_ref = std::pin::Pin::new_unchecked(&mut *input_ptr); - network_pin.as_mut().addDeconvolutionNd( - input_ref, - nb_output_maps as i64, - kernel_dims, - kernel_w, - bias_w, - ) as *mut std::ffi::c_void - }; - if layer_ptr.is_null() { - return Err(Error::Runtime( - "Failed to add deconvolution layer".to_string(), - )); - } - Ok(DeconvolutionLayer::from_ptr(layer_ptr)) - } - - pub fn add_concatenation(&mut self, inputs: &[&Tensor]) -> Result { - let mut input_ptrs: Vec<*mut std::ffi::c_void> = inputs.iter().map(|t| t.inner).collect(); - let layer_ptr = unsafe { - trtx_sys::network_add_concatenation( - self.inner, - input_ptrs.as_mut_ptr(), - inputs.len() as i32, - ) - }; - if layer_ptr.is_null() { - return Err(Error::Runtime( - "Failed to add concatenation layer".to_string(), - )); - } - Ok(ConcatenationLayer::from_ptr(layer_ptr)) - } - - pub fn add_constant( - &mut self, - dims: &[i32], - weights: &[u8], - data_type: trtx_sys::DataType, - ) -> Result { - use trtx_sys::DataType; - let element_count: i64 = dims.iter().map(|&d| d as i64).product(); - let bytes_per_element = match data_type { - DataType::kFLOAT => 4, - DataType::kHALF => 2, - DataType::kINT8 => 1, - DataType::kINT32 => 4, - DataType::kUINT8 => 1, - DataType::kBOOL => 1, - _ => { - return Err(Error::Runtime(format!( - "Unsupported data type: {data_type:?}", - ))) - } - }; - let expected_bytes = element_count * bytes_per_element; - if weights.len() as i64 != 
expected_bytes { - return Err(Error::Runtime(format!( - "Weight size mismatch: expected {} bytes, got {} bytes", - expected_bytes, - weights.len() - ))); - } - let dims_i64: Vec = dims.iter().map(|&d| d as i64).collect(); - let dims_struct = trtx_sys::Dims::from_slice(&dims_i64); - let weights_struct = trtx_sys::nvinfer1::Weights::new_with_type( - data_type.into(), - weights.as_ptr() as *const std::ffi::c_void, - element_count, - ); - let layer_ptr = unsafe { - let network_ptr = self.inner as *mut trtx_sys::nvinfer1::INetworkDefinition; - let mut network_pin = std::pin::Pin::new_unchecked(&mut *network_ptr); - network_pin - .as_mut() - .addConstant(&dims_struct, weights_struct) as *mut std::ffi::c_void - }; - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add constant tensor".to_string())); - } - Ok(ConstantLayer::from_ptr(layer_ptr)) - } - - pub fn add_softmax(&mut self, input: &Tensor, axes: u32) -> Result { - let layer_ptr = unsafe { - let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor); - let mut network_pin = crate::autocxx_helpers::cast_and_pin::< - trtx_sys::nvinfer1::INetworkDefinition, - >(self.inner); - let layer_ptr = network_pin - .as_mut() - .addSoftMax(std::pin::Pin::new_unchecked(input_ref)); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add softmax layer".to_string())); - } - let mut layer_pin = std::pin::Pin::new_unchecked(&mut *layer_ptr); - layer_pin.as_mut().setAxes(axes); - layer_ptr as *mut std::ffi::c_void - }; - Ok(SoftMaxLayer::from_ptr(layer_ptr)) - } - - pub fn add_scale( - &mut self, - input: &Tensor, - mode: ScaleMode, - shift: &[u8], - scale: &[u8], - power: &[u8], - ) -> Result { - let weight_count = match mode { - ScaleMode::kUNIFORM => 1i64, - ScaleMode::kCHANNEL => { - let input_dims = input.dimensions()?; - if input_dims.len() >= 4 { - input_dims[1] as i64 - } else if !input_dims.is_empty() { - input_dims[0] as i64 - } else { - 1i64 - } - } - ScaleMode::kELEMENTWISE => { - 
let input_dims = input.dimensions()?; - input_dims.iter().map(|&d| d as i64).product::() - } - }; - - let shift_w = trtx_sys::nvinfer1::Weights::new_float( - shift.as_ptr() as *const std::ffi::c_void, - weight_count, - ); - let scale_w = trtx_sys::nvinfer1::Weights::new_float( - scale.as_ptr() as *const std::ffi::c_void, - weight_count, - ); - let power_w = trtx_sys::nvinfer1::Weights::new_float( - power.as_ptr() as *const std::ffi::c_void, - weight_count, - ); - let layer_ptr = unsafe { - let network_ptr = self.inner as *mut trtx_sys::nvinfer1::INetworkDefinition; - let mut network_pin = std::pin::Pin::new_unchecked(&mut *network_ptr); - let input_ptr = input.inner as *mut trtx_sys::nvinfer1::ITensor; - let input_ref = std::pin::Pin::new_unchecked(&mut *input_ptr); - network_pin - .as_mut() - .addScale(input_ref, mode.into(), shift_w, scale_w, power_w) - as *mut std::ffi::c_void - }; - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add scale layer".to_string())); - } - Ok(ScaleLayer::from_ptr(layer_ptr)) - } - - pub fn add_reduce( - &mut self, - input: &Tensor, - op: trtx_sys::nvinfer1::ReduceOperation, - axes: u32, - keep_dims: bool, - ) -> Result { - let layer_ptr = unsafe { - let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor); - let mut network_pin = crate::autocxx_helpers::cast_and_pin::< - trtx_sys::nvinfer1::INetworkDefinition, - >(self.inner); - let layer_ptr = network_pin.as_mut().addReduce( - std::pin::Pin::new_unchecked(input_ref), - op, - axes, - keep_dims, - ); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add reduce layer".to_string())); - } - layer_ptr as *mut std::ffi::c_void - }; - Ok(ReduceLayer::from_ptr(layer_ptr)) - } - - pub fn add_cumulative( - &mut self, - input: &Tensor, - axis: i32, - op: trtx_sys::CumulativeOperation, - exclusive: bool, - reverse: bool, - ) -> Result { - let axis_bytes = axis.to_le_bytes(); - let axis_constant = self.add_constant(&[], &axis_bytes, 
trtx_sys::DataType::kINT32)?; - let axis_tensor = axis_constant.get_output(0)?; - self.add_cumulative_with_axis_tensor(input, &axis_tensor, op, exclusive, reverse) - } - - pub fn add_cumulative_with_axis_tensor( - &mut self, - input: &Tensor, - axis_tensor: &Tensor, - op: trtx_sys::CumulativeOperation, - exclusive: bool, - reverse: bool, - ) -> Result { - let layer_ptr = unsafe { - let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor); - let axis_ref = &mut *(axis_tensor.inner as *mut trtx_sys::nvinfer1::ITensor); - let mut network_pin = crate::autocxx_helpers::cast_and_pin::< - trtx_sys::nvinfer1::INetworkDefinition, - >(self.inner); - let layer_ptr = network_pin.as_mut().addCumulative( - std::pin::Pin::new_unchecked(input_ref), - std::pin::Pin::new_unchecked(axis_ref), - op.into(), - exclusive, - reverse, - ); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add cumulative layer".to_string())); - } - layer_ptr as *mut std::ffi::c_void - }; - Ok(CumulativeLayer::from_ptr(layer_ptr)) - } - - pub fn add_slice( - &mut self, - input: &Tensor, - start: &[i32], - size: &[i32], - stride: &[i32], - ) -> Result { - if start.len() != size.len() || start.len() != stride.len() { - return Err(Error::Runtime( - "start, size, and stride must have the same length".to_string(), - )); - } - let start_i64: Vec = start.iter().map(|&d| d as i64).collect(); - let size_i64: Vec = size.iter().map(|&d| d as i64).collect(); - let stride_i64: Vec = stride.iter().map(|&d| d as i64).collect(); - let start_dims = trtx_sys::Dims::from_slice(&start_i64); - let size_dims = trtx_sys::Dims::from_slice(&size_i64); - let stride_dims = trtx_sys::Dims::from_slice(&stride_i64); - let network_ref = - unsafe { &mut *(self.inner as *mut trtx_sys::nvinfer1::INetworkDefinition) }; - let mut network_pin = unsafe { std::pin::Pin::new_unchecked(network_ref) }; - let input_ref = unsafe { &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor) }; - let mut input_pin = unsafe { 
std::pin::Pin::new_unchecked(input_ref) }; - let layer_ptr = network_pin.as_mut().addSlice( - input_pin.as_mut(), - &start_dims, - &size_dims, - &stride_dims, - ); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add slice layer".to_string())); - } - Ok(SliceLayer::from_ptr(layer_ptr as *mut _)) - } - - pub fn add_resize(&mut self, input: &Tensor) -> Result { - let layer_ptr = unsafe { - let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor); - let mut network_pin = crate::autocxx_helpers::cast_and_pin::< - trtx_sys::nvinfer1::INetworkDefinition, - >(self.inner); - let layer_ptr = network_pin - .as_mut() - .addResize(std::pin::Pin::new_unchecked(input_ref)); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add resize layer".to_string())); - } - layer_ptr as *mut std::ffi::c_void - }; - Ok(ResizeLayer::from_ptr(layer_ptr)) - } - - pub fn add_topk( - &mut self, - input: &Tensor, - op: TopKOperation, - k: i32, - axes: u32, - ) -> Result { - let layer_ptr = unsafe { - let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor); - let mut network_pin = crate::autocxx_helpers::cast_and_pin::< - trtx_sys::nvinfer1::INetworkDefinition, - >(self.inner); - let layer_ptr = network_pin.as_mut().addTopK( - std::pin::Pin::new_unchecked(input_ref), - op.into(), - k, - axes, - ); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add topk layer".to_string())); - } - layer_ptr as *mut std::ffi::c_void - }; - Ok(TopKLayer::from_ptr(layer_ptr)) - } - - pub fn add_gather( - &mut self, - data: &Tensor, - indices: &Tensor, - axis: i32, - ) -> Result { - let layer_ptr = unsafe { - let data_ref = &mut *(data.inner as *mut trtx_sys::nvinfer1::ITensor); - let indices_ref = &mut *(indices.inner as *mut trtx_sys::nvinfer1::ITensor); - let mut network_pin = crate::autocxx_helpers::cast_and_pin::< - trtx_sys::nvinfer1::INetworkDefinition, - >(self.inner); - let layer_ptr = network_pin.as_mut().addGather( - 
std::pin::Pin::new_unchecked(data_ref), - std::pin::Pin::new_unchecked(indices_ref), - axis, - ); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add gather layer".to_string())); - } - layer_ptr as *mut std::ffi::c_void - }; - Ok(GatherLayer::from_ptr(layer_ptr)) - } - - pub fn add_scatter( - &mut self, - data: &Tensor, - indices: &Tensor, - updates: &Tensor, - mode: trtx_sys::nvinfer1::ScatterMode, - ) -> Result { - let layer_ptr = unsafe { - let data_ref = &mut *(data.inner as *mut trtx_sys::nvinfer1::ITensor); - let indices_ref = &mut *(indices.inner as *mut trtx_sys::nvinfer1::ITensor); - let updates_ref = &mut *(updates.inner as *mut trtx_sys::nvinfer1::ITensor); - let mut network_pin = crate::autocxx_helpers::cast_and_pin::< - trtx_sys::nvinfer1::INetworkDefinition, - >(self.inner); - let layer_ptr = network_pin.as_mut().addScatter( - std::pin::Pin::new_unchecked(data_ref), - std::pin::Pin::new_unchecked(indices_ref), - std::pin::Pin::new_unchecked(updates_ref), - mode, - ); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add scatter layer".to_string())); - } - layer_ptr as *mut std::ffi::c_void - }; - Ok(ScatterLayer::from_ptr(layer_ptr)) - } - - pub fn add_quantize( - &mut self, - input: &Tensor, - scale: &Tensor, - output_type: trtx_sys::nvinfer1::DataType, - ) -> Result { - let layer_ptr = unsafe { - let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor); - let scale_ref = &mut *(scale.inner as *mut trtx_sys::nvinfer1::ITensor); - let mut network_pin = crate::autocxx_helpers::cast_and_pin::< - trtx_sys::nvinfer1::INetworkDefinition, - >(self.inner); - let layer_ptr = network_pin.as_mut().addQuantize( - std::pin::Pin::new_unchecked(input_ref), - std::pin::Pin::new_unchecked(scale_ref), - output_type, - ); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add quantize layer".to_string())); - } - layer_ptr as *mut std::ffi::c_void - }; - Ok(QuantizeLayer::from_ptr(layer_ptr)) - } - - pub fn 
add_dequantize( - &mut self, - input: &Tensor, - scale: &Tensor, - output_type: trtx_sys::nvinfer1::DataType, - ) -> Result { - let layer_ptr = unsafe { - let input_ref = &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor); - let scale_ref = &mut *(scale.inner as *mut trtx_sys::nvinfer1::ITensor); - let mut network_pin = crate::autocxx_helpers::cast_and_pin::< - trtx_sys::nvinfer1::INetworkDefinition, - >(self.inner); - let layer_ptr = network_pin.as_mut().addDequantize( - std::pin::Pin::new_unchecked(input_ref), - std::pin::Pin::new_unchecked(scale_ref), - output_type, - ); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add dequantize layer".to_string())); - } - layer_ptr as *mut std::ffi::c_void - }; - Ok(DequantizeLayer::from_ptr(layer_ptr)) - } - - pub fn add_select( - &mut self, - condition: &Tensor, - then_input: &Tensor, - else_input: &Tensor, - ) -> Result { - let layer_ptr = unsafe { - let condition_ref = &mut *(condition.inner as *mut trtx_sys::nvinfer1::ITensor); - let then_ref = &mut *(then_input.inner as *mut trtx_sys::nvinfer1::ITensor); - let else_ref = &mut *(else_input.inner as *mut trtx_sys::nvinfer1::ITensor); - let mut network_pin = crate::autocxx_helpers::cast_and_pin::< - trtx_sys::nvinfer1::INetworkDefinition, - >(self.inner); - let layer_ptr = network_pin.as_mut().addSelect( - std::pin::Pin::new_unchecked(condition_ref), - std::pin::Pin::new_unchecked(then_ref), - std::pin::Pin::new_unchecked(else_ref), - ); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add select layer".to_string())); - } - layer_ptr as *mut std::ffi::c_void - }; - Ok(SelectLayer::from_ptr(layer_ptr)) - } - - pub fn add_padding( - &mut self, - input: &Tensor, - pre_padding: &[i32], - post_padding: &[i32], - ) -> Result { - if pre_padding.len() != post_padding.len() { - return Err(Error::Runtime( - "pre_padding and post_padding must have the same length".to_string(), - )); - } - let pre_i64: Vec = pre_padding.iter().map(|&d| d as 
i64).collect(); - let post_i64: Vec = post_padding.iter().map(|&d| d as i64).collect(); - let pre_dims = trtx_sys::Dims::from_slice(&pre_i64); - let post_dims = trtx_sys::Dims::from_slice(&post_i64); - let network_ref = - unsafe { &mut *(self.inner as *mut trtx_sys::nvinfer1::INetworkDefinition) }; - let mut network_pin = unsafe { std::pin::Pin::new_unchecked(network_ref) }; - let input_ref = unsafe { &mut *(input.inner as *mut trtx_sys::nvinfer1::ITensor) }; - let mut input_pin = unsafe { std::pin::Pin::new_unchecked(input_ref) }; - let layer_ptr = - network_pin - .as_mut() - .addPaddingNd(input_pin.as_mut(), &pre_dims, &post_dims); - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add padding layer".to_string())); - } - Ok(PaddingLayer::from_ptr(layer_ptr as *mut _)) - } - - pub fn add_assertion(&mut self, condition: &Tensor, message: &str) -> Result<()> { - let message_cstr = std::ffi::CString::new(message)?; - let layer_ptr = unsafe { - trtx_sys::network_add_assertion(self.inner, condition.inner, message_cstr.as_ptr()) - }; - if layer_ptr.is_null() { - return Err(Error::Runtime("Failed to add assertion layer".to_string())); - } - Ok(()) - } - - pub fn add_loop(&mut self) -> Result<*mut std::ffi::c_void> { - let loop_ptr = unsafe { trtx_sys::network_add_loop(self.inner) }; - if loop_ptr.is_null() { - return Err(Error::Runtime("Failed to add loop".to_string())); - } - Ok(loop_ptr) - } - - pub fn add_if_conditional(&mut self) -> Result<*mut std::ffi::c_void> { - let if_ptr = unsafe { trtx_sys::network_add_if_conditional(self.inner) }; - if if_ptr.is_null() { - return Err(Error::Runtime("Failed to add if conditional".to_string())); - } - Ok(if_ptr) - } -} - -impl Drop for NetworkDefinition { - fn drop(&mut self) { - if !self.inner.is_null() { - unsafe { - trtx_sys::delete_network(self.inner); - } - } - } -} - -unsafe impl Send for NetworkDefinition {} diff --git a/trtx/src/real/onnx_parser.rs b/trtx/src/real/onnx_parser.rs deleted file mode 100644 
index 3f06915..0000000 --- a/trtx/src/real/onnx_parser.rs +++ /dev/null @@ -1,110 +0,0 @@ -//! Real TensorRT ONNX parser implementation - -use crate::error::{Error, Result}; -use crate::logger::Logger; -use crate::network::NetworkDefinition; - -/// ONNX parser (real mode) -pub struct OnnxParser { - inner: *mut std::ffi::c_void, -} - -impl OnnxParser { - #[cfg(not(any( - feature = "link_tensorrt_onnxparser", - feature = "dlopen_tensorrt_onnxparser" - )))] - pub fn new(network: &mut NetworkDefinition, logger: &Logger) -> Result { - Err(Error::TrtOnnxParserLibraryNotLoaded) - } - - #[cfg(any( - feature = "link_tensorrt_onnxparser", - feature = "dlopen_tensorrt_onnxparser" - ))] - pub fn new(network: &mut NetworkDefinition, logger: &Logger) -> Result { - let network_ptr = network.as_mut_ptr(); - let logger_ptr = logger.as_logger_ptr(); - let parser_ptr = { - #[cfg(feature = "link_tensorrt_onnxparser")] - unsafe { - trtx_sys::create_onnx_parser(network_ptr, logger_ptr) - } - #[cfg(not(feature = "link_tensorrt_onnxparser"))] - #[cfg(feature = "dlopen_tensorrt_rtx")] - unsafe { - use libloading::Symbol; - use std::ffi::c_void; - - use crate::TRT_ONNXPARSER_LIB; - - if !TRT_ONNXPARSER_LIB.read()?.is_some() { - crate::dynamically_load_tensorrt_onnxparser(None::)?; - } - - let lock = TRT_ONNXPARSER_LIB - .read() - .map_err(|_| Error::LockPoisining)?; - let create_onnx_parser: Symbol *mut c_void> = - lock.as_ref() - .ok_or(Error::TrtOnnxParserLibraryNotLoaded)? 
- .get(b"createNvOnnxParser_INTERNAL")?; - create_onnx_parser(network_ptr, logger_ptr, trtx_sys::get_tensorrt_version()) - } - }; - if parser_ptr.is_null() { - return Err(Error::Runtime("Failed to create ONNX parser".to_string())); - } - Ok(OnnxParser { inner: parser_ptr }) - } - - pub fn parse(&self, model_bytes: &[u8]) -> Result<()> { - if self.inner.is_null() { - return Err(Error::Runtime("Invalid parser".to_string())); - } - let success = unsafe { - trtx_sys::parser_parse( - self.inner, - model_bytes.as_ptr() as *const std::ffi::c_void, - model_bytes.len(), - ) - }; - if !success { - let error_msg = unsafe { - let num_errors = trtx_sys::parser_get_nb_errors(self.inner); - if num_errors > 0 { - let err_ptr = trtx_sys::parser_get_error(self.inner, 0); - if !err_ptr.is_null() { - let desc_ptr = trtx_sys::parser_error_desc(err_ptr); - if !desc_ptr.is_null() { - std::ffi::CStr::from_ptr(desc_ptr) - .to_str() - .unwrap_or("Failed to parse ONNX model") - .to_string() - } else { - "Failed to parse ONNX model".to_string() - } - } else { - "Failed to parse ONNX model".to_string() - } - } else { - "Failed to parse ONNX model".to_string() - } - }; - return Err(Error::Runtime(error_msg)); - } - Ok(()) - } -} - -impl Drop for OnnxParser { - fn drop(&mut self) { - if !self.inner.is_null() { - unsafe { - trtx_sys::delete_parser(self.inner); - } - } - } -} - -unsafe impl Send for OnnxParser {} diff --git a/trtx/src/real/runtime.rs b/trtx/src/real/runtime.rs deleted file mode 100644 index 93dd54b..0000000 --- a/trtx/src/real/runtime.rs +++ /dev/null @@ -1,256 +0,0 @@ -//! 
Real TensorRT runtime implementation - -use crate::error::{Error, Result}; -use crate::logger::Logger; -use std::ffi::CStr; - -/// CUDA engine (real mode) -pub struct CudaEngine { - inner: *mut std::ffi::c_void, -} - -impl CudaEngine { - pub fn get_nb_io_tensors(&self) -> Result { - if self.inner.is_null() { - return Err(Error::Runtime("Invalid engine".to_string())); - } - let count = unsafe { - crate::autocxx_helpers::cast_and_pin::(self.inner) - .getNbIOTensors() - }; - Ok(count) - } - - pub fn get_tensor_name(&self, index: i32) -> Result { - if self.inner.is_null() { - return Err(Error::Runtime("Invalid engine".to_string())); - } - let name_ptr = unsafe { - crate::autocxx_helpers::cast_and_pin::(self.inner) - .getIOTensorName(index) - }; - if name_ptr.is_null() { - return Err(Error::InvalidArgument("Invalid tensor index".to_string())); - } - Ok(unsafe { CStr::from_ptr(name_ptr) }.to_str()?.to_string()) - } - - pub fn get_tensor_shape(&self, name: &str) -> Result> { - if self.inner.is_null() { - return Err(Error::Runtime("Invalid engine".to_string())); - } - let name_cstr = std::ffi::CString::new(name)?; - let dims = unsafe { - crate::autocxx_helpers::cast_and_pin::(self.inner) - .getTensorShape(name_cstr.as_ptr()) - }; - let nb_dims = dims.nbDims as usize; - if nb_dims > 8 { - return Err(Error::Runtime("Tensor has too many dimensions".to_string())); - } - let mut shape = Vec::with_capacity(nb_dims); - for i in 0..nb_dims { - shape.push(dims.d[i]); - } - Ok(shape) - } - - /// Returns the data type of the tensor (e.g. kFLOAT, kHALF). - /// Required for correct buffer sizing and f32/f16 conversion when I/O uses half precision. 
- pub fn get_tensor_dtype(&self, name: &str) -> Result { - if self.inner.is_null() { - return Err(Error::Runtime("Invalid engine".to_string())); - } - let name_cstr = std::ffi::CString::new(name)?; - let dtype = unsafe { - crate::autocxx_helpers::cast_and_pin::(self.inner) - .getTensorDataType(name_cstr.as_ptr()) - }; - Ok(dtype) - } - - pub fn create_execution_context(&self) -> Result> { - if self.inner.is_null() { - return Err(Error::Runtime("Invalid engine".to_string())); - } - let context_ptr = unsafe { - crate::autocxx_helpers::cast_and_pin::(self.inner) - .createExecutionContext( - trtx_sys::nvinfer1::ExecutionContextAllocationStrategy::kSTATIC, - ) - }; - if context_ptr.is_null() { - return Err(Error::Runtime( - "Failed to create execution context".to_string(), - )); - } - Ok(ExecutionContext { - inner: context_ptr as *mut _, - _engine: std::marker::PhantomData, - }) - } - - #[allow(dead_code)] - pub(crate) fn as_ptr(&self) -> *const std::ffi::c_void { - self.inner - } -} - -impl Drop for CudaEngine { - fn drop(&mut self) { - if !self.inner.is_null() { - unsafe { - trtx_sys::delete_engine(self.inner); - } - } - } -} - -unsafe impl Send for CudaEngine {} -unsafe impl Sync for CudaEngine {} - -/// Execution context (real mode) -pub struct ExecutionContext<'a> { - inner: *mut std::ffi::c_void, - _engine: std::marker::PhantomData<&'a CudaEngine>, -} - -impl<'a> ExecutionContext<'a> { - /// Binds a tensor to a device memory address. - /// - /// # Safety - /// `data` must point to valid CUDA memory with at least the tensor's size in bytes, - /// and remain valid for the duration of inference. 
- pub unsafe fn set_tensor_address( - &mut self, - name: &str, - data: *mut std::ffi::c_void, - ) -> Result<()> { - if self.inner.is_null() { - return Err(Error::Runtime("Invalid execution context".to_string())); - } - let name_cstr = std::ffi::CString::new(name)?; - let success = - crate::autocxx_helpers::cast_and_pin::( - self.inner, - ) - .setTensorAddress(name_cstr.as_ptr(), data as *mut _); - if !success { - return Err(Error::Runtime("Failed to set tensor address".to_string())); - } - Ok(()) - } - - /// Enqueues inference on the given CUDA stream. - /// - /// # Safety - /// `cuda_stream` must be a valid CUDA stream, and all tensor addresses must - /// point to valid device memory. - pub unsafe fn enqueue_v3(&mut self, cuda_stream: *mut std::ffi::c_void) -> Result<()> { - if self.inner.is_null() { - return Err(Error::Runtime("Invalid execution context".to_string())); - } - let success = - crate::autocxx_helpers::cast_and_pin::( - self.inner, - ) - .enqueueV3(cuda_stream as *mut _); - if !success { - return Err(Error::Runtime("Failed to enqueue inference".to_string())); - } - Ok(()) - } -} - -impl Drop for ExecutionContext<'_> { - fn drop(&mut self) { - if !self.inner.is_null() { - unsafe { - trtx_sys::delete_context(self.inner); - } - } - } -} - -unsafe impl Send for ExecutionContext<'_> {} - -/// Runtime (real mode) -pub struct Runtime<'a> { - inner: *mut std::ffi::c_void, - _logger: &'a Logger, -} - -impl<'a> Runtime<'a> { - #[cfg(not(feature = "link_tensorrt_rtx"))] - #[cfg(not(feature = "dlopen_tensorrt_rtx"))] - pub fn new(logger: &'a Logger) -> Result { - Err(Error::TrtRtxLibraryNotLoaded) - } - - #[cfg(any(feature = "link_tensorrt_rtx", feature = "dlopen_tensorrt_rtx"))] - pub fn new(logger: &'a Logger) -> Result { - let logger_ptr = logger.as_logger_ptr(); - let runtime_ptr = { - #[cfg(feature = "link_tensorrt_rtx")] - unsafe { - trtx_sys::create_infer_builder(logger_ptr) - } - #[cfg(not(feature = "link_tensorrt_rtx"))] - #[cfg(feature = 
"dlopen_tensorrt_rtx")] - unsafe { - use libloading::Symbol; - use std::ffi::c_void; - - use crate::TRTLIB; - if !TRTLIB.read()?.is_some() { - crate::dynamically_load_tensorrt(None::)?; - } - - let lock = TRTLIB.read()?; - let create_infer_builder: Symbol *mut c_void> = lock - .as_ref() - .ok_or(Error::TrtRtxLibraryNotLoaded)? - .get(b"createInferRuntime_INTERNAL")?; - create_infer_builder(logger_ptr, trtx_sys::get_tensorrt_version()) - } - }; - if runtime_ptr.is_null() { - return Err(Error::Runtime("Failed to create runtime".to_string())); - } - Ok(Runtime { - inner: runtime_ptr, - _logger: logger, - }) - } - - pub fn deserialize_cuda_engine(&self, data: &[u8]) -> Result { - if self.inner.is_null() { - return Err(Error::Runtime("Invalid runtime".to_string())); - } - let engine_ptr = unsafe { - trtx_sys::runtime_deserialize_cuda_engine( - self.inner, - data.as_ptr() as *const _, - data.len(), - ) - }; - if engine_ptr.is_null() { - return Err(Error::Runtime("Failed to deserialize engine".to_string())); - } - Ok(CudaEngine { - inner: engine_ptr as *mut _, - }) - } -} - -impl Drop for Runtime<'_> { - fn drop(&mut self) { - if !self.inner.is_null() { - unsafe { - trtx_sys::delete_runtime(self.inner); - } - } - } -} - -unsafe impl Send for Runtime<'_> {} diff --git a/trtx/src/refitter.rs b/trtx/src/refitter.rs new file mode 100644 index 0000000..8556a96 --- /dev/null +++ b/trtx/src/refitter.rs @@ -0,0 +1,540 @@ +use std::ffi::{CStr, CString}; +use std::marker::PhantomData; +use std::pin::Pin; + +use crate::interfaces::{ErrorRecorder, RecordError}; +use crate::{ + error::{Error, Result}, + CudaEngine, Logger, +}; +use autocxx::cxx::UniquePtr; +use trtx_sys::{nvinfer1, DataType}; + +pub struct Weights<'data, T> { + data: &'data [T], + data_type: DataType, +} + +impl Weights<'_, T> { + fn as_raw(&self) -> nvinfer1::Weights { + nvinfer1::Weights { + type_: self.data_type.into(), + values: self.data.as_ptr() as *const std::ffi::c_void, + count: (size_of_val(self.data) * 8 / 
self.data_type.size_bits()) as i64,
+        }
+    }
+}
+
+pub struct Refitter<'logger, 'engine> {
+    inner: UniquePtr<nvinfer1::IRefitter>,
+    error_recorder: Option<Pin<Box<ErrorRecorder>>>,
+    _logger: PhantomData<&'logger Logger>,
+    _engine: PhantomData<&'engine CudaEngine<'engine>>,
+}
+
+impl<'logger, 'engine> Refitter<'logger, 'engine> {
+    #[cfg(not(feature = "link_tensorrt_rtx"))]
+    #[cfg(not(feature = "dlopen_tensorrt_rtx"))]
+    pub fn new(_cuda_engine: &'engine CudaEngine, _logger: &'logger Logger) -> Result<Self> {
+        Err(Error::TrtRtxLibraryNotLoaded)
+    }
+
+    #[cfg(any(feature = "link_tensorrt_rtx", feature = "dlopen_tensorrt_rtx"))]
+    pub fn new(cuda_engine: &'engine CudaEngine, logger: &'logger Logger) -> Result<Self> {
+        #[cfg(not(feature = "mock"))]
+        {
+            let logger_ptr = logger.as_logger_ptr();
+            let engine_ptr = cuda_engine.inner.as_mut_ptr() as *mut std::ffi::c_void;
+            let refitter = {
+                #[cfg(feature = "link_tensorrt_rtx")]
+                unsafe {
+                    trtx_sys::create_infer_refitter(engine_ptr, logger_ptr)
+                }
+                #[cfg(not(feature = "link_tensorrt_rtx"))]
+                #[cfg(feature = "dlopen_tensorrt_rtx")]
+                unsafe {
+                    use libloading::Symbol;
+                    use std::ffi::c_void;
+
+                    use crate::TRTLIB;
+                    if !TRTLIB.read()?.is_some() {
+                        crate::dynamically_load_tensorrt(None::<String>)?;
+                    }
+
+                    let lock = TRTLIB.read()?;
+                    let create_infer_refitter: Symbol<
+                        fn(*mut c_void, *mut c_void, u32) -> *mut nvinfer1::IRefitter,
+                    > = lock
+                        .as_ref()
+                        .ok_or(Error::TrtRtxLibraryNotLoaded)?
+                        .get(b"createInferRefitter_INTERNAL")?;
+                    create_infer_refitter(engine_ptr, logger_ptr, trtx_sys::get_tensorrt_version())
+                }
+            };
+            if refitter.is_null() {
+                return Err(Error::Runtime("Failed to create refitter".to_string()));
+            }
+            Ok(Self {
+                inner: unsafe { UniquePtr::from_raw(refitter) },
+                error_recorder: None,
+                _engine: Default::default(),
+                _logger: Default::default(),
+            })
+        }
+        #[cfg(feature = "mock")]
+        Ok(Refitter {
+            inner: UniquePtr::null(),
+            error_recorder: None,
+            _engine: Default::default(),
+            _logger: Default::default(),
+        })
+    }
+
+    /// Specify new weights for a layer of given name and role.
+    pub fn set_weights<T>(
+        &mut self,
+        layer_name: &str,
+        role: nvinfer1::WeightsRole,
+        weights: Weights<'engine, T>,
+    ) -> Result<()> {
+        let name_cstr = CString::new(layer_name)?;
+        if unsafe {
+            self.inner
+                .pin_mut()
+                .setWeights(name_cstr.as_ptr(), role, weights.as_raw())
+        } {
+            Ok(())
+        } else {
+            Err(Error::Runtime(
+                "setWeights rejected (invalid layer/role/count/type)".to_string(),
+            ))
+        }
+    }
+
+    /// Refit the associated engine. Returns an error if validation fails or there are missing weights.
+    pub fn refit_cuda_engine(&mut self) -> Result<()> {
+        if self.inner.pin_mut().refitCudaEngine() {
+            Ok(())
+        } else {
+            Err(Error::Runtime(
+                "refitCudaEngine failed (validation or getMissingWeights != 0)".to_string(),
+            ))
+        }
+    }
+
+    /// Get descriptions of missing weights (layer name + role). Call with a size to limit results.
+    pub fn get_missing(&self, max_count: i32) -> Result<Vec<(String, nvinfer1::WeightsRole)>> {
+        let n = max_count.max(0) as usize;
+        let mut layer_names: Vec<*const std::ffi::c_char> = vec![std::ptr::null(); n];
+        let mut roles: Vec<u32> = vec![0; n];
+        let refitter_ptr = self.refitter_ptr();
+        let count = unsafe {
+            trtx_sys::trtx_refitter_get_missing(
+                refitter_ptr,
+                n as i32,
+                layer_names.as_mut_ptr(),
+                roles.as_mut_ptr(),
+            )
+        };
+        let count = count.max(0) as usize;
+        let mut out = Vec::with_capacity(count);
+        for i in 0..count.min(layer_names.len()) {
+            let ptr = layer_names[i];
+            if ptr.is_null() {
+                break;
+            }
+            let s = unsafe { CStr::from_ptr(ptr) }.to_str()?.to_string();
+            let role = unsafe { std::mem::transmute::<u32, nvinfer1::WeightsRole>(roles[i]) };
+            out.push((s, role));
+        }
+        Ok(out)
+    }
+
+    /// Get descriptions of all weights that could be refit (layer name + role).
+    pub fn get_all(&self, max_count: i32) -> Result<Vec<(String, nvinfer1::WeightsRole)>> {
+        let n = max_count.max(0) as usize;
+        let mut layer_names: Vec<*const std::ffi::c_char> = vec![std::ptr::null(); n];
+        let mut roles: Vec<u32> = vec![0; n];
+        let refitter_ptr = self.refitter_ptr();
+        let count = unsafe {
+            trtx_sys::trtx_refitter_get_all(
+                refitter_ptr,
+                n as i32,
+                layer_names.as_mut_ptr(),
+                roles.as_mut_ptr(),
+            )
+        };
+        let count = count.max(0) as usize;
+        let mut out = Vec::with_capacity(count);
+        for i in 0..count.min(layer_names.len()) {
+            let ptr = layer_names[i];
+            if ptr.is_null() {
+                break;
+            }
+            let s = unsafe { CStr::from_ptr(ptr) }.to_str()?.to_string();
+            let role = unsafe { std::mem::transmute::<u32, nvinfer1::WeightsRole>(roles[i]) };
+            out.push((s, role));
+        }
+        Ok(out)
+    }
+
+    /// See [nvinfer1::IRefitter::setErrorRecorder]
+    ///
+    /// The Rust bindings only allow setting the error recorder once
+    pub fn set_error_recorder(&mut self, error_recorder: Box<dyn RecordError>) -> Result<()> {
+        let error_recorder = ErrorRecorder::new(error_recorder)?;
+        if self.error_recorder.is_some() {
+            // would need to make sure that we don't destroy a recorder still in use
+            // could offer this as an unsafe method for users who only set this when there is no
+            // refit in progress. Or we only accept a ref to the error recorder and force the user
+            // via lifetimes to keep it alive for the refitter lifetime
+            panic!("Setting an error recorder more than once not supported at the moment");
+        }
+        self.error_recorder = Some(error_recorder);
+        let rec = self
+            .error_recorder
+            .as_mut()
+            .unwrap()
+            .as_trt_error_recorder();
+        #[cfg(not(feature = "mock"))]
+        unsafe {
+            self.inner.pin_mut().setErrorRecorder(rec)
+        };
+        Ok(())
+    }
+
+    /// Get the assigned error recorder, or null if none.
+    pub fn get_error_recorder(&self) -> *mut nvinfer1::IErrorRecorder {
+        self.inner.getErrorRecorder()
+    }
+
+    /// Specify new weights by name (host location by default).
+    pub fn set_named_weights<T>(
+        &mut self,
+        name: &str,
+        weights: &Weights<'engine, T>,
+    ) -> Result<()> {
+        let name_cstr = CString::new(name)?;
+        if unsafe {
+            self.inner
+                .pin_mut()
+                .setNamedWeights(name_cstr.as_ptr(), weights.as_raw())
+        } {
+            Ok(())
+        } else {
+            Err(Error::Runtime(
+                "setNamedWeights rejected (invalid name/count/type)".to_string(),
+            ))
+        }
+    }
+
+    /// Get names of missing weights.
+    pub fn get_missing_weights(&self, max_count: i32) -> Result<Vec<String>> {
+        let n = max_count.max(0) as usize;
+        let mut names: Vec<*const std::ffi::c_char> = vec![std::ptr::null(); n];
+        let count = unsafe {
+            trtx_sys::trtx_refitter_get_missing_weights(
+                self.refitter_ptr(),
+                n as i32,
+                names.as_mut_ptr(),
+            )
+        };
+        let count = count.max(0) as usize;
+        let out = names
+            .iter()
+            .take(count.min(names.len()))
+            .take_while(|n| !n.is_null())
+            .map(|n| unsafe { CStr::from_ptr(*n).to_string_lossy().to_string() })
+            .collect();
+        Ok(out)
+    }
+
+    /// Get names of all weights that could be refit.
+    pub fn get_all_weights(&self, max_count: i32) -> Result<Vec<String>> {
+        let n = max_count.max(0) as usize;
+        let mut names: Vec<*const std::ffi::c_char> = vec![std::ptr::null(); n];
+        let count = unsafe {
+            trtx_sys::trtx_refitter_get_all_weights(
+                self.refitter_ptr(),
+                n as i32,
+                names.as_mut_ptr(),
+            )
+        };
+        let count = count.max(0) as usize;
+        let out = names
+            .iter()
+            .take(count.min(names.len()))
+            .take_while(|n| !n.is_null())
+            .map(|n| unsafe { CStr::from_ptr(*n).to_string_lossy().to_string() })
+            .collect();
+        Ok(out)
+    }
+
+    /// Raw pointer to the underlying IRefitter (for C wrappers). Caller must not use after Refitter is dropped.
+    fn refitter_ptr(&self) -> *mut std::ffi::c_void {
+        self.inner.as_ptr() as *mut std::ffi::c_void
+    }
+
+    /// Get the logger with which the refitter was created. Returns raw pointer.
+    pub fn get_logger(&self) -> *mut nvinfer1::ILogger {
+        self.inner.getLogger()
+    }
+
+    /// Set the maximum number of threads used by the refitter.
+    pub fn set_max_threads(&mut self, max_threads: i32) -> Result<()> {
+        if self.inner.pin_mut().setMaxThreads(max_threads) {
+            Ok(())
+        } else {
+            Err(Error::InvalidArgument("setMaxThreads failed".to_string()))
+        }
+    }
+
+    /// Get the maximum number of threads that can be used by the refitter.
+    pub fn get_max_threads(&self) -> i32 {
+        self.inner.getMaxThreads()
+    }
+
+    /// Specify new weights by name with explicit host/device location.
+    ///
+    /// # Safety
+    /// `data` of `weights` must be a valid pointer with correct size and alignment for the data type;
+    /// the pointer must correspond to `location` and must remain valid while being used by TensorRT.
+    pub unsafe fn set_named_weights_with_location(
+        &mut self,
+        name: &str,
+        weights: nvinfer1::Weights,
+        location: nvinfer1::TensorLocation,
+    ) -> Result<()> {
+        let name_cstr = CString::new(name)?;
+        if unsafe {
+            self.inner
+                .pin_mut()
+                .setNamedWeights1(name_cstr.as_ptr(), weights, location)
+        } {
+            Ok(())
+        } else {
+            Err(Error::Runtime(
+                "setNamedWeights (with location) rejected".to_string(),
+            ))
+        }
+    }
+
+    /// Get weights currently associated with the given name.
+    pub fn get_named_weights(&self, weights_name: &str) -> nvinfer1::Weights {
+        let name_cstr = CString::new(weights_name).expect("name contains null");
+        unsafe { self.inner.getNamedWeights(name_cstr.as_ptr()) }
+    }
+
+    /// Get the location for weights associated with the given name.
+    pub fn get_weights_location(&self, weights_name: &str) -> nvinfer1::TensorLocation {
+        let name_cstr = CString::new(weights_name).expect("name contains null");
+        unsafe { self.inner.getWeightsLocation(name_cstr.as_ptr()) }
+    }
+
+    /// Unset weights for the given name. Returns false if they were never set.
+    pub fn unset_named_weights(&mut self, weights_name: &str) -> bool {
+        let name_cstr = CString::new(weights_name).expect("name contains null");
+        unsafe { self.inner.pin_mut().unsetNamedWeights(name_cstr.as_ptr()) }
+    }
+
+    /// Set whether to validate weights during refitting (default true).
+    pub fn set_weights_validation(&mut self, weights_validation: bool) {
+        self.inner
+            .pin_mut()
+            .setWeightsValidation(weights_validation);
+    }
+
+    /// Get whether weights validation is enabled during refitting.
+    pub fn get_weights_validation(&self) -> bool {
+        self.inner.getWeightsValidation()
+    }
+
+    /// Enqueue weights refitting on the given CUDA stream.
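+    ///
+    /// Sketch (assumes a valid CUDA stream, e.g. the stream helper used in this
+    /// crate's tests; not a doctest):
+    ///
+    /// ```ignore
+    /// let stream = trtx::cuda::get_default_stream();
+    /// unsafe { refitter.refit_cuda_engine_async(stream)? };
+    /// ```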
+    ///
+    /// # Safety
+    ///
+    /// `cuda_stream` must be a valid CUDA stream
+    pub unsafe fn refit_cuda_engine_async(
+        &mut self,
+        cuda_stream: *mut std::ffi::c_void,
+    ) -> Result<()> {
+        if self
+            .inner
+            .pin_mut()
+            .refitCudaEngineAsync(cuda_stream as *mut _)
+        {
+            Ok(())
+        } else {
+            Err(Error::Runtime(
+                "refitCudaEngineAsync failed (validation or getMissingWeights != 0)".to_string(),
+            ))
+        }
+    }
+
+    /// Get the weights prototype (type and count) for the given name. Values pointer is null.
+    pub fn get_weights_prototype(&self, weights_name: &str) -> nvinfer1::Weights {
+        let name_cstr = CString::new(weights_name).expect("name contains null");
+        unsafe { self.inner.getWeightsPrototype(name_cstr.as_ptr()) }
+    }
+}
+
+#[cfg(test)]
+#[cfg(not(feature = "mock"))]
+mod tests {
+    use crate::interfaces::RecordError;
+    use std::sync::atomic::{AtomicI32, Ordering};
+    use std::sync::{Arc, Mutex};
+    use trtx_sys::BuilderFlag;
+    use trtx_sys::ErrorCode;
+
+    use super::*;
+    use crate::builder::MemoryPoolType;
+    use crate::refitter::Weights;
+    use crate::{Builder, DataType, Logger, Runtime};
+
+    /// Error recorder that collects reported errors into a shared `Vec<(ErrorCode, String)>`.
+    struct VecErrorRecorder {
+        messages: Arc<Mutex<Vec<(ErrorCode, String)>>>,
+        ref_count: AtomicI32,
+    }
+
+    impl VecErrorRecorder {
+        fn new(messages: Arc<Mutex<Vec<(ErrorCode, String)>>>) -> Self {
+            Self {
+                messages,
+                ref_count: AtomicI32::new(0),
+            }
+        }
+    }
+
+    impl RecordError for VecErrorRecorder {
+        fn nb_errors(&self) -> i32 {
+            self.messages.lock().unwrap().len() as i32
+        }
+        fn error_code(&self, error_idx: i32) -> ErrorCode {
+            self.messages.lock().unwrap()[error_idx as usize].0
+        }
+        fn error_desc(&self, _error_idx: i32) -> &CStr {
+            static EMPTY: &[u8] = b"\0";
+            unsafe { CStr::from_bytes_with_nul_unchecked(EMPTY) }
+        }
+        fn has_overflowed(&self) -> bool {
+            false
+        }
+        fn clear(&self) {
+            self.messages.lock().unwrap().clear();
+        }
+        fn report_error(&self, val: ErrorCode, desc: &str) -> bool {
+            self.messages.lock().unwrap().push((val, desc.to_string()));
+            true
+        }
+        fn inc_ref_count(&self) -> i32 {
+            self.ref_count.fetch_add(1, Ordering::SeqCst) + 1
+        }
+        fn dec_ref_count(&self) -> i32 {
+            self.ref_count.fetch_sub(1, Ordering::SeqCst) - 1
+        }
+    }
+
+    /// Build a minimal network with one refittable constant layer: constant [1,4] -> output.
+    fn build_constant_network(logger: &Logger) -> Result<Vec<u8>> {
+        let mut builder = Builder::new(logger)?;
+        let mut network = builder.create_network(0)?;
+
+        // Single constant layer with 4 floats, named for refitter lookup
+        let dims = [1, 4];
+        let initial: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
+        let weights_bytes: &[u8] = unsafe {
+            std::slice::from_raw_parts(
+                initial.as_ptr() as *const u8,
+                initial.len() * std::mem::size_of::<f32>(),
+            )
+        };
+        let mut const_layer = network.add_constant(&dims, weights_bytes, DataType::kFLOAT)?;
+        const_layer.set_name(&mut network, "refit_const")?;
+
+        let mut output = const_layer.get_output(&network, 0)?;
+        output.set_name(&mut network, "output")?;
+        network.mark_output(&mut output);
+
+        let mut config = builder.create_config()?;
+        config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 24);
+        config.set_flag(BuilderFlag::kREFIT);
+        let unstripped_engine_data = builder.build_serialized_network(&mut network, &mut config)?;
+        config.set_flag(BuilderFlag::kSTRIP_PLAN);
+        let engine_data = builder.build_serialized_network(&mut network, &mut config)?;
+        assert!(engine_data.len() < unstripped_engine_data.len());
+        Ok(engine_data.to_vec())
+    }
+
+    #[test]
+    fn refitter_from_constant_network() {
+        let logger = Logger::stderr().expect("logger");
+        let engine_data = build_constant_network(&logger).expect("build network");
+        assert!(!engine_data.is_empty());
+
+        let mut runtime = Runtime::new(&logger).expect("runtime");
+        let engine = runtime
+            .deserialize_cuda_engine(&engine_data)
+            .expect("deserialize engine");
+
+        let mut refitter = Refitter::new(&engine, &logger).expect("refitter");
+
+        // Discover refittable weights
+        let all = refitter.get_all_weights(64).expect("get_all_weights");
+        assert!(
+            !all.is_empty(),
+            "engine should have at least one refittable weight (constant layer)"
+        );
+        let weight_name = &all[0];
+
+        // Prototype: type and count (values pointer is null)
+        let proto = refitter.get_weights_prototype(weight_name);
+        assert!(proto.count >= 0 || proto.count == -1);
+        // Refit with same shape: 4 floats
+        let new_vals: [f32; 4] = [10.0, 20.0, 30.0, 40.0];
+        let new_weights = Weights {
+            data_type: DataType::kFLOAT,
+            data: &new_vals,
+        };
+        refitter
+            .set_named_weights(weight_name, &new_weights)
+            .expect("set_named_weights");
+
+        refitter.refit_cuda_engine().expect("refit_cuda_engine");
+    }
+
+    #[test]
+    fn refitter_error_recorder_collects_invalid_weight_error() {
+        let logger = Logger::stderr().expect("logger");
+        let engine_data = build_constant_network(&logger).expect("build network");
+        assert!(!engine_data.is_empty());
+
+        let mut runtime = Runtime::new(&logger).expect("runtime");
+        let engine = runtime
+            .deserialize_cuda_engine(&engine_data)
+            .expect("deserialize engine");
+
+        let mut refitter = Refitter::new(&engine, &logger).expect("refitter");
+
+        let weight_name = refitter.get_all_weights(64).expect("get_all_weights")[0].clone();
+
+        let errors: Arc<Mutex<Vec<(ErrorCode, String)>>> = Arc::new(Mutex::new(Vec::new()));
+        let recorder = Box::new(VecErrorRecorder::new(Arc::clone(&errors)));
+        refitter.set_error_recorder(recorder).unwrap();
+
+        let wrong_weights = Weights {
+            data_type: DataType::kFLOAT,
+            data: &[1.0f32],
+        };
+        let _ = refitter.set_named_weights(&weight_name, &wrong_weights);
+
+        refitter.get_named_weights("nonexistent_weight_name");
+
+        let collected = errors.lock().unwrap();
+        assert!(
+            !collected.is_empty(),
+            "error recorder should have collected at least one error (invalid weight or nonexistent name)"
+        );
+    }
+}
diff --git a/trtx/src/runtime.rs b/trtx/src/runtime.rs
index 71bee27..b0c45ce 100644
--- a/trtx/src/runtime.rs
+++ b/trtx/src/runtime.rs
@@ -1,8 +1,503 @@
 //! Runtime for deserializing and managing TensorRT engines
-//!
-//! Delegates to real/ or mock/ based on feature flag.
-#[cfg(feature = "mock")]
-pub use crate::mock::runtime::*;
+use std::ffi::CString;
+use std::marker::PhantomData;
+use std::pin::Pin;
+
+use cxx::UniquePtr;
+use trtx_sys::nvinfer1;
+
+pub use crate::cuda_engine::CudaEngine;
+pub use crate::engine_inspector::EngineInspector;
+use crate::error::{Error, Result};
+use crate::interfaces::{DebugListener, ProcessDebugTensor};
+use crate::logger::Logger;
+
+/// Execution context for running inference
+pub struct ExecutionContext<'a> {
+    inner: UniquePtr<nvinfer1::IExecutionContext>,
+    _engine: std::marker::PhantomData<&'a CudaEngine<'a>>,
+    debug_listener: Option<Pin<Box<DebugListener>>>,
+}
+
+impl<'a> ExecutionContext<'a> {
+    pub(crate) unsafe fn from_ptr(
+        execution_context: *mut nvinfer1::IExecutionContext,
+    ) -> Result<Self> {
+        #[cfg(not(feature = "mock"))]
+        if execution_context.is_null() {
+            return Err(Error::Runtime(
+                "Failed to create ExecutionContext".to_string(),
+            ));
+        }
+        Ok(ExecutionContext {
+            inner: UniquePtr::from_raw(execution_context),
+            _engine: Default::default(),
+            debug_listener: None,
+        })
+    }
+
+    /// See [nvinfer1::IExecutionContext::setDebugListener].
+    /// The Rust bindings only allow setting the debug listener once per execution context.
+    pub fn set_debug_listener(&mut self, listener: Box<dyn ProcessDebugTensor>) -> Result<()> {
+        let debug_listener = DebugListener::new(listener)?;
+        if self.debug_listener.is_some() {
+            panic!("Setting a debug listener more than once not supported at the moment");
+        }
+        self.debug_listener = Some(debug_listener);
+        #[cfg(not(feature = "mock"))]
+        {
+            let success = unsafe {
+                self.inner.pin_mut().setDebugListener(
+                    self.debug_listener
+                        .as_ref()
+                        .expect("debug_listener can't be empty, we just set it")
+                        .as_raw(),
+                )
+            };
+            if !success {
+                self.debug_listener = None;
+                return Err(Error::Runtime("setDebugListener failed".to_string()));
+            }
+        }
+        Ok(())
+    }
+
+    /// See [nvinfer1::IExecutionContext::setTensorDebugState].
+    pub fn set_tensor_debug_state(&mut self, name: &str, flag: bool) -> Result<()> {
+        let name = CString::new(name)?;
+        if !unsafe {
+            self.inner
+                .pin_mut()
+                .setTensorDebugState(name.as_ptr(), flag)
+        } {
+            Err(Error::FailedToSetProperty(
+                crate::error::PropertySetAttempt::ExecutionContextTensorDebugState,
+            ))
+        } else {
+            Ok(())
+        }
+    }
+
+    /// See [nvinfer1::IExecutionContext::getDebugState].
+    pub fn get_tensor_debug_state(&self, name: &str) -> Result<bool> {
+        let name = CString::new(name)?;
+        unsafe { Ok(self.inner.getDebugState(name.as_ptr())) }
+    }
+
+    /// See [nvinfer1::IExecutionContext::setAllTensorsDebugState].
+    pub fn set_all_tensors_debug_state(&mut self, flag: bool) -> Result<()> {
+        if !self.inner.pin_mut().setAllTensorsDebugState(flag) {
+            Err(Error::FailedToSetProperty(
+                crate::error::PropertySetAttempt::ExecutionContextTensorDebugState,
+            ))
+        } else {
+            Ok(())
+        }
+    }
+
+    /// See [nvinfer1::IExecutionContext::setUnfusedTensorsDebugState].
+    pub fn set_unfused_tensors_debug_state(&mut self, flag: bool) -> Result<()> {
+        if !self.inner.pin_mut().setUnfusedTensorsDebugState(flag) {
+            Err(Error::FailedToSetProperty(
+                crate::error::PropertySetAttempt::ExecutionContextTensorDebugState,
+            ))
+        } else {
+            Ok(())
+        }
+    }
+
+    /// See [nvinfer1::IExecutionContext::getUnfusedTensorsDebugState].
+    pub fn get_unfused_tensor_debug_state(&self) -> bool {
+        self.inner.getUnfusedTensorsDebugState()
+    }
+
+    /// Binds a tensor to a device memory address.
+    ///
+    /// # Safety
+    /// `data` must point to valid CUDA memory with at least the tensor's size in bytes,
+    /// and remain valid for the duration of inference.
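+    ///
+    /// Sketch of a typical bind-and-run sequence (buffer and stream helpers as used in
+    /// this crate's tests; not a doctest):
+    ///
+    /// ```ignore
+    /// let mut input = DeviceBuffer::new(input_bytes.len())?; // device allocation
+    /// input.copy_from_host(&input_bytes)?;
+    /// unsafe {
+    ///     context.set_tensor_address("input", input.as_ptr())?;
+    ///     context.enqueue_v3(get_default_stream())?;
+    /// }
+    /// synchronize()?; // wait for the stream before reading outputs
+    /// ```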
+    pub unsafe fn set_tensor_address(
+        &mut self,
+        name: &str,
+        data: *mut std::ffi::c_void,
+    ) -> Result<()> {
+        #[cfg(not(feature = "mock"))]
+        {
+            if self.inner.is_null() {
+                return Err(Error::Runtime("Invalid execution context".to_string()));
+            }
+            let name_cstr = std::ffi::CString::new(name)?;
+            let success = self
+                .inner
+                .pin_mut()
+                .setTensorAddress(name_cstr.as_ptr(), data as *mut _);
+            if !success {
+                return Err(Error::Runtime("Failed to set tensor address".to_string()));
+            }
+        }
+        Ok(())
+    }
+
+    /// Enqueues inference on the given CUDA stream.
+    ///
+    /// # Safety
+    /// `cuda_stream` must be a valid CUDA stream, and all tensor addresses must
+    /// point to valid device memory.
+    pub unsafe fn enqueue_v3(&mut self, cuda_stream: *mut std::ffi::c_void) -> Result<()> {
+        #[cfg(not(feature = "mock"))]
+        {
+            if self.inner.is_null() {
+                return Err(Error::Runtime("Invalid execution context".to_string()));
+            }
+            let success = self.inner.pin_mut().enqueueV3(cuda_stream as *mut _);
+            if !success {
+                return Err(Error::Runtime("Failed to enqueue inference".to_string()));
+            }
+        }
+        Ok(())
+    }
+}
+
+/// Runtime for deserializing engines
+pub struct Runtime<'logger> {
+    inner: UniquePtr<nvinfer1::IRuntime>,
+    _logger: PhantomData<&'logger Logger>,
+}
+
+impl<'runtime> Runtime<'runtime> {
+    #[cfg(not(feature = "link_tensorrt_rtx"))]
+    #[cfg(not(feature = "dlopen_tensorrt_rtx"))]
+    pub fn new(_logger: &'runtime Logger) -> Result<Self> {
+        Err(Error::TrtRtxLibraryNotLoaded)
+    }
+
+    #[cfg(any(feature = "link_tensorrt_rtx", feature = "dlopen_tensorrt_rtx"))]
+    pub fn new(logger: &'runtime Logger) -> Result<Self> {
+        #[cfg(not(feature = "mock"))]
+        {
+            let logger_ptr = logger.as_logger_ptr();
+            let runtime_ptr = {
+                #[cfg(feature = "link_tensorrt_rtx")]
+                unsafe {
+                    trtx_sys::create_infer_runtime(logger_ptr)
+                }
+                #[cfg(not(feature = "link_tensorrt_rtx"))]
+                #[cfg(feature = "dlopen_tensorrt_rtx")]
+                unsafe {
+                    use libloading::Symbol;
+                    use std::ffi::c_void;
+
+                    use crate::TRTLIB;
+                    if TRTLIB.read()?.is_none() {
+                        crate::dynamically_load_tensorrt(None::<&str>)?;
+                    }
+
+                    let lock = TRTLIB.read()?;
+                    let create_infer_runtime: Symbol<fn(*mut c_void, u32) -> *mut c_void> = lock
+                        .as_ref()
+                        .ok_or(Error::TrtRtxLibraryNotLoaded)?
+                        .get(b"createInferRuntime_INTERNAL")?;
+                    create_infer_runtime(logger_ptr, trtx_sys::get_tensorrt_version())
+                }
+            } as *mut nvinfer1::IRuntime;
+            if runtime_ptr.is_null() {
+                return Err(Error::Runtime("Failed to create runtime".to_string()));
+            }
+            Ok(Runtime {
+                inner: unsafe { UniquePtr::from_raw(runtime_ptr) },
+                _logger: Default::default(),
+            })
+        }
+        #[cfg(feature = "mock")]
+        Ok(Runtime {
+            inner: UniquePtr::null(),
+            _logger: Default::default(),
+        })
+    }
+
+    pub fn deserialize_cuda_engine(&'_ mut self, data: &[u8]) -> Result<CudaEngine<'_>> {
+        if cfg!(feature = "mock") {
+            Ok(unsafe { CudaEngine::from_ptr(std::ptr::null_mut()) })
+        } else {
+            unsafe {
+                let engine = self.inner.pin_mut().deserializeCudaEngine(
+                    data.as_ref().as_ptr() as *const autocxx::c_void,
+                    data.len(),
+                );
+                Ok(CudaEngine::from_ptr(engine.as_mut().ok_or_else(|| {
+                    Error::Runtime("Failed to deserialize engine".to_string())
+                })?))
+            }
+        }
+    }
+
+    //pub fn deserialize_cuda_engine_v2(
+    //&'_ mut self,
+    //stream_reader: &'runtime mut StreamReaderV2,
+    //) -> Result<CudaEngine<'_>> {
+    //if cfg!(feature = "mock") {
+    //Ok(unsafe { CudaEngine::from_ptr(std::ptr::null_mut()) })
+    //} else {
+    //unsafe {
+    //let engine = self
+    //.inner
+    //.pin_mut()
+    //.deserializeCudaEngine1(stream_reader.pin_mut());
+    //Ok(CudaEngine::from_ptr(engine.as_mut().ok_or_else(|| {
+    //Error::Runtime("Failed to deserialize engine".to_string())
+    //})?))
+    //}
+    //}
+    //}
+}
+
+#[cfg(test)]
 #[cfg(not(feature = "mock"))]
-pub use crate::real::runtime::*;
+mod tests {
+    use std::sync::{Arc, Mutex};
+
+    use crate::builder::{Builder, MemoryPoolType};
+    use crate::cuda::{synchronize, DeviceBuffer};
+    use crate::interfaces::{ProcessDebugTensor, ProcessDebugTensorResult};
+    use crate::logger::Logger;
+    use crate::{DataType, ElementWiseOperation, Runtime};
+    use trtx_sys::{Dims64, TensorLocation};
+
+    /// Builds a network: input tensor_0 [1] -> +1 -> tensor_1 -> +1 -> tensor_2 -> +1 -> tensor_3 -> +1 -> tensor_4 (output).
+    /// Each intermediate tensor is named and marked for debug.
+    fn build_plus1_chain(logger: &Logger) -> crate::Result<(Vec<u8>, Vec<String>)> {
+        let mut builder = Builder::new(logger)?;
+        let mut network = builder.create_network(0)?;
+
+        let one_bytes = 1.0f32.to_le_bytes();
+        let mut tensor = network.add_input("tensor_0", DataType::kFLOAT, &[1])?;
+        let mut debug_names = Vec::new();
+
+        for i in 1..=4 {
+            let one_layer =
+                network.add_small_constant_copied(&[1], &one_bytes, DataType::kFLOAT)?;
+            let one_t = one_layer.get_output(&network, 0)?;
+            let mut sum_layer =
+                network.add_elementwise(&tensor, &one_t, ElementWiseOperation::kSUM)?;
+            sum_layer.set_name(&mut network, &format!("plus1_{}", i))?;
+            tensor = sum_layer.get_output(&network, 0)?;
+            let name = format!("tensor_{}", i);
+            tensor.set_name(&mut network, &name)?;
+            network.mark_tensor_debug(&tensor)?;
+            assert!(network.is_debug_tensor(&tensor));
+            debug_names.push(name);
+        }
+        network.mark_output(&tensor);
+
+        let mut config = builder.create_config()?;
+        config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 20);
+        //config.set_flag(trtx_sys::BuilderFlag::kDEBUG);
+        let engine_data = builder.build_serialized_network(&mut network, &mut config)?;
+        Ok((engine_data.to_vec(), debug_names))
+    }
+
+    type ExpectedResults = Vec<(String, Vec<i64>)>;
+    /// Debug listener that collects tensor names and shapes for verification.
+    struct CollectingDebugListener {
+        seen: Arc<Mutex<ExpectedResults>>,
+    }
+
+    impl ProcessDebugTensor for CollectingDebugListener {
+        unsafe fn process_debug_tensor(
+            &self,
+            _addr: *const std::ffi::c_void,
+            _location: TensorLocation,
+            _type_: DataType,
+            shape: &Dims64,
+            name: Option<&str>,
+            _stream: *mut std::ffi::c_void,
+        ) -> ProcessDebugTensorResult {
+            let dims: Vec<i64> = shape
+                .d
+                .iter()
+                .take(shape.nbDims as usize)
+                .copied()
+                .collect();
+            self.seen
+                .lock()
+                .unwrap()
+                .push((name.unwrap().to_string(), dims));
+            Ok(())
+        }
+    }
+
+    /// Builds a small conv network: input [1,1,4,4] -> conv(1->4) -> conv(4->4) -> conv(4->4) -> output.
+    /// Each conv output is named and marked for debug.
+    fn build_conv_chain(logger: &Logger) -> crate::Result<(Vec<u8>, Vec<String>)> {
+        // Declare kernel bytes before builder so their lifetime outlives 'network.
+        // conv0: out=4, in=1, 3x3; conv1/2: out=4, in=4, 3x3
+        let make_kernel = |out_ch: usize, in_ch: usize| -> Vec<u8> {
+            std::iter::repeat_n(0.1f32, out_ch * in_ch * 3 * 3)
+                .flat_map(|v| v.to_le_bytes())
+                .collect()
+        };
+        let kernel_0 = make_kernel(4, 1);
+        let kernel_1 = make_kernel(4, 4);
+        let kernel_2 = make_kernel(4, 4);
+
+        let mut builder = Builder::new(logger)?;
+        let mut network = builder.create_network(0)?;
+
+        // Input: [N=1, C=1, H=4, W=4]; TensorRT conv requires at least 4D
+        let mut tensor = network.add_input("input", DataType::kFLOAT, &[1, 1, 4, 4])?;
+        let mut debug_names = Vec::new();
+
+        let conv_defs: [(i32, &Vec<u8>); 3] = [(4, &kernel_0), (4, &kernel_1), (4, &kernel_2)];
+        for (i, &(out_ch, kbytes)) in conv_defs.iter().enumerate() {
+            let weights = crate::ConvWeights {
+                kernel_weights: kbytes,
+                kernel_dtype: DataType::kFLOAT,
+                bias_weights: None,
+                bias_dtype: None,
+            };
+            let mut conv = network.add_convolution(&tensor, out_ch, &[3, 3], &weights)?;
+            conv.set_padding(&mut network, &[1i64, 1i64]);
+            let name = format!("conv_out_{}", i);
+            conv.set_name(&mut network, &name)?;
+            tensor = conv.get_output(&network, 0)?;
+            tensor.set_name(&mut network, &name)?;
+            network.mark_tensor_debug(&tensor)?;
+            debug_names.push(name);
+        }
+        network.mark_output(&tensor);
+
+        let mut config = builder.create_config()?;
+        config.set_memory_pool_limit(MemoryPoolType::kWORKSPACE, 1 << 20);
+        let engine_data = builder.build_serialized_network(&mut network, &mut config)?;
+        Ok((engine_data.to_vec(), debug_names))
+    }
+
+    #[test]
+    #[ignore = "only works on TRT enterprise at the moment"]
+    fn set_debug_listener_conv_chain() {
+        let logger = Logger::stderr().expect("logger");
+        let (engine_data, _debug_names) = build_conv_chain(&logger).expect("build conv network");
+
+        let mut runtime = Runtime::new(&logger).expect("runtime");
+        let mut engine = runtime
+            .deserialize_cuda_engine(&engine_data)
+            .expect("deserialize");
+        let mut context = engine
+            .create_execution_context()
+            .expect("execution context");
+
+        let seen = Arc::new(Mutex::new(Vec::<(String, Vec<i64>)>::new()));
+        context
+            .set_debug_listener(Box::new(CollectingDebugListener {
+                seen: Arc::clone(&seen),
+            }))
+            .expect("set_debug_listener");
+        context.set_all_tensors_debug_state(true).unwrap();
+        context.set_unfused_tensors_debug_state(true).unwrap();
+
+        // input: 1 channel 4x4, output: 4 channels 4x4
+        let input_elems = 4 * 4;
+        let output_elems = 4 * 4 * 4;
+        let elem_size = std::mem::size_of::<f32>();
+        let input_bytes: Vec<u8> = std::iter::repeat_n(1.0f32, input_elems)
+            .flat_map(|v| v.to_le_bytes())
+            .collect();
+        let mut input_device = DeviceBuffer::new(input_elems * elem_size).expect("input buffer");
+        let output_device = DeviceBuffer::new(output_elems * elem_size).expect("output buffer");
+        input_device
+            .copy_from_host(&input_bytes)
+            .expect("copy input");
+
+        unsafe {
+            context
+                .set_tensor_address("input", input_device.as_ptr())
+                .expect("set input");
+            context
+                .set_tensor_address("conv_out_2", output_device.as_ptr())
+                .expect("set output");
+            context
+                .enqueue_v3(crate::cuda::get_default_stream())
+                .expect("enqueue");
+        }
+        synchronize().expect("sync");
+
+        let seen = seen.lock().unwrap();
+        assert!(
+            !seen.is_empty(),
+            "debug listener should have seen at least one tensor, saw 0"
+        );
+    }
+
+    #[test]
+    #[ignore = "only works on TRT enterprise at the moment"]
+    fn set_debug_listener_plus1_chain() {
+        let logger = Logger::stderr().expect("logger");
+        let (engine_data, expected_debug_names) =
+            build_plus1_chain(&logger).expect("build network");
+        assert_eq!(
+            expected_debug_names,
+            ["tensor_1", "tensor_2", "tensor_3", "tensor_4"]
+        );
+
+        let mut runtime = Runtime::new(&logger).expect("runtime");
+        let mut engine = runtime
+            .deserialize_cuda_engine(&engine_data)
+            .expect("deserialize");
+        let mut context = engine
+            .create_execution_context()
+            .expect("execution context");
+
+        let seen = Arc::new(Mutex::new(Vec::<(String, Vec<i64>)>::new()));
+        context
+            .set_debug_listener(Box::new(CollectingDebugListener {
+                seen: Arc::clone(&seen),
+            }))
+            .expect("set_debug_listener");
+        context.set_all_tensors_debug_state(true).unwrap();
+        context.set_unfused_tensors_debug_state(true).unwrap();
+
+        let elem_size = std::mem::size_of::<f32>();
+        let mut input_device = DeviceBuffer::new(elem_size).expect("input buffer");
+        let output_device = DeviceBuffer::new(elem_size).expect("output buffer");
+        input_device
+            .copy_from_host(&0.0f32.to_le_bytes())
+            .expect("copy input");
+
+        unsafe {
+            context
+                .set_tensor_address("tensor_0", input_device.as_ptr())
+                .expect("set input");
+            context
+                .set_tensor_address("tensor_4", output_device.as_ptr())
+                .expect("set output");
+            context
+                .enqueue_v3(crate::cuda::get_default_stream())
+                .expect("enqueue");
+        }
+        synchronize().expect("sync");
+
+        let mut output_bytes = [0u8; 4];
+        output_device
+            .copy_to_host(&mut output_bytes)
+            .expect("copy output");
+        let output_val = f32::from_le_bytes(output_bytes);
+        assert!(
+            (output_val - 4.0f32).abs() < 1e-5,
+            "expected output 4.0 (0+1+1+1+1), got {}",
+            output_val
+        );
+
+        let seen = seen.lock().unwrap();
+        assert!(
+            seen.len() >= 4,
+            "debug listener should see at least 4 tensors, saw {}",
+            seen.len()
+        );
+        for expected in &expected_debug_names {
+            assert!(
+                seen.iter().any(|(n, _)| n.contains(expected.as_str())),
+                "expected debug tensor {:?} among names {:?}",
+                expected,
+                seen.iter().map(|(n, _)| n.as_str()).collect::<Vec<_>>()
+            );
+        }
+    }
+}
diff --git a/trtx/tests/dynloading.rs b/trtx/tests/dynloading.rs
index fee1a8d..6fd2ad2 100644
--- a/trtx/tests/dynloading.rs
+++ b/trtx/tests/dynloading.rs
@@ -9,29 +9,13 @@ mod tests {
     // test binary
     #[test]
    fn dynloading() {
-        let logger = Logger::stderr().unwrap();
-
-        // not linking let's builder creation fail
-        #[cfg(not(feature = "link_tensorrt_rtx"))]
-        assert!(matches!(
-            Builder::new(&logger),
-            Err(trtx::Error::TrtRtxLibraryNotLoaded)
-        ));
-
-        // Loading the library fixes the error
         trtx::dynamically_load_tensorrt(None::<&str>).unwrap();
         let logger = Logger::stderr().unwrap();
-        let builder = Builder::new(&logger).unwrap();
+        let mut builder = Builder::new(&logger).unwrap();
         let mut network = builder.create_network(0).unwrap();
 
-        // not linking let's builder creation fail
-        #[cfg(not(feature = "link_tensorrt_onnxparser"))]
-        assert!(matches!(
-            OnnxParser::new(&mut network, &logger),
-            Err(trtx::Error::TrtOnnxParserLibraryNotLoaded)
-        ));
-
-        // Loading the library fixes the error
         trtx::dynamically_load_tensorrt_onnxparser(None::<&str>).unwrap();
         OnnxParser::new(&mut network, &logger).unwrap();
diff --git a/trtx/tests/test_builder_config.rs b/trtx/tests/test_builder_config.rs
index 134739e..0b9b2c5 100644
--- a/trtx/tests/test_builder_config.rs
+++ b/trtx/tests/test_builder_config.rs
@@ -14,7 +14,7 @@ mod tests {
         trtx::dynamically_load_tensorrt(None::<&str>).unwrap();
 
         let logger = Logger::stderr().unwrap();
-        let builder = Builder::new(&logger).unwrap();
+        let mut builder = Builder::new(&logger).unwrap();
         let mut config = builder.create_config().unwrap();
 
         // Test timing iterations
@@ -40,6 +40,7 @@
         // Test DLA core
         config.set_dla_core(-1);
+        #[cfg(not(feature = "mock"))]
         assert_eq!(config.get_dla_core(), -1);
 
         // Test device type
@@ -52,7 +53,7 @@
         // Test tactic sources (kEDGE_MASK_CONVOLUTIONS = 3)
         let sources = 1u32 << 3;
-        assert!(config.set_tactic_sources(sources));
+        config.set_tactic_sources(sources).unwrap();
         #[cfg(not(feature = "mock"))]
         assert_eq!(config.get_tactic_sources(), sources);
 
@@ -97,7 +98,9 @@
         assert_eq!(config.get_max_nb_tactics(), 10);
 
         // Test tiling optimization level
-        assert!(config.set_tiling_optimization_level(TilingOptimizationLevel::kFAST));
+        config
+            .set_tiling_optimization_level(TilingOptimizationLevel::kFAST)
+            .unwrap();
         #[cfg(not(feature = "mock"))]
         assert_eq!(
             config.get_tiling_optimization_level(),
@@ -105,15 +108,17 @@
         );
 
         // Test L2 limit for tiling
-        assert!(config.set_l2_limit_for_tiling(1024));
+        config.set_l2_limit_for_tiling(1024).unwrap();
         #[cfg(not(feature = "mock"))]
         assert_eq!(config.get_l2_limit_for_tiling(), 1024);
 
         // Test compute capabilities
-        assert!(config.set_nb_compute_capabilities(1));
+        config.set_nb_compute_capabilities(1).unwrap();
         #[cfg(not(feature = "mock"))]
         assert_eq!(config.get_nb_compute_capabilities(), 1);
-        assert!(config.set_compute_capability(ComputeCapability::kCURRENT, 0));
+        config
+            .set_compute_capability(ComputeCapability::kCURRENT, 0)
+            .unwrap();
         #[cfg(not(feature = "mock"))]
         assert_eq!(
             config.get_compute_capability(0),