|
| 1 | +# |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | +# |
| 8 | + |
| 9 | +"""Regression tests for temp-slot recycling in the MLX program builder. |
| 10 | +
|
| 11 | +Two invariants are guarded here: |
| 12 | +
|
| 13 | +1. ``SlotManager.tmp_scope`` reclaims temp tids/vids on exit (and creating a |
| 14 | + temp slot outside a scope raises), so local ids are reused. |
| 15 | +2. The serialized graph coalesces slots that share ``(id_space, idx)`` to a |
| 16 | + single global Tid/Vid, so ``num_temp_tensors`` / ``num_values`` reflect that |
| 17 | + reuse. Without this, recycled slots each get their own runtime slot (which is |
| 18 | + never freed until end-of-execution ``reset()``), inflating peak memory. This |
| 19 | + is easy to silently reintroduce (e.g. enumerating distinct Slot objects), so |
| 20 | + it is asserted directly. |
| 21 | +
|
| 22 | +Run:: |
| 23 | +
|
| 24 | + python -m unittest executorch.backends.mlx.builder.test_slot_recycling |
| 25 | +""" |
| 26 | + |
| 27 | +import unittest |
| 28 | + |
| 29 | +import torch |
| 30 | +import torch.nn as nn |
| 31 | + |
| 32 | +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder |
| 33 | +from executorch.backends.mlx.builder.slot_manager import ( |
| 34 | + IdSpace, |
| 35 | + IdType, |
| 36 | + Slot, |
| 37 | + SlotManager, |
| 38 | +) |
| 39 | +from executorch.backends.mlx.serialization.mlx_graph_schema import Tid, Vid |
| 40 | + |
| 41 | + |
| 42 | +def _trivial_ep(): |
| 43 | + """Minimal ExportedProgram just to satisfy ``MLXProgramBuilder.__init__``. |
| 44 | +
|
| 45 | + The graph is never processed; the coalescing tests drive the builder's slot |
| 46 | + bookkeeping directly. |
| 47 | + """ |
| 48 | + |
| 49 | + class _Identity(nn.Module): |
| 50 | + def forward(self, x): |
| 51 | + return x + 1 |
| 52 | + |
| 53 | + return torch.export.export(_Identity(), (torch.zeros(2),)) |
| 54 | + |
| 55 | + |
| 56 | +class TmpScopeTest(unittest.TestCase): |
| 57 | + def test_make_tmp_requires_scope(self): |
| 58 | + sm = SlotManager() |
| 59 | + with self.assertRaises(RuntimeError): |
| 60 | + sm.make_tmp_slot() |
| 61 | + with self.assertRaises(RuntimeError): |
| 62 | + sm.make_tmp_value_slot() |
| 63 | + |
| 64 | + def test_tmp_ids_reclaimed_and_reused(self): |
| 65 | + sm = SlotManager() |
| 66 | + with sm.tmp_scope(): |
| 67 | + _, a = sm.make_tmp_slot() |
| 68 | + _, b = sm.make_tmp_slot() |
| 69 | + self.assertNotEqual(a.idx, b.idx) # live simultaneously |
| 70 | + self.assertTrue(sm.is_alive(a)) |
| 71 | + # Reclaimed on exit. |
| 72 | + self.assertFalse(sm.is_alive(a)) |
| 73 | + self.assertFalse(sm.is_alive(b)) |
| 74 | + # Next scope reuses a freed idx. |
| 75 | + with sm.tmp_scope(): |
| 76 | + _, c = sm.make_tmp_slot() |
| 77 | + self.assertIn(c.idx, (a.idx, b.idx)) |
| 78 | + |
| 79 | + def test_value_slots_reclaimed(self): |
| 80 | + sm = SlotManager() |
| 81 | + with sm.tmp_scope(): |
| 82 | + _, v = sm.make_tmp_value_slot() |
| 83 | + self.assertTrue(sm.is_alive(v)) |
| 84 | + self.assertFalse(sm.is_alive(v)) |
| 85 | + |
| 86 | + def test_nested_scopes(self): |
| 87 | + sm = SlotManager() |
| 88 | + with sm.tmp_scope(): |
| 89 | + _, outer = sm.make_tmp_slot() |
| 90 | + with sm.tmp_scope(): |
| 91 | + _, inner = sm.make_tmp_slot() |
| 92 | + # Inner scope reclaimed its slot; outer slot is still live. |
| 93 | + self.assertFalse(sm.is_alive(inner)) |
| 94 | + self.assertTrue(sm.is_alive(outer)) |
| 95 | + |
| 96 | + |
| 97 | +class SlotCoalescingTest(unittest.TestCase): |
| 98 | + """Slots sharing ``(id_space, idx)`` must map to one global Tid/Vid.""" |
| 99 | + |
| 100 | + def _builder_with_slots(self, tensor_slots, value_slots): |
| 101 | + P = MLXProgramBuilder(_trivial_ep()) |
| 102 | + # Start from a clean slot table so the trivial graph's own slots don't |
| 103 | + # interfere, then register synthetic slots as if emitted by handlers. |
| 104 | + P.slot_manager = SlotManager() |
| 105 | + P._tid_slot_map = [] |
| 106 | + P._vid_slot_map = [] |
| 107 | + for i, s in enumerate(tensor_slots): |
| 108 | + P.slot_manager.name_to_slot[f"t{i}"] = s |
| 109 | + P._tid_slot_map.append((Tid(idx=None), s)) |
| 110 | + for i, s in enumerate(value_slots): |
| 111 | + P.slot_manager.name_to_slot[f"v{i}"] = s |
| 112 | + P._vid_slot_map.append((Vid(idx=None), s)) |
| 113 | + return P |
| 114 | + |
| 115 | + def test_reused_tids_coalesce(self): |
| 116 | + a = Slot(IdType.Tensor, IdSpace.Temp, 0) |
| 117 | + b = Slot(IdType.Tensor, IdSpace.Temp, 0) # reused idx 0 (disjoint life) |
| 118 | + c = Slot(IdType.Tensor, IdSpace.Temp, 1) |
| 119 | + k = Slot(IdType.Tensor, IdSpace.Constant, 0) |
| 120 | + P = self._builder_with_slots([a, b, c, k], []) |
| 121 | + |
| 122 | + used, num_tensors, _ = P._collect_used_slots() |
| 123 | + slot_to_tid, _ = P._create_slot_mappings(used) |
| 124 | + |
| 125 | + self.assertEqual(slot_to_tid[a], slot_to_tid[b], "reused idx must coalesce") |
| 126 | + self.assertNotEqual(slot_to_tid[a], slot_to_tid[c], "distinct idx stays distinct") |
| 127 | + # Counts reflect distinct (id_space, idx), not distinct Slot objects. |
| 128 | + self.assertEqual(num_tensors[IdSpace.Temp], 2) |
| 129 | + self.assertEqual(sum(num_tensors.values()), len(set(slot_to_tid.values()))) |
| 130 | + # Emitted Tid references collapse in the serialized graph too. |
| 131 | + ref = {id(s): t for t, s in P._tid_slot_map} |
| 132 | + self.assertEqual(ref[id(a)].idx, ref[id(b)].idx) |
| 133 | + self.assertNotEqual(ref[id(a)].idx, ref[id(c)].idx) |
| 134 | + |
| 135 | + def test_reused_vids_coalesce(self): |
| 136 | + # SymInt and SymBool share the vid pool, so equal idx must coalesce. |
| 137 | + v0 = Slot(IdType.SymInt, IdSpace.Temp, 0) |
| 138 | + v0b = Slot(IdType.SymBool, IdSpace.Temp, 0) |
| 139 | + v1 = Slot(IdType.SymInt, IdSpace.Temp, 1) |
| 140 | + P = self._builder_with_slots([], [v0, v0b, v1]) |
| 141 | + |
| 142 | + used, _, num_values = P._collect_used_slots() |
| 143 | + _, slot_to_vid = P._create_slot_mappings(used) |
| 144 | + |
| 145 | + self.assertEqual(slot_to_vid[v0], slot_to_vid[v0b], "shared vid idx must coalesce") |
| 146 | + self.assertNotEqual(slot_to_vid[v0], slot_to_vid[v1]) |
| 147 | + self.assertEqual(num_values[IdSpace.Temp], 2) |
| 148 | + self.assertEqual(sum(num_values.values()), len(set(slot_to_vid.values()))) |
| 149 | + |
| 150 | + |
| 151 | +if __name__ == "__main__": |
| 152 | + unittest.main() |
0 commit comments