Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,38 @@ Ort::Status GetMainContextNode(const OrtGraph** graphs,
return Ort::Status();
}

// Reads the IO_NAME_OVERRIDES string attribute from an EPContext node and decodes it into an
// internal->external tensor-name map.
//
// The attribute is a flat list of "internal=external" entries, each terminated by ';'.
// Entries that are empty, lack an '=', or have an empty name on either side are ignored.
// If the same internal name occurs more than once, the first occurrence is kept.
//
// Returns an empty map when `ep_context_node` is null or the decoded attribute value is empty
// (an empty string is passed to the attribute helper as the default value).
std::unordered_map<std::string, std::string> ParseIoNameOverrides(const OrtNode* ep_context_node) {
  std::unordered_map<std::string, std::string> overrides;
  if (ep_context_node != nullptr) {
    OrtNodeAttrHelper node_helper(*ep_context_node);
    const std::string encoded = node_helper.Get(IO_NAME_OVERRIDES, std::string{});
    const size_t total = encoded.size();

    // Walk the ';'-separated entries; the final entry may or may not carry a trailing ';'.
    for (size_t begin = 0, end = 0; begin < total; begin = end + 1) {
      end = encoded.find(';', begin);
      if (end == std::string::npos) {
        end = total;
      }

      const std::string entry = encoded.substr(begin, end - begin);
      const size_t split = entry.find('=');
      if (split == std::string::npos) {
        // No separator in this entry (this also covers an empty entry such as ";;").
        continue;
      }

      std::string internal_name = entry.substr(0, split);
      std::string external_name = entry.substr(split + 1);
      if (internal_name.empty() || external_name.empty()) {
        continue;
      }
      overrides.emplace(std::move(internal_name), std::move(external_name));
    }
  }
  return overrides;
}

