Skip to content

Commit dce0bec

Browse files
committed
up
1 parent a096d94 commit dce0bec

4 files changed

Lines changed: 245 additions & 31 deletions

File tree

.github/workflows/mlx.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ jobs:
7979
backends/mlx/test/test_pattern_utils.py \
8080
backends/mlx/test/test_partitioner.py \
8181
backends/mlx/test/test_serialization_dedup.py \
82+
backends/mlx/test/test_slot_recycling.py \
8283
examples/models/gemma4_31b/quant/tests/test_pack_mlx.py \
8384
examples/models/gemma4_31b/tests/test_mlx_pipeline.py \
8485
-v

backends/mlx/builder/program_builder.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,13 @@ def make_tmp_value_slot(self) -> Tuple[str, Slot]:
242242
"""Create a temporary value (SymInt) slot."""
243243
return self.slot_manager.make_tmp_value_slot()
244244

245+
def tmp_scope(self):
    """Return a temp-slot scope obtained from this builder's slot manager.

    Pure delegation to ``self.slot_manager.tmp_scope()``; the result is
    meant to be used in a ``with`` statement wrapped around code that
    creates temporary slots.

    See :meth:`SlotManager.tmp_scope`.
    """
    slot_manager = self.slot_manager
    return slot_manager.tmp_scope()
