Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
InsertConstShapesPass,
InsertControlFlowRescalesPass,
InsertDataLayoutCastsPass,
InsertDynamicPaddingPass,
InsertInt32CastsAfterInt64PlaceholdersPass,
InsertRescaleInt32Pass,
InsertRescalePass,
Expand Down Expand Up @@ -630,6 +631,7 @@ def _tosa_pipeline(
[
CastInt64BuffersToInt32Pass(exported_program),
FuseEqualPlaceholdersPass(exported_program),
InsertDynamicPaddingPass(),
FuseConsecutiveConcatShapesPass(),
EnsureUniqueOutputNodesPass(),
RemoveNoopPass(),
Expand Down
14 changes: 8 additions & 6 deletions backends/arm/_passes/insert_dynamic_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class InsertDynamicPaddingPass(ArmOpTargetedPass):
_passes_required_after: Set[Type[ExportPass]] = set()
target_ops = (
exir_ops.backend.tosa.CONV2D.default,
exir_ops.backend.tosa.CONV3D.default,
exir_ops.backend.tosa.DEPTHWISE_CONV2D.default,
exir_ops.backend.tosa.MAX_POOL2D.default,
exir_ops.backend.tosa.AVG_POOL2D.default,
Expand Down Expand Up @@ -57,11 +58,12 @@ def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue:
if not self._is_dynamic_padding(padding):
return super().call_operator(op, args, kwargs, meta, updated)

# Create a pad op before conv2d
# Create a pad op before the convolution/pool op.
input_tensor = args[0]

zero_padding_pair = [0, 0]
zero_spatial_padding = [0, 0, 0, 0]
spatial_rank = 3 if op == exir_ops.backend.tosa.CONV3D.default else 2
zero_spatial_padding = [0] * (spatial_rank * 2)
N_padding = super().call_shape_operator(
exir_ops.backend.tosa.CONST_SHAPE.default,
(zero_padding_pair,),
Expand Down Expand Up @@ -93,7 +95,7 @@ def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue:
meta,
True,
)
new_conv2d_args = list(args)
new_conv2d_args[0] = pad_res
new_conv2d_args[padding_index] = zero_spatial_padding
return super().call_operator(op, tuple(new_conv2d_args), kwargs, meta, updated)
new_args = list(args)
new_args[0] = pad_res
new_args[padding_index] = zero_spatial_padding
return super().call_operator(op, tuple(new_args), kwargs, meta, updated)
26 changes: 14 additions & 12 deletions backends/arm/_passes/rewrite_conv_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,23 +97,25 @@ def _adjust_pad_if_needed(

if isinstance(mod_remainder, torch.SymInt):
shape_env = get_context_shape_env()
exact_values = evaluate_symbolic_expr_values(
mod_remainder.node.expr, shape_env
)
exact_values = evaluate_symbolic_expr_values(mod_remainder, shape_env)
if exact_values is not None:
mod_remainder_upper = max(exact_values)
if len(exact_values) == 1:
mod_remainder = int(next(iter(exact_values)))
elif mod_remainder_upper == 0:
mod_remainder = 0
else:
return pad - mod_remainder
else:
value_ranges = shape_env.bound_sympy(mod_remainder.node.expr)
mod_remainder_upper = int(value_ranges.upper)
if mod_remainder_upper == 0:
mod_remainder = 0
else:
mod_remainder_upper = mod_remainder

if mod_remainder_upper > pad:
# SizeAdjustInputPass already trims symbolic remainder classes
# that would force negative padding. Keep the symbolic
# expression here instead of asking ShapeEnv to normalize it.
return pad - mod_remainder
if mod_remainder > pad:
raise RuntimeError(
"This case should be handled by the SizeAdjustInputPass, is it enabled?\n"
"This case should be handled by SizeAdjustInputPass, is it enabled?\n"
)

return pad - mod_remainder

def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool:
Expand Down
43 changes: 41 additions & 2 deletions backends/arm/_passes/size_adjust_input_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,41 @@ def _greater_than(input: SymIntLike, other: int) -> bool | torch.SymBool:
return input > other


def _get_slice_adjustment(
remainder: SymIntLike,
pad: int,
stride: int,
) -> SymIntLike | None:
"""Return the amount to slice from the end of a conv dimension.

The required trim is ``max(remainder - pad, 0)``. For symbolic shapes we
encode that clamp using only integer arithmetic that the TOSA shape
materializer already supports: a sum of floor-div terms over the possible
residue classes.

"""
if not isinstance(remainder, torch.SymInt):
return remainder - pad if remainder > pad else None

shape_env = get_context_shape_env()
exact_values = evaluate_symbolic_expr_values(remainder.node.expr, shape_env)
if exact_values is not None:
adjustments = {max(value - pad, 0) for value in exact_values}
if len(adjustments) == 1:
adjustment = next(iter(adjustments))
return adjustment if adjustment > 0 else None

if pad >= stride - 1:
return None

adjustment: SymIntLike | None = None # type: ignore[no-redef]
for threshold in range(pad + 1, stride):
term = (remainder + stride - threshold) // stride
adjustment = term if adjustment is None else adjustment + term

return adjustment


def get_slices_convolution(conv_node: torch.fx.Node) -> Slices:
slices: Slices = []

Expand All @@ -85,8 +120,12 @@ def get_slices_convolution(conv_node: torch.fx.Node) -> Slices:
remainder = conv_remainder(
input_shape[dim], pad, dilation, weight_shape[dim], stride
)
if _greater_than(remainder, pad):
adjustment = remainder - pad
adjustment = _get_slice_adjustment(
remainder,
pad,
stride,
)
if adjustment is not None:
args = (dim, 0, input_shape[dim] - adjustment)
slices.append(args)

Expand Down
113 changes: 82 additions & 31 deletions backends/arm/_passes/symbolic_value_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,70 @@ def _symbol_values(symbol: sympy.Symbol, shape_env: ShapeEnv) -> _ExactValues:
return frozenset(sympy.Integer(value) for value in range(lower, upper + 1))


def _expr_symbols_to_values(
expr: sympy.Basic,
shape_env: ShapeEnv,
) -> dict[sympy.Symbol, _ExactValues]:
return {symbol: _symbol_values(symbol, shape_env) for symbol in expr.free_symbols}


def _try_expr_to_int(expr: sympy.Basic) -> Optional[int]:
integer_value = _expr_to_int(expr)
if integer_value is not None:
return integer_value

try:
return _expr_to_int(sympy.simplify(expr))
except (RecursionError, TypeError):
return None


def _constant_expr_values(expr: sympy.Basic) -> Optional[set[int]]:
if expr.free_symbols:
return None

integer_value = _try_expr_to_int(expr)
return {integer_value} if integer_value is not None else None


def _evaluate_exact_values(
expr: sympy.Basic,
shape_env: ShapeEnv,
) -> _ExactValues:
try:
return sympy_interp(
_ExactValueAnalysis,
_expr_symbols_to_values(expr, shape_env),
expr,
missing_handler=lambda symbol: _symbol_values(symbol, shape_env),
)
except (RecursionError, TypeError):
return None


def _exact_values_to_ints(exact_values: _ExactValues) -> Optional[set[int]]:
if exact_values is None:
return None

result: set[int] = set()
for value in exact_values:
integer_value = _try_expr_to_int(value)
if integer_value is None:
return None
result.add(integer_value)
return result


def _map_values(values: _ExactValues, fn) -> _ExactValues:
if values is None:
return None

result = {sympy.simplify(fn(value)) for value in values}
result = set()
for value in values:
try:
result.add(fn(value))
except (RecursionError, TypeError):
return None
if len(result) > _MAX_SET_SIZE:
return None
return frozenset(result)
Expand All @@ -55,7 +114,13 @@ def _combine_values(lhs: _ExactValues, rhs: _ExactValues, fn) -> _ExactValues:
if len(lhs) * len(rhs) > _MAX_SET_SIZE * _MAX_SET_SIZE:
return None

result = {sympy.simplify(fn(a, b)) for a in lhs for b in rhs}
result = set()
for a in lhs:
for b in rhs:
try:
result.add(fn(a, b))
except (RecursionError, TypeError):
return None
if len(result) > _MAX_SET_SIZE:
return None
return frozenset(result)
Expand All @@ -80,6 +145,12 @@ def mod(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
return None
return _combine_values(lhs, rhs, lambda a, b: sympy.Mod(a, b))

@staticmethod
def floordiv(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
if rhs is None or any(value == 0 for value in rhs):
return None
return _combine_values(lhs, rhs, lambda a, b: sympy.floor(a / b))

@staticmethod
def pow(lhs: _ExactValues, rhs: _ExactValues) -> _ExactValues:
return _combine_values(lhs, rhs, lambda a, b: a**b)
Expand All @@ -104,35 +175,15 @@ def evaluate_symbolic_expr_values(
) -> Optional[set[int]]:
"""Return a best-effort finite set of possible integer values.

The helper first relies on ``bound_sympy`` for cheap singleton detection.
When interval bounds are not precise enough, it falls back to a small
exact-set analysis over bounded symbols using ``sympy_interp``.
The helper avoids ShapeEnv bound queries here because some exported dynamic
expressions trigger very deep SymPy normalization. Instead, it relies on a
small exact-set analysis over bounded symbols using ``sympy_interp``.

"""
root_expr = sympy.simplify(
expr.node.expr if isinstance(expr, torch.SymInt) else expr
)
value_range = shape_env.bound_sympy(root_expr)
if value_range.is_int and value_range.is_singleton():
singleton = _expr_to_int(value_range.lower)
return {singleton} if singleton is not None else None

exact_values = sympy_interp(
_ExactValueAnalysis,
{
symbol: _symbol_values(symbol, shape_env)
for symbol in root_expr.free_symbols
},
root_expr,
missing_handler=lambda symbol: _symbol_values(symbol, shape_env),
)
if exact_values is None:
return None
root_expr = expr.node.expr if isinstance(expr, torch.SymInt) else expr

result: set[int] = set()
for value in exact_values:
integer_value = _expr_to_int(sympy.simplify(value))
if integer_value is None:
return None
result.add(integer_value)
return result
constant_values = _constant_expr_values(root_expr)
if constant_values is not None:
return constant_values

return _exact_values_to_ints(_evaluate_exact_values(root_expr, shape_env))
Loading
Loading