diff --git a/autoparallel/graph_passes/graph_pp_runner.py b/autoparallel/graph_passes/graph_pp_runner.py index 02c1c0ee..48aa4451 100644 --- a/autoparallel/graph_passes/graph_pp_runner.py +++ b/autoparallel/graph_passes/graph_pp_runner.py @@ -36,7 +36,24 @@ def _execute_graph( if not hasattr(gm, "_compiled"): from torch._inductor.compile_fx import compile_fx_inner - gm._compiled = compile_fx_inner(gm, args) # type: ignore[assignment, attr-defined] + # If gm has a shape_env (set by pp_joint_graph_builder), convert + # example_inputs to fake tensors so compile_fx_inner's + # fake_tensor_prop picks up the shape_env's allow_scalar_outputs. + # This enables Inductor compilation of graphs containing + # _local_scalar_dense (from data-dependent EP routing). + compile_args = args + shape_env = getattr(gm, "shape_env", None) + if shape_env is not None: + import torch + from torch._subclasses.fake_tensor import FakeTensorMode + + fake_mode = FakeTensorMode(shape_env=shape_env) + compile_args = [ + fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] + + gm._compiled = compile_fx_inner(gm, compile_args) # type: ignore[assignment, attr-defined] return gm._compiled(args) # type: ignore[operator, attr-defined] return fx.Interpreter(gm).boxed_run(args)