backend: resolve pl.dynamic dims in the distributed host orchestrator #1871

Draft
zhaozhaozz wants to merge 1 commit into hw-native-sys:main from zhaozhaozz:fix/host-orch-dynamic-dims

Conversation

@zhaozhaozz

Contributor

Draft / WIP — companion to a still-in-progress DeepSeek V4 packed
chunked-prefill change in pypto-lib. The fix itself is self-contained, but
it is gated on that consuming feature and is posted here for review/visibility.

Problem

The distributed host orchestrator generated for an @pl.jit.host wrapper
slices per-rank inputs, e.g.

tensors["x_hc__ssa_v0"][rank, 0:DEEPSEEK_PREFILL_TOKENS_DYN, 0:4, 0:4096]

When the host function carries pl.dynamic() dims, the dim symbol
(DEEPSEEK_PREFILL_TOKENS_DYN) is a runtime-only value that is never bound
anywhere in the generated orchestration/host_orch.py, so executing the
orchestrator fails with NameError: name '...DYN' is not defined.

The device-side codegen already recovers dynamic dims from runtime tensor
shapes (_append_dynamic_dim_unpacking); the Python host orchestrator had
no equivalent. No existing distributed kernel combines @pl.jit.host with
pl.dynamic, so this path was previously untested.
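The failure can be reproduced in isolation. The sketch below is a minimal stand-in, not the real generated file: `HOST_ORCH_SRC`, `orch_entry`, and the `SimpleNamespace` tensor are hypothetical, and only the slice expression mirrors the one quoted above.

```python
from types import SimpleNamespace

# Hypothetical stand-in for a generated host-orchestrator function.
# The dynamic-dim symbol is used in the slice but never assigned.
HOST_ORCH_SRC = '''
def orch_entry(orch, tensors, rank):
    return tensors["x_hc__ssa_v0"][rank, 0:DEEPSEEK_PREFILL_TOKENS_DYN, 0:4, 0:4096]
'''

namespace = {}
exec(HOST_ORCH_SRC, namespace)  # defining the function succeeds

# Assumed tensor shape, for illustration only.
tensors = {"x_hc__ssa_v0": SimpleNamespace(shape=(2, 128, 4, 4096))}

error_message = None
try:
    namespace["orch_entry"](None, tensors, 0)
except NameError as exc:
    # The name lookup fails while the index tuple is being built,
    # before the tensor is ever sliced.
    error_message = str(exc)

print(error_message)
```

Because the lookup fails while the index expression is being evaluated, the error surfaces only when the orchestrator runs, not when the file is generated or imported.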

Fix

Add a post-processing step in _generate_with_distributed
(python/pypto/backend/pto_backend.py) that, for each orchestrator function,
emits at the top of the body — once per dynamic-dim symbol —

<DIM> = tensors["<key>"].shape[<dim>]

derived from the first generated slice that uses the symbol, so the binding
does not depend on how the symbol's name is emitted.
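The post-processing step can be sketched as a standalone function over the generated source string. This is an illustration under assumptions rather than the code in `pto_backend.py`: the `slice_re` and `bound_re` patterns are guesses at what the PR uses, and `bind_dynamic_dims`, `FakeTensor`, and `orch_entry` are hypothetical names.

```python
import re
from collections import OrderedDict

# Assumed patterns; the PR excerpt does not show how slice_re / bound_re are defined.
slice_re = re.compile(r'tensors\["([^"]+)"\]\[([^\]]*)\]')  # tensors["key"][...]
bound_re = re.compile(r"^\d+\s*:\s*([A-Za-z_]\w*)$")        # 0:SYMBOL
def_re = re.compile(r"^(\s*)def\s+\w+\(orch\b.*\):\s*$")    # def f(orch, ...):


def bind_dynamic_dims(orch_code: str) -> str:
    """Prepend `<DIM> = tensors["<key>"].shape[<dim>]` to each orchestrator body."""
    resolved: "OrderedDict[str, str]" = OrderedDict()
    for key, idx_str in slice_re.findall(orch_code):
        for dim_i, comp in enumerate(c.strip() for c in idx_str.split(",")):
            m = bound_re.match(comp)
            if m and m.group(1) not in resolved:
                # First slice that uses the symbol decides where it is read from.
                resolved[m.group(1)] = f'tensors["{key}"].shape[{dim_i}]'
    if not resolved:
        return orch_code  # no dynamic dims: leave the source untouched

    out = []
    for line in orch_code.split("\n"):
        out.append(line)
        m = def_re.match(line)
        if m:
            indent = m.group(1) + "    "
            out.extend(f"{indent}{name} = {expr}" for name, expr in resolved.items())
    return "\n".join(out)


class FakeTensor:
    """Stand-in tensor: exposes .shape and echoes back the index it is sliced with."""

    def __init__(self, shape):
        self.shape = shape

    def __getitem__(self, index):
        return index


SRC = '''\
def orch_entry(orch, tensors, rank):
    return tensors["x_hc__ssa_v0"][rank, 0:DEEPSEEK_PREFILL_TOKENS_DYN, 0:4, 0:4096]
'''

patched = bind_dynamic_dims(SRC)
print(patched)

namespace = {}
exec(patched, namespace)
index = namespace["orch_entry"](None, {"x_hc__ssa_v0": FakeTensor((2, 128, 4, 4096))}, 0)
```

In this sketch the symbol sits in index position 1 of the slice, so the emitted binding reads `.shape[1]` of the same tensor, and the patched function runs without a `NameError`. Source with no symbolic slice bounds is returned unchanged.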

Scope / risk

  • Additive; operates on the generated host-orchestration source string.
  • No-op when the orchestrator has no pl.dynamic() dims (the common case),
    so existing distributed kernels are unaffected.

Validation

Built locally; the generated host_orch.py now defines the dynamic dims and
the distributed program dispatches across ranks. Exercised end-to-end by the
companion pypto-lib packed chunked-prefill kernel on a2a3 (2-card EP2).

The distributed host orchestrator (orchestration/host_orch.py) slices
per-rank inputs such as tensors["x_hc"][rank, 0:DIM, ...]. When a host
function carries pl.dynamic() dims, DIM is a runtime-only symbol that is
never bound in the generated Python, so executing the orchestrator raises
NameError.

The device-side codegen already recovers dynamic dims from runtime tensor
shapes (_append_dynamic_dim_unpacking). Add the equivalent for the Python
host orchestrator: emit, once per symbol, <DIM> = tensors["<key>"].shape[<dim>]
derived from the first slice that uses it, at the top of each orchestrator
function body. No-op when the orchestrator has no dynamic dims (the common
case), so existing distributed kernels are unaffected.

This unblocks @pl.jit.host wrappers that loop over ranks and slice
[N_RANKS, PREFILL_TOKENS_DYN, ...] packed tensors (e.g. DeepSeek V4 packed
chunked prefill).
@coderabbitai

coderabbitai Bot commented Jun 27, 2026


Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.


@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request introduces a helper function _resolve_host_orch_dynamic_dims to extract and define pl.dynamic() dimension symbols at the top of host-orchestrator functions, preventing NameError during execution. The reviewer identified a critical issue where symbols are resolved globally and injected into all functions, which can lead to runtime KeyErrors if a function does not receive the specific tensor in its local scope. A detailed code suggestion was provided to resolve and inject these symbols on a per-function basis.


Comment on lines +1157 to +1175
    resolved: "OrderedDict[str, str]" = OrderedDict()
    for key, idx_str in slice_re.findall(orch_code):
        for dim_i, comp in enumerate(c.strip() for c in idx_str.split(",")):
            m = bound_re.match(comp)
            if m and m.group(1) not in resolved:
                resolved[m.group(1)] = f'tensors["{key}"].shape[{dim_i}]'
    if not resolved:
        return orch_code

    def_re = re.compile(r"^(\s*)def\s+\w+\(orch\b.*\):\s*$")
    out: list[str] = []
    for line in orch_code.split("\n"):
        out.append(line)
        m = def_re.match(line)
        if m:
            indent = m.group(1) + "    "
            out.append(f"{indent}# Recover pl.dynamic() dims from runtime tensor shapes for host slicing.")
            out.extend(f"{indent}{name} = {expr}" for name, expr in resolved.items())
    return "\n".join(out)


high

Currently, _resolve_host_orch_dynamic_dims resolves all dynamic dimension symbols globally across the entire orch_code and injects all of them into every function matching def_re.

If the generated orchestrator file contains multiple functions (e.g., multiple orchestrators or helper functions) that operate on different sets of tensors, this global injection will define variables using tensors that may not exist in the local tensors dictionary of a specific function. This will raise a KeyError at runtime when that function is executed.

To prevent this, we should resolve and inject the dynamic dimension symbols per function by scanning the slices within each function's body and only injecting the symbols that are actually used in that specific function.

    def_re = re.compile(r"^(\s*)def\s+\w+\(orch\b.*\):\s*$")
    func_headers = {}
    func_resolved = {}
    current_func_idx = None

    lines = orch_code.split("\n")
    for line_i, line in enumerate(lines):
        m_def = def_re.match(line)
        if m_def:
            current_func_idx = line_i
            func_headers[line_i] = m_def.group(1)
            func_resolved[line_i] = OrderedDict()
            continue

        if current_func_idx is not None:
            if line.strip() and not line.startswith(" ") and not line.startswith("\t"):
                current_func_idx = None
            else:
                for key, idx_str in slice_re.findall(line):
                    for dim_i, comp in enumerate(c.strip() for c in idx_str.split(",")):
                        m = bound_re.match(comp)
                        if m and m.group(1) not in func_resolved[current_func_idx]:
                            func_resolved[current_func_idx][m.group(1)] = f'tensors["{key}"].shape[{dim_i}]'

    if not any(func_resolved.values()):
        return orch_code

    out: list[str] = []
    for line_i, line in enumerate(lines):
        out.append(line)
        if line_i in func_headers and func_resolved[line_i]:
            indent = func_headers[line_i] + "    "
            out.append(f"{indent}# Recover pl.dynamic() dims from runtime tensor shapes for host slicing.")
            out.extend(f"{indent}{name} = {expr}" for name, expr in func_resolved[line_i].items())
    return "\n".join(out)

@zhaozhaozz

Contributor Author

Consumed by hw-native-sys/pypto-lib#625 (DeepSeek V4 packed chunked prefill, WIP) — that kernel is the first @pl.jit.host + pl.dynamic user and needs this host-orchestrator dynamic-dim binding to dispatch.
