Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions .github/workflows/linux_x64_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,21 @@ jobs:
run: |
echo "TRTEP_LIBRARY_PATH=$GITHUB_WORKSPACE/orttrtep/lib/libORTTensorRTEp.so" >> $GITHUB_ENV

- name: Run tests
# Runs the standalone TensorRT EP GoogleTest binary (trt_ep_tests) when it is present in the
# downloaded artifact and writes the results to trt_ep_unit_test_results.xml.
# A missing binary only produces a warning; the step does not fail in that case.
- name: Run unit tests
  env:
    TRT_EP_LIBRARY_PATH: ${{ env.TRTEP_LIBRARY_PATH }}
    TESTDATA_DIR: "${{ github.workspace }}/orttrtep/bin/testdata"
  run: |
    # Values under `env:` are passed through verbatim (only ${{ }} expressions are evaluated),
    # so writing "$LD_LIBRARY_PATH" there yields a literal string and drops whatever search
    # path the job already had. Compose the variable in the shell instead.
    export LD_LIBRARY_PATH="$GITHUB_WORKSPACE/onnxruntime:$GITHUB_WORKSPACE/orttrtep/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"

    TEST_EXE="$GITHUB_WORKSPACE/orttrtep/bin/trt_ep_tests"
    if [ -f "$TEST_EXE" ]; then
      chmod +x "$TEST_EXE"
      "$TEST_EXE" --gtest_output=xml:trt_ep_unit_test_results.xml
    else
      echo "WARNING: trt_ep_tests not found, skipping unit tests"
    fi

- name: Run onnxruntime_provider_test
env:
ORT_UNIT_TEST_MAIN_LOG_LEVEL: 0
ORT_TRT_EP_ENABLE_BUILDER_PLACEHOLDER: 1
Expand All @@ -375,20 +389,6 @@ jobs:
"$GITHUB_WORKSPACE/onnxruntime/onnxruntime_provider_test" \
"${{ env.ARTIFACT_NAME }}"

- name: Run unit tests
env:
TRT_EP_LIBRARY_PATH: ${{ env.TRTEP_LIBRARY_PATH }}
TESTDATA_DIR: "${{ github.workspace }}/orttrtep/bin/testdata"
LD_LIBRARY_PATH: "${{ github.workspace }}/onnxruntime:${{ github.workspace }}/orttrtep/lib:$LD_LIBRARY_PATH"
run: |
TEST_EXE="$GITHUB_WORKSPACE/orttrtep/bin/trt_ep_tests"
if [ -f "$TEST_EXE" ]; then
chmod +x "$TEST_EXE"
"$TEST_EXE" --gtest_output=xml:trt_ep_unit_test_results.xml
else
echo "WARNING: trt_ep_tests not found, skipping unit tests"
fi

- name: Upload build artifacts
if: ${{ !cancelled() }}
uses: actions/upload-artifact@v7
Expand Down
47 changes: 26 additions & 21 deletions .github/workflows/windows_x64_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,32 @@ jobs:
python-version: '3.x'
architecture: x64

- name: Run tests
# Runs the standalone TensorRT EP GoogleTest binary (trt_ep_tests.exe) when it exists.
# Before running, the first onnxruntime.dll found recursively under <workspace>\onnxruntime is
# copied next to the test executable; if none is found a warning is printed and the test still runs.
# A non-zero exit code from the test binary fails the step; a missing binary only produces a warning.
- name: Run unit tests
shell: pwsh
env:
TRT_EP_LIBRARY_PATH: ${{ env.TRTEP_LIBRARY_PATH }}
TESTDATA_DIR: '${{ github.workspace }}\orttrtep\bin\testdata'
run: |
$testExe = "${{ github.workspace }}\orttrtep\bin\trt_ep_tests.exe"
if (Test-Path $testExe) {
# Find and copy onnxruntime.dll to same directory as test exe so it's found
$ortDll = Get-ChildItem -Path "${{ github.workspace }}\onnxruntime" -Filter "onnxruntime.dll" -Recurse | Select-Object -First 1
if ($ortDll) {
Copy-Item $ortDll.FullName "${{ github.workspace }}\orttrtep\bin\onnxruntime.dll" -Force
Write-Host "Copied onnxruntime.dll from $($ortDll.FullName)"
} else {
Write-Warning "onnxruntime.dll not found in ORT artifacts"
}

& $testExe --gtest_output=xml:trt_ep_unit_test_results.xml
if ($lastExitCode -ne 0) {
exit $lastExitCode
}
} else {
Write-Warning "trt_ep_tests.exe not found, skipping unit tests"
}

