Skip to content

Commit ca22d56

Browse files
authored
Reorder slice before view (#20240)
Differential Revision: D108217652 Pull Request resolved: #20240
1 parent a9dd615 commit ca22d56

2 files changed

Lines changed: 447 additions & 1 deletion

File tree

backends/cadence/aot/reorder_ops.py

Lines changed: 234 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from collections import defaultdict
1313
from math import prod
14-
from typing import Callable, cast, DefaultDict, List, Tuple
14+
from typing import Callable, cast, DefaultDict, List, Optional, Tuple
1515

1616
import torch
1717
import torch.fx
@@ -781,6 +781,239 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
781781
return True
782782

783783

784+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
785+
class MoveSliceBeforeViewPass(RemoveOrReplacePassInterface):
786+
"""Move a slice_copy above a view_copy when the slice is re-expressible as a
787+
single slice on one dim of the pre-view tensor.
788+
789+
Rewrites view(x) -> slice(dim=d, start, end, step) into
790+
slice(x, dim=d', start', end', step') -> view(sliced, slice_out_shape), so the
791+
slice lands directly on x. This may be useful in attention patterns, where
792+
we view outputs of a large linear into a new shape where the number of
793+
attention heads are the last dim, and we need to run independent computation
794+
per head. Moving the slice before the view can allow us to then directly slice
795+
the constant linear weights.
796+
797+
A view is a contiguous reshape: it never moves or reorders elements, it only
798+
re-groups the shared row-major index space into different dims. A slice keeps
799+
an arithmetic progression of indices (start, start+step, ...) along one viewed
800+
dim, and that progression collapses back to a *single* slice on one pre-view
801+
dim exactly when the row-major strides line up. ``_derive_pre_view_slice``
802+
handles the three cases that qualify:
803+
804+
* untouched dim: the viewed dim is left unchanged by the view -- same size
805+
and same inner stride as some pre-view dim -- so the slice copies over
806+
verbatim (any step).
807+
* contiguous: the viewed dim and a pre-view dim span the same flat extent
808+
(a split's outermost factor, or a merge that aligns), so a contiguous
809+
(step==1) slice maps to a contiguous pre-view slice.
810+
* strided: the viewed dim is an innermost factor of a pre-view dim
811+
(identical inner stride) selected width-1, so it maps to a strided
812+
pre-view slice with step == the viewed dim's size.
813+
814+
Everything else -- middle factors, wider strided selections -- is block-strided
815+
(runs separated by gaps), which no single slice can express, so it is left
816+
unchanged.
817+
818+
Each slice is handled independently, so a view that fans out to several slices
819+
is rewritten one slice at a time and the now-dead view is removed by dead-code
820+
elimination -- there is no single-user requirement on the view.
821+
"""
822+
823+
@property
824+
def targets(self) -> list[EdgeOpOverload]:
825+
return [exir_ops.edge.aten.slice_copy.Tensor]
826+
827+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
828+
view_node = get_arg(node, "input", torch.fx.Node)
829+
if view_node.target != exir_ops.edge.aten.view_copy.default:
830+
return False
831+
832+
x_node = get_arg(view_node, "input", torch.fx.Node)
833+
pre_view_shape = tuple(x_node.meta["val"].shape)
834+
post_view_shape = tuple(view_node.meta["val"].shape)
835+
if 0 in pre_view_shape or 0 in post_view_shape:
836+
return False
837+
838+
dim = get_arg(node, "dim", int)
839+
if dim < 0:
840+
dim += len(post_view_shape)
841+
post_view_size = post_view_shape[dim]
842+
843+
bounds = self._normalize_slice(node, post_view_size)
844+
if bounds is None:
845+
return False
846+
start, stop, step = bounds
847+
848+
# The slice's own output shape gives the selected-element count along the
849+
# sliced dim directly -- it is exactly output_shape[dim].
850+
slice_out_shape = tuple(node.meta["val"].shape)
851+
post_view_count = slice_out_shape[dim]
852+
if post_view_count == 0:
853+
return False
854+
855+
# Row-major stride of the sliced viewed dim, and of every pre-view dim.
856+
post_view_stride = prod(post_view_shape[dim + 1 :])
857+
pre_view_strides = self._row_major_strides(pre_view_shape)
858+
859+
derived = self._derive_pre_view_slice(
860+
pre_view_shape,
861+
pre_view_strides,
862+
post_view_stride,
863+
post_view_size,
864+
start,
865+
stop,
866+
step,
867+
post_view_count,
868+
)
869+
if derived is None:
870+
return False
871+
pre_view_dim, pre_view_start, pre_view_stop, pre_view_step = derived
872+
873+
graph = node.graph
874+
with graph.inserting_before(node):
875+
new_slice_args = (
876+
x_node,
877+
pre_view_dim,
878+
pre_view_start,
879+
pre_view_stop,
880+
pre_view_step,
881+
)
882+
new_slice = graph.create_node(
883+
"call_function",
884+
exir_ops.edge.aten.slice_copy.Tensor,
885+
args=new_slice_args,
886+
)
887+
new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
888+
x_node.meta["val"], *new_slice_args[1:]
889+
)
890+
new_view = graph.create_node(
891+
"call_function",
892+
exir_ops.edge.aten.view_copy.default,
893+
args=(new_slice, list(slice_out_shape)),
894+
)
895+
new_view.meta["val"] = exir_ops.edge.aten.view_copy.default(
896+
new_slice.meta["val"], list(slice_out_shape)
897+
)
898+
899+
node.replace_all_uses_with(new_view)
900+
return True
901+
902+
@staticmethod
903+
def _row_major_strides(shape: tuple[int, ...]) -> list[int]:
904+
"""Row-major (contiguous) strides for ``shape``."""
905+
strides = [1] * len(shape)
906+
acc = 1
907+
for i in range(len(shape) - 1, -1, -1):
908+
strides[i] = acc
909+
acc *= shape[i]
910+
return strides
911+
912+
def _normalize_slice(
913+
self, node: torch.fx.Node, post_view_size: int
914+
) -> Optional[tuple[int, int, int]]:
915+
"""Resolve the slice to concrete, clamped ``(start, stop, step)`` ints, or
916+
None if the bounds are dynamic or the step is non-positive (neither of
917+
which this pass handles)."""
918+
step = get_arg(node, "step")
919+
920+
if not isinstance(step, int):
921+
return None
922+
923+
if step <= 0:
924+
return None
925+
926+
raw_start = get_arg(node, "start")
927+
raw_stop = get_arg(node, "end")
928+
929+
# Make sure raw_start/raw_stop are not symbolic.
930+
if (raw_start is not None and not isinstance(raw_start, int)) or (
931+
raw_stop is not None and not isinstance(raw_stop, int)
932+
):
933+
return None
934+
935+
start = 0 if raw_start is None else raw_start
936+
stop = post_view_size if raw_stop is None else raw_stop
937+
if start < 0:
938+
start += post_view_size
939+
if stop < 0:
940+
stop += post_view_size
941+
start = max(0, min(start, post_view_size))
942+
stop = max(0, min(stop, post_view_size))
943+
return start, stop, step
944+
945+
def _derive_pre_view_slice(
946+
self,
947+
pre_view_shape: tuple[int, ...],
948+
pre_view_strides: list[int],
949+
post_view_stride: int,
950+
post_view_size: int,
951+
start: int,
952+
stop: int,
953+
step: int,
954+
post_view_count: int,
955+
) -> tuple[int, int, int, int] | None:
956+
"""Return ``(dim, start, stop, step)`` for the single pre-view-tensor slice
957+
equivalent to slicing the viewed dim, or None if no single pre-view slice
958+
reproduces it.
959+
960+
Both shapes index the same row-major flat space, so the sliced viewed dim
961+
(size ``post_view_size``, inner stride ``post_view_stride``) lines up with
962+
one pre-view dim (size ``pre_view_size``, inner stride ``pre_view_stride``)
963+
in one of three ways.
964+
"""
965+
for pre_view_dim, (pre_view_stride, pre_view_size) in enumerate(
966+
zip(pre_view_strides, pre_view_shape)
967+
):
968+
# Untouched: the viewed dim is identical to this pre-view dim (same
969+
# size and same inner stride), so the slice applies verbatim, any step.
970+
if pre_view_stride == post_view_stride and pre_view_size == post_view_size:
971+
return pre_view_dim, start, stop, step
972+
973+
# Contiguous: the viewed dim and this pre-view dim span the same flat
974+
# extent (same period), and the selected band aligns to this dim's
975+
# boundaries. A contiguous (step==1) viewed slice
976+
# [start, start+post_view_count) is the flat band [start*
977+
# post_view_stride, (start+post_view_count)*post_view_stride), a
978+
# contiguous slice on this pre-view dim iff both ends are multiples of
979+
# its stride.
980+
if (
981+
step == 1
982+
and post_view_size * post_view_stride == pre_view_size * pre_view_stride
983+
):
984+
flat_start = start * post_view_stride
985+
flat_stop = (start + post_view_count) * post_view_stride
986+
if (
987+
flat_start % pre_view_stride == 0
988+
and flat_stop % pre_view_stride == 0
989+
):
990+
return (
991+
pre_view_dim,
992+
flat_start // pre_view_stride,
993+
flat_stop // pre_view_stride,
994+
1,
995+
)
996+
997+
# Strided is the ONLY way the reshape itself introduces a stride, and
998+
# it requires a width-1 selection (post_view_count == 1): the viewed
999+
# dim is an innermost factor of this pre-view dim (identical inner
1000+
# stride), so fixing that single factor index and letting the rest of
1001+
# the pre-view dim run yields a uniform stride equal to the viewed dim's
1002+
# size. Any wider selection (post_view_count > 1) of an inner factor
1003+
# leaves runs separated by gaps -- block-strided, not a single slice --
1004+
# so width-1 is required.
1005+
if (
1006+
post_view_count == 1
1007+
and post_view_size > 1
1008+
and pre_view_stride == post_view_stride
1009+
and pre_view_size % post_view_size == 0
1010+
):
1011+
pre_view_count = pre_view_size // post_view_size
1012+
pre_view_stop = start + (pre_view_count - 1) * post_view_size + 1
1013+
return pre_view_dim, start, pre_view_stop, post_view_size
1014+
return None
1015+
1016+
7841017
@register_cadence_pass(CadencePassAttribute(opt_level=1))
7851018
class PropagateSlice(RemoveOrReplacePassInterface):
7861019
"""Propagate slice_copy before element-wise ops when the cost model

0 commit comments

Comments
 (0)