Ort::Status GetEpContextFromMainNode(const OrtNode* main_context_node,
const OrtApi& ort_api,
const std::basic_string<ORTCHAR_T>& ctx_onnx_model_path,
Expand Down Expand Up @@ -296,7 +328,8 @@ Ort::Status CreateEPContextNodes(const OrtNode** fused_nodes,
const Ort::Logger& logger,
bool share_ep_contexts,
bool stop_share_ep_contexts,
const std::string& ep_name) {
const std::string& ep_name,
const std::unordered_map<std::string, std::string>& tensor_name_overrides) {
// Supporting multiple partitions still needs more work; that is outside the EP's scope.
// Existing code already ensures there is a single partition before this method gets invoked.
for (size_t idx = 0; idx < count; ++idx) {
Expand Down Expand Up @@ -433,6 +466,27 @@ Ort::Status CreateEPContextNodes(const OrtNode** fused_nodes,
&attr));
attributes.push_back(attr);

// Persist the offload_graph_io_quantization tensor-name overrides so the cached-context load
// path can resolve graph I/O by name rather than by position (position is unreliable when QNN
// reorders graph outputs). Skipped when the map is empty (offload disabled) so non-offload
// context models are byte-for-byte unaffected.
if (!tensor_name_overrides.empty()) {
std::string encoded;
for (const auto& [internal, external] : tensor_name_overrides) {
encoded += internal;
encoded += '=';
encoded += external;
encoded += ';';
}
attr = nullptr;
ORT_CXX_RETURN_ON_API_FAIL(ort_api.CreateOpAttr(IO_NAME_OVERRIDES.c_str(),
encoded.c_str(),
static_cast<int>(encoded.length()),
ORT_OP_ATTR_STRING,
&attr));
attributes.push_back(attr);
}

ORT_CXX_RETURN_ON_API_FAIL(model_editor_api.CreateNode(EPCONTEXT_OP.c_str(),
kMSDomain,
graph_name.c_str(),
Expand Down
12 changes: 11 additions & 1 deletion onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include <string>
#include <unordered_map>
#include <vector>

#include "core/providers/qnn/builder/qnn_def.h"
Expand All @@ -26,6 +27,9 @@ static const std::string EP_SDK_VER = "ep_sdk_version";
static const std::string PARTITION_NAME = "partition_name";
static const std::string SOURCE = "source";
static const std::string MAX_SIZE = "max_size";
// Serialized internal->external tensor-name overrides produced by offload_graph_io_quantization.
// Persisted into the EPContext node so the cached-context load path can resolve graph I/O by name.
static const std::string IO_NAME_OVERRIDES = "io_name_overrides";

// EP_CONTEXT_TYPES
static const std::string EP_CONTEXT_TYPE_BIN = "bin";
Expand All @@ -49,6 +53,11 @@ Ort::Status GetMainContextNode(const OrtGraph** graphs,
const OrtApi& ort_api,
std::vector<int>& main_context_pos);

// Parses the IO_NAME_OVERRIDES attribute (if present) from an EPContext node into an
// internal->external tensor-name map. Returns an empty map when the attribute is absent
// (offload disabled, or a context model generated before this attribute existed).
std::unordered_map<std::string, std::string> ParseIoNameOverrides(const OrtNode* ep_context_node);

Ort::Status GetEpContextFromMainNode(const OrtNode* main_context_node,
const OrtApi& ort_api,
const std::basic_string<ORTCHAR_T>& ctx_onnx_model_path,
Expand Down Expand Up @@ -85,7 +94,8 @@ Ort::Status CreateEPContextNodes(const OrtNode** fused_nodes,
const Ort::Logger& logger,
bool share_ep_contexts,
bool stop_share_ep_contexts,
const std::string& ep_name);
const std::string& ep_name,
const std::unordered_map<std::string, std::string>& tensor_name_overrides);

} // namespace qnn
} // namespace onnxruntime
82 changes: 61 additions & 21 deletions onnxruntime/core/providers/qnn/builder/qnn_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,33 +171,73 @@ Ort::Status QnnModel::SetGraphInputOutputInfo(const QnnModelContext& context) {
std::forward_as_tuple(i, static_cast<int32_t>(elem_type), std::move(shape)));
}

// DLC tensors may carry overridden names that differ from the fused node I/O names.
// QNN tensors deserialized from the context binary may carry overridden names (produced by
// offload_graph_io_quantization) that differ from the fused-node I/O names. Alias each QNN name
// onto the correct fused-node entry so GetOutputIndex resolves to the right ORT index/type.
if (graph_info_) {
auto add_qnn_name_aliases = [](GraphInputOutputInfo& io_info,
auto alias_entry = [](GraphInputOutputInfo& io_info,
const std::string& qnn_name,
const std::string& fused_name) {
if (qnn_name == fused_name || io_info.indices.find(qnn_name) != io_info.indices.end()) {
return;
}
auto idx_it = io_info.indices.find(fused_name);
if (idx_it != io_info.indices.end()) {
io_info.indices.emplace(qnn_name, idx_it->second);
}
auto tensor_it = io_info.tensors.find(fused_name);
if (tensor_it != io_info.tensors.end()) {
const OnnxTensorInfo& info = tensor_it->second;
io_info.tensors.emplace(std::piecewise_construct,
std::forward_as_tuple(qnn_name),
std::forward_as_tuple(info.index_, info.data_type_,
std::vector<int64_t>(info.shape_)));
}
};

if (context.tensor_name_overrides && !context.tensor_name_overrides->empty()) {
// Preferred: resolve by the persisted map (order-independent).
// The bin tensor name is the `external` (e.g. "sep_cls_score"); the fused-node edge is
// the `internal` (e.g. "sep_cls_score_QuantizeLinear_Output").
auto alias_by_name = [&](GraphInputOutputInfo& io_info,
const std::vector<QnnTensorWrapper>& qnn_tensors) {
std::unordered_set<std::string> qnn_names;
for (const auto& t : qnn_tensors) {
qnn_names.insert(t.GetName());
}
for (const auto& [internal, external] : *context.tensor_name_overrides) {
if (qnn_names.count(external)) {
alias_entry(io_info, external, internal);
}
}
};
alias_by_name(graph_inputs_, graph_info_->InputTensors());
alias_by_name(graph_outputs_, graph_info_->OutputTensors());
} else {
// Legacy fallback for context binaries generated before the io_name_overrides attribute
// existed. Pairs by position, which is unreliable when QNN reorders graph inputs or outputs.
auto alias_by_position = [&](GraphInputOutputInfo& io_info,
const std::vector<QnnTensorWrapper>& qnn_tensors,
const std::vector<std::string>& fused_order) {
for (size_t i = 0; i < qnn_tensors.size() && i < fused_order.size(); ++i) {
const std::string& qnn_name = qnn_tensors[i].GetName();
const std::string& fused_name = fused_order[i];
if (qnn_name != fused_name && io_info.indices.find(qnn_name) == io_info.indices.end()) {
auto idx_it = io_info.indices.find(fused_name);
if (idx_it != io_info.indices.end()) {
io_info.indices.emplace(qnn_name, idx_it->second);
}
auto tensor_it = io_info.tensors.find(fused_name);
if (tensor_it != io_info.tensors.end()) {
const OnnxTensorInfo& info = tensor_it->second;
io_info.tensors.emplace(std::piecewise_construct,
std::forward_as_tuple(qnn_name),
std::forward_as_tuple(info.index_, info.data_type_,
std::vector<int64_t>(info.shape_)));
bool any_alias = false;
for (size_t i = 0; i < qnn_tensors.size() && i < fused_order.size(); ++i) {
const std::string& qnn_name = qnn_tensors[i].GetName();
if (qnn_name != fused_order[i] && io_info.indices.find(qnn_name) == io_info.indices.end()) {
alias_entry(io_info, qnn_name, fused_order[i]);
any_alias = true;
}
}
return any_alias;
};
bool aliased = alias_by_position(graph_inputs_, graph_info_->InputTensors(), fused_input_order);
aliased |= alias_by_position(graph_outputs_, graph_info_->OutputTensors(), fused_output_order);
if (aliased) {
ORT_CXX_LOG(context.logger, ORT_LOGGING_LEVEL_WARNING,
"QNN context binary has renamed graph I/O but no io_name_overrides attribute; "
"falling back to positional name aliasing, which may misbind reordered I/O. "
"Regenerate the context model with the current EP to embed the name mapping.");
}
};

add_qnn_name_aliases(graph_inputs_, graph_info_->InputTensors(), fused_input_order);
add_qnn_name_aliases(graph_outputs_, graph_info_->OutputTensors(), fused_output_order);
}
}

return Ort::Status();
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ size_t GetQnnTensorDataSizeInBytes(const Qnn_Tensor_t& tensor);
bool QnnTensorHasDynamicShape(const Qnn_Tensor_t& tensor);

// TODO: make these work with Wrappers?
std::ostream& operator<<(std::ostream& out, const Qnn_DataType_t& data_type);
std::ostream& operator<<(std::ostream& out, const Qnn_Param_t& qnn_param);
std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor);
std::ostream& operator<<(std::ostream& out, const QnnOpConfigWrapper& op_conf_wrapper);
Expand Down
19 changes: 15 additions & 4 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1697,6 +1697,7 @@ OrtStatus* QnnEp::CompileContextModel(const OrtGraph** graphs,
// Collect graph and fused nodes names.
std::vector<std::pair<std::string, std::string>> names;
names.reserve(count);
std::vector<std::unordered_map<std::string, std::string>> io_name_overrides_per_graph(count);

for (size_t graph_idx = 0; graph_idx < count; ++graph_idx) {
const char* graph_name = nullptr;
Expand Down Expand Up @@ -1729,6 +1730,7 @@ OrtStatus* QnnEp::CompileContextModel(const OrtGraph** graphs,
}

names.push_back(std::pair<std::string, std::string>(graph_name, ep_context_node_name));
io_name_overrides_per_graph[graph_idx] = qnn::ParseIoNameOverrides(ep_context_node);
}

// Get QnnModel from EP shared contexts
Expand Down Expand Up @@ -1762,7 +1764,9 @@ OrtStatus* QnnEp::CompileContextModel(const OrtGraph** graphs,
/*onnx_output_names=*/nullptr,
/*model_settings=*/nullptr,
/*graph_configs=*/nullptr,
/*tensor_name_overrides=*/nullptr,
/*tensor_name_overrides=*/io_name_overrides_per_graph[graph_idx].empty()
? nullptr
: &io_name_overrides_per_graph[graph_idx],
/*json_qnn_graph_path=*/{}};
RETURN_IF_NOT_OK(qnn_model_shared->SetGraphInputOutputInfo(context));
RETURN_IF_NOT_OK(qnn_model_shared->SetupQnnInputOutput(logger_));
Expand Down Expand Up @@ -1840,7 +1844,9 @@ OrtStatus* QnnEp::CompileContextModel(const OrtGraph** graphs,
/*onnx_output_names=*/nullptr,
/*model_settings=*/nullptr,
/*graph_configs=*/nullptr,
/*tensor_name_overrides=*/nullptr,
/*tensor_name_overrides=*/io_name_overrides_per_graph[graph_idx].empty()
? nullptr
: &io_name_overrides_per_graph[graph_idx],
/*json_qnn_graph_path=*/{}};
RETURN_IF_NOT_OK(qnn_model->SetGraphInputOutputInfo(context));
RETURN_IF_NOT_OK(qnn_model->SetupQnnInputOutput(logger_));
Expand Down Expand Up @@ -1906,7 +1912,8 @@ OrtStatus* QnnEp::CreateEPContextNodes(const OrtGraph* graph,
logger_,
share_ep_contexts_,
stop_share_ep_contexts_,
name_));
name_,
tensor_name_overrides_));

// Get compatibility info for later query in GetCompiledModelCompatibilityInfo.
Ort::Status status = qnn_cache_compatibility_manager_->GetCompatibilityInfo(compatibility_info_);
Expand Down Expand Up @@ -2132,13 +2139,17 @@ OrtStatus* ORT_API_CALL QnnEp::CompileImpl(_In_ OrtEp* this_ptr,
#endif // _WIN32

// Clean up transient GetCapability→Compile state.
// NOTE: tensor_name_overrides_ must NOT be cleared here; it is read by CreateEPContextNodes
// below to serialize the io_name_overrides attribute into the EPContext model.
ep->onnx_graph_io_names_.reset();
ep->tensor_name_overrides_.clear();

if (ep->context_cache_enabled_) {
RETURN_IF_NOT_NULL(ep->CreateEPContextNodes(graphs[0], fused_nodes, count, ep_context_nodes));
}

// Clear only after CreateEPContextNodes has serialized the map into the EPContext model.
ep->tensor_name_overrides_.clear();

#if defined(_WIN32) && (defined(__aarch64__) || defined(_M_ARM64))
end = std::chrono::steady_clock::now();
auto total_compile_time = std::chrono::duration_cast<std::chrono::milliseconds>(end - compile_start);
Expand Down
Loading