Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 148 additions & 84 deletions onnxruntime/core/providers/qnn/qnn_ep_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

#include "core/providers/qnn/qnn_ep_utils.h"

#include <deque>
#include <iostream>
#include <string>
#include <unordered_set>

#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/providers/qnn/common/inlined_containers.h"
Expand Down Expand Up @@ -1824,6 +1826,56 @@ std::vector<OrtNodeGroup> OrtSelectorManager::GetOrtQDQSelections(const OrtGraph

namespace utils {

size_t ComputeGroupExternalInDegree(
const QnnNodeGroupInfo& group,
const std::unordered_map<const OrtNode*, const OrtNodeUnit*>& node_unit_map,
const OrtApi& ort_api) {
size_t external = 0;

auto count_external_producers = [&](const OrtNode* node) {
size_t num_inputs = 0;
if (OrtStatus* s = ort_api.Node_GetNumInputs(node, &num_inputs); s != nullptr) {
ort_api.ReleaseStatus(s);
return;
}
std::vector<const OrtValueInfo*> inputs(num_inputs);
if (OrtStatus* s = ort_api.Node_GetInputs(node, inputs.data(), inputs.size()); s != nullptr) {
ort_api.ReleaseStatus(s);
return;
}
for (const OrtValueInfo* input : inputs) {
if (input == nullptr) {
continue;
}
const OrtNode* producer_node = nullptr;
if (OrtStatus* s = ort_api.ValueInfo_GetValueProducer(input, &producer_node, nullptr); s != nullptr) {
ort_api.ReleaseStatus(s);
continue;
}
if (producer_node == nullptr) {
continue; // Initializer or graph input.
}
auto it = node_unit_map.find(producer_node);
if (it == node_unit_map.cend()) {
continue;
}
const OrtNodeUnit* producer_nu = it->second;
if (group.member_set.count(producer_nu) > 0) {
continue; // Internal edge within the group; do not count.
}
++external;
}
};

for (const OrtNodeUnit* member_nu : group.member_node_units) {
for (const OrtNode* node : member_nu->GetAllNodesInGroup()) {
count_external_producers(node);
}
}

return external;
}

// QNN-EP COPY START
// Below implementations are directly copied from "core/common/common.h"
// Returns whether `key` is in `container`.
Expand All @@ -1841,49 +1893,50 @@ std::vector<std::vector<const OrtNode*>> CreateSupportedPartitionNodeGroups(
const OrtApi& ort_api,
const std::vector<const OrtNode*>& supported_nodes,
const std::string& ep_type,
const std::unordered_map<const OrtNode*, const OrtNodeUnit*>& node_unit_map) {
const std::unordered_map<const OrtNode*, const OrtNodeUnit*>& node_unit_map,
const std::vector<QnnNodeGroupInfo>& groups,
const std::unordered_map<const OrtNodeUnit*, size_t>& node_unit_to_group_id,
const Ort::Logger& logger) {
std::vector<std::vector<const OrtNode*>> supported_groups{};

// Fast membership test for EP support on the target OrtNode.
std::unordered_set<const OrtNode*> supported_nodes_set(supported_nodes.cbegin(), supported_nodes.cend());

size_t num_nodes = 0;
auto status = ort_api.Graph_GetNumNodes(graph, &num_nodes);
if (status != nullptr) {
ort_api.ReleaseStatus(status);
return {};
}
std::vector<const OrtNode*> graph_nodes(num_nodes);
status = ort_api.Graph_GetNodes(graph, graph_nodes.data(), graph_nodes.size());
if (status != nullptr) {
ort_api.ReleaseStatus(status);
return {};
}

// #inputs from unprocessed nodes (in-degree) per node.
std::unordered_map<size_t, size_t> in_degree{};
// Nodes that are ready to process.
std::deque<const OrtNode*> nodes_to_process{};
// Nodes that will be processed when considering the next partition node group.
std::deque<const OrtNode*> nodes_to_process_with_next_group{};
// Per-group in-degree; indexed by group_id.
std::vector<size_t> group_in_degree(groups.size(), 0);
size_t live_group_count = 0;
for (const QnnNodeGroupInfo& g : groups) {
if (g.is_defunct) {
continue;
}
group_in_degree[g.group_id] = g.external_in_degree;
++live_group_count;
}

// Initialize in-degrees and find root nodes.
for (size_t node_idx = 0; node_idx < num_nodes; ++node_idx) {
const OrtNode* node = graph_nodes[node_idx];
const OrtNodeUnit* node_unit = node_unit_map.at(node);
std::deque<size_t> queue_current{};
std::deque<size_t> queue_next{};

if (&node_unit->GetNode() != node) {
// Only process the target node.
// Seed with groups that have no external in-edges.
for (const QnnNodeGroupInfo& g : groups) {
if (g.is_defunct) {
continue;
}

size_t degree = node_unit->GetInputEdgesCount(ort_api);
in_degree.insert({node_unit->Index(), degree});
if (degree == 0) {
nodes_to_process.push_back(node);
if (group_in_degree[g.group_id] == 0) {
queue_current.push_back(g.group_id);
}
}

std::vector<const OrtNode*> supported_group{};
// The partition node group's border is the aggregate of its nodes' output nodes.
InlinedHashSet<const OrtNode*> supported_group_border{};
// Group ids of unprocessed downstream groups reachable from the currently in-progress partition. Mirrors the
// `supported_group_border` concept in the old BFS but tracks groups instead of OrtNodes.
std::unordered_set<size_t> supported_group_border{};

auto close_group = [&]() {
if (!supported_group.empty()) {
Expand All @@ -1893,87 +1946,98 @@ std::vector<std::vector<const OrtNode*>> CreateSupportedPartitionNodeGroups(
}
};

size_t num_nodes_processed = 0;
size_t num_groups_processed = 0;

while (!nodes_to_process.empty() || !nodes_to_process_with_next_group.empty()) {
if (nodes_to_process.empty()) {
// We have processed all the nodes that we can while building this partition node group, start a new one.
while (!queue_current.empty() || !queue_next.empty()) {
if (queue_current.empty()) {
close_group();
nodes_to_process.swap(nodes_to_process_with_next_group);
queue_current.swap(queue_next);
continue;
}

const OrtNode* node = nodes_to_process.front();
nodes_to_process.pop_front();

const OrtNodeUnit* node_unit = node_unit_map.at(node);
const bool is_qdq_node_unit = node_unit->UnitType() == OrtNodeUnit::Type::QDQGroup;

// A node that is already assigned to an EP other than current EP is unsupported.
const char* node_ep_name;
ORT_CONTINUE_ON_ERROR(ort_api.Node_GetEpName(node, &node_ep_name), ort_api);
const bool is_node_supported = ((std::string(node_ep_name).empty() || node_ep_name == ep_type) &&
std::find(supported_nodes.cbegin(), supported_nodes.cend(), node) != supported_nodes.cend());
size_t gid = queue_current.front();
queue_current.pop_front();

if (!is_node_supported && Contains(supported_group_border, node)) {
// An unsupported node on the border will be processed after the current partition node group.
nodes_to_process_with_next_group.push_back(node);
const QnnNodeGroupInfo& g = groups[gid];
if (g.is_defunct) {
++num_groups_processed;
continue;
}

if (is_node_supported) {
if (is_qdq_node_unit) {
// Add DQ -> node -> Q for the node unit and must be in topological order.
for (const OrtNode* dq : node_unit->GetDQNodes()) {
supported_group.push_back(dq);
}

supported_group.push_back(node);
const OrtNode* redundent_clip_node = node_unit->GetRedundantClipNode();
if (redundent_clip_node) {
supported_group.push_back(redundent_clip_node);
supported_group_border.erase(redundent_clip_node);
}
// A group is supported if IQnnNodeGroup::IsSupported accepted it AND the target OrtNode is not already claimed
// by another EP AND our own supported_nodes set contains the target (belt-and-suspenders check).
const OrtNode* target_ortnode = &g.target_node_unit->GetNode();
const char* node_ep_name = nullptr;
ORT_CONTINUE_ON_ERROR(ort_api.Node_GetEpName(target_ortnode, &node_ep_name), ort_api);
const std::string ep_name_str = node_ep_name ? std::string(node_ep_name) : std::string{};
const bool is_group_supported = g.is_supported &&
(ep_name_str.empty() || ep_name_str == ep_type) &&
supported_nodes_set.find(target_ortnode) != supported_nodes_set.cend();

// Border deferral: an unsupported group on the border of the in-progress partition gets pushed to the next
// partition so the current partition can close with its existing members intact.
if (!is_group_supported && supported_group_border.count(gid) > 0) {
queue_next.push_back(gid);
continue;
}

for (const OrtNode* q : node_unit->GetQNodes()) {
supported_group.push_back(q);
if (is_group_supported) {
// Emit all member OrtNodes atomically. This preserves QDQGroup DQ/target/Q ordering for each member
// NodeUnit (via GetAllNodesInGroup) and binds all fusion members to the same partition.
for (const OrtNodeUnit* member_nu : g.member_node_units) {
for (const OrtNode* node : member_nu->GetAllNodesInGroup()) {
supported_group.push_back(node);
}
} else {
supported_group.push_back(node);
}

// Remove node from the border.
supported_group_border.erase(node);
supported_group_border.erase(gid);
}

// For each downstream node:
// 1: Add the downstream node to the border if the current node is supported.
// 2: Adjust in-degrees of the nodes consuming the current node's outputs, and add any new nodes to process.
for (const OrtNode* output_node : node_unit->GetOutputNodes(ort_api)) {
const OrtNodeUnit* downstream_node_unit = node_unit_map.at(output_node);
const OrtNode* downstream_node = &downstream_node_unit->GetNode();

if (is_node_supported) {
supported_group_border.insert(downstream_node);
}
// Propagate to downstream groups. For each distinct edge from any member OrtNode to a NodeUnit outside the
// group, decrement the downstream group's in-degree. Edge-count semantics match ComputeGroupExternalInDegree:
// multiple edges from the group to the same downstream group produce multiple decrements.
for (const OrtNodeUnit* member_nu : g.member_node_units) {
for (const OrtNode* output_node : member_nu->GetOutputNodes(ort_api)) {
auto it_nu = node_unit_map.find(output_node);
if (it_nu == node_unit_map.cend()) {
continue;
}
const OrtNodeUnit* downstream_nu = it_nu->second;
if (g.member_set.count(downstream_nu) > 0) {
continue; // Edge to another member of the same group; internal.
}
auto it_gid = node_unit_to_group_id.find(downstream_nu);
if (it_gid == node_unit_to_group_id.cend()) {
continue;
}
size_t downstream_gid = it_gid->second;

auto& downstream_node_in_degree = in_degree[downstream_node_unit->Index()];
--downstream_node_in_degree;
if (is_group_supported) {
supported_group_border.insert(downstream_gid);
}

if (downstream_node_in_degree == 0) {
nodes_to_process.push_back(downstream_node);
auto& d_in = group_in_degree[downstream_gid];
if (d_in > 0) {
--d_in;
if (d_in == 0) {
queue_current.push_back(downstream_gid);
}
}
}
}

++num_nodes_processed;
++num_groups_processed;
}

close_group();

if (num_nodes_processed != in_degree.size()) {
ORT_CXX_API_THROW("Processed " + std::to_string(num_nodes_processed) +
" nodes. Expected to process " + std::to_string(in_degree.size()),
ORT_EP_FAIL);
if (num_groups_processed != live_group_count) {
std::string msg = "Processed " + std::to_string(num_groups_processed) +
" groups. Expected to process " + std::to_string(live_group_count) +
". An IQnnNodeGroup cycle may exist (a fusion member depends on an unsupported op that also " +
"depends on another fusion member). Cycle detection + demotion should have run upstream in " +
"GetSupportedNodes.";
ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_ERROR, msg.c_str());
ORT_CXX_API_THROW(msg, ORT_EP_FAIL);
}

return supported_groups;
Expand Down
34 changes: 33 additions & 1 deletion onnxruntime/core/providers/qnn/qnn_ep_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <optional>
#include <tuple>
Expand Down Expand Up @@ -401,13 +402,44 @@ class OrtSelectorManager {

namespace utils {

// Describes an IQnnNodeGroup (multi-NodeUnit fusion or a single NodeUnit wrapped as a trivial 1-member group)
// in terms the partitioner can reason about. One entry exists per IQnnNodeGroup produced by GetQnnNodeGroups
// (supported or not), so every NodeUnit in the graph belongs to exactly one group.
//
// The partitioner treats each group as an atomic BFS scheduling unit: all members are admitted into the same
// partition, or none. This prevents BFS from splitting fusion members across partition boundaries when an
// unsupported op sits topologically between them.
struct QnnNodeGroupInfo {
size_t group_id = 0; // Dense index into the groups vector.
const OrtNodeUnit* target_node_unit = nullptr; // From IQnnNodeGroup::GetTargetNodeUnit(); nullptr if defunct.
std::vector<const OrtNodeUnit*> member_node_units; // All member NodeUnits (includes the target).
std::unordered_set<const OrtNodeUnit*> member_set; // O(1) membership test; mirrors member_node_units.
bool is_supported = false; // Cached IQnnNodeGroup::IsSupported() result.
bool is_defunct = false; // Set when the group is demoted; BFS skips defunct entries.
size_t external_in_degree = 0; // Count of input edges from NodeUnits NOT in member_set.
};

// Walks each member's OrtNode inputs and counts edges whose producer NodeUnit is not in the group's member_set.
// Initializers (null producer) are skipped. For QDQGroup members, walks DQ/target/Q nodes.
size_t ComputeGroupExternalInDegree(
const QnnNodeGroupInfo& group,
const std::unordered_map<const OrtNode*, const OrtNodeUnit*>& node_unit_map,
const OrtApi& ort_api);

// Refer to CreateSupportedPartitions in partitioning_utils.cc.
//
// Group-aware partitioner. `groups` describes every IQnnNodeGroup (fusion or 1-member wrapper);
// `node_unit_to_group_id` maps each NodeUnit to its owning group. The BFS iterates groups as atomic units,
// guaranteeing that members of any multi-NodeUnit fusion land in the same partition.
std::vector<std::vector<const OrtNode*>> CreateSupportedPartitionNodeGroups(
const OrtGraph* graph,
const OrtApi& ort_api,
const std::vector<const OrtNode*>& supported_nodes,
const std::string& ep_type,
const std::unordered_map<const OrtNode*, const OrtNodeUnit*>& node_unit_map);
const std::unordered_map<const OrtNode*, const OrtNodeUnit*>& node_unit_map,
const std::vector<QnnNodeGroupInfo>& groups,
const std::unordered_map<const OrtNodeUnit*, size_t>& node_unit_to_group_id,
const Ort::Logger& logger);

} // namespace utils

Expand Down
Loading