Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion autoparallel/graph_passes/graph_pp_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,24 @@ def _execute_graph(
if not hasattr(gm, "_compiled"):
from torch._inductor.compile_fx import compile_fx_inner

gm._compiled = compile_fx_inner(gm, args) # type: ignore[assignment, attr-defined]
# If gm has a shape_env (set by pp_joint_graph_builder), convert
# example_inputs to fake tensors so compile_fx_inner's
# fake_tensor_prop picks up the shape_env's allow_scalar_outputs.
# This enables Inductor compilation of graphs containing
# _local_scalar_dense (from data-dependent EP routing).
compile_args = args
shape_env = getattr(gm, "shape_env", None)
if shape_env is not None:
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode(shape_env=shape_env)
compile_args = [
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
]

gm._compiled = compile_fx_inner(gm, compile_args) # type: ignore[assignment, attr-defined]
return gm._compiled(args) # type: ignore[operator, attr-defined]
return fx.Interpreter(gm).boxed_run(args)

Expand Down
Loading