Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,113 @@ compiled[grid, stream](A.ctypes.data, O.ctypes.data, 4, 128)
**Only `entry=True` kernels support `.compile()` and `[grid, stream]` launch.**
Calling `.compile()` on an `entry=False` module raises an error.

### Loading an existing PTO file

Use `source=` when the kernel implementation already exists as hand-written PTO
IR, but you still want PTODSL's Python compile and launch workflow. The
decorated Python function declares the host-side ABI; `source=` provides the
kernel body either as a PTO file path or as PTO text directly.

<!-- ptodsl-doc-test: {"mode":"launch_fragment","fixture":"launch.source_backed_tadd","symbol":"tadd","files":{"kernels/tadd_entry.pto":"module {\n func.func @tadd_entry(%arg0: !pto.ptr<f32, gm>, %arg1: !pto.ptr<f32, gm>, %arg2: !pto.ptr<f32, gm>, %arg3: i32) {\n return\n }\n}\n"}} -->
```python
@pto.jit(
name="tadd_entry",
target="a5",
backend="vpto",
source="kernels/tadd_entry.pto",
)
def tadd(
A_ptr: pto.ptr(pto.f32, "gm"),
B_ptr: pto.ptr(pto.f32, "gm"),
O_ptr: pto.ptr(pto.f32, "gm"),
numel: pto.i32,
):
# The body is ignored when source= is provided.
pass


compiled = tadd.compile()
compiled[grid, stream](A, B, O, numel)
```

The Python function body is not traced in this form. Keep it empty, or leave a
short comment for readers. Positional parameters still matter: PTODSL uses them
to build the launch wrapper and marshal Python, NumPy, or torch-npu arguments.

The PTO source must contain one non-declaration `func.func` whose symbol matches
the JIT entry name. By default, the entry name is the Python function name. Pass
`name=` when the PTO symbol differs from the Python function name, or to select
which function is the entry when the source contains more than one kernel:

```mlir
module {
func.func @tadd_entry(
%a: !pto.ptr<f32, gm>,
%b: !pto.ptr<f32, gm>,
%o: !pto.ptr<f32, gm>,
%numel: i32) {
// hand-written PTO body
return
}
}
```

PTODSL checks the selected PTO function before compiling:

- The number of PTO function arguments must match the Python positional
parameters.
- Each argument type must match the Python annotation, position by position.
- The PTO entry must return no values.
- If the file or entry cannot be found, the diagnostic names the requested entry
and source path.

When `source` is a filesystem path, relative paths are resolved from the Python
file that declares the decorated function, so tests can keep the Python wrapper
next to the PTO file:

```text
case.py
kernels/tadd_entry.pto
```

```python
# case.py
@pto.jit(name="tadd_entry", source="kernels/tadd_entry.pto")
def tadd(A_ptr: pto.ptr(pto.f32, "gm"), O_ptr: pto.ptr(pto.f32, "gm")):
pass
```

For short tests, `source` can also embed the PTO text directly:

```python
tadd_source = """module {
func.func @tadd_entry(%a: !pto.ptr<f32, gm>, %o: !pto.ptr<f32, gm>) {
return
}
}
"""

@pto.jit(name="tadd_entry", source=tadd_source)
def tadd(A_ptr: pto.ptr(pto.f32, "gm"), O_ptr: pto.ptr(pto.f32, "gm")):
pass
```

Source-backed entries use the same `.compile()` and `compiled[grid, stream](...)`
launch syntax as ordinary traced entries. If the PTO file contents change,
compiling the same declaration again rebuilds the artifact instead of reusing
the previously cached one.

Limitations:

- `source=` is only supported for launchable `entry=True` kernels.
- Keyword-only `pto.const_expr` parameters are not supported with `source=`.
Source files are loaded as fixed PTO IR text; PTODSL does not template or
specialize the source file.
- `.compile(...)` does not accept constexpr bindings for source-backed kernels.
- `backend=`, `mode=`, and `insert_sync=` still matter. For source-backed VPTO
files, set `backend="vpto"`. When `mode="explicit"`, PTODSL compiles the
source as explicit PTO; otherwise sync insertion follows the same policy as
ordinary `@pto.jit` entries.

### SPMD built-ins