251+
245252
def make_or_get_constant(self, name: str, tensor: torch.Tensor) -> Slot:
246253
"""
247254
Creates an extra constant outside of the ExportedProgram state_dict.
@@ -529,7 +536,8 @@ def _process_nodes(self) -> None: # noqa C901
529536

530537
if self.node_info[n].handler is not None:
531538
handler = self.node_info[n].handler
532-
handler(self, n)
539+
with self.tmp_scope():
540+
handler(self, n)
533541
self._mark_supported(n, handler=handler)
534542
continue
535543

@@ -558,7 +566,8 @@ def _process_nodes(self) -> None: # noqa C901
558566
continue
559567

560568
try:
561-
handler(self, n)
569+
with self.tmp_scope():
570+
handler(self, n)
562571
self._mark_supported(n, handler=handler)
563572
except Exception as e:
564573
trace_str = traceback.format_exc()
@@ -688,14 +697,20 @@ def _collect_used_slots(
688697
# Inputs, outputs, mutable buffers - always include
689698
used_slots.add(s)
690699

700+
# Count distinct physical slots. Slots that share (id_space, idx) are the
701+
# same slot reused across disjoint lifetimes (delete-as-you-go reclaim /
702+
# tmp_scope) and are coalesced to a single global id below, so they must
703+
# be counted once. (For non-tensors, SymInt/SymBool share the vid pool.)
691704
num_tensors: Dict[IdSpace, int] = defaultdict(int)
692705
num_values: Dict[IdSpace, int] = defaultdict(int)
693-
seen: Set[Slot] = set()
706+
seen_keys: Set[Tuple[bool, IdSpace, int]] = set()
694707
for s in used_slots:
695-
if s in seen:
708+
is_tensor = s.id_type == IdType.Tensor
709+
key = (is_tensor, s.id_space, s.idx)
710+
if key in seen_keys:
696711
continue
697-
seen.add(s)
698-
if s.id_type == IdType.Tensor:
712+
seen_keys.add(key)
713+
if is_tensor:
699714
num_tensors[s.id_space] += 1
700715
else:
701716
num_values[s.id_space] += 1
@@ -719,19 +734,28 @@ def _create_slot_mappings(
719734
IdSpace.Temp: 4,
720735
}
721736

737+
# Coalesce slots that share (id_space, idx) to a single global id. Such
738+
# slots are the same physical slot reused across disjoint lifetimes
739+
# (delete-as-you-go reclaim / tmp_scope), so they must map to the same
740+
# global Tid/Vid. Sorting by (id_space, idx) keeps per-space id ranges
741+
# contiguous, matching the counts from _collect_used_slots.
742+
def _coalesce(slots: List[Slot]) -> Dict[Slot, int]:
743+
mapping: Dict[Slot, int] = {}
744+
key_to_global: Dict[Tuple[IdSpace, int], int] = {}
745+
for s in sorted(slots, key=lambda s: (id_space_order[s.id_space], s.idx)):
746+
key = (s.id_space, s.idx)
747+
gid = key_to_global.get(key)
748+
if gid is None:
749+
gid = len(key_to_global)
750+
key_to_global[key] = gid
751+
mapping[s] = gid
752+
return mapping
753+
722754
# Create Tid mapping
723-
slot_to_tid = sorted(
724-
[s for s in used_slots if s.id_type == IdType.Tensor],
725-
key=lambda s: (id_space_order[s.id_space], s.idx),
726-
)
727-
slot_to_tid = {s: idx for idx, s in enumerate(slot_to_tid)}
755+
slot_to_tid = _coalesce([s for s in used_slots if s.id_type == IdType.Tensor])
728756

729757
# Create Vid mapping
730-
slot_to_vid = sorted(
731-
[s for s in used_slots if s.id_type != IdType.Tensor],
732-
key=lambda s: (id_space_order[s.id_space], s.idx),
733-
)
734-
slot_to_vid = {s: idx for idx, s in enumerate(slot_to_vid)}
758+
slot_to_vid = _coalesce([s for s in used_slots if s.id_type != IdType.Tensor])
735759

736760
# Remap all Tid/Vid values in instructions to use global indices
737761
if hasattr(self, "_tid_slot_map"):

backends/mlx/builder/slot_manager.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88

99
import uuid
1010
from collections import defaultdict
11+
from contextlib import contextmanager
1112
from dataclasses import dataclass
1213
from enum import auto, Enum
13-
from typing import Dict, Optional, Tuple, Union
14+
from typing import Dict, Iterator, List, Optional, Tuple, Union
1415

1516
import torch
1617
from torch.fx.node import Node
@@ -73,6 +74,54 @@ def __init__(self):
7374
self.tid_managers: Dict[IdSpace, IdManager] = defaultdict(IdManager)
7475
self.vid_managers: Dict[IdSpace, IdManager] = defaultdict(IdManager)
7576
self.name_to_slot: Dict[str, Slot] = {}
77+
# Stack of active temp-slot scopes (see ``tmp_scope``). Temp tids/vids
78+
# allocated via make_tmp_slot()/make_tmp_value_slot() are registered on
79+
# the innermost scope and their ids returned for reuse on scope exit.
80+
self._tmp_scopes: List[List[Slot]] = []
81+
82+
@contextmanager
def tmp_scope(self) -> Iterator[None]:
    """Open a scope whose temporary slot ids are handed back on exit.

    While the scope is active, every slot created through
    :meth:`make_tmp_slot` / :meth:`make_tmp_value_slot` is recorded on the
    innermost open scope. When the ``with`` block ends (normally or via an
    exception) each recorded slot's index is returned to the matching id
    manager: the tensor-id manager for ``IdType.Tensor`` slots, the
    value-id manager for everything else. Later allocations may then reuse
    those indices.

    Scopes nest; a slot always belongs to the scope that was innermost when
    it was created. Creating a temp slot with no scope open raises
    ``RuntimeError``.

    The Slot objects themselves are not removed from ``name_to_slot`` here;
    only their indices are released.
    """
    # Push a fresh record list; temp-slot creation appends to the top entry.
    self._tmp_scopes.append([])
    try:
        yield
    finally:
        # Pop unconditionally so an exception inside the block cannot leave
        # a stale scope on the stack.
        released = self._tmp_scopes.pop()
        for tmp in released:
            if tmp.id_type == IdType.Tensor:
                pools = self.tid_managers
            else:
                pools = self.vid_managers
            pools[tmp.id_space].return_id(tmp.idx)
107+
def _new_tmp_slot(self, id_type: IdType, prefix: str) -> Tuple[str, Slot]:
    """Create a temporary slot in ``IdSpace.Temp`` bound to the innermost scope.

    Args:
        id_type: Kind of slot to create. ``IdType.Tensor`` takes its index
            from the tensor-id manager; any other type takes it from the
            value-id manager.
        prefix: Prefix of the generated slot name, which has the form
            ``"<prefix>_<uuid4 hex>"``.

    Returns:
        ``(name, slot)``. The slot is also stored in ``self.name_to_slot``
        under ``name`` and appended to the innermost open scope.

    Raises:
        RuntimeError: If no ``tmp_scope()`` is currently open.
    """
    if not self._tmp_scopes:
        # Report the public entry point the caller actually used. The raw
        # prefix ("tmp" / "tmp_val") is a slot-name prefix, not a function
        # name, so "tmp() must be called ..." would point at nothing.
        entry_point = {
            "tmp": "make_tmp_slot",
            "tmp_val": "make_tmp_value_slot",
        }.get(prefix, "_new_tmp_slot")
        raise RuntimeError(
            f"{entry_point}() must be called within a SlotManager.tmp_scope() "
            "context so temporary ids can be reclaimed and reused."
        )
    name = f"{prefix}_{uuid.uuid4().hex}"
    id_space = IdSpace.Temp
    manager = (
        self.tid_managers[id_space]
        if id_type == IdType.Tensor
        else self.vid_managers[id_space]
    )
    idx = manager.get_id()
    slot = Slot(id_type=id_type, id_space=id_space, idx=idx)
    self.name_to_slot[name] = slot
    # Register on the innermost scope so its index is released on scope exit.
    self._tmp_scopes[-1].append(slot)
    return name, slot
76125

77126
def set_slot(self, node_or_name: Union[Node, str], slot: Slot):
78127
if isinstance(node_or_name, Node):
@@ -129,23 +178,11 @@ def make_constant_slot(self, name: str) -> Slot:
129178
return slot
130179

131180
def make_tmp_slot(self) -> Tuple[str, Slot]:
132-
name = f"tmp_{uuid.uuid4().hex}"
133-
id_space = IdSpace.Temp
134-
manager = self.tid_managers[id_space]
135-
idx = manager.get_id()
136-
slot = Slot(id_type=IdType.Tensor, id_space=id_space, idx=idx)
137-
self.name_to_slot[name] = slot
138-
return name, slot
181+
return self._new_tmp_slot(IdType.Tensor, "tmp")
139182

140183
def make_tmp_value_slot(self) -> Tuple[str, Slot]:
141184
"""Create a temporary SymInt slot and register it."""
142-
name = f"tmp_val_{uuid.uuid4().hex}"
143-
id_space = IdSpace.Temp
144-
manager = self.vid_managers[id_space]
145-
idx = manager.get_id()
146-
slot = Slot(id_type=IdType.SymInt, id_space=id_space, idx=idx)
147-
self.name_to_slot[name] = slot
148-
return name, slot
185+
return self._new_tmp_slot(IdType.SymInt, "tmp_val")
149186

150187
def make_or_get_slots(
151188
self, node: Node, id_space: IdSpace = IdSpace.Temp
backends/mlx/test/test_slot_recycling.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
#
8+
9+
"""Regression tests for temp-slot recycling in the MLX program builder.
10+
11+
Two invariants are guarded here:
12+
13+
1. ``SlotManager.tmp_scope`` reclaims temp tids/vids on exit (and creating a
14+
temp slot outside a scope raises), so local ids are reused.
15+
2. The serialized graph coalesces slots that share ``(id_space, idx)`` to a
16+
single global Tid/Vid, so ``num_temp_tensors`` / ``num_values`` reflect that
17+
reuse. Without this, recycled slots each get their own runtime slot (which is
18+
never freed until end-of-execution ``reset()``), inflating peak memory. This
19+
is easy to silently reintroduce (e.g. enumerating distinct Slot objects), so
20+
it is asserted directly.
21+
22+
Run::
23+
24+
python -m unittest executorch.backends.mlx.test.test_slot_recycling
25+
"""
26+
27+
import unittest
28+
29+
import torch
30+
import torch.nn as nn
31+
32+
from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder
33+
from executorch.backends.mlx.builder.slot_manager import (
34+
IdSpace,
35+
IdType,
36+
Slot,
37+
SlotManager,
38+
)
39+
from executorch.backends.mlx.serialization.mlx_graph_schema import Tid, Vid
40+
41+
42+
def _trivial_ep():
43+
"""Minimal ExportedProgram just to satisfy ``MLXProgramBuilder.__init__``.
44+
45+
The graph is never processed; the coalescing tests drive the builder's slot
46+
bookkeeping directly.
47+
"""
48+
49+
class _Identity(nn.Module):
50+
def forward(self, x):
51+
return x + 1
52+
53+
return torch.export.export(_Identity(), (torch.zeros(2),))
54+
55+
56+
class TmpScopeTest(unittest.TestCase):
    """``SlotManager.tmp_scope`` must release temp ids when the scope ends."""

    def test_make_tmp_requires_scope(self):
        # Neither temp factory may be used while no scope is open.
        manager = SlotManager()
        for factory in (manager.make_tmp_slot, manager.make_tmp_value_slot):
            with self.assertRaises(RuntimeError):
                factory()

    def test_tmp_ids_reclaimed_and_reused(self):
        manager = SlotManager()
        with manager.tmp_scope():
            first = manager.make_tmp_slot()[1]
            second = manager.make_tmp_slot()[1]
            self.assertNotEqual(first.idx, second.idx)  # live simultaneously
            self.assertTrue(manager.is_alive(first))
        # Reclaimed on exit.
        for released in (first, second):
            self.assertFalse(manager.is_alive(released))
        # Next scope reuses a freed idx.
        with manager.tmp_scope():
            third = manager.make_tmp_slot()[1]
        self.assertIn(third.idx, (first.idx, second.idx))

    def test_value_slots_reclaimed(self):
        manager = SlotManager()
        with manager.tmp_scope():
            value_slot = manager.make_tmp_value_slot()[1]
            self.assertTrue(manager.is_alive(value_slot))
        self.assertFalse(manager.is_alive(value_slot))

    def test_nested_scopes(self):
        manager = SlotManager()
        with manager.tmp_scope():
            outer_slot = manager.make_tmp_slot()[1]
            with manager.tmp_scope():
                inner_slot = manager.make_tmp_slot()[1]
            # Inner scope reclaimed its slot; outer slot is still live.
            self.assertFalse(manager.is_alive(inner_slot))
            self.assertTrue(manager.is_alive(outer_slot))
95+
96+
97+
class SlotCoalescingTest(unittest.TestCase):
    """Slots sharing ``(id_space, idx)`` must map to one global Tid/Vid."""

    def _builder_with_slots(self, tensor_slots, value_slots):
        """Build an ``MLXProgramBuilder`` whose slot tables hold only the given slots.

        Each tensor slot is registered as ``t<i>`` and paired with a fresh
        ``Tid(idx=None)`` in ``_tid_slot_map``; each value slot is registered
        as ``v<i>`` and paired with a fresh ``Vid(idx=None)`` in
        ``_vid_slot_map``.
        """
        builder = MLXProgramBuilder(_trivial_ep())
        # Start from a clean slot table so the trivial graph's own slots don't
        # interfere, then register synthetic slots as if emitted by handlers.
        builder.slot_manager = SlotManager()
        builder._tid_slot_map = []
        builder._vid_slot_map = []
        for pos, tensor_slot in enumerate(tensor_slots):
            builder.slot_manager.name_to_slot[f"t{pos}"] = tensor_slot
            builder._tid_slot_map.append((Tid(idx=None), tensor_slot))
        for pos, value_slot in enumerate(value_slots):
            builder.slot_manager.name_to_slot[f"v{pos}"] = value_slot
            builder._vid_slot_map.append((Vid(idx=None), value_slot))
        return builder

    def test_reused_tids_coalesce(self):
        first = Slot(IdType.Tensor, IdSpace.Temp, 0)
        recycled = Slot(IdType.Tensor, IdSpace.Temp, 0)  # reused idx 0 (disjoint life)
        other = Slot(IdType.Tensor, IdSpace.Temp, 1)
        constant = Slot(IdType.Tensor, IdSpace.Constant, 0)
        builder = self._builder_with_slots([first, recycled, other, constant], [])

        used, num_tensors, _ = builder._collect_used_slots()
        slot_to_tid, _ = builder._create_slot_mappings(used)

        self.assertEqual(
            slot_to_tid[first], slot_to_tid[recycled], "reused idx must coalesce"
        )
        self.assertNotEqual(
            slot_to_tid[first], slot_to_tid[other], "distinct idx stays distinct"
        )
        # Counts reflect distinct (id_space, idx), not distinct Slot objects.
        self.assertEqual(num_tensors[IdSpace.Temp], 2)
        distinct_global_ids = set(slot_to_tid.values())
        self.assertEqual(sum(num_tensors.values()), len(distinct_global_ids))
        # Emitted Tid references collapse in the serialized graph too.
        # Keyed by object identity so each Slot instance is looked up separately.
        tid_by_slot_id = {id(s): t for t, s in builder._tid_slot_map}
        self.assertEqual(
            tid_by_slot_id[id(first)].idx, tid_by_slot_id[id(recycled)].idx
        )
        self.assertNotEqual(
            tid_by_slot_id[id(first)].idx, tid_by_slot_id[id(other)].idx
        )

    def test_reused_vids_coalesce(self):
        # SymInt and SymBool share the vid pool, so equal idx must coalesce.
        sym_int_0 = Slot(IdType.SymInt, IdSpace.Temp, 0)
        sym_bool_0 = Slot(IdType.SymBool, IdSpace.Temp, 0)
        sym_int_1 = Slot(IdType.SymInt, IdSpace.Temp, 1)
        builder = self._builder_with_slots([], [sym_int_0, sym_bool_0, sym_int_1])

        used, _, num_values = builder._collect_used_slots()
        _, slot_to_vid = builder._create_slot_mappings(used)

        self.assertEqual(
            slot_to_vid[sym_int_0],
            slot_to_vid[sym_bool_0],
            "shared vid idx must coalesce",
        )
        self.assertNotEqual(slot_to_vid[sym_int_0], slot_to_vid[sym_int_1])
        self.assertEqual(num_values[IdSpace.Temp], 2)
        distinct_global_ids = set(slot_to_vid.values())
        self.assertEqual(sum(num_values.values()), len(distinct_global_ids))
149+
150+
151+
if __name__ == "__main__":
152+
unittest.main()

0 commit comments

Comments
 (0)