- name: Run onnxruntime_provider_test
shell: pwsh
env:
ORT_UNIT_TEST_MAIN_LOG_LEVEL: 0
Expand All @@ -336,26 +361,6 @@ jobs:
exit $lastExitCode
}

- name: Run unit tests
shell: pwsh
env:
TRT_EP_LIBRARY_PATH: ${{ env.TRTEP_LIBRARY_PATH }}
TESTDATA_DIR: '${{ github.workspace }}\orttrtep\bin\testdata'
run: |
$testExe = "${{ github.workspace }}\orttrtep\bin\trt_ep_tests.exe"
if (Test-Path $testExe) {
# Copy onnxruntime.dll to same directory as test exe so it's found
Copy-Item "${{ github.workspace }}\onnxruntime\onnxruntime.dll" `
"${{ github.workspace }}\orttrtep\bin\onnxruntime.dll" -ErrorAction SilentlyContinue

& $testExe --gtest_output=xml:trt_ep_unit_test_results.xml
if ($lastExitCode -ne 0) {
exit $lastExitCode
}
} else {
Write-Warning "trt_ep_tests.exe not found, skipping unit tests"
}

- name: Upload build artifacts
if: ${{ !cancelled() }}
uses: actions/upload-artifact@v7
Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ if(ORTTensorRTEp_BUILD_TESTS)
include(GoogleTest)
# Register the GoogleTest cases of trt_ep_tests with CTest. Every discovered test gets
# TESTDATA_DIR pointing at the "testdata" directory next to the test binary.
# DISCOVERY_MODE PRE_TEST delays test enumeration until just before the tests are run
# instead of running the binary as a post-build step.
# NOTE(review): presumably chosen because the binary cannot be executed at build time
# (runtime shared libraries not yet on the search path) — confirm.
gtest_discover_tests(trt_ep_tests
PROPERTIES ENVIRONMENT "TESTDATA_DIR=$<TARGET_FILE_DIR:trt_ep_tests>/testdata"
DISCOVERY_MODE PRE_TEST
)
endif()

Expand Down
125 changes: 125 additions & 0 deletions src/tensorrt_execution_provider_custom_ops.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "tensorrt_execution_provider_custom_ops.h"
#include "nv_includes.h"

#include <cstring>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#ifdef _WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif

namespace trt_ep {

/*
* Create custom op domain list for TRT plugins.
*
* Collects all registered TRT plugins from the TRT registry and creates custom ops
* with "trt.plugins" domain. Additionally, if users specify extra plugin libraries,
* TRT EP will load them at runtime which will register those plugins to the TRT
* plugin registry.
*
* Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin.
* So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation.
*/
OrtStatus* CreateTensorRTCustomOpDomainList(const char* ep_name,
const std::string& extra_plugin_lib_paths,
std::vector<OrtCustomOpDomain*>& domain_list) {
// Static storage for the custom op domain and custom ops.
// These must persist for the process lifetime since ORT holds raw pointers to them.
static std::unique_ptr<Ort::CustomOpDomain> custom_op_domain;
static std::vector<std::unique_ptr<TensorRTCustomOp>> created_custom_op_list;
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);

// If already initialized, just return the cached domain.
if (custom_op_domain != nullptr) {
domain_list.push_back(*custom_op_domain);
return nullptr;
}

// Load any extra TRT plugin libraries if specified.
// When the TRT plugin library is loaded, the global static object is created and the
// plugin is registered to TRT registry. This is done through macro, for example,
// REGISTER_TENSORRT_PLUGIN(VisionTransformerPluginCreator).
// extra_plugin_lib_paths has the format of "path_1;path_2....;path_n"
if (!extra_plugin_lib_paths.empty()) {
std::stringstream extra_plugin_libs(extra_plugin_lib_paths);
std::string lib;
while (std::getline(extra_plugin_libs, lib, ';')) {
#ifdef _WIN32
HMODULE handle = LoadLibraryA(lib.c_str());
if (handle == nullptr) {
// Log but don't fail - some plugins may be optional
}
#else
void* handle = dlopen(lib.c_str(), RTLD_NOW | RTLD_GLOBAL);
if (handle == nullptr) {
// Log but don't fail
}
#endif
}
}

try {
// Initialize default TRT plugins
initLibNvInferPlugins(nullptr, "");

// Get all registered TRT plugins from registry
int num_plugin_creator = 0;
auto plugin_creators = getPluginRegistry()->getAllCreators(&num_plugin_creator);
std::unordered_set<std::string> registered_plugin_names;

custom_op_domain = std::make_unique<Ort::CustomOpDomain>("trt.plugins");

for (int i = 0; i < num_plugin_creator; i++) {
auto plugin_creator = plugin_creators[i];
nvinfer1::AsciiChar const* plugin_name = nullptr;
if (std::strcmp(plugin_creators[i]->getInterfaceInfo().kind, "PLUGIN CREATOR_V1") == 0) {
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996) // Ignore warning C4996: deprecated API
#endif
auto plugin_creator_v1 = static_cast<nvinfer1::IPluginCreator const*>(plugin_creator);
plugin_name = plugin_creator_v1->getPluginName();
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
} else if (std::strcmp(plugin_creators[i]->getInterfaceInfo().kind, "PLUGIN CREATOR_V3ONE") == 0) {
auto plugin_creator_v3 = static_cast<nvinfer1::IPluginCreatorV3One const*>(plugin_creator);
plugin_name = plugin_creator_v3->getPluginName();
} else {
continue; // Unknown plugin creator type, skip
}

// Each plugin may have different versions; we only register once per name
if (registered_plugin_names.find(plugin_name) != registered_plugin_names.end()) {
continue;
}

auto custom_op = std::make_unique<TensorRTCustomOp>(ep_name, nullptr);
custom_op->SetName(plugin_name);
custom_op_domain->Add(custom_op.get());
created_custom_op_list.push_back(std::move(custom_op));
registered_plugin_names.insert(plugin_name);
}

domain_list.push_back(*custom_op_domain);
} catch (const std::exception&) {
// Failed to get TRT plugins. The domain won't be added but this is not fatal.
custom_op_domain.reset();
}

return nullptr;
}

} // namespace trt_ep
86 changes: 86 additions & 0 deletions src/tensorrt_execution_provider_custom_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT

#include <string>
#include <vector>

namespace trt_ep {

/// Kernel type paired with TensorRTCustomOp.
///
/// It only records the compute stream it was given; Compute() does nothing because the
/// actual implementation lives in the TensorRT plugin.
struct TensorRTCustomKernel {
  TensorRTCustomKernel(const OrtKernelInfo* /*info*/, void* compute_stream)
      : compute_stream_{compute_stream} {}

  /// Intentionally empty: the implementation is in the TensorRT plugin.
  void Compute(OrtKernelContext* /*context*/) {}

 private:
  void* compute_stream_;
};

/// Custom op description used to expose a TensorRT plugin to ONNX Runtime.
///
/// The op carries a settable name and the execution provider type it belongs to.
/// Inputs and outputs are declared variadic, heterogeneous and of undefined element type;
/// the reported input/output counts default to 1 and can be overridden through the setters.
struct TensorRTCustomOp : Ort::CustomOpBase<TensorRTCustomOp, TensorRTCustomKernel> {
  /// @param provider        Execution provider type string; copied into the op.
  /// @param compute_stream  Opaque stream pointer forwarded to every kernel created by this op.
  explicit TensorRTCustomOp(const char* provider, void* compute_stream)
      : provider_{provider}, compute_stream_{compute_stream} {}

  // ---- identity -------------------------------------------------------------------------

  const char* GetName() const { return name_.c_str(); }

  void SetName(const char* name) { name_ = name; }

  const char* GetExecutionProviderType() const { return provider_.c_str(); }

  // ---- kernel ---------------------------------------------------------------------------

  /// Allocates a new TensorRTCustomKernel; the returned pointer is owned by the caller.
  void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const {
    return new TensorRTCustomKernel(info, compute_stream_);
  }

  // ---- inputs ---------------------------------------------------------------------------

  size_t GetInputTypeCount() const { return num_inputs_; }

  void SetInputTypeCount(size_t num) { num_inputs_ = num; }

  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  }

  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const {
    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC;
  }

  bool GetVariadicInputHomogeneity() const {
    return false;  // heterogeneous
  }

  // ---- outputs --------------------------------------------------------------------------

  size_t GetOutputTypeCount() const { return num_outputs_; }

  void SetOutputTypeCount(size_t num) { num_outputs_ = num; }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  }

  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const {
    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC;
  }

  bool GetVariadicOutputHomogeneity() const {
    return false;  // heterogeneous
  }

 private:
  std::string provider_;
  void* compute_stream_;
  std::string name_;
  size_t num_inputs_ = 1;   // set to 1 to match with default min_arity for variadic input
  size_t num_outputs_ = 1;  // set to 1 to match with default min_arity for variadic output
};

/// Creates the custom op domain for TRT plugins and appends it to `domain_list`.
///
/// @param ep_name                 Execution provider name passed to each created TensorRTCustomOp.
/// @param extra_plugin_lib_paths  Extra plugin libraries to load, in the form "path_1;path_2;...;path_n".
/// @param domain_list             Output; receives OrtCustomOpDomain pointers that remain valid
///                                for the lifetime of the process (static storage).
/// @return An OrtStatus* (nullptr means success). Note that this is a status, not a count of
///         domains; callers should inspect `domain_list` to see what was added.
OrtStatus* CreateTensorRTCustomOpDomainList(const char* ep_name,
const std::string& extra_plugin_lib_paths,
std::vector<OrtCustomOpDomain*>& domain_list);

} // namespace trt_ep
37 changes: 37 additions & 0 deletions src/tensorrt_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* e
ReleaseAllocator = ReleaseAllocatorImpl;
CreateDataTransfer = CreateDataTransferImpl;
IsStreamAware = IsStreamAwareImpl;
GetNumCustomOpDomains = GetNumCustomOpDomainsImpl;
GetCustomOpDomains = GetCustomOpDomainsImpl;
}

TensorrtExecutionProviderFactory::~TensorrtExecutionProviderFactory() {
Expand Down Expand Up @@ -444,6 +446,41 @@ bool ORT_API_CALL TensorrtExecutionProviderFactory::IsStreamAwareImpl(const OrtE
return true;
}

/// Reports how many custom op domains this factory exposes, creating them on first use.
///
/// @param this_ptr     The factory (must be a TensorrtExecutionProviderFactory); must not be null.
/// @param num_domains  Output; receives the number of entries in custom_op_domain_list_.
///                     Set to 0 when an error status is returned.
/// @return nullptr on success, otherwise an OrtStatus describing the failure.
///
/// The function is noexcept, so any exception raised while building the list is converted
/// into an error status instead of escaping (which would terminate the process).
OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetNumCustomOpDomainsImpl(
    OrtEpFactory* this_ptr, size_t* num_domains) noexcept {
  if (this_ptr == nullptr || num_domains == nullptr) {
    return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "this_ptr and num_domains must not be null");
  }

  *num_domains = 0;
  auto* factory = static_cast<TensorrtExecutionProviderFactory*>(this_ptr);

  try {
    if (factory->custom_op_domain_list_.empty()) {
      // Note: extra_plugin_lib_paths are not available at factory level (they come from
      // per-session provider options). Default TRT plugins are still registered here.
      // Extra plugins loaded during EP creation will register to the TRT registry but
      // won't add new custom op domains retroactively.
      OrtStatus* status = trt_ep::CreateTensorRTCustomOpDomainList(
          factory->ep_name_.c_str(), "", factory->custom_op_domain_list_);
      if (status != nullptr) {
        return status;
      }
    }

    *num_domains = factory->custom_op_domain_list_.size();
  } catch (const std::exception& ex) {
    return Ort::GetApi().CreateStatus(ORT_FAIL, ex.what());
  } catch (...) {
    return Ort::GetApi().CreateStatus(ORT_FAIL, "Unknown error while creating TensorRT custom op domains");
  }

  return nullptr;
}

/// Copies the first `num_domains` cached custom op domain pointers into `domains`.
///
/// @param this_ptr     The factory (must be a TensorrtExecutionProviderFactory); must not be null.
/// @param domains      Caller-provided array with room for at least `num_domains` entries.
///                     May be null only when `num_domains` is 0.
/// @param num_domains  Number of entries to copy; must not exceed the size of
///                     custom_op_domain_list_.
/// @return nullptr on success, ORT_INVALID_ARGUMENT status otherwise.
///
/// This function only reads custom_op_domain_list_; it does not create the domains itself.
OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetCustomOpDomainsImpl(
    OrtEpFactory* this_ptr, OrtCustomOpDomain** domains, size_t num_domains) noexcept {
  if (this_ptr == nullptr) {
    return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "this_ptr must not be null");
  }

  auto* factory = static_cast<TensorrtExecutionProviderFactory*>(this_ptr);

  if (num_domains > factory->custom_op_domain_list_.size()) {
    return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "num_domains exceeds available custom op domains");
  }

  if (num_domains == 0) {
    return nullptr;  // Nothing to copy; a null output array is acceptable here.
  }

  if (domains == nullptr) {
    return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "domains must not be null when num_domains is non-zero");
  }

  for (size_t i = 0; i < num_domains; ++i) {
    domains[i] = factory->custom_op_domain_list_[i];
  }

  return nullptr;
}

OrtStatus* TensorrtExecutionProviderFactory::GetKernelRegistryForEp(TensorrtExecutionProvider& ep,
const OrtKernelRegistry** out_kernel_registry) {
*out_kernel_registry = nullptr;
Expand Down
Loading
Loading