Available inside an `entry=True` body:
Expand Down
75 changes: 75 additions & 0 deletions ptodsl/ptodsl/_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,75 @@ def kernel_module_launch_error(function_name: str | None = None) -> RuntimeError
)


def jit_source_entry_false_error(
    source: object,
    *,
    function_name: str | None = None,
) -> TypeError:
    """Build the diagnostic for ``@pto.jit(entry=False, source=...)``.

    Args:
        source: The ``source=`` value passed to the decorator; it is embedded
            in the message via ``repr``.
        function_name: Optional kernel name. When truthy, it is quoted in the
            message so the offending kernel can be identified.

    Returns:
        A ``TypeError`` instance. The caller is responsible for raising it.
    """
    subject = (
        f"@pto.jit(source=...) kernel {function_name!r}"
        if function_name
        else "@pto.jit(source=...) kernel"
    )
    message = (
        f"{subject} does not support entry=False while source={source!r}. "
        "Source-backed JIT is currently limited to launchable entry kernels."
    )
    return TypeError(message)


def jit_source_constexpr_error(
    name: str,
    source: object,
    *,
    function_name: str | None = None,
) -> TypeError:
    """Build the diagnostic for a ``pto.const_expr`` parameter on a source-backed kernel.

    Args:
        name: Name of the offending keyword-only ``pto.const_expr`` parameter.
        source: The ``source=`` value passed to the decorator; it is embedded
            in the message via ``repr``.
        function_name: Optional kernel name. When truthy, it is quoted in the
            message.

    Returns:
        A ``TypeError`` instance. The caller is responsible for raising it.
    """
    subject = "@pto.jit(source=...) kernel"
    if function_name:
        subject = f"{subject} {function_name!r}"
    headline = (
        f"{subject} does not support keyword-only pto.const_expr parameter '{name}' while source={source!r}. "
    )
    rationale = (
        "Source-backed JIT currently loads a fixed PTO IR file and does not template or specialize source text."
    )
    return TypeError(headline + rationale)


def jit_source_compile_constexpr_error(
    names: list[str] | tuple[str, ...],
    source: object,
    *,
    function_name: str | None = None,
) -> TypeError:
    """Build the diagnostic for constexpr bindings passed to ``.compile(...)`` in source mode.

    Args:
        names: Names of the rejected constexpr bindings; they are joined with
            ``", "`` in the order given.
        source: The ``source=`` value passed to the decorator; it is embedded
            in the message via ``repr``.
        function_name: Optional kernel name. When truthy, it is quoted in the
            message.

    Returns:
        A ``TypeError`` instance. The caller is responsible for raising it.
    """
    subject = (
        f"@pto.jit(source=...) kernel {function_name!r}"
        if function_name
        else "@pto.jit(source=...) kernel"
    )
    binding_list = ", ".join(names)
    message = (
        f"{subject} does not accept .compile(...) constexpr binding(s) {binding_list} while source={source!r}. "
        "Source-backed JIT currently loads a fixed PTO IR file and does not template or specialize source text."
    )
    return TypeError(message)


def jit_source_file_error(source: object, resolved_path: object, reason: str) -> FileNotFoundError:
    """Build the diagnostic for a PTO source file that could not be resolved or loaded.

    Args:
        source: The ``source=`` value as written by the user; it is embedded
            in the message via ``repr``.
        resolved_path: The path that was actually tried. It is converted with
            ``str`` and then quoted.
        reason: Text appended after the colon describing the failure.

    Returns:
        A ``FileNotFoundError`` instance. The caller is responsible for
        raising it.
    """
    path_text = str(resolved_path)
    message = f"@pto.jit(source={source!r}) could not load PTO IR source file {path_text!r}: {reason}"
    return FileNotFoundError(message)


def jit_source_entry_error(source_path: object, entry_name: str, reason: str) -> TypeError:
    """Build the diagnostic for a failure to select the entry function in a PTO source.

    Args:
        source_path: Location of the PTO source. It is converted with ``str``
            and then quoted.
        entry_name: The requested entry symbol; it is embedded via ``repr``.
        reason: Text appended after the colon describing the failure.

    Returns:
        A ``TypeError`` instance. The caller is responsible for raising it.
    """
    location = str(source_path)
    message = f"@pto.jit(source=...) could not bind entry {entry_name!r} in {location!r}: {reason}"
    return TypeError(message)


