diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 7bb02647c..61e4b8e50 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -224,6 +224,113 @@ compiled[grid, stream](A.ctypes.data, O.ctypes.data, 4, 128) **Only `entry=True` kernels support `.compile()` and `[grid, stream]` launch.** Calling `.compile()` on an `entry=False` module raises an error. +### Loading an existing PTO file + +Use `source=` when the kernel implementation already exists as hand-written PTO +IR, but you still want PTODSL's Python compile and launch workflow. The +decorated Python function declares the host-side ABI; `source=` provides the +kernel body either as a PTO file path or as PTO text directly. + + +```python +@pto.jit( + name="tadd_entry", + target="a5", + backend="vpto", + source="kernels/tadd_entry.pto", +) +def tadd( + A_ptr: pto.ptr(pto.f32, "gm"), + B_ptr: pto.ptr(pto.f32, "gm"), + O_ptr: pto.ptr(pto.f32, "gm"), + numel: pto.i32, +): + # The body is ignored when source= is provided. + pass + + +compiled = tadd.compile() +compiled[grid, stream](A, B, O, numel) +``` + +The Python function body is not traced in this form. Keep it empty, or leave a +short comment for readers. Positional parameters still matter: PTODSL uses them +to build the launch wrapper and marshal Python, NumPy, or torch-npu arguments. + +The PTO source must contain one non-declaration `func.func` whose symbol matches +the JIT entry name. By default, the entry name is the Python function name. Use +`name=` when the PTO symbol has a different name, or when the source contains more +than one kernel: + +```mlir +module { + func.func @tadd_entry( + %a: !pto.ptr, + %b: !pto.ptr, + %o: !pto.ptr, + %numel: i32) { + // hand-written PTO body + return + } +} +``` + +PTODSL checks the selected PTO function before compiling: + +- The number of PTO function arguments must match the Python positional + parameters. +- Each argument type must match the Python annotation, position by position. +- The PTO entry must return no values. +- If the file or entry cannot be found, the diagnostic names the requested entry + and source path. + +When `source` is a filesystem path, relative paths are resolved from the Python +file that declares the decorated function, so tests can keep the Python wrapper +next to the PTO file: + +```text +case.py +kernels/tadd_entry.pto +``` + +```python +# case.py +@pto.jit(name="tadd_entry", source="kernels/tadd_entry.pto") +def tadd(A_ptr: pto.ptr(pto.f32, "gm"), O_ptr: pto.ptr(pto.f32, "gm")): + pass +``` + +For short tests, `source` can also embed the PTO text directly: + +```python +tadd_source = """module { + func.func @tadd_entry(%a: !pto.ptr, %o: !pto.ptr) { + return + } +} +""" + +@pto.jit(name="tadd_entry", source=tadd_source) +def tadd(A_ptr: pto.ptr(pto.f32, "gm"), O_ptr: pto.ptr(pto.f32, "gm")): + pass +``` + +Source-backed entries use the same `.compile()` and `compiled[grid, stream](...)` +launch syntax as ordinary traced entries. If the PTO file contents change, +compiling the same declaration again rebuilds the cached artifact. + +Limitations: + +- `source=` is only supported for launchable `entry=True` kernels. +- Keyword-only `pto.const_expr` parameters are not supported with `source=`. + Source files are loaded as fixed PTO IR text; PTODSL does not template or + specialize the source file. +- `.compile(...)` does not accept constexpr bindings for source-backed kernels. +- `backend=`, `mode=`, and `insert_sync=` still matter. For source-backed VPTO + files, set `backend="vpto"`. When `mode="explicit"`, PTODSL compiles the + source as explicit PTO; otherwise sync insertion follows the same policy as + ordinary `@pto.jit` entries. + ### SPMD built-ins Available inside an `entry=True` body: diff --git a/ptodsl/ptodsl/_diagnostics.py b/ptodsl/ptodsl/_diagnostics.py index de1f303d5..572b1d595 100644 --- a/ptodsl/ptodsl/_diagnostics.py +++ b/ptodsl/ptodsl/_diagnostics.py @@ -164,6 +164,75 @@ def kernel_module_launch_error(function_name: str | None = None) -> RuntimeError ) +def jit_source_entry_false_error( + source: object, + *, + function_name: str | None = None, +) -> TypeError: + """Return one diagnostic for unsupported ``@pto.jit(entry=False, source=...)``.""" + target = "@pto.jit(source=...) kernel" + if function_name: + target = f"@pto.jit(source=...) kernel {function_name!r}" + return TypeError( + f"{target} does not support entry=False while source={source!r}. " + "Source-backed JIT is currently limited to launchable entry kernels." + ) + + +def jit_source_constexpr_error( + name: str, + source: object, + *, + function_name: str | None = None, +) -> TypeError: + """Return one diagnostic for unsupported source-backed ``pto.const_expr`` params.""" + target = "@pto.jit(source=...) kernel" + if function_name: + target = f"@pto.jit(source=...) kernel {function_name!r}" + return TypeError( + f"{target} does not support keyword-only pto.const_expr parameter '{name}' while source={source!r}. " + "Source-backed JIT currently loads a fixed PTO IR file and does not template or specialize source text." + ) + + +def jit_source_compile_constexpr_error( + names: list[str] | tuple[str, ...], + source: object, + *, + function_name: str | None = None, +) -> TypeError: + """Return one diagnostic for ``.compile(...)`` constexpr bindings in source mode.""" + target = "@pto.jit(source=...) kernel" + if function_name: + target = f"@pto.jit(source=...) kernel {function_name!r}" + joined = ", ".join(names) + return TypeError( + f"{target} does not accept .compile(...) constexpr binding(s) {joined} while source={source!r}. " + "Source-backed JIT currently loads a fixed PTO IR file and does not template or specialize source text." + ) + + +def jit_source_file_error(source: object, resolved_path: object, reason: str) -> FileNotFoundError: + """Return one diagnostic for source path resolution/loading failures.""" + return FileNotFoundError( + f"@pto.jit(source={source!r}) could not load PTO IR source file {str(resolved_path)!r}: {reason}" + ) + + +def jit_source_entry_error(source_path: object, entry_name: str, reason: str) -> TypeError: + """Return one diagnostic for source entry selection failures.""" + return TypeError( + f"@pto.jit(source=...) could not bind entry {entry_name!r} in {str(source_path)!r}: {reason}" + ) + + +def jit_source_abi_error(source_path: object, entry_name: str, reason: str) -> TypeError: + """Return one diagnostic for source ABI verification failures.""" + return TypeError( + f"@pto.jit(source=...) ABI mismatch for entry {entry_name!r} in {str(source_path)!r}: {reason}" + ) + + def jit_keyword_only_non_constexpr_error(name: str, annotation: object) -> TypeError: """Return one diagnostic for keyword-only params that are not ``pto.const_expr``.""" return TypeError( @@ -419,6 +488,12 @@ def unsupported_public_surface_error(name: str) -> AttributeError: "kernel_module_compile_error", "kernel_module_launch_error", "kernel_module_return_value_error", + "jit_source_abi_error", + "jit_source_compile_constexpr_error", + "jit_source_constexpr_error", + "jit_source_entry_false_error", + "jit_source_entry_error", + "jit_source_file_error", "invalid_jit_mode_error", "invalid_jit_backend_error", "jit_legacy_tensor_spec_helper_error", diff --git a/ptodsl/ptodsl/_jit.py b/ptodsl/ptodsl/_jit.py index 5c23bc97f..fc86a21b9 100644 --- a/ptodsl/ptodsl/_jit.py +++ b/ptodsl/ptodsl/_jit.py @@ -15,6 +15,8 @@ from ._diagnostics import ( invalid_jit_backend_error, invalid_jit_mode_error, + jit_source_constexpr_error, + jit_source_entry_false_error, kernel_module_launch_error, ) from ._kernel_compilation import CompiledKernelHandle, KernelCompiler @@ -169,6 +171,7 @@ def jit( insert_sync: bool | None = None, ast_rewrite: bool | None = None, frontend_options: Mapping | None = None, + source: str | None = None, ): """ Decorator that wraps a Python function as a PTODSL JIT kernel template. @@ -195,6 +198,11 @@ def jit( frontend_options: Reserved structured frontend options. Currently supports ``ast_rewrite`` and ``rewrite_part={"control_flow"}``. + source: + Optional filesystem path to a PTO IR source file. When + provided, PTODSL keeps the Python signature as the host ABI + declaration but loads the kernel implementation from the + source file instead of tracing the Python body. The decorated function is replaced by a :class:`KernelHandle` that: @@ -214,6 +222,17 @@ def decorator(fn): kernel_signature = parse_jit_kernel_signature(fn, entry=entry) normalized_mode = _normalize_mode(mode, fn=fn) normalized_backend = _normalize_backend(backend, fn=fn) + if source is not None: + if not isinstance(source, str): + raise TypeError("@pto.jit source must be a filesystem path string when provided") + if entry is False: + raise jit_source_entry_false_error(source, function_name=fn_name) + if kernel_signature.constexpr_parameters: + raise jit_source_constexpr_error( + kernel_signature.constexpr_parameters[0].name, + source, + function_name=fn_name, + ) source_file = None try: source_file = inspect.getsourcefile(fn) or inspect.getfile(fn) @@ -232,6 +251,7 @@ def decorator(fn): module_style=ModuleStyle.BACKEND_PARTITIONED, source_file=source_file, source_line=getattr(fn.__code__, "co_firstlineno", None), + jit_source=source, ), kernel_signature, fn, diff --git a/ptodsl/ptodsl/_kernel_compilation.py b/ptodsl/ptodsl/_kernel_compilation.py index 1029d3888..3d2217ee4 100644 --- a/ptodsl/ptodsl/_kernel_compilation.py +++ b/ptodsl/ptodsl/_kernel_compilation.py @@ -12,8 +12,13 @@ import inspect from ._ast_rewrite import rewrite_jit_function -from ._diagnostics import kernel_module_compile_error, kernel_module_launch_error +from ._diagnostics import ( + jit_source_compile_constexpr_error, + kernel_module_compile_error, + kernel_module_launch_error, +) from ._runtime.launch import LaunchHandle, parse_launch_spec +from ._source_loader import SourceModuleLoader from ._tracing import ModuleArtifact, SignatureTracingRuntime @@ -89,6 +94,8 @@ def tracing_callback(self): def compile(self, **constexpr_bindings): if self._module_spec.entry is False: raise kernel_module_compile_error(self._py_name) + if self._module_spec.jit_source is not None: + return self._compile_source_backed(**constexpr_bindings) normalized_bindings = self._kernel_signature.bind_constexpr_bindings(constexpr_bindings) kernel_identity = self._kernel_identity if self._ast_rewrite: @@ -124,6 +131,36 @@ def compile(self, **constexpr_bindings): self._compiled_cache[specialization_key] = compiled return compiled + def _compile_source_backed(self, **constexpr_bindings): + if constexpr_bindings: + raise jit_source_compile_constexpr_error( + tuple(sorted(constexpr_bindings)), + self._module_spec.jit_source, + function_name=self._module_spec.function_name, + ) + + loader = SourceModuleLoader(self._module_spec, self._kernel_signature) + specialization_key = self._kernel_signature.specialization_key( + loader.cache_identity(), + {}, + ) + + cached = self._compiled_cache.get(specialization_key) + if cached is not None: + return cached + + compiled = CompiledKernelHandle( + self._py_name, + specialization_key=specialization_key, + constexpr_bindings={}, + module_factory=loader.build_module, + module_spec=self._module_spec, + kernel_signature=self._kernel_signature, + ) + compiled.build() + self._compiled_cache[specialization_key] = compiled + return compiled + def cached_specializations(self): return tuple(self._compiled_cache.values()) diff --git a/ptodsl/ptodsl/_runtime/native_build.py b/ptodsl/ptodsl/_runtime/native_build.py index 777fa5420..0821c326b 100644 --- a/ptodsl/ptodsl/_runtime/native_build.py +++ b/ptodsl/ptodsl/_runtime/native_build.py @@ -45,12 +45,18 @@ def _run_ptoas( *, target_arch: str, insert_sync: bool | None = None, + backend: str | None = None, + pto_level: str | None = None, ) -> None: ptoas = resolve_ptoas_binary() cmd = [ str(ptoas), f"--pto-arch={target_arch}", ] + if backend is not None: + cmd.append(f"--pto-backend={backend}") + if pto_level is not None: + cmd.append(f"--pto-level={pto_level}") if insert_sync is True: cmd.append("--enable-insert-sync") cmd.extend([ @@ -70,6 +76,15 @@ def _effective_insert_sync(*, mode: str, insert_sync: bool | None) -> bool: return mode != "explicit" +def _source_ptoas_overrides(module_spec) -> dict: + if getattr(module_spec, "jit_source", None) is None: + return {} + overrides = {"backend": module_spec.backend} + if module_spec.mode == "explicit": + overrides["pto_level"] = "level3" + return overrides + + def _host_compile_flags() -> list[str]: return common_include_flags() + [ "-std=gnu++17", @@ -199,6 +214,7 @@ def build_native_library( mode=module_spec.mode, insert_sync=module_spec.insert_sync, ), + **_source_ptoas_overrides(module_spec), ) launch_object = artifacts.cache_dir / "launch.o" diff --git a/ptodsl/ptodsl/_source_loader.py b/ptodsl/ptodsl/_source_loader.py new file mode 100644 index 000000000..b936937dd --- /dev/null +++ b/ptodsl/ptodsl/_source_loader.py @@ -0,0 +1,205 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Source-backed module loading for ``@pto.jit(source=...)``.""" + +from __future__ import annotations + +import hashlib +from dataclasses import dataclass +from pathlib import Path + +from ._bootstrap import make_context +from ._diagnostics import ( + jit_source_abi_error, + jit_source_entry_error, + jit_source_file_error, +) + +from mlir.ir import Location, Module + + +@dataclass(frozen=True) +class SourceModuleArtifact: + """Parsed source-backed module plus identity metadata.""" + + module: Module + mlir_text: str + resolved_path: Path | None + source_kind: str + content_digest: str + + +class SourceModuleLoader: + """Resolve, parse, and verify one source-backed JIT module.""" + + def __init__(self, module_spec, kernel_signature): + self._module_spec = module_spec + self._kernel_signature = kernel_signature + self._artifact: SourceModuleArtifact | None = None + + @property + def source(self) -> str: + source = self._module_spec.jit_source + if source is None: + raise RuntimeError("source-backed loader requires KernelModuleSpec.jit_source") + return source + + def cache_identity(self) -> tuple: + """Return source identity for the specialization key.""" + artifact = self._load() + return ( + "source", + str(artifact.resolved_path), + artifact.content_digest, + self._module_spec.function_name, + self._module_spec.entry, + self._module_spec.target_arch, + self._module_spec.kernel_kind, + self._module_spec.backend, + self._module_spec.mode, + self._module_spec.insert_sync, + self._module_spec.module_style, + ) + + def build_module(self): + """Return ``(module, metadata)`` for ``ModuleArtifact``.""" + artifact = self._load() + metadata = { + "mlir_text": artifact.mlir_text, + "source_kind": artifact.source_kind, + "source_digest": artifact.content_digest, + } + if artifact.resolved_path is not None: + metadata["source_path"] = str(artifact.resolved_path) + return artifact.module, metadata + + def _load(self) -> SourceModuleArtifact: + if self._artifact is None: + resolved_path, mlir_text, source_kind = self._resolve_source() + content_digest = hashlib.sha256(mlir_text.encode("utf-8")).hexdigest() + ctx = make_context() + with ctx, Location.unknown(): + module = Module.parse(mlir_text) + entry = self._select_entry(module, resolved_path) + self._verify_entry_abi(entry, resolved_path) + module.operation.verify() + self._artifact = SourceModuleArtifact( + module=module, + mlir_text=mlir_text, + resolved_path=resolved_path, + source_kind=source_kind, + content_digest=content_digest, + ) + return self._artifact + + def _resolve_source_path(self) -> Path: + raw_path = Path(self.source) + if raw_path.is_absolute(): + return raw_path.resolve() + declaring_file = self._module_spec.source_file + if declaring_file: + return (Path(declaring_file).resolve().parent / raw_path).resolve() + return raw_path.resolve() + + def _looks_like_inline_source(self, source: str) -> bool: + stripped = source.lstrip() + if "\n" in source or "\r" in source: + return True + return stripped.startswith("module {") or stripped.startswith("builtin.module {") or ( + "module {" in stripped and "func.func @" in stripped + ) + + def _resolve_source(self) -> tuple[Path | None, str, str]: + if self._looks_like_inline_source(self.source): + return None, self.source, "inline" + resolved_path = self._resolve_source_path() + return resolved_path, self._read_source_text(resolved_path), "path" + + def _read_source_text(self, resolved_path: Path) -> str: + try: + return resolved_path.read_text(encoding="utf-8") + except FileNotFoundError as exc: + raise jit_source_file_error(self.source, resolved_path, "file does not exist") from exc + except OSError as exc: + raise jit_source_file_error(self.source, resolved_path, str(exc)) from exc + + def _select_entry(self, module: Module, resolved_path: Path | None): + matches = [] + for op in _walk_ops(module.operation): + if op.operation.name != "func.func": + continue + if getattr(op, "is_external", False): + continue + if _symbol_name(op) == self._module_spec.function_name: + matches.append(op) + + if not matches: + raise jit_source_entry_error( + _source_location_label(self.source, resolved_path), + self._module_spec.function_name, + "missing non-declaration func.func with this symbol name", + ) + if len(matches) > 1: + raise jit_source_entry_error( + _source_location_label(self.source, resolved_path), + self._module_spec.function_name, + f"found {len(matches)} matching non-declaration func.func ops", + ) + return matches[0] + + def _verify_entry_abi(self, entry, resolved_path: Path | None) -> None: + source_label = _source_location_label(self.source, resolved_path) + expected = tuple(str(type_obj) for type_obj in self._kernel_signature.compute_entry_arg_types()) + actual = tuple(str(type_obj) for type_obj in entry.type.inputs) + results = tuple(str(type_obj) for type_obj in entry.type.results) + + if results: + raise jit_source_abi_error( + source_label, + self._module_spec.function_name, + f"source entry must return no values, got ({', '.join(results)})", + ) + if len(actual) != len(expected): + raise jit_source_abi_error( + source_label, + self._module_spec.function_name, + "parameter count differs; " + f"expected ({', '.join(expected)}), got ({', '.join(actual)})", + ) + for index, (expected_type, actual_type) in enumerate(zip(expected, actual)): + if expected_type != actual_type: + raise jit_source_abi_error( + source_label, + self._module_spec.function_name, + f"parameter {index} differs; expected {expected_type}, got {actual_type}", + ) + + +def _symbol_name(op) -> str | None: + attrs = op.attributes + if "sym_name" not in attrs: + return None + return str(attrs["sym_name"]).strip('"') + + +def _walk_ops(root_op): + for region in root_op.regions: + for block in region.blocks: + for op in block.operations: + yield op + yield from _walk_ops(op.operation) + + +def _source_location_label(source: str, resolved_path: Path | None): + if resolved_path is not None: + return resolved_path + digest = hashlib.sha256(source.encode("utf-8")).hexdigest()[:12] + return f"" + + +__all__ = ["SourceModuleLoader"] diff --git a/ptodsl/ptodsl/_tracing/artifacts.py b/ptodsl/ptodsl/_tracing/artifacts.py index 14a50ec4a..5dc51e3fe 100644 --- a/ptodsl/ptodsl/_tracing/artifacts.py +++ b/ptodsl/ptodsl/_tracing/artifacts.py @@ -19,10 +19,11 @@ class ModuleArtifact: Subclasses may either pass an eager ``module`` or a lazy ``module_factory``. """ - def __init__(self, py_name: str, *, module=None, module_factory=None): + def __init__(self, py_name: str, *, module=None, module_factory=None, mlir_text: str | None = None): self._py_name = py_name self._cached_module = module self._module_factory = module_factory + self._cached_mlir_text = mlir_text self._build_metadata = {} def build(self): @@ -34,6 +35,8 @@ def build(self): if isinstance(built, tuple): self._cached_module, metadata = built self._build_metadata = dict(metadata or {}) + if "mlir_text" in self._build_metadata: + self._cached_mlir_text = self._build_metadata["mlir_text"] else: self._cached_module = built self._build_metadata = {} @@ -45,7 +48,10 @@ def mlir_module(self): def mlir_text(self) -> str: """Return the textual MLIR form.""" - return str(self.build()) + self.build() + if self._cached_mlir_text is not None: + return self._cached_mlir_text + return str(self._cached_module) def verify(self) -> None: """Verify the cached module operation.""" diff --git a/ptodsl/ptodsl/_tracing/module_builder.py b/ptodsl/ptodsl/_tracing/module_builder.py index f15f62e39..f4108724c 100644 --- a/ptodsl/ptodsl/_tracing/module_builder.py +++ b/ptodsl/ptodsl/_tracing/module_builder.py @@ -38,6 +38,7 @@ class KernelModuleSpec: module_style: ModuleStyle = ModuleStyle.NESTED source_file: str | None = None source_line: int | None = None + jit_source: str | None = None def _build_flat_aicore_module(spec: KernelModuleSpec, arg_types): diff --git a/ptodsl/tests/support/docs_fragment_fixtures.py b/ptodsl/tests/support/docs_fragment_fixtures.py index 788cf3178..db833700c 100644 --- a/ptodsl/tests/support/docs_fragment_fixtures.py +++ b/ptodsl/tests/support/docs_fragment_fixtures.py @@ -369,6 +369,27 @@ def kernel_name( assert record.marshaled_arg_count == 4 """ ), + "launch.source_backed_tadd": _fixture( + f""" + grid = 2 + stream = object() + A = 0x1000 + B = 0x2000 + O = 0x3000 + numel = 128 + + {SNIPPET_PLACEHOLDER} + + assert compiled.ir_function_name == "tadd_entry" + assert compiled.build_metadata()["source_path"].endswith("kernels/tadd_entry.pto") + assert len(PTODSL_DOC_LAUNCH_RECORDS) == 1 + record = PTODSL_DOC_LAUNCH_RECORDS[0] + assert record.grid == grid + assert record.stream is stream + assert record.args == (A, B, O, numel) + assert record.marshaled_arg_count == 4 + """ + ), "launch.mat_add_wrapper": _fixture( f""" import numpy as np diff --git a/ptodsl/tests/test_docs_as_test.py b/ptodsl/tests/test_docs_as_test.py index 42b251575..83ec5a279 100644 --- a/ptodsl/tests/test_docs_as_test.py +++ b/ptodsl/tests/test_docs_as_test.py @@ -69,6 +69,7 @@ class DocTestDirective: symbol: Optional[str] = None compile_kwargs: Optional[dict[str, object]] = None fixture: Optional[str] = None + files: Optional[dict[str, str]] | None = None @dataclass(frozen=True) @@ -257,11 +258,25 @@ def parse_test_directive(block: MarkdownCodeBlock) -> DocTestDirective: symbol = payload.get("symbol") compile_kwargs = payload.get("compile") fixture = payload.get("fixture") + files = payload.get("files") expect( isinstance(mode, str) and mode, f"{block_label(block)}: ptodsl-doc-test metadata must define a non-empty string 'mode'", ) + if files is not None: + expect( + isinstance(files, dict) and all(isinstance(path, str) and isinstance(text, str) for path, text in files.items()), + f"{block_label(block, symbol if isinstance(symbol, str) and symbol else None)}: " + "ptodsl-doc-test metadata 'files' must be an object mapping relative file paths to text", + ) + for path in files: + file_path = Path(path) + expect( + not file_path.is_absolute() and ".." not in file_path.parts, + f"{block_label(block, symbol if isinstance(symbol, str) and symbol else None)}: " + f"ptodsl-doc-test metadata file path must be relative and stay inside the snippet directory: {path!r}", + ) if mode in ("compile", "compile_fragment"): expect( isinstance(symbol, str) and symbol, @@ -282,8 +297,9 @@ def parse_test_directive(block: MarkdownCodeBlock) -> DocTestDirective: symbol=symbol, compile_kwargs=compile_kwargs, fixture=fixture, + files=files, ) - return DocTestDirective(mode=mode, symbol=symbol, compile_kwargs=compile_kwargs) + return DocTestDirective(mode=mode, symbol=symbol, compile_kwargs=compile_kwargs, files=files) if mode == "launch_fragment": expect( @@ -300,7 +316,7 @@ def parse_test_directive(block: MarkdownCodeBlock) -> DocTestDirective: f"{block_label(block, symbol if isinstance(symbol, str) and symbol else None)}: " "ptodsl-doc-test launch_fragment does not accept a 'compile' object; the snippet owns its compile/launch flow", ) - return DocTestDirective(mode=mode, symbol=symbol, fixture=fixture) + return DocTestDirective(mode=mode, symbol=symbol, fixture=fixture, files=files) expect( False, @@ -310,23 +326,46 @@ def parse_test_directive(block: MarkdownCodeBlock) -> DocTestDirective: return DocTestDirective(mode=mode) +def _write_directive_files(snippet_dir: Path, files: dict[str, str] | None) -> None: + if not files: + return + for relative_path, text in files.items(): + output_path = snippet_dir / relative_path + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(text, encoding="utf-8") + + +@contextmanager +def directive_execution_dir(directive: DocTestDirective): + if not directive.files: + yield None + return + + with tempfile.TemporaryDirectory() as temp_dir: + snippet_dir = Path(temp_dir) + _write_directive_files(snippet_dir, directive.files) + yield snippet_dir + + def execute_source( source: str, block: MarkdownCodeBlock, symbol: Optional[str] = None, *, extra_namespace: Optional[dict[str, object]] = None, + source_dir: Path | None = None, ) -> dict[str, object]: + source_file = block.path if source_dir is None else source_dir / "case.py" namespace: dict[str, object] = { "__builtins__": __builtins__, "__name__": "__ptodsl_doc_snippet__", - "__file__": str(block.path), + "__file__": str(source_file), "pto": pto, "scalar": scalar, } if extra_namespace is not None: namespace.update(extra_namespace) - filename = f"{block.path}::codeblock:{block.start_line}" + filename = f"{source_file}::codeblock:{block.start_line}" source_lines = source.splitlines(keepends=True) linecache.cache[filename] = (len(source), None, source_lines, filename) try: @@ -409,8 +448,9 @@ def verify_compiled_target( def run_compile_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None: directive = parse_test_directive(block) - namespace = execute_source(block.text, block, directive.symbol) - verify_compiled_target(block, directive, namespace, ptoas_bin, frontend_verify=False) + with directive_execution_dir(directive) as source_dir: + namespace = execute_source(block.text, block, directive.symbol, source_dir=source_dir) + verify_compiled_target(block, directive, namespace, ptoas_bin, frontend_verify=False) def run_compile_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None: @@ -429,8 +469,9 @@ def run_compile_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> Non raise AssertionError( f"{block_label(block, directive.symbol)}: fragment fixture {directive.fixture!r} is invalid: {exc}" ) from exc - namespace = execute_source(rendered_source, block, directive.symbol) - verify_compiled_target(block, directive, namespace, ptoas_bin, frontend_verify=False) + with directive_execution_dir(directive) as source_dir: + namespace = execute_source(rendered_source, block, directive.symbol, source_dir=source_dir) + verify_compiled_target(block, directive, namespace, ptoas_bin, frontend_verify=False) def run_launch_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None: @@ -450,13 +491,15 @@ def run_launch_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None f"{block_label(block, directive.symbol)}: fragment fixture {directive.fixture!r} is invalid: {exc}" ) from exc - with capture_launch_records() as launch_records: - execute_source( - rendered_source, - block, - directive.symbol, - extra_namespace={"PTODSL_DOC_LAUNCH_RECORDS": launch_records}, - ) + with directive_execution_dir(directive) as source_dir: + with capture_launch_records() as launch_records: + execute_source( + rendered_source, + block, + directive.symbol, + extra_namespace={"PTODSL_DOC_LAUNCH_RECORDS": launch_records}, + source_dir=source_dir, + ) expect( bool(launch_records), diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index fd1b39164..c7ed86ead 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -12,6 +12,7 @@ import re import sys from tempfile import TemporaryDirectory +from importlib.util import module_from_spec, spec_from_file_location from typing import Optional from unittest import mock @@ -3118,6 +3119,187 @@ def main() -> None: pointer_block64.specialization_key.constexpr_signature == (("BLOCK", 64),), "pointer-first specialization key should change only with constexpr bindings", ) + source_native_build_compiled = None + source_explicit_native_build_compiled = None + source_no_insert_sync_native_build_compiled = None + with TemporaryDirectory() as tmpdir: + source_path = Path(tmpdir) / "source_kernel.pto" + source_text_v1 = ( + "module {\n" + " func.func @selected_source_entry(%arg0: !pto.ptr, %arg1: i32) {\n" + " return\n" + " }\n" + " func.func @other_source_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + source_text_v2 = ( + "module {\n" + " func.func @selected_source_entry(%arg0: !pto.ptr, %arg1: i32) {\n" + " %c0 = arith.constant 0 : i32\n" + " return\n" + " }\n" + " func.func @other_source_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + source_path.write_text(source_text_v1, encoding="utf-8") + + @pto.jit(name="selected_source_entry", target="a5", source=str(source_path)) + def source_backed_probe(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): + raise RuntimeError("source-backed JIT should not trace the Python body") + + source_compiled_v1 = source_backed_probe.compile() + source_compiled_v1_again = source_backed_probe.compile() + expect(source_compiled_v1 is source_compiled_v1_again, "unchanged source-backed JIT should hit specialization cache") + expect( + source_compiled_v1.mlir_text() == source_text_v1, + "source-backed JIT mlir_text() should preserve the authored source text", + ) + expect( + source_compiled_v1.ir_function_name == "selected_source_entry", + "source-backed JIT should use name= for entry selection and launch wrapper naming", + ) + expect( + source_compiled_v1.build_metadata()["source_path"] == str(source_path.resolve()), + "source-backed JIT metadata should expose the resolved source path", + ) + + source_path.write_text(source_text_v2, encoding="utf-8") + source_compiled_v2 = source_backed_probe.compile() + expect( + source_compiled_v2 is not source_compiled_v1, + "editing the source file should materialize a new specialization", + ) + expect( + source_compiled_v2.specialization_key != source_compiled_v1.specialization_key, + "source-backed specialization key should include source content", + ) + expect( + source_compiled_v2.mlir_text() == source_text_v2, + "source-backed JIT should reload changed source text", + ) + source_auto_path = Path(tmpdir) / "source_auto.pto" + source_auto_text = ( + "module {\n" + " func.func @source_auto_native(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + source_auto_path.write_text(source_auto_text, encoding="utf-8") + + @pto.jit(name="source_auto_native", target="a5", backend="vpto", source=str(source_auto_path)) + def source_auto_native(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + source_native_build_compiled = source_auto_native.compile() + + source_explicit_path = Path(tmpdir) / "source_explicit.pto" + source_explicit_text = ( + "module {\n" + " func.func @source_explicit_native(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + source_explicit_path.write_text(source_explicit_text, encoding="utf-8") + + @pto.jit( + name="source_explicit_native", + target="a5", + backend="vpto", + mode="explicit", + source=str(source_explicit_path), + ) + def source_explicit_native(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + source_explicit_native_build_compiled = source_explicit_native.compile() + + source_no_insert_sync_path = Path(tmpdir) / "source_no_insert_sync.pto" + source_no_insert_sync_text = ( + "module {\n" + " func.func @source_no_insert_sync_native(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + source_no_insert_sync_path.write_text(source_no_insert_sync_text, encoding="utf-8") + + @pto.jit( + name="source_no_insert_sync_native", + target="a5", + backend="vpto", + insert_sync=False, + source=str(source_no_insert_sync_path), + ) + def source_no_insert_sync_native(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + source_no_insert_sync_native_build_compiled = source_no_insert_sync_native.compile() + + relative_case_dir = Path(tmpdir) / "relative_case" + relative_case_dir.mkdir() + relative_source_path = relative_case_dir / "relative_kernel.pto" + relative_source_text = ( + "module {\n" + " func.func @relative_source_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + relative_source_path.write_text(relative_source_text, encoding="utf-8") + relative_module_path = relative_case_dir / "relative_case.py" + relative_module_path.write_text( + "from ptodsl import pto\n" + "@pto.jit(name='relative_source_entry', target='a5', source='relative_kernel.pto')\n" + "def relative_kernel(ptr: pto.ptr(pto.f32, 'gm')):\n" + " raise RuntimeError('source-backed JIT should not trace the Python body')\n", + encoding="utf-8", + ) + spec = spec_from_file_location("ptodsl_relative_source_case", relative_module_path) + expect(spec is not None and spec.loader is not None, "relative source test module should be importable") + relative_module = module_from_spec(spec) + spec.loader.exec_module(relative_module) + relative_compiled = relative_module.relative_kernel.compile() + expect( + relative_compiled.mlir_text() == relative_source_text, + "relative source paths should resolve against the declaring Python file", + ) + expect( + relative_compiled.build_metadata()["source_path"] == str(relative_source_path.resolve()), + "relative source-backed metadata should expose the resolved source path", + ) + + inline_source_text = ( + "module {\n" + " func.func @inline_source_entry(%arg0: !pto.ptr, %arg1: i32) {\n" + " return\n" + " }\n" + "}\n" + ) + + @pto.jit(name="inline_source_entry", target="a5", source=inline_source_text) + def inline_source_backed_probe(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): + raise RuntimeError("source-backed JIT should not trace the Python body") + + inline_compiled = inline_source_backed_probe.compile() + expect( + inline_compiled.mlir_text() == inline_source_text, + "inline source-backed JIT mlir_text() should preserve the authored source text", + ) + inline_metadata = inline_compiled.build_metadata() + expect( + inline_metadata["source_kind"] == "inline", + "inline source-backed JIT metadata should mark the source kind as inline", + ) + expect( + "source_path" not in inline_metadata, + "inline source-backed JIT metadata should not synthesize a filesystem path", + ) pointer_artifacts_default = artifact_paths( pointer_default._py_name, pointer_default.ir_function_name, @@ -3482,6 +3664,9 @@ def main() -> None: ("pure-container", host_vec_copy.compile()), ("same-backend-multi-child-container", kernel_module_compiled), ("mixed-backend-container", emitc_entry_calls_vpto_kernel_module_probe.compile()), + ("source-auto", source_native_build_compiled), + ("source-explicit", source_explicit_native_build_compiled), + ("source-no-insert-sync", source_no_insert_sync_native_build_compiled), ) native_build_observations = [] @@ -3499,13 +3684,15 @@ def fake_artifacts(py_name, ir_function_name, specialization_key): manifest_path=cache_dir / "manifest.json", ) - def fake_run_ptoas(mlir_path, kernel_object, *, target_arch, insert_sync=None): + def fake_run_ptoas(mlir_path, kernel_object, *, target_arch, insert_sync=None, backend=None, pto_level=None): native_build_observations.append( { "mlir_path": mlir_path, "kernel_object": kernel_object, "target_arch": target_arch, "insert_sync": insert_sync, + "backend": backend, + "pto_level": pto_level, "mlir_text": mlir_path.read_text(encoding="utf-8"), } ) @@ -3562,14 +3749,25 @@ def fake_link_shared_library(launch_object, kernel_object, shared_library, *, ke observation["insert_sync"] == expected_insert_sync, f"{label} native build should forward the effective insert_sync policy to ptoas", ) + expected_backend = compiled._module_spec.backend if compiled._module_spec.jit_source is not None else None + expected_pto_level = "level3" if compiled._module_spec.jit_source is not None and compiled._module_spec.mode == "explicit" else None expect( - observation["mlir_text"] == compiled.mlir_text(), - f"{label} native build should hand the backend-partitioned container MLIR to ptoas unchanged", + observation["backend"] == expected_backend, + f"{label} native build should only forward ptoas backend overrides for source-backed kernels", + ) + expect( + observation["pto_level"] == expected_pto_level, + f"{label} native build should only map explicit mode to ptoas level3 for source-backed kernels", ) expect( - observation["mlir_text"].count("module") >= 2, - f"{label} native build should route the unified outer+child container through ptoas", + observation["mlir_text"] == compiled.mlir_text(), + f"{label} native build should hand the backend-partitioned container MLIR to ptoas unchanged", ) + if compiled._module_spec.jit_source is None: + expect( + observation["mlir_text"].count("module") >= 2, + f"{label} native build should route the unified outer+child container through ptoas", + ) with TemporaryDirectory() as tmpdir: tmpdir_path = Path(tmpdir) mlir_path = tmpdir_path / "kernel.mlir" @@ -3626,6 +3824,32 @@ def fake_run_ptoas_cmd(cmd, *, cwd=None): "--enable-insert-sync" in ptoas_cmds[0], "native build should pass --enable-insert-sync when the compiled module explicitly requests it", ) + ptoas_cmds.clear() + with mock.patch.object(native_build_runtime, "resolve_ptoas_binary", return_value=Path("/tmp/fake-ptoas")), mock.patch.object( + native_build_runtime, "_run", side_effect=fake_run_ptoas_cmd + ): + native_build_runtime._run_ptoas( + mlir_path, + kernel_object, + target_arch="a5", + backend="vpto", + pto_level="level3", + insert_sync=True, + ) + expect(len(ptoas_cmds) == 1, "native build should issue exactly one ptoas command with source-backed overrides") + source_ptoas_cmd = ptoas_cmds[0] + expect( + "--pto-backend=vpto" in source_ptoas_cmd, + "source-backed native build should pass the decorator backend to ptoas", + ) + expect( + "--pto-level=level3" in source_ptoas_cmd, + "source-backed explicit mode should pass --pto-level=level3 to ptoas", + ) + expect( + "--enable-insert-sync" in source_ptoas_cmd, + "source-backed native build should still pass explicit/effective insert-sync to ptoas", + ) expect("valid=?" not in default_text, "default alloc_tile() should keep full static valid-shape when valid_shape= is omitted") auto_mode_violation = expect_raises( RuntimeError, @@ -5005,6 +5229,9 @@ def _enter_inline_simt_with_resource_attr(): launch_handle = block64[1, None] expect(callable(launch_handle), "compiled[grid, stream] should return a launch callable") expect(hasattr(launch_handle, "__call__"), "launch handle should support __call__") + source_launch_handle = source_native_build_compiled[1, None] + expect(callable(source_launch_handle), "source-backed compiled[grid, stream] should return a launch callable") + expect(hasattr(source_launch_handle, "__call__"), "source-backed launch handle should support __call__") print("ptodsl_jit_compile: PASS") os._exit(0) diff --git a/ptodsl/tests/test_jit_diagnostics.py b/ptodsl/tests/test_jit_diagnostics.py index 96ebcd4a3..7c1d344bc 100644 --- a/ptodsl/tests/test_jit_diagnostics.py +++ b/ptodsl/tests/test_jit_diagnostics.py @@ -9,6 +9,7 @@ from pathlib import Path import sys +from tempfile import TemporaryDirectory sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "ptodsl")) @@ -404,6 +405,31 @@ def bad_probe(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): return bad_probe +def define_source_non_string_probe(): + @pto.jit(target="a5", source=123) + def bad_probe(rows: pto.i32): + _ = rows + + return bad_probe + + +def define_source_entry_false_probe(): + @pto.jit(target="a5", entry=False, source="kernel.pto") + def bad_probe(tile: pto.Tile): + pto.pipe_barrier(pto.Pipe.ALL) + + return bad_probe + + +def define_source_constexpr_probe(): + @pto.jit(target="a5", source="kernel.pto") + def bad_probe(ptr: pto.ptr(pto.f32, "gm"), *, BLOCK: pto.const_expr = 8): + _ = ptr + _ = BLOCK + + return bad_probe + + @pto.jit(target="a5") def regular_entry_probe(rows: pto.i32): _ = rows @@ -788,6 +814,195 @@ def main() -> None: "mask", "do not belong at this kernel-module boundary", ) + expect_raises( + define_source_non_string_probe, + TypeError, + "@pto.jit source must be a filesystem path string when provided", + ) + expect_raises( + define_source_entry_false_probe, + TypeError, + "@pto.jit(source=...) kernel 'bad_probe' does not support entry=False while source='kernel.pto'", + "Source-backed JIT is currently limited to launchable entry kernels", + ) + expect_raises( + define_source_constexpr_probe, + TypeError, + "@pto.jit(source=...) kernel 'bad_probe' does not support keyword-only pto.const_expr parameter 'BLOCK' while source='kernel.pto'", + "does not template or specialize source text", + ) + with TemporaryDirectory() as tmpdir: + source_dir = Path(tmpdir) + missing_path = source_dir / "missing.pto" + + @pto.jit(target="a5", source=str(missing_path)) + def source_missing_file_probe(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + source_missing_file_probe.compile, + FileNotFoundError, + "@pto.jit(source=", + "missing.pto", + "file does not exist", + ) + + missing_entry_path = source_dir / "missing_entry.pto" + missing_entry_path.write_text( + "module {\n" + " func.func @other_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(name="wanted_entry", target="a5", source=str(missing_entry_path)) + def source_missing_entry_probe(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + source_missing_entry_probe.compile, + TypeError, + "could not bind entry 'wanted_entry'", + "missing non-declaration func.func", + ) + + ambiguous_entry_path = source_dir / "ambiguous_entry.pto" + ambiguous_entry_path.write_text( + "module {\n" + " func.func @ambiguous_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + " builtin.module {\n" + " func.func @ambiguous_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(target="a5", source=str(ambiguous_entry_path)) + def ambiguous_entry(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + ambiguous_entry.compile, + TypeError, + "could not bind entry 'ambiguous_entry'", + "found 2 matching non-declaration func.func ops", + ) + + count_mismatch_path = source_dir / "count_mismatch.pto" + count_mismatch_path.write_text( + "module {\n" + " func.func @count_mismatch(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(target="a5", source=str(count_mismatch_path)) + def count_mismatch(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + count_mismatch.compile, + TypeError, + "ABI mismatch for entry 'count_mismatch'", + "parameter count differs", + "expected (!pto.ptr, i32)", + "got (!pto.ptr)", + ) + + type_mismatch_path = source_dir / "type_mismatch.pto" + type_mismatch_path.write_text( + "module {\n" + " func.func @type_mismatch(%arg0: !pto.ptr, %arg1: i32) {\n" + " return\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(target="a5", source=str(type_mismatch_path)) + def type_mismatch(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + type_mismatch.compile, + TypeError, + "ABI mismatch for entry 'type_mismatch'", + "parameter 0 differs", + "expected !pto.ptr", + "got !pto.ptr", + ) + + non_void_path = source_dir / "non_void.pto" + non_void_path.write_text( + "module {\n" + " func.func @non_void(%arg0: !pto.ptr) -> i32 {\n" + " %c0 = arith.constant 0 : i32\n" + " return %c0 : i32\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(target="a5", source=str(non_void_path)) + def non_void(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + non_void.compile, + TypeError, + "ABI mismatch for entry 'non_void'", + "source entry must return no values", + "i32", + ) + + compile_constexpr_path = source_dir / "compile_constexpr.pto" + compile_constexpr_path.write_text( + "module {\n" + " func.func @compile_constexpr(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(target="a5", source=str(compile_constexpr_path)) + def compile_constexpr(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + lambda: compile_constexpr.compile(BLOCK=8), + TypeError, + "@pto.jit(source=...) kernel 'compile_constexpr' does not accept .compile(...) constexpr binding(s) BLOCK", + "does not template or specialize source text", + ) + + inline_count_mismatch_source = ( + "module {\n" + " func.func @inline_count_mismatch(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + + @pto.jit(target="a5", source=inline_count_mismatch_source) + def inline_count_mismatch(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + inline_count_mismatch.compile, + TypeError, + "ABI mismatch for entry 'inline_count_mismatch'", + "parameter count differs", + "} { +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from pathlib import Path +import sys + +import numpy as np + + +def _bootstrap_dsl_st_common() -> None: + here = Path(__file__).resolve() + for candidate in here.parents: + common_dir = candidate / "test" / "dsl-st" + if (common_dir / "common.py").exists(): + sys.path.insert(0, str(common_dir)) + return + raise RuntimeError("Unable to locate test/dsl-st/common.py from vmadd.py") + + +_bootstrap_dsl_st_common() + +from common import auto_main, golden_output_case +from ptodsl import pto + + +ELEMS = 1024 +SEED = 29 + +VMADD_SOURCE = """module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { func.func @a5_extra_vmadd_kernel( %f_acc: !pto.ptr, %f_lhs: !pto.ptr, %f_rhs: !pto.ptr, %out_vmadd: !pto.ptr) @@ -57,3 +83,47 @@ return } } +""" + + +@pto.jit( + name="a5_extra_vmadd_kernel", + target="a5", + backend="vpto", + mode="explicit", + source=VMADD_SOURCE, +) +def a5_extra_vmadd_kernel( + f_acc: pto.ptr(pto.f32, "gm"), + f_lhs: pto.ptr(pto.f32, "gm"), + f_rhs: pto.ptr(pto.f32, "gm"), + out_vmadd: pto.ptr(pto.f32, "gm"), +): + pass + + +def make_inputs(): + rng = np.random.default_rng(SEED) + f_acc = rng.uniform(-2.0, 2.0, size=ELEMS).astype(np.float32) + f_lhs = rng.uniform(-3.0, 3.0, size=ELEMS).astype(np.float32) + f_rhs = rng.uniform(-1.0, 1.0, size=ELEMS).astype(np.float32) + return [f_acc, f_lhs, f_rhs] + + +def make_expected(f_acc, f_lhs, f_rhs): + return (f_lhs * f_acc + f_rhs).astype(np.float32) + + +CASES = [ + golden_output_case( + "a5_extra_vmadd", + a5_extra_vmadd_kernel, + inputs=make_inputs, + expected=make_expected, + rtol=2e-4, + atol=2e-4, + ), +] + + +auto_main(globals()) diff --git a/test/vpto/cases/micro-op/a5-extra/vmadd/compare.py b/test/vpto/cases/micro-op/a5-extra/vmadd/compare.py deleted file mode 100644 index baed02d3b..000000000 --- a/test/vpto/cases/micro-op/a5-extra/vmadd/compare.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import os -import sys - -import numpy as np - - -def compare_float(golden_path: str, output_path: str, label: str, atol: float) -> bool: - if not os.path.exists(golden_path) or not os.path.exists(output_path): - print(f"[ERROR] missing file for {label}") - return False - golden = np.fromfile(golden_path, dtype=np.float32) - output = np.fromfile(output_path, dtype=np.float32) - ok = golden.shape == output.shape and np.allclose( - golden, output, atol=atol, rtol=atol, equal_nan=True - ) - if not ok: - diff = np.max(np.abs(golden.astype(np.float64) - output.astype(np.float64))) - print(f"[ERROR] compare failed: {label}, max_abs_diff={diff}") - return ok - - -def compare_exact(golden_path: str, output_path: str, dtype, label: str) -> bool: - if not os.path.exists(golden_path) or not os.path.exists(output_path): - print(f"[ERROR] missing file for {label}") - return False - golden = np.fromfile(golden_path, dtype=dtype) - output = np.fromfile(output_path, dtype=dtype) - ok = golden.shape == output.shape and np.array_equal(golden, output) - if not ok: - mismatch = np.flatnonzero(golden != output) - first = int(mismatch[0]) if mismatch.size else -1 - print( - f"[ERROR] compare failed: {label}, first_mismatch={first}, " - f"golden={golden[first] if first >= 0 else 'n/a'}, " - f"output={output[first] if first >= 0 else 'n/a'}" - ) - return ok - - -def main() -> None: - checks = [ - compare_float("golden_vmadd.bin", "out_vmadd.bin", "vmadd", 2e-4), - ] - if not all(checks): - sys.exit(2) - print("[INFO] compare passed (a5 extra vmadd)") - - -if __name__ == "__main__": - main() diff --git a/test/vpto/cases/micro-op/a5-extra/vmadd/golden.py b/test/vpto/cases/micro-op/a5-extra/vmadd/golden.py deleted file mode 100644 index 276b6c82c..000000000 --- a/test/vpto/cases/micro-op/a5-extra/vmadd/golden.py +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import argparse -from pathlib import Path - -import numpy as np - -ELEMS = 1024 -SEED = 29 - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument("--output-dir", type=Path, default=Path(".")) - parser.add_argument("--seed", type=int, default=SEED) - args = parser.parse_args() - out = args.output_dir - out.mkdir(parents=True, exist_ok=True) - rng = np.random.default_rng(args.seed) - - f_acc = rng.uniform(-2.0, 2.0, size=ELEMS).astype(np.float32) - f_lhs = rng.uniform(-3.0, 3.0, size=ELEMS).astype(np.float32) - f_rhs = rng.uniform(-1.0, 1.0, size=ELEMS).astype(np.float32) - - f_acc.tofile(out / "f_acc.bin") - f_lhs.tofile(out / "f_lhs.bin") - f_rhs.tofile(out / "f_rhs.bin") - - (f_lhs * f_acc + f_rhs).astype(np.float32).tofile(out / "golden_vmadd.bin") - - -if __name__ == "__main__": - main() diff --git a/test/vpto/cases/micro-op/a5-extra/vmadd/launch.cpp b/test/vpto/cases/micro-op/a5-extra/vmadd/launch.cpp deleted file mode 100644 index 91edee7f6..000000000 --- a/test/vpto/cases/micro-op/a5-extra/vmadd/launch.cpp +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#ifndef __VEC_SCOPE__ -#define __VEC_SCOPE__ -#endif -#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) -typedef struct { unsigned char v; } hifloat8_t; -typedef struct { unsigned char v; } float8_e4m3_t; -typedef struct { unsigned char v; } float8_e5m2_t; -typedef struct { unsigned char v; } float8_e8m0_t; -typedef struct { unsigned char v; } float4_e1m2x2_t; -typedef struct { unsigned char v; } float4_e2m1x2_t; -#endif -#include -#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) -#include -#endif -#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) -struct MrgSortExecutedNumList { - uint16_t mrgSortList0; - uint16_t mrgSortList1; - uint16_t mrgSortList2; - uint16_t mrgSortList3; -}; -#endif -#ifndef __CPU_SIM -#include "acl/acl.h" -#endif - -extern "C" __global__ [aicore] void a5_extra_vmadd_kernel( - __gm__ float *f_acc, __gm__ float *f_lhs, __gm__ float *f_rhs, - __gm__ float *out_vmadd); - -void LaunchA5ExtraVmadd(float *p0, float *p1, float *p2, float *p3, - void *stream) { - a5_extra_vmadd_kernel<<<1, nullptr, stream>>>( - (__gm__ float *)p0, (__gm__ float *)p1, (__gm__ float *)p2, - (__gm__ float *)p3); -} diff --git a/test/vpto/cases/micro-op/a5-extra/vmadd/main.cpp b/test/vpto/cases/micro-op/a5-extra/vmadd/main.cpp deleted file mode 100644 index d5a3b781d..000000000 --- a/test/vpto/cases/micro-op/a5-extra/vmadd/main.cpp +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#include "acl/acl.h" -#include "test_common.h" -#include -#include -#include - -using namespace PtoTestCommon; - -#define ACL_CHECK(expr) \ - do { \ - const aclError _ret = (expr); \ - if (_ret != ACL_SUCCESS) { \ - std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ - (int)_ret, __FILE__, __LINE__); \ - const char *_recent = aclGetRecentErrMsg(); \ - if (_recent != nullptr && _recent[0] != '\0') \ - std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ - rc = 1; \ - goto cleanup; \ - } \ - } while (0) - -#define FILE_CHECK(expr, path) \ - do { \ - if (!(expr)) { \ - std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ - path, __FILE__, __LINE__); \ - rc = 1; \ - goto cleanup; \ - } \ - } while (0) - -void LaunchA5ExtraVmadd(float *p0, float *p1, float *p2, float *p3, - void *stream); - -struct Buffer { - const char *path; - size_t size; - bool input; - void *host; - void *device; -}; - -int main() { - constexpr size_t kElems = 1024; - constexpr size_t kF32Bytes = kElems * sizeof(float); - - Buffer bufs[] = { - {"./f_acc.bin", kF32Bytes, true, nullptr, nullptr}, - {"./f_lhs.bin", kF32Bytes, true, nullptr, nullptr}, - {"./f_rhs.bin", kF32Bytes, true, nullptr, nullptr}, - {"./out_vmadd.bin", kF32Bytes, false, nullptr, nullptr}, - }; - - int rc = 0; - bool aclInited = false; - bool deviceSet = false; - int deviceId = 0; - aclrtStream stream = nullptr; - - ACL_CHECK(aclInit(nullptr)); - aclInited = true; - if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) - deviceId = std::atoi(envDevice); - ACL_CHECK(aclrtSetDevice(deviceId)); - deviceSet = true; - ACL_CHECK(aclrtCreateStream(&stream)); - - for (Buffer &buf : bufs) { - ACL_CHECK(aclrtMallocHost(&buf.host, buf.size)); - ACL_CHECK(aclrtMalloc(&buf.device, buf.size, ACL_MEM_MALLOC_HUGE_FIRST)); - if (!buf.input) - continue; - size_t fileSize = buf.size; - FILE_CHECK(ReadFile(buf.path, fileSize, buf.host, buf.size) && - fileSize == buf.size, - buf.path); - ACL_CHECK(aclrtMemcpy(buf.device, buf.size, buf.host, buf.size, - ACL_MEMCPY_HOST_TO_DEVICE)); - } - - LaunchA5ExtraVmadd( - static_cast(bufs[0].device), static_cast(bufs[1].device), - static_cast(bufs[2].device), - static_cast(bufs[3].device), - stream); - - ACL_CHECK(aclrtSynchronizeStream(stream)); - - for (Buffer &buf : bufs) { - if (buf.input) - continue; - ACL_CHECK(aclrtMemcpy(buf.host, buf.size, buf.device, buf.size, - ACL_MEMCPY_DEVICE_TO_HOST)); - FILE_CHECK(WriteFile(buf.path, buf.host, buf.size), buf.path); - } - -cleanup: - for (Buffer &buf : bufs) { - if (buf.device != nullptr) - aclrtFree(buf.device); - if (buf.host != nullptr) - aclrtFreeHost(buf.host); - } - if (stream != nullptr) - aclrtDestroyStream(stream); - if (deviceSet) - aclrtResetDevice(deviceId); - if (aclInited) - aclFinalize(); - return rc; -} diff --git a/test/vpto/scripts/run_host_vpto_validation.sh b/test/vpto/scripts/run_host_vpto_validation.sh index 709b29747..fa530911a 100755 --- a/test/vpto/scripts/run_host_vpto_validation.sh +++ b/test/vpto/scripts/run_host_vpto_validation.sh @@ -25,6 +25,7 @@ CASE_NAME="${CASE_NAME:-}" DEVICE="${DEVICE:-SIM}" SIM_LIB_DIR="${SIM_LIB_DIR:-}" COMPILE_ONLY="${COMPILE_ONLY:-0}" +PTODSL_SIM_SOC_VERSION="${PTODSL_SIM_SOC_VERSION:-Ascend950PR_9599}" log() { echo "[$(date +'%F %T')] $*" @@ -94,8 +95,6 @@ resolve_sim_lib_dir() { die "SIM_LIB_DIR is required for DEVICE=SIM and no dav_3510 simulator lib dir was found under: ${ASCEND_HOME_PATH}" } -resolve_sim_lib_dir - BISHENG_BIN="${BISHENG_BIN:-${ASCEND_HOME_PATH}/bin/bisheng}" command -v "${BISHENG_BIN}" >/dev/null 2>&1 || die "bisheng not found: ${BISHENG_BIN}" @@ -104,13 +103,54 @@ command -v python3 >/dev/null 2>&1 || die "python3 not found" mkdir -p "${WORK_SPACE}" WORK_SPACE="$(cd "${WORK_SPACE}" && pwd)" +is_ptodsl_case_dir() { + [[ -f "$1/kernel.py" ]] +} + +is_ptodsl_case_file() { + local case_path="$1" + local base_name + base_name="$(basename "${case_path}")" + [[ -f "${case_path}" ]] && + [[ "${case_path}" == *.py ]] && + [[ "${base_name}" != "golden.py" ]] && + [[ "${base_name}" != "compare.py" ]] && + [[ "${base_name}" != "kernel.py" ]] && + [[ "${base_name}" != _* ]] +} + +is_ptodsl_case_path() { + is_ptodsl_case_dir "$1" || is_ptodsl_case_file "$1" +} + +ptodsl_case_script() { + local case_path="$1" + if is_ptodsl_case_dir "${case_path}"; then + printf '%s\n' "${case_path}/kernel.py" + return 0 + fi + if is_ptodsl_case_file "${case_path}"; then + printf '%s\n' "${case_path}" + return 0 + fi + die "path is not a PTODSL case: ${case_path}" +} + +validate_case_path() { + local case_name="$1" + local case_path="$2" + + if is_ptodsl_case_path "${case_path}"; then + return 0 + fi + [[ -d "${case_path}" ]] || die "case ${case_name} is neither a directory nor a PTODSL case file" + for f in launch.cpp main.cpp golden.py compare.py; do + [[ -f "${case_path}/${f}" ]] || die "case ${case_name} is missing ${f}" + done + [[ -f "${case_path}/kernel.pto" ]] || die "case ${case_name} must provide kernel.pto" +} + discover_cases() { - local required_files=( - launch.cpp - main.cpp - golden.py - compare.py - ) local onboard_only_prefix="onboard-only/" if [[ -n "${CASE_NAME}" ]]; then @@ -119,28 +159,36 @@ discover_cases() { "${CASE_NAME}" == "${onboard_only_prefix}"* ]]; then die "case ${CASE_NAME} is onboard-only and cannot run with DEVICE=SIM" fi - local requested_dir="${CASES_ROOT}/${CASE_NAME}" - [[ -d "${requested_dir}" ]] || die "unknown case: ${CASE_NAME}" - for f in "${required_files[@]}"; do - [[ -f "${requested_dir}/${f}" ]] || die "case ${CASE_NAME} is missing ${f}" - done - [[ -f "${requested_dir}/kernel.pto" ]] || - die "case ${CASE_NAME} must provide kernel.pto" + local requested_path="${CASES_ROOT}/${CASE_NAME}" + [[ -e "${requested_path}" ]] || die "unknown case: ${CASE_NAME}" + validate_case_path "${CASE_NAME}" "${requested_path}" printf "%s\n" "${CASE_NAME}" return 0 fi find "${CASES_ROOT}" -mindepth 1 -type d | sort | while read -r dir; do - local ok=1 - for f in "${required_files[@]}"; do - if [[ ! -f "${dir}/${f}" ]]; then - ok=0 - break - fi - done - [[ "${ok}" -eq 1 ]] || continue - [[ -f "${dir}/kernel.pto" ]] || continue + [[ -f "${dir}/kernel.pto" || -f "${dir}/kernel.py" ]] || continue local rel="${dir#${CASES_ROOT}/}" + if ! is_ptodsl_case_dir "${dir}"; then + local ok=1 + for f in launch.cpp main.cpp golden.py compare.py; do + if [[ ! -f "${dir}/${f}" ]]; then + ok=0 + break + fi + done + [[ "${ok}" -eq 1 ]] || continue + fi + if [[ "${DEVICE}" == "SIM" && "${COMPILE_ONLY}" != "1" && + "${rel}" == "${onboard_only_prefix}"* ]]; then + continue + fi + printf "%s\n" "${rel}" + done + + find "${CASES_ROOT}" -type f -name '*.py' | sort | while read -r path; do + is_ptodsl_case_file "${path}" || continue + local rel="${path#${CASES_ROOT}/}" if [[ "${DEVICE}" == "SIM" && "${COMPILE_ONLY}" != "1" && "${rel}" == "${onboard_only_prefix}"* ]]; then continue @@ -157,6 +205,19 @@ fi readarray -t CASES < <(discover_cases) [[ "${#CASES[@]}" -gt 0 ]] || die "no cases found under ${CASES_ROOT}" +needs_legacy_sim_lib=0 +if [[ "${DEVICE}" == "SIM" ]]; then + for case_name in "${CASES[@]}"; do + if ! is_ptodsl_case_path "${CASES_ROOT}/${case_name}"; then + needs_legacy_sim_lib=1 + break + fi + done +fi +if [[ "${needs_legacy_sim_lib}" == "1" ]]; then + resolve_sim_lib_dir +fi + case_output_token() { printf '%s' "$1" | sed 's#[/[:space:]]#_#g' } @@ -246,9 +307,34 @@ build_host_executable() { -lstdc++ -lascendcl -lm -ltiling_api -lplatform -lc_sec -ldl -lnnopbase } +run_ptodsl_case() { + local case_name="$1" + local case_path="$2" + local out_dir="$3" + local case_script + case_script="$(ptodsl_case_script "${case_path}")" + + log "[$case_name] run PTODSL source-backed case" + ( + cd "${out_dir}" + export PTODSL_CACHE_DIR="${out_dir}/ptodsl-cache" + export PATH="$(dirname "${PTOAS_BIN}"):${PATH}" + if [[ "${DEVICE}" == "SIM" ]]; then + "${ROOT_DIR}/scripts/sim_dsl.sh" \ + --soc-version "${PTODSL_SIM_SOC_VERSION}" \ + --output "${out_dir}/msprof" \ + "${case_script}" + else + export LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" + python3 "${case_script}" + fi + ) + log "[$case_name] output dir: ${out_dir}" +} + build_one_impl() { local case_name="$1" - local case_dir="${CASES_ROOT}/${case_name}" + local case_path="${CASES_ROOT}/${case_name}" local case_token case_token="$(case_output_token "${case_name}")" local out_dir="${WORK_SPACE}/${case_token}" @@ -257,12 +343,18 @@ build_one_impl() { local kernel_so="${out_dir}/lib${case_token}_kernel.so" local -a ptoas_args=() + if is_ptodsl_case_path "${case_path}"; then + run_ptodsl_case "${case_name}" "${case_path}" "${out_dir}" + return 0 + fi + local case_dir="${case_path}" + + [[ -f "${case_dir}/kernel.pto" ]] || die "missing kernel.pto for ${case_name}" + [[ -f "${case_dir}/main.cpp" ]] || die "missing main.cpp for ${case_name}" [[ -f "${case_dir}/launch.cpp" ]] || die "missing launch.cpp for ${case_name}" [[ -f "${case_dir}/golden.py" ]] || die "missing golden.py for ${case_name}" [[ -f "${case_dir}/compare.py" ]] || die "missing compare.py for ${case_name}" - [[ -f "${case_dir}/kernel.pto" ]] || - die "missing kernel.pto for ${case_name}" if [[ -f "${case_dir}/ptoas.flags" ]]; then read -r -a ptoas_args < "${case_dir}/ptoas.flags" diff --git a/test/vpto/scripts/run_host_vpto_validation_parallel.sh b/test/vpto/scripts/run_host_vpto_validation_parallel.sh index 98be7d669..2bb116972 100755 --- a/test/vpto/scripts/run_host_vpto_validation_parallel.sh +++ b/test/vpto/scripts/run_host_vpto_validation_parallel.sh @@ -74,13 +74,41 @@ WORK_SPACE="$(cd "${WORK_SPACE}" && pwd)" SUMMARY_FILE="${WORK_SPACE}/parallel-summary.tsv" RUNNER_LOG="${WORK_SPACE}/parallel-runner.log" +is_ptodsl_case_dir() { + [[ -f "$1/kernel.py" ]] +} + +is_ptodsl_case_file() { + local case_path="$1" + local base_name + base_name="$(basename "${case_path}")" + [[ -f "${case_path}" ]] && + [[ "${case_path}" == *.py ]] && + [[ "${base_name}" != "golden.py" ]] && + [[ "${base_name}" != "compare.py" ]] && + [[ "${base_name}" != "kernel.py" ]] && + [[ "${base_name}" != _* ]] +} + +is_ptodsl_case_path() { + is_ptodsl_case_dir "$1" || is_ptodsl_case_file "$1" +} + +validate_case_path() { + local case_name="$1" + local case_path="$2" + + if is_ptodsl_case_path "${case_path}"; then + return 0 + fi + [[ -d "${case_path}" ]] || die "case ${case_name} is neither a directory nor a PTODSL case file" + for f in launch.cpp main.cpp golden.py compare.py; do + [[ -f "${case_path}/${f}" ]] || die "case ${case_name} is missing ${f}" + done + [[ -f "${case_path}/kernel.pto" ]] || die "case ${case_name} must provide kernel.pto" +} + discover_cases() { - local required_files=( - launch.cpp - main.cpp - golden.py - compare.py - ) local onboard_only_prefix="onboard-only/" if [[ -n "${CASE_NAME}" ]]; then @@ -88,28 +116,39 @@ discover_cases() { "${CASE_NAME}" == "${onboard_only_prefix}"* ]]; then die "case ${CASE_NAME} is onboard-only and cannot run with DEVICE=SIM" fi - local requested_dir="${CASES_ROOT}/${CASE_NAME}" - [[ -d "${requested_dir}" ]] || die "unknown case: ${CASE_NAME}" - for f in "${required_files[@]}"; do - [[ -f "${requested_dir}/${f}" ]] || die "case ${CASE_NAME} is missing ${f}" - done - [[ -f "${requested_dir}/kernel.pto" ]] || - die "case ${CASE_NAME} must provide kernel.pto" + local requested_path="${CASES_ROOT}/${CASE_NAME}" + [[ -e "${requested_path}" ]] || die "unknown case: ${CASE_NAME}" + validate_case_path "${CASE_NAME}" "${requested_path}" printf "%s\n" "${CASE_NAME}" return 0 fi find "${CASES_ROOT}" -mindepth 1 -type d | sort | while read -r dir; do - local ok=1 - for f in "${required_files[@]}"; do - if [[ ! -f "${dir}/${f}" ]]; then - ok=0 - break - fi - done - [[ "${ok}" -eq 1 ]] || continue - [[ -f "${dir}/kernel.pto" ]] || continue + [[ -f "${dir}/kernel.pto" || -f "${dir}/kernel.py" ]] || continue local rel="${dir#${CASES_ROOT}/}" + if ! is_ptodsl_case_dir "${dir}"; then + local ok=1 + for f in launch.cpp main.cpp golden.py compare.py; do + if [[ ! -f "${dir}/${f}" ]]; then + ok=0 + break + fi + done + [[ "${ok}" -eq 1 ]] || continue + fi + if [[ "${DEVICE:-SIM}" == "SIM" && "${COMPILE_ONLY:-0}" != "1" && + "${rel}" == "${onboard_only_prefix}"* ]]; then + continue + fi + if [[ -n "${CASE_PREFIX}" && "${rel}" != "${CASE_PREFIX}"* ]]; then + continue + fi + printf "%s\n" "${rel}" + done + + find "${CASES_ROOT}" -type f -name '*.py' | sort | while read -r path; do + is_ptodsl_case_file "${path}" || continue + local rel="${path#${CASES_ROOT}/}" if [[ "${DEVICE:-SIM}" == "SIM" && "${COMPILE_ONLY:-0}" != "1" && "${rel}" == "${onboard_only_prefix}"* ]]; then continue