backend: resolve pl.dynamic dims in the distributed host orchestrator#1871
backend: resolve pl.dynamic dims in the distributed host orchestrator#1871zhaozhaozz wants to merge 1 commit into
Conversation
The distributed host orchestrator (orchestration/host_orch.py) slices per-rank inputs such as tensors["x_hc"][rank, 0:DIM, ...]. When a host function carries pl.dynamic() dims, DIM is a runtime-only symbol that is never bound in the generated Python, so executing the orchestrator raises NameError. The device-side codegen already recovers dynamic dims from runtime tensor shapes (_append_dynamic_dim_unpacking). Add the equivalent for the Python host orchestrator: emit, once per symbol, <DIM> = tensors["<key>"].shape[<dim>] derived from the first slice that uses it, at the top of each orchestrator function body. No-op when the orchestrator has no dynamic dims (the common case), so existing distributed kernels are unaffected. This unblocks @pl.jit.host wrappers that loop over ranks and slice [N_RANKS, PREFILL_TOKENS_DYN, ...] packed tensors (e.g. DeepSeek V4 packed chunked prefill).
|
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces a helper function _resolve_host_orch_dynamic_dims to extract and define pl.dynamic() dimension symbols at the top of host-orchestrator functions, preventing NameError during execution. The reviewer identified a critical issue where symbols are resolved globally and injected into all functions, which can lead to runtime KeyErrors if a function does not receive the specific tensor in its local scope. A detailed code suggestion was provided to resolve and inject these symbols on a per-function basis.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| resolved: "OrderedDict[str, str]" = OrderedDict() | ||
| for key, idx_str in slice_re.findall(orch_code): | ||
| for dim_i, comp in enumerate(c.strip() for c in idx_str.split(",")): | ||
| m = bound_re.match(comp) | ||
| if m and m.group(1) not in resolved: | ||
| resolved[m.group(1)] = f'tensors["{key}"].shape[{dim_i}]' | ||
| if not resolved: | ||
| return orch_code | ||
|
|
||
| def_re = re.compile(r"^(\s*)def\s+\w+\(orch\b.*\):\s*$") | ||
| out: list[str] = [] | ||
| for line in orch_code.split("\n"): | ||
| out.append(line) | ||
| m = def_re.match(line) | ||
| if m: | ||
| indent = m.group(1) + " " | ||
| out.append(f"{indent}# Recover pl.dynamic() dims from runtime tensor shapes for host slicing.") | ||
| out.extend(f"{indent}{name} = {expr}" for name, expr in resolved.items()) | ||
| return "\n".join(out) |
There was a problem hiding this comment.
Currently, _resolve_host_orch_dynamic_dims resolves all dynamic dimension symbols globally across the entire orch_code and injects all of them into every function matching def_re.
If the generated orchestrator file contains multiple functions (e.g., multiple orchestrators or helper functions) that operate on different sets of tensors, this global injection will define variables using tensors that may not exist in the local tensors dictionary of a specific function. This will raise a KeyError at runtime when that function is executed.
To prevent this, we should resolve and inject the dynamic dimension symbols per function by scanning the slices within each function's body and only injecting the symbols that are actually used in that specific function.
def_re = re.compile(r"^(\s*)def\s+\w+\(orch\b.*\):\s*$")
func_headers = {}
func_resolved = {}
current_func_idx = None
lines = orch_code.split("\n")
for line_i, line in enumerate(lines):
m_def = def_re.match(line)
if m_def:
current_func_idx = line_i
func_headers[line_i] = m_def.group(1)
func_resolved[line_i] = OrderedDict()
continue
if current_func_idx is not None:
if line.strip() and not line.startswith(" ") and not line.startswith("\t"):
current_func_idx = None
else:
for key, idx_str in slice_re.findall(line):
for dim_i, comp in enumerate(c.strip() for c in idx_str.split(",")):
m = bound_re.match(comp)
if m and m.group(1) not in func_resolved[current_func_idx]:
func_resolved[current_func_idx][m.group(1)] = f'tensors["{key}"].shape[{dim_i}]'
if not any(func_resolved.values()):
return orch_code
out: list[str] = []
for line_i, line in enumerate(lines):
out.append(line)
if line_i in func_headers and func_resolved[line_i]:
indent = func_headers[line_i] + " "
out.append(f"{indent}# Recover pl.dynamic() dims from runtime tensor shapes for host slicing.")
out.extend(f"{indent}{name} = {expr}" for name, expr in func_resolved[line_i].items())
return "\n".join(out)|
Consumed by hw-native-sys/pypto-lib#625 (DeepSeek V4 packed chunked prefill, WIP) — that kernel is the first |
Problem
The distributed host orchestrator generated for an
@pl.jit.hostwrapperslices per-rank inputs, e.g.
When the host function carries
pl.dynamic()dims, the dim symbol(
DEEPSEEK_PREFILL_TOKENS_DYN) is a runtime-only value that is never boundanywhere in the generated
orchestration/host_orch.py, so executing theorchestrator fails with
NameError: name '...DYN' is not defined.The device-side codegen already recovers dynamic dims from runtime tensor
shapes (
_append_dynamic_dim_unpacking); the Python host orchestrator hadno equivalent. No existing distributed kernel combines
@pl.jit.hostwithpl.dynamic, so this path was previously untested.Fix
Add a post-processing step in
_generate_with_distributed(
python/pypto/backend/pto_backend.py) that, for each orchestrator function,emits at the top of the body — once per dynamic-dim symbol —
derived from the first generated slice that uses the symbol (so it is robust
to the symbol's emitted name).
Scope / risk
pl.dynamic()dims (the common case),so existing distributed kernels are unaffected.
Validation
Built locally; the generated
host_orch.pynow defines the dynamic dims andthe distributed program dispatches across ranks. Exercised end-to-end by the
companion pypto-lib packed chunked-prefill kernel on a2a3 (2-card EP2).