def jit_source_abi_error(source_path: object, entry_name: str, reason: str) -> TypeError:
    """Build the diagnostic for an ABI verification failure on a source-backed entry.

    Args:
        source_path: Location of the PTO source. It is converted with ``str``
            and then quoted.
        entry_name: The entry symbol being verified; it is embedded via ``repr``.
        reason: Text appended after the colon describing the mismatch.

    Returns:
        A ``TypeError`` instance. The caller is responsible for raising it.
    """
    location = str(source_path)
    message = f"@pto.jit(source=...) ABI mismatch for entry {entry_name!r} in {location!r}: {reason}"
    return TypeError(message)


def jit_keyword_only_non_constexpr_error(name: str, annotation: object) -> TypeError:
"""Return one diagnostic for keyword-only params that are not ``pto.const_expr``."""
return TypeError(
Expand Down Expand Up @@ -419,6 +488,12 @@ def unsupported_public_surface_error(name: str) -> AttributeError:
"kernel_module_compile_error",
"kernel_module_launch_error",
"kernel_module_return_value_error",
"jit_source_abi_error",
"jit_source_compile_constexpr_error",
"jit_source_constexpr_error",
"jit_source_entry_false_error",
"jit_source_entry_error",
"jit_source_file_error",
"invalid_jit_mode_error",
"invalid_jit_backend_error",
"jit_legacy_tensor_spec_helper_error",
Expand Down
20 changes: 20 additions & 0 deletions ptodsl/ptodsl/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from ._diagnostics import (
invalid_jit_backend_error,
invalid_jit_mode_error,
jit_source_constexpr_error,
jit_source_entry_false_error,
kernel_module_launch_error,
)
from ._kernel_compilation import CompiledKernelHandle, KernelCompiler
Expand Down Expand Up @@ -169,6 +171,7 @@ def jit(
insert_sync: bool | None = None,
ast_rewrite: bool | None = None,
frontend_options: Mapping | None = None,
source: str | None = None,
):
"""
Decorator that wraps a Python function as a PTODSL JIT kernel template.
Expand All @@ -195,6 +198,11 @@ def jit(
frontend_options:
Reserved structured frontend options. Currently supports
``ast_rewrite`` and ``rewrite_part={"control_flow"}``.
source:
Optional filesystem path to a PTO IR source file. When
provided, PTODSL keeps the Python signature as the host ABI
declaration but loads the kernel implementation from the
source file instead of tracing the Python body.

The decorated function is replaced by a :class:`KernelHandle` that:

Expand All @@ -214,6 +222,17 @@ def decorator(fn):
kernel_signature = parse_jit_kernel_signature(fn, entry=entry)
normalized_mode = _normalize_mode(mode, fn=fn)
normalized_backend = _normalize_backend(backend, fn=fn)
if source is not None:
if not isinstance(source, str):
raise TypeError("@pto.jit source must be a filesystem path string when provided")
if entry is False:
raise jit_source_entry_false_error(source, function_name=fn_name)
if kernel_signature.constexpr_parameters:
raise jit_source_constexpr_error(
kernel_signature.constexpr_parameters[0].name,
source,
function_name=fn_name,
)
source_file = None
try:
source_file = inspect.getsourcefile(fn) or inspect.getfile(fn)
Expand All @@ -232,6 +251,7 @@ def decorator(fn):
module_style=ModuleStyle.BACKEND_PARTITIONED,
source_file=source_file,
source_line=getattr(fn.__code__, "co_firstlineno", None),
jit_source=source,
),
kernel_signature,
fn,
Expand Down
39 changes: 38 additions & 1 deletion ptodsl/ptodsl/_kernel_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
import inspect

from ._ast_rewrite import rewrite_jit_function
from ._diagnostics import kernel_module_compile_error, kernel_module_launch_error
from ._diagnostics import (
jit_source_compile_constexpr_error,
kernel_module_compile_error,
kernel_module_launch_error,
)
from ._runtime.launch import LaunchHandle, parse_launch_spec
from ._source_loader import SourceModuleLoader
from ._tracing import ModuleArtifact, SignatureTracingRuntime


Expand Down Expand Up @@ -89,6 +94,8 @@ def tracing_callback(self):
def compile(self, **constexpr_bindings):
if self._module_spec.entry is False:
raise kernel_module_compile_error(self._py_name)
if self._module_spec.jit_source is not None:
return self._compile_source_backed(**constexpr_bindings)
normalized_bindings = self._kernel_signature.bind_constexpr_bindings(constexpr_bindings)
kernel_identity = self._kernel_identity
if self._ast_rewrite:
Expand Down Expand Up @@ -124,6 +131,36 @@ def compile(self, **constexpr_bindings):
self._compiled_cache[specialization_key] = compiled
return compiled

def _compile_source_backed(self, **constexpr_bindings):
    """Compile a kernel whose module spec carries a ``jit_source``.

    Any constexpr binding is rejected up front. Otherwise a
    ``SourceModuleLoader`` is created for this kernel, a specialization key
    is derived from the loader's ``cache_identity()`` with no constexpr
    bindings, and the handle cached under that key is returned if present.
    On a cache miss, a new ``CompiledKernelHandle`` is built with the
    loader's ``build_module`` as its module factory, stored in the cache,
    and returned.

    Raises:
        TypeError: (via ``jit_source_compile_constexpr_error``) when any
            constexpr binding is supplied.
    """
    spec = self._module_spec
    if constexpr_bindings:
        raise jit_source_compile_constexpr_error(
            tuple(sorted(constexpr_bindings)),
            spec.jit_source,
            function_name=spec.function_name,
        )

    signature = self._kernel_signature
    source_loader = SourceModuleLoader(spec, signature)
    # NOTE(review): presumably cache_identity() reflects the source contents
    # so that edits produce a new key — confirm against SourceModuleLoader.
    cache_key = signature.specialization_key(source_loader.cache_identity(), {})

    handle = self._compiled_cache.get(cache_key)
    if handle is None:
        handle = CompiledKernelHandle(
            self._py_name,
            specialization_key=cache_key,
            constexpr_bindings={},
            module_factory=source_loader.build_module,
            module_spec=spec,
            kernel_signature=signature,
        )
        # Build before caching so a failed build is never stored.
        handle.build()
        self._compiled_cache[cache_key] = handle
    return handle

def cached_specializations(self):
return tuple(self._compiled_cache.values())

Expand Down
16 changes: 16 additions & 0 deletions ptodsl/ptodsl/_runtime/native_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,18 @@ def _run_ptoas(
*,
target_arch: str,
insert_sync: bool | None = None,
backend: str | None = None,
pto_level: str | None = None,
) -> None:
ptoas = resolve_ptoas_binary()
cmd = [
str(ptoas),
f"--pto-arch={target_arch}",
]
if backend is not None:
cmd.append(f"--pto-backend={backend}")
if pto_level is not None:
cmd.append(f"--pto-level={pto_level}")
if insert_sync is True:
cmd.append("--enable-insert-sync")
cmd.extend([
Expand All @@ -70,6 +76,15 @@ def _effective_insert_sync(*, mode: str, insert_sync: bool | None) -> bool:
return mode != "explicit"


def _source_ptoas_overrides(module_spec) -> dict:
    """Return extra keyword arguments for ``_run_ptoas`` for source-backed kernels.

    Args:
        module_spec: Object whose optional ``jit_source`` attribute marks a
            source-backed kernel; ``backend`` and ``mode`` are read only when
            ``jit_source`` is set.

    Returns:
        An empty dict when ``jit_source`` is missing or ``None``. Otherwise a
        dict with ``"backend"`` set to ``module_spec.backend``, plus
        ``"pto_level": "level3"`` when ``module_spec.mode`` is ``"explicit"``.
    """
    is_source_backed = getattr(module_spec, "jit_source", None) is not None
    if not is_source_backed:
        return {}

    backend = module_spec.backend
    if module_spec.mode != "explicit":
        return {"backend": backend}
    return {"backend": backend, "pto_level": "level3"}


def _host_compile_flags() -> list[str]:
return common_include_flags() + [
"-std=gnu++17",
Expand Down Expand Up @@ -199,6 +214,7 @@ def build_native_library(
mode=module_spec.mode,
insert_sync=module_spec.insert_sync,
),
**_source_ptoas_overrides(module_spec),
)

launch_object = artifacts.cache_dir / "launch.o"
Expand Down
Loading
Loading