Skip to content

Commit 0eb0585

Browse files
author
ssjia
committed
[ET-VK] Fix exponential blowup in tag_memory_meta_pass repset tracing
Pull Request resolved: #18207 The trace_node_users_to_constrain_repset DFS previously tracked search depth as a per-branch int counter, allowing each branch of a fan-out to independently explore up to max_trace_search_depth nodes. In transformer-style graphs with heavy fan-out this caused exponential blowup in the number of nodes visited. Replace the int counter with a mutable list containing a single int that is shared by reference across all recursive branches. This limits the TOTAL number of nodes explored per top-level trace call to max_trace_search_depth (16), regardless of fan-out structure. Authored with Claude. ghstack-source-id: 353546691 @exported-using-ghexport Differential Revision: [D96790445](https://our.internmc.facebook.com/intern/diff/D96790445/)
1 parent 22174fa commit 0eb0585

1 file changed

Lines changed: 15 additions & 9 deletions

File tree

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,10 @@ def __init__(
132132
self.texture_limits = texture_limits
133133
self.force_fp16 = force_fp16
134134

135-
# Magic number to limit "lookahead" when tracing through users of an operator
136-
# to constrain the representation of its arguments/outputs.
137-
self.max_trace_search_depth = None
135+
# Limit the total number of nodes explored when tracing through users of
136+
# an operator to constrain the representation of its arguments/outputs.
137+
# Without a limit, transformer-style graphs cause exponential blowup.
138+
self.max_trace_search_depth = 64
138139

139140
def is_valid_op_node(self, node: Any) -> bool:
140141
"""
@@ -261,7 +262,7 @@ def constrain_repset_with_user(
261262
current_node: torch.fx.Node,
262263
arg_i: int,
263264
arg_repset: utils.TensorRepSet,
264-
search_depth: int = 0,
265+
search_depth: list[int] | None = None,
265266
) -> utils.TensorRepSet:
266267
"""
267268
Attempts to constrain `arg_repset` based on the required repset of the argument
@@ -301,21 +302,26 @@ def constrain_repset_with_user(
301302
current_node, arg_repset, search_depth
302303
)
303304

304-
def trace_node_users_to_constrain_repset(
305+
def trace_node_users_to_constrain_repset( # noqa: C901
305306
self,
306307
origin_node: torch.fx.Node,
307308
repset: utils.TensorRepSet,
308-
search_depth: int = 0,
309+
search_depth: list[int] | None = None,
309310
) -> utils.TensorRepSet:
310311
"""
311312
For an ambiguous repset, try to constrain the repset by tracing the required
312313
repsets of the users of `origin_node`. The idea is to try to find a representation
313314
that can be used the longest without needing user nodes to insert a transition
314315
for its arguments.
315316
"""
316-
# Optionally limit the search depth to improve export time
317+
# Optionally limit the total number of nodes explored to improve export
318+
# time. search_depth is a mutable list so that all branches of a fan-out
319+
# share a single counter, preventing exponential blowup.
317320
if self.max_trace_search_depth is not None:
318-
if search_depth > self.max_trace_search_depth:
321+
if search_depth is None:
322+
search_depth = [self.max_trace_search_depth]
323+
search_depth[0] -= 1
324+
if search_depth[0] <= 0:
319325
return repset
320326

321327
users_to_trace = origin_node.users
@@ -339,7 +345,7 @@ def trace_node_users_to_constrain_repset(
339345

340346
if arg_i_in_user is not None:
341347
repset = self.constrain_repset_with_user(
342-
usage_node, arg_i_in_user, repset, search_depth + 1
348+
usage_node, arg_i_in_user, repset, search_depth
343349
)
344350

345351
if repset.is_constrained():

0 commit comments

Comments
 (0)