From 13d9486decac6a08f642514ba0024a52f631dd86 Mon Sep 17 00:00:00 2001 From: ManiSadati Date: Tue, 30 Jun 2026 19:44:55 +0000 Subject: [PATCH 1/6] feat(ptodsl): add the first series of TileLib templates with metadata and constraint support --- ptodsl/ptodsl/tilelib/__init__.py | 113 ++++++++++ ptodsl/ptodsl/tilelib/_render_runtime.py | 156 ++++++++++++++ ptodsl/ptodsl/tilelib/author.py | 107 ++++++++++ ptodsl/ptodsl/tilelib/constraints.py | 174 +++++++++++++++ ptodsl/ptodsl/tilelib/decorator.py | 93 ++++++++ ptodsl/ptodsl/tilelib/metadata.py | 160 ++++++++++++++ ptodsl/ptodsl/tilelib/registry.py | 160 ++++++++++++++ ptodsl/ptodsl/tilelib/render.py | 85 ++++++++ ptodsl/ptodsl/tilelib/templates/__init__.py | 39 ++++ .../ptodsl/tilelib/templates/a5/__init__.py | 8 + ptodsl/ptodsl/tilelib/templates/a5/tadd.py | 110 ++++++++++ ptodsl/ptodsl/tilelib/templates/a5/tbinop.py | 200 ++++++++++++++++++ ptodsl/ptodsl/tilelib/templates/a5/tcolmax.py | 66 ++++++ ptodsl/ptodsl/tilelib/templates/a5/tdiv.py | 39 ++++ ptodsl/ptodsl/tilelib/templates/a5/tmax.py | 33 +++ ptodsl/ptodsl/tilelib/templates/a5/tmin.py | 33 +++ ptodsl/ptodsl/tilelib/templates/a5/tmul.py | 110 ++++++++++ ptodsl/ptodsl/tilelib/templates/a5/tsub.py | 33 +++ .../fixtures/tadd_a5_8x64_f32.golden.mlir | 38 ++++ ptodsl/tests/test_tilelib_constraints.py | 52 +++++ ptodsl/tests/test_tilelib_elementwise.py | 49 +++++ ptodsl/tests/test_tilelib_render.py | 71 +++++++ ptodsl/tests/test_tilelib_select.py | 108 ++++++++++ scripts/ptoas_env.sh | 3 + 24 files changed, 2040 insertions(+) create mode 100644 ptodsl/ptodsl/tilelib/__init__.py create mode 100644 ptodsl/ptodsl/tilelib/_render_runtime.py create mode 100644 ptodsl/ptodsl/tilelib/author.py create mode 100644 ptodsl/ptodsl/tilelib/constraints.py create mode 100644 ptodsl/ptodsl/tilelib/decorator.py create mode 100644 ptodsl/ptodsl/tilelib/metadata.py create mode 100644 ptodsl/ptodsl/tilelib/registry.py create mode 100644 ptodsl/ptodsl/tilelib/render.py create mode 100644 ptodsl/ptodsl/tilelib/templates/__init__.py create mode 100644 ptodsl/ptodsl/tilelib/templates/a5/__init__.py create mode 100644 ptodsl/ptodsl/tilelib/templates/a5/tadd.py create mode 100644 ptodsl/ptodsl/tilelib/templates/a5/tbinop.py create mode 100644 ptodsl/ptodsl/tilelib/templates/a5/tcolmax.py create mode 100644 ptodsl/ptodsl/tilelib/templates/a5/tdiv.py create mode 100644 ptodsl/ptodsl/tilelib/templates/a5/tmax.py create mode 100644 ptodsl/ptodsl/tilelib/templates/a5/tmin.py create mode 100644 ptodsl/ptodsl/tilelib/templates/a5/tmul.py create mode 100644 ptodsl/ptodsl/tilelib/templates/a5/tsub.py create mode 100644 ptodsl/tests/fixtures/tadd_a5_8x64_f32.golden.mlir create mode 100644 ptodsl/tests/test_tilelib_constraints.py create mode 100644 ptodsl/tests/test_tilelib_elementwise.py create mode 100644 ptodsl/tests/test_tilelib_render.py create mode 100644 ptodsl/tests/test_tilelib_select.py diff --git a/ptodsl/ptodsl/tilelib/__init__.py b/ptodsl/ptodsl/tilelib/__init__.py new file mode 100644 index 000000000..4cc104e7c --- /dev/null +++ b/ptodsl/ptodsl/tilelib/__init__.py @@ -0,0 +1,113 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""PTODSL TileLib: ptodsl-native templates for ExpandTileOp (migration of tilelang-dsl). + +Layers: + - metadata : TileSpec + dtypes + TemplateMetadata + - author : body ops (for_/get_lanes/make_mask/vlds/vadd/vsts), engine-routed + - decorator : @tile_template (registers a version + its metadata) + - registry : constraint/priority selection among registered versions + - render : render_best(op, target, specs) + CLI (the ExpandTileOp seam) + - templates/ : the ported per-arch template bodies +""" + +from .author import ( + PostUpdate, + Tile, + addptr, + const_expr, + for_, + get_lanes, + if_, + make_mask, + static_range, + vadd, + vdiv, + vecscope, + vlds, + vmax, + vmin, + vmul, + vsts, + vsub, + yield_, +) +from .constraints import ( + BLayout, + SLayout, + check_layout, + check_memory_space, + check_type, + require_contiguous, +) +from .decorator import SpecializedTileTemplate, TileTemplate, tile_template +from .metadata import ScalarType, TemplateMetadata, TileSpec, bf16, f16, f32, i8, i16, i32 +from .registry import ( + AmbiguousTemplate, + NoMatchingTemplate, + TileTemplateRegistry, + default_registry, + legal_candidates, + register, + select, +) +from .render import render_best, select_and_specialize + +__all__ = [ + # authoring surface + "Tile", + "PostUpdate", + "tile_template", + "for_", + "static_range", + "if_", + "yield_", + "const_expr", + "vecscope", + "get_lanes", + "make_mask", + "addptr", + "vlds", + "vadd", + "vsub", + "vmul", + "vmax", + "vmin", + "vdiv", + "vsts", + # specs / metadata + "TileSpec", + "ScalarType", + "TemplateMetadata", + "BLayout", + "SLayout", + "check_type", + "check_memory_space", + "check_layout", + "require_contiguous", + "f32", + "f16", + "bf16", + "i32", + "i16", + "i8", + # descriptors + "TileTemplate", + "SpecializedTileTemplate", + # registry / selection + "TileTemplateRegistry", + "default_registry", + "legal_candidates", + "register", + "select", + "NoMatchingTemplate", + "AmbiguousTemplate", + # rendering + "render_best", + "select_and_specialize", +] diff --git a/ptodsl/ptodsl/tilelib/_render_runtime.py b/ptodsl/ptodsl/tilelib/_render_runtime.py new file mode 100644 index 000000000..e72e855be --- /dev/null +++ b/ptodsl/ptodsl/tilelib/_render_runtime.py @@ -0,0 +1,156 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +TileLib render runtime. + +Traces a tilelang-style template body into a standalone ``func.func`` whose MLIR is on +par with the legacy tilelang-dsl render (``tile_buf_addr`` -> memref, ``memref.subview``, +``pto.vlds/vadd/vsts``, dynamic ``pto.tile_valid_rows/cols``). + +Control flow is handled by the engine's AST rewrite (``rewrite_jit_function``): plain +``for x in range(...)`` in the template body is rewritten at trace time to +``pto.for_(...).carry(...)`` (the ``_control_flow`` surface, re-exported via :mod:`author`), +with loop-carried variables detected by liveness. This module therefore owns only the +entry tile_buf typing and the golden-shaped module/func container; loops, slicing, +valid-shape, mask and vector ops all come from the existing ptodsl engine. +""" + +from __future__ import annotations + +import inspect + +from .metadata import TileSpec, scalar_descriptor +from .._ast_rewrite import rewrite_jit_function +from .._bootstrap import make_context +from .._surface_types import Tile +from .._surface_values import TileValue +from .._tracing import KernelModuleSpec, ModuleStyle, TracingRuntime +from .._tracing.active import activate_runtime, activate_session +from .._types import _resolve + +from mlir.dialects import func +from mlir.ir import Attribute, InsertionPoint, Location, Module, StringAttr, UnitAttr + + +# ── tile handle handed to the template body ──────────────────────────────────────── + +class _TemplateTile(TileValue): + """Engine ``TileValue`` with the template-author alias ``element_type`` and forced + dynamic ``valid_shape`` (emit ``pto.tile_valid_rows/cols`` rather than folding the + static ``v_row/v_col`` carried in the tile_buf type). + + Metadata (shape/dtype/memory_space) is supplied from the ``TileSpec`` because a raw + entry-block ``tile_buf`` type is not introspectable by ``parse_tile_type_metadata``; + supplying it explicitly takes the fast path in ``infer_memref_type_from_surface_value``. + """ + + def __init__(self, value, spec: TileSpec): + elem = _resolve(scalar_descriptor(spec.dtype)) + super().__init__( + value, + shape=tuple(spec.shape), + physical_shape=tuple(spec.shape), + dtype=elem, + memory_space=spec.memory_space, + valid_shape=None, + ) + # Force the dynamic valid-shape ops to match the tilelang render. + self.static_valid_shape = None + self._valid_shape._cache.clear() + + @property + def element_type(self): + return self.dtype + + +# ── tracing runtime ──────────────────────────────────────────────────────────────── + +class _TemplateTrace(TracingRuntime): + def __init__(self, descriptor, tile_specs: dict): + super().__init__( + KernelModuleSpec( + function_name=descriptor.name, + target_arch=descriptor.target, + kernel_kind="vector", + mode="auto", + module_style=ModuleStyle.NESTED, + source_file=inspect.getsourcefile(descriptor.py_fn) or inspect.getfile(descriptor.py_fn), + source_line=getattr(descriptor.py_fn.__code__, "co_firstlineno", None), + ) + ) + self.descriptor = descriptor + self.tile_specs = tile_specs + self._ordered_specs: list = [] + self._signature_parameters = tuple(inspect.signature(descriptor.py_fn).parameters.items()) + + def compute_argument_types(self): + arg_types = [] + ordered = [] + for param_name, param in self._signature_parameters: + if not _is_tile_annotation(param.annotation): + raise TypeError( + f"tile-template parameters must be annotated Tile; {param_name!r} is {param.annotation!r}" + ) + spec = self.tile_specs.get(param_name) + if spec is None: + raise ValueError(f"missing TileSpec for parameter {param_name!r}") + ordered.append((param_name, spec)) + arg_types.append(spec.mlir_type()) + self._ordered_specs = ordered + return arg_types + + def bind_entry_arguments(self, entry_arguments): + return tuple( + _TemplateTile(arg, spec) for arg, (_, spec) in zip(entry_arguments, self._ordered_specs) + ) + + def trace_entry(self, *args): + # Apply the engine's AST control-flow rewrite so the template body can use plain + # `for x in range(...)` (rewritten to pto.for_(...).carry(...)) like tilelang. + rewritten = rewrite_jit_function(self.descriptor.py_fn) + rewritten(*args) + + # Custom golden-shaped container: single module(target_arch) + func(instance, kernel_kind). + def build_module(self): + ctx = make_context() + with ctx, Location.unknown(): + arg_types = list(self.compute_argument_types()) + module, ir_fn = self._create_instance_module(arg_types) + session = self.create_session(module, ir_fn) + entry = ir_fn.add_entry_block() + with InsertionPoint(entry), activate_runtime(self), activate_session(session): + self.initialize_session(session, entry) + args = self.bind_entry_arguments(entry.arguments) + self.trace_entry(*args) + self.validate_trace_state() + self.emit_return() + self.finalize_session(session) + session.validate_final_state() + self.verify_module(module) + return module + + def _create_instance_module(self, arg_types): + module = Module.create() + module.operation.attributes["pto.target_arch"] = StringAttr.get(self.descriptor.target) + with InsertionPoint(module.body): + fn_ty = func.FunctionType.get(arg_types, []) + ir_fn = func.FuncOp(self.descriptor.name, fn_ty) + ir_fn.attributes["pto.tilelang.instance"] = UnitAttr.get() + ir_fn.attributes["pto.kernel_kind"] = Attribute.parse("#pto.kernel_kind") + return module, ir_fn + + +def _is_tile_annotation(annotation) -> bool: + if annotation is Tile: + return True + if isinstance(annotation, str): + return annotation == "Tile" or annotation.endswith(".Tile") + return getattr(annotation, "__name__", None) == "Tile" + + +__all__ = ["_TemplateTrace", "_TemplateTile"] diff --git a/ptodsl/ptodsl/tilelib/author.py b/ptodsl/ptodsl/tilelib/author.py new file mode 100644 index 000000000..048252abd --- /dev/null +++ b/ptodsl/ptodsl/tilelib/author.py @@ -0,0 +1,107 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Author-facing TileLib body surface. + +A template imports this namespace as ``pto`` and writes a tilelang-style body. Control +flow uses the **engine's** AST-rewrite surface (``for x in range(...)`` is rewritten to +``pto.for_(...).carry(...)`` at trace time — see ``_render_runtime`` / ``_ast_rewrite``), +so ``for_``/``static_range``/``if_``/``yield_`` are re-exported from ``_control_flow`` +rather than reimplemented. The body ops route to the existing ``_ops`` engine. +""" + +from __future__ import annotations + +# Engine control-flow surface (target of the AST rewrite). +from .._control_flow import const_expr, for_, if_, static_range, vecscope, yield_ +from .._surface_types import PostUpdate, Tile +from .. import _ops +from .._types import _resolve + + +def get_lanes(dtype) -> int: + """Vector lanes for *dtype* (256-byte vreg). Returns a Python int used as a loop step.""" + return _ops._elements_per_vreg(_resolve(dtype)) + + +def make_mask(dtype, value): + """``pto.plt_b{8,16,32}`` predicate mask; returns a ``(mask, remained)``-unpackable value.""" + return _ops.make_mask(dtype, value) + + +def addptr(base_ptr, index_offset): + """``pto.addptr`` – advance a pointer by an element offset.""" + return _ops.addptr(base_ptr, index_offset) + + +def vlds(src_ptr, offset=None, result_vreg_type=None, *, dist=None, post_update=PostUpdate.OFF): + """``pto.vlds`` from a tile slice or pointer.""" + return _ops.vlds( + src_ptr, + offset, + result_vreg_type, + dist=dist, + post_update=post_update, + ) + + +def vadd(lhs, rhs, mask): + """``pto.vadd`` element-wise add.""" + return _ops.vadd(lhs, rhs, mask) + + +def vsub(lhs, rhs, mask): + """``pto.vsub`` element-wise subtract.""" + return _ops.vsub(lhs, rhs, mask) + + +def vmul(lhs, rhs, mask): + """``pto.vmul`` element-wise multiply.""" + return _ops.vmul(lhs, rhs, mask) + + +def vmax(lhs, rhs, mask): + """``pto.vmax`` element-wise maximum.""" + return _ops.vmax(lhs, rhs, mask) + + +def vmin(lhs, rhs, mask): + """``pto.vmin`` element-wise minimum.""" + return _ops.vmin(lhs, rhs, mask) + + +def vdiv(lhs, rhs, mask): + """``pto.vdiv`` element-wise divide (default precision).""" + return _ops.vdiv(lhs, rhs, mask) + + +def vsts(vec, dst_ptr, offset, mask=None, *, dist=None, post_update=PostUpdate.OFF): + """``pto.vsts`` to a tile slice or pointer.""" + return _ops.vsts(vec, dst_ptr, offset, mask, dist=dist, post_update=post_update) + + +__all__ = [ + "Tile", + "for_", + "static_range", + "if_", + "yield_", + "const_expr", + "vecscope", + "get_lanes", + "make_mask", + "addptr", + "PostUpdate", + "vlds", + "vadd", + "vsub", + "vmul", + "vmax", + "vmin", + "vdiv", + "vsts", +] diff --git a/ptodsl/ptodsl/tilelib/constraints.py b/ptodsl/ptodsl/tilelib/constraints.py new file mode 100644 index 000000000..2a6601495 --- /dev/null +++ b/ptodsl/ptodsl/tilelib/constraints.py @@ -0,0 +1,174 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Constraint predicates + evaluation for TileLib version selection. + +A template may declare ``constraints=[predicate, ...]`` (legality rules, e.g. "row-major +layout and a 1-row output"). During selection we build a per-operand context and call each +predicate by **name-matching its parameters** against that context — the same introspection +convention as tilelang-dsl's `_evaluate_constraints`, so predicates port verbatim. The +predicate receives keys like ``src_shape`` / ``dst_valid_shape`` / ``src_config`` and returns +a truthy value when legal. + +``BLayout`` / ``SLayout`` mirror tilelang's enums so a copied predicate's +``cfg.b_layout != pto.BLayout.ROW_MAJOR`` comparison works unchanged (str enums compare equal +to the raw layout strings carried in operand specs). +""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass +from enum import Enum + + +class BLayout(str, Enum): + ROW_MAJOR = "row_major" + COL_MAJOR = "col_major" + + +class SLayout(str, Enum): + NONE_BOX = "none_box" + ROW_MAJOR = "row_major" + COL_MAJOR = "col_major" + + +@dataclass(frozen=True) +class _ConfigView: + """The ``{name}_config`` object a constraint sees (``.b_layout`` / ``.s_layout`` strings, + which compare equal to the BLayout/SLayout str-enums).""" + + b_layout: str + s_layout: str + + +def build_context(tile_specs: dict, target: str, op: str) -> dict: + """Build the flat name-keyed context predicates are matched against.""" + context: dict = {"target": target, "op": op} + operand_dtypes = [] + operand_memory_spaces = [] + operand_rows = [] + operand_cols = [] + operand_sizes = [] + operand_valid_cols = [] + operand_b_layouts = [] + for name, spec in tile_specs.items(): + shape = tuple(spec.shape) + valid = tuple(spec.valid_shape) if getattr(spec, "valid_shape", None) else shape + dtype = spec.dtype.name + memory_space = getattr(spec, "memory_space", "ub") + b_layout = getattr(spec, "b_layout", "row_major") + s_layout = getattr(spec, "s_layout", "none_box") + operand_dtypes.append(dtype) + operand_memory_spaces.append(memory_space) + operand_sizes.append(_shape_size(shape)) + operand_b_layouts.append(b_layout) + context[f"{name}_shape"] = shape + context[f"{name}_valid_shape"] = valid + context[f"{name}_dtype"] = dtype + context[f"{name}_memory_space"] = memory_space + context[f"{name}_config"] = _ConfigView( + b_layout=b_layout, + s_layout=s_layout, + ) + if len(shape) == 2: + context[f"{name}_rows"], context[f"{name}_cols"] = shape + context[f"{name}_valid_rows"], context[f"{name}_valid_cols"] = valid + operand_rows.append(shape[0]) + operand_cols.append(shape[1]) + operand_valid_cols.append(valid[1]) + context["operand_dtypes"] = tuple(operand_dtypes) + context["operand_memory_spaces"] = tuple(operand_memory_spaces) + context["operand_rows"] = tuple(operand_rows) + context["operand_cols"] = tuple(operand_cols) + context["operand_sizes"] = tuple(operand_sizes) + context["operand_valid_cols"] = tuple(operand_valid_cols) + context["operand_b_layouts"] = tuple(operand_b_layouts) + return context + + +def _shape_size(shape): + size = 1 + for dim in shape: + size *= dim + return size + + +def check_type(expected): + expected = tuple(expected) + + def _check_type(operand_dtypes, **_): + return tuple(operand_dtypes) == expected + + return _check_type + + +def check_memory_space(expected): + def _check_memory_space(operand_memory_spaces, **_): + return all(space == expected for space in operand_memory_spaces) + + return _check_memory_space + + +def check_layout(expected): + def _check_layout(operand_b_layouts, **_): + return all(layout == expected for layout in operand_b_layouts) + + return _check_layout + + +def require_contiguous(required=True): + def _require_contiguous(operand_rows, operand_cols, operand_valid_cols, **_): + if not required: + return True + full_cols = all(valid == cols for valid, cols in zip(operand_valid_cols, operand_cols)) + single_row = all(rows == 1 for rows in operand_rows) + return full_cols or single_row + + return _require_contiguous + + +def passes(predicates, context: dict) -> bool: + """Return True iff every predicate is satisfied for *context* (legality filter).""" + for predicate in predicates: + try: + signature = inspect.signature(predicate) + except (TypeError, ValueError): + return False + kwargs: dict = {} + for parameter in signature.parameters.values(): + if parameter.kind == inspect.Parameter.VAR_KEYWORD: + for key, value in context.items(): + kwargs.setdefault(key, value) + continue + if parameter.kind == inspect.Parameter.VAR_POSITIONAL: + continue + if parameter.name in context: + kwargs[parameter.name] = context[parameter.name] + elif parameter.default is not inspect.Parameter.empty: + continue + else: + # A required parameter we can't supply -> treat as not satisfiable. + return False + try: + if not predicate(**kwargs): + return False + except Exception: + return False + return True + + +__all__ = [ + "BLayout", + "SLayout", + "build_context", + "check_layout", + "check_memory_space", + "check_type", + "passes", + "require_contiguous", +] diff --git a/ptodsl/ptodsl/tilelib/decorator.py b/ptodsl/ptodsl/tilelib/decorator.py new file mode 100644 index 000000000..1ff46ef32 --- /dev/null +++ b/ptodsl/ptodsl/tilelib/decorator.py @@ -0,0 +1,93 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""The ``@tile_template`` decorator + the registered descriptor / specialization artifact.""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass + +from . import registry as _registry +from ._render_runtime import _TemplateTrace +from .metadata import TemplateMetadata +from .._tracing import ModuleArtifact + + +@dataclass(frozen=True) +class TileTemplate: + """A registered template version: the Python body + its metadata + parameter order.""" + + py_fn: object + metadata: TemplateMetadata + param_names: tuple + + @property + def name(self) -> str: + return self.metadata.name + + @property + def target(self) -> str: + return self.metadata.target + + @property + def op(self) -> str: + return self.metadata.op + + def specialize(self, **tile_specs) -> "SpecializedTileTemplate": + return SpecializedTileTemplate(self, tile_specs) + + +class SpecializedTileTemplate(ModuleArtifact): + """A ``TileTemplate`` bound to concrete ``TileSpec``s; ``.mlir_text()`` renders it.""" + + def __init__(self, descriptor: TileTemplate, tile_specs: dict): + super().__init__( + descriptor.name, + module_factory=lambda: _TemplateTrace(descriptor, tile_specs).build_module(), + ) + self.descriptor = descriptor + self.tile_specs = tile_specs + + +def tile_template(*, op, target="a5", name=None, dtypes=(), layouts=(), + memory_spaces=(), constraints=(), priority=0, fusible=False, + loop_depth=None, id=None, Tail=None, is_post_update=False, + tags=(), register=True): + """Register a Python function as a TileLib implementation of *op* for *target*.""" + if target != "a5": + raise ValueError("tile-template tracing currently only supports target='a5'") + + def decorator(fn): + descriptor = TileTemplate( + py_fn=fn, + metadata=TemplateMetadata.build( + op=op, + target=target, + name=name or fn.__name__, + dtypes=dtypes, + layouts=layouts, + memory_spaces=memory_spaces, + constraints=constraints, + priority=priority, + fusible=fusible, + loop_depth=loop_depth, + id=id, + Tail=Tail, + is_post_update=is_post_update, + tags=tags, + ), + param_names=tuple(inspect.signature(fn).parameters.keys()), + ) + if register: + _registry.register(descriptor) + return descriptor + + return decorator + + +__all__ = ["TileTemplate", "SpecializedTileTemplate", "tile_template"] diff --git a/ptodsl/ptodsl/tilelib/metadata.py b/ptodsl/ptodsl/tilelib/metadata.py new file mode 100644 index 000000000..f55d9702e --- /dev/null +++ b/ptodsl/ptodsl/tilelib/metadata.py @@ -0,0 +1,160 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLib template metadata + tile specialization specs. + +``TemplateMetadata`` carries both the *hard constraints* used to decide whether a +template is legal for a concrete TileOp (op/target/dtypes/layouts/memory_spaces) and the +*selection hints* used to rank legal candidates (priority/fusible/tags). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from .._types import ( + float16 as _float16, + float32 as _float32, + int8 as _int8, + int16 as _int16, + int32 as _int32, + int64 as _int64, + tile_buf_type as _tile_buf_type, +) + +from mlir.ir import Type + + +@dataclass(frozen=True) +class ScalarType: + """Author-facing dtype tag (used to build entry tile_buf types at specialize time).""" + + name: str + + def __repr__(self) -> str: + return self.name + + +f32 = ScalarType("f32") +f16 = ScalarType("f16") +bf16 = ScalarType("bf16") +i32 = ScalarType("i32") +i16 = ScalarType("i16") +i8 = ScalarType("i8") + + +def scalar_descriptor(dtype: ScalarType): + """Map a TileLib ``ScalarType`` to a ptodsl ``_types`` dtype descriptor.""" + descriptors = { + "f32": _float32, + "f16": _float16, + "bf16": Type.parse("bf16"), + "i8": _int8, + "i16": _int16, + "i32": _int32, + "i64": _int64, + } + descriptor = descriptors.get(dtype.name) + if descriptor is None: + raise ValueError(f"unsupported scalar dtype {dtype.name}") + return descriptor + + +@dataclass(frozen=True) +class TileSpec: + """Concrete specialization of one tile operand. + + ``valid_shape``/``b_layout``/``s_layout`` are carried for constraint evaluation + (selection). Rendering currently always emits row-major/none-box tile_buf types; a + non-row-major operand is rejected by the relevant template's constraints before render. + """ + + shape: tuple + dtype: ScalarType + memory_space: str = "ub" + valid_shape: tuple | None = None + b_layout: str = "row_major" + s_layout: str = "none_box" + + def __post_init__(self): + if len(self.shape) != 2: + raise ValueError("TileSpec currently only supports rank-2 tile shapes") + if any(not isinstance(dim, int) or dim <= 0 for dim in self.shape): + raise ValueError("TileSpec.shape must contain positive integers") + if self.memory_space != "ub": + raise ValueError("TileSpec currently only supports ub tiles") + + def mlir_type(self): + rows, cols = self.shape + return _tile_buf_type( + [rows, cols], + scalar_descriptor(self.dtype), + [rows, cols], + blayout="RowMajor", + address_space=self.memory_space, + slayout="NoneBox", + fractal_size=512, + pad="Null", + ) + + +@dataclass(frozen=True) +class TemplateMetadata: + """Hard constraints + selection hints for one registered template version.""" + + op: str + target: str + name: str + # Hard constraints + dtypes: tuple = () # tuple of per-operand dtype-name tuples, e.g. (("f32","f32","f32"),) + layouts: tuple = () + memory_spaces: tuple = () + # Hard constraints (legality predicates: callables matched by param name — see constraints.py) + constraints: tuple = () + # Selection hints + priority: int = 0 + fusible: bool = False + loop_depth: int | None = None + id: int | None = None + Tail: object = None + is_post_update: bool = False + tags: tuple = () + + @staticmethod + def build(*, op, target, name, dtypes=(), layouts=(), memory_spaces=(), + constraints=(), priority=0, fusible=False, loop_depth=None, + id=None, Tail=None, is_post_update=False, tags=()): + return TemplateMetadata( + op=op, + target=target, + name=name, + dtypes=tuple(tuple(sig) for sig in dtypes), + layouts=tuple(layouts), + memory_spaces=tuple(memory_spaces), + constraints=tuple(constraints), + priority=priority, + fusible=fusible, + loop_depth=loop_depth, + id=id, + Tail=Tail, + is_post_update=bool(is_post_update), + tags=tuple(tags), + ) + + +__all__ = [ + "ScalarType", + "TileSpec", + "TemplateMetadata", + "scalar_descriptor", + "f32", + "f16", + "bf16", + "i32", + "i16", + "i8", +] diff --git a/ptodsl/ptodsl/tilelib/registry.py b/ptodsl/ptodsl/tilelib/registry.py new file mode 100644 index 000000000..404971f72 --- /dev/null +++ b/ptodsl/ptodsl/tilelib/registry.py @@ -0,0 +1,160 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLib registry + version selection. + +Mirrors the *logic* of tilelang-dsl's ``select_kernel`` (filter legal candidates, rank by +priority) but is engine-agnostic: it returns the chosen ``TileTemplate`` descriptor, which +the caller renders via ptodsl's engine. Selection order: + + 1. filter by op + target (hard) + 2. filter by dtype-signature (hard) + 3. no legal candidate -> error + 4. one legal candidate -> choose it + 5. several -> highest priority wins; remaining ties -> error +""" + +from __future__ import annotations + +from . import constraints as _constraints + + +class NoMatchingTemplate(Exception): + pass + + +class AmbiguousTemplate(Exception): + pass + + +class TileTemplateRegistry: + def __init__(self): + self._descriptors: list = [] + + def register(self, descriptor) -> None: + # Re-registration (e.g. module reload) replaces the prior entry with the same name. + self._descriptors = [ + d for d in self._descriptors + if not (d.op == descriptor.op and d.target == descriptor.target and d.name == descriptor.name) + ] + self._descriptors.append(descriptor) + + def all(self) -> tuple: + return tuple(self._descriptors) + + def lookup(self, op: str, target: str) -> list: + return [d for d in self._descriptors if d.op == op and d.target == target] + + def legal_candidates(self, op: str, target: str, tile_specs: dict, + context_attrs: dict | None = None) -> list: + candidates = self.lookup(op, target) + if not candidates: + raise NoMatchingTemplate(f"no template registered for op={op!r} target={target!r}") + + legal = [d for d in candidates if _dtype_signature_matches(d, tile_specs)] + if not legal: + sig = _dtype_signature(candidates[0], tile_specs) + raise NoMatchingTemplate( + f"no template for op={op!r} target={target!r} matches dtype signature {sig}" + ) + + # Hard legality constraints (e.g. layout / valid-shape rules). + context = _constraints.build_context(tile_specs, target, op) + legal = [d for d in legal if _constraints.passes(d.metadata.constraints, context)] + if not legal: + raise NoMatchingTemplate( + f"no template for op={op!r} target={target!r} satisfies its constraints " + f"(layout/valid-shape) for the given operands" + ) + + legal.sort(key=lambda d: d.metadata.priority, reverse=True) + return legal + + def select(self, op: str, target: str, tile_specs: dict, + context_attrs: dict | None = None, candidate_id: str | None = None): + legal = self.legal_candidates(op, target, tile_specs, context_attrs) + if candidate_id: + for descriptor in legal: + if descriptor.name == candidate_id: + return descriptor + legal_names = ", ".join(d.name for d in legal) + raise NoMatchingTemplate( + f"candidate {candidate_id!r} is not a legal template for op={op!r} " + f"target={target!r}; legal candidates: {legal_names}" + ) + + if len(legal) == 1: + return legal[0] + + top_priority = legal[0].metadata.priority + winners = [d for d in legal if d.metadata.priority == top_priority] + if len(winners) > 1: + names = ", ".join(d.name for d in winners) + raise AmbiguousTemplate( + f"multiple templates tie at priority {top_priority} for op={op!r} target={target!r}: {names}" + ) + return legal[0] + + +def _dtype_signature(descriptor, tile_specs: dict) -> tuple: + """Per-operand dtype-name tuple in the template's parameter order.""" + return tuple(tile_specs[name].dtype.name for name in descriptor.param_names) + + +def _dtype_signature_matches(descriptor, tile_specs: dict) -> bool: + # Empty dtypes metadata == "accepts any dtype signature". + if not descriptor.metadata.dtypes: + return True + try: + sig = _dtype_signature(descriptor, tile_specs) + except KeyError: + return False + return sig in descriptor.metadata.dtypes + + +# Process-wide default registry (the decorator registers into this one). +_DEFAULT_REGISTRY = TileTemplateRegistry() + + +def default_registry() -> TileTemplateRegistry: + return _DEFAULT_REGISTRY + + +def register(descriptor) -> None: + _DEFAULT_REGISTRY.register(descriptor) + + +def _load_default_templates(op: str, target: str) -> None: + # Import lazily to avoid a registry/templates import cycle during package + # initialization. The loader is cached and registers descriptors as a + # module-import side effect. + from .templates import load_template + + load_template(op, target) + + +def legal_candidates(op: str, target: str, tile_specs: dict, + context_attrs: dict | None = None): + _load_default_templates(op, target) + return _DEFAULT_REGISTRY.legal_candidates(op, target, tile_specs, context_attrs) + + +def select(op: str, target: str, tile_specs: dict, context_attrs: dict | None = None, + candidate_id: str | None = None): + _load_default_templates(op, target) + return _DEFAULT_REGISTRY.select(op, target, tile_specs, context_attrs, candidate_id) + + +__all__ = [ + "TileTemplateRegistry", + "NoMatchingTemplate", + "AmbiguousTemplate", + "default_registry", + "legal_candidates", + "register", + "select", +] diff --git a/ptodsl/ptodsl/tilelib/render.py b/ptodsl/ptodsl/tilelib/render.py new file mode 100644 index 000000000..d18ca4fb4 --- /dev/null +++ b/ptodsl/ptodsl/tilelib/render.py @@ -0,0 +1,85 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Render a selected TileLib template to MLIR. + +``render_best`` is the seam ``ExpandTileOp`` will call: select the legal/best template +for an op+target+operand specialization, then render it via ptodsl's engine. + +CLI (standalone, parallels lib/TileOps/render_template_mlir.py): + + python3 -m ptodsl.tilelib.render --op pto.tadd --target a5 \\ + --tile dst=8x64@ub:f32 --tile src0=8x64@ub:f32 --tile src1=8x64@ub:f32 \\ + -o /tmp/ptodsl_tadd_tilelib.mlir +""" + +from __future__ import annotations + +import argparse + +from . import registry as _registry +from .metadata import ScalarType, TileSpec + + +def select_and_specialize(op: str, target: str, tile_specs: dict, + context_attrs: dict | None = None, + candidate_id: str | None = None): + # Registry selection lazily imports only the module for this (target, op). + descriptor = _registry.select(op, target, tile_specs, context_attrs, candidate_id) + return descriptor.specialize(**tile_specs) + + +def render_best(op: str, target: str, tile_specs: dict, + context_attrs: dict | None = None, + candidate_id: str | None = None) -> str: + return select_and_specialize(op, target, tile_specs, context_attrs, candidate_id).mlir_text() + + +# ── CLI ───────────────────────────────────────────────────────────────────────────── + +_DTYPES = {"f32", "f16", "bf16", "i32", "i16", "i8"} + + +def _parse_tile_arg(spec: str): + """Parse ``name=RxCxMEM:dtype`` like ``dst=8x64@ub:f32``.""" + name, _, rest = spec.partition("=") + if not name or not rest: + raise argparse.ArgumentTypeError(f"invalid --tile {spec!r}; expected name=RxC@mem:dtype") + shape_mem, _, dtype = rest.partition(":") + shape_str, _, mem = shape_mem.partition("@") + mem = mem or "ub" + dtype = dtype or "f32" + if dtype not in _DTYPES: + raise argparse.ArgumentTypeError(f"unsupported dtype {dtype!r} in --tile {spec!r}") + dims = tuple(int(d) for d in shape_str.split("x")) + return name, TileSpec(shape=dims, dtype=ScalarType(dtype), memory_space=mem) + + +def main(argv=None): + parser = argparse.ArgumentParser(prog="ptodsl.tilelib.render") + parser.add_argument("--op", required=True) + parser.add_argument("--target", default="a5") + parser.add_argument("--tile", action="append", default=[], help="name=RxC@mem:dtype") + parser.add_argument("-o", "--output", default=None) + args = parser.parse_args(argv) + + tile_specs = dict(_parse_tile_arg(t) for t in args.tile) + text = render_best(args.op, args.target, tile_specs) + + if args.output: + with open(args.output, "w", encoding="utf-8") as handle: + handle.write(text) + print(f"wrote {args.output}") + else: + print(text) + + +if __name__ == "__main__": + main() + + +__all__ = ["render_best", "select_and_specialize"] diff --git a/ptodsl/ptodsl/tilelib/templates/__init__.py b/ptodsl/ptodsl/tilelib/templates/__init__.py new file mode 100644 index 000000000..4bf193a33 --- /dev/null +++ b/ptodsl/ptodsl/tilelib/templates/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Lazy loader for ported, per-architecture TileLib templates.""" + +from functools import lru_cache +from importlib import import_module + + +_TEMPLATE_MODULES = { + ("a5", "pto.tadd"): ".a5.tadd", + ("a5", "pto.tcolmax"): ".a5.tcolmax", + ("a5", "pto.tdiv"): ".a5.tdiv", + ("a5", "pto.tmax"): ".a5.tmax", + ("a5", "pto.tmin"): ".a5.tmin", + ("a5", "pto.tmul"): ".a5.tmul", + ("a5", "pto.tsub"): ".a5.tsub", +} + + +@lru_cache(maxsize=None) +def load_template(op: str, target: str) -> bool: + """Import and register only the template module for ``(target, op)``. + + Both this cache and Python's module cache make repeated requests no-ops. + Returns ``False`` when this TileLib has no module for the requested pair. + """ + module_name = _TEMPLATE_MODULES.get((target, op)) + if module_name is None: + return False + import_module(module_name, package=__name__) + return True + + +__all__ = ["load_template"] diff --git a/ptodsl/ptodsl/tilelib/templates/a5/__init__.py b/ptodsl/ptodsl/tilelib/templates/a5/__init__.py new file mode 100644 index 000000000..0cc0684bb --- /dev/null +++ b/ptodsl/ptodsl/tilelib/templates/a5/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""A5 TileLib template modules, loaded individually by operation.""" diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tadd.py b/ptodsl/ptodsl/tilelib/templates/a5/tadd.py new file mode 100644 index 000000000..bf9e3d62d --- /dev/null +++ b/ptodsl/ptodsl/tilelib/templates/a5/tadd.py @@ -0,0 +1,110 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""PTODSL TileLib templates for pto.tadd.""" + +import ptodsl.tilelib as pto + +from . import tbinop + + +class AddOp: + @staticmethod + def BinInstr(reg_src0, reg_src1, preg): + return pto.vadd(reg_src0, reg_src1, preg) + + +def TAdd(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, version): + tbinop.BinaryInstr(dst, src0, src1, AddOp, version) + + +@pto.tile_template( + op="pto.tadd", + target="a5", + name="template_tadd_2d_no_post_update", + id=0, + constraints=[ + pto.check_type(("f32", "f32", "f32")), + pto.check_memory_space("ub"), + pto.check_layout("row_major"), + pto.require_contiguous(False), + ], + priority=0, + loop_depth=2, + Tail=tbinop.has_tail, + is_post_update=False, + tags=["binop", "2d", "no_post_update"], +) +def template_tadd_2d_no_post_update(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + TAdd(dst, src0, src1, tbinop.VFIMPL_2D_NO_POST_UPDATE) + + +@pto.tile_template( + op="pto.tadd", + target="a5", + name="template_tadd_1d_no_post_update", + id=1, + constraints=[ + pto.check_type(("f32", "f32", "f32")), + pto.check_memory_space("ub"), + pto.check_layout("row_major"), + pto.require_contiguous(True), + ], + priority=0, + loop_depth=1, + Tail=tbinop.has_tail, + is_post_update=False, + tags=["binop", "1d", "no_post_update"], +) +def template_tadd_1d_no_post_update(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + TAdd(dst, src0, src1, tbinop.VFIMPL_1D_NO_POST_UPDATE) + + +@pto.tile_template( + op="pto.tadd", + target="a5", + name="template_tadd_2d_post_update", + id=2, + constraints=[ + pto.check_type(("f32", "f32", "f32")), + pto.check_memory_space("ub"), + pto.check_layout("row_major"), + pto.require_contiguous(False), + ], + priority=0, + loop_depth=2, + Tail=tbinop.has_tail, + is_post_update=True, + tags=["binop", "2d", "post_update"], +) +def template_tadd_2d_post_update(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + TAdd(dst, src0, src1, tbinop.VFIMPL_2D_POST_UPDATE) + + +@pto.tile_template( + op="pto.tadd", + target="a5", + name="template_tadd_1d_post_update", + id=3, + constraints=[ + pto.check_type(("f32", "f32", "f32")), + pto.check_memory_space("ub"), + pto.check_layout("row_major"), + pto.require_contiguous(True), + ], + priority=0, + loop_depth=1, + Tail=tbinop.has_tail, + is_post_update=True, + tags=["binop", "1d", "post_update"], +) +def template_tadd_1d_post_update(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + TAdd(dst, src0, src1, tbinop.VFIMPL_1D_POST_UPDATE) + + +# Compatibility alias for tests and examples that import the original tadd template. +template_tadd = template_tadd_2d_no_post_update diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tbinop.py b/ptodsl/ptodsl/tilelib/templates/a5/tbinop.py new file mode 100644 index 000000000..e5966fe1e --- /dev/null +++ b/ptodsl/ptodsl/tilelib/templates/a5/tbinop.py @@ -0,0 +1,200 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Shared A5 binary elementwise TileLib helpers. + +This mirrors the pto-isa ``TBinOp`` shape at the PTODSL template level: op-specific +templates provide the vector instruction (``pto.vadd``, ``pto.vmul``, ...), while this +module owns the variant dispatch and traversal skeletons. +""" + +import ptodsl.tilelib as pto + +VFIMPL_1D_NO_POST_UPDATE = "1d_no_post_update" +VFIMPL_2D_NO_POST_UPDATE = "2d_no_post_update" +VFIMPL_1D_POST_UPDATE = "1d_post_update" +VFIMPL_2D_POST_UPDATE = "2d_post_update" + + +def is_single_row_tile(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + """Trace-time shape check for the basic 1D tile-slice implementation.""" + return src0.shape[0] == 1 and src1.shape[0] == 1 and dst.shape[0] == 1 + + +def has_tail(operand_sizes, **_): + # Placeholder until this matches the final binop tail rule. + return operand_sizes[0] % 8 != 0 + + +def BinaryInstr(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, op, version): + """Dispatch a binary TileOp implementation variant, mirroring pto-isa BinaryInstr.""" + if version == VFIMPL_1D_NO_POST_UPDATE: + TBinOps_1D_NoPostUpdate(dst, src0, src1, op) + elif version == VFIMPL_2D_NO_POST_UPDATE: + TBinOps_2D_NoPostUpdate(dst, src0, src1, op) + elif version == VFIMPL_1D_POST_UPDATE: + TBinOps_1D_PostUpdate(dst, src0, src1, op) + elif version == VFIMPL_2D_POST_UPDATE: + TBinOps_2D_PostUpdate(dst, src0, src1, op) + else: + if is_single_row_tile(src0, src1, dst): + TBinOps_1D_NoPostUpdate(dst, src0, src1, op) + else: + TBinOps_2D_NoPostUpdate(dst, src0, src1, op) + + +def TBinOps_1D_NoPostUpdate(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, op): + """Emit a basic single-row no-post-update binary op.""" + dtype = dst.element_type + _, valid_cols = dst.valid_shape + lanes = pto.get_lanes(dtype) + + col_loop = pto.for_(0, valid_cols, step=lanes).carry(remained=valid_cols) + with col_loop: + col = col_loop.iv + mask, remained = pto.make_mask(dtype, col_loop.remained) + vreg0 = pto.vlds(src0[0, col:]) + vreg1 = pto.vlds(src1[0, col:]) + vreg2 = op.BinInstr(vreg0, vreg1, mask) + pto.vsts(vreg2, dst[0, col:], mask) + col_loop.update(remained=remained) + + +def TBinOps_1D_PostUpdate(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, op): + """Emit a contiguous post-update binary op.""" + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + lanes = pto.get_lanes(dtype) + valid_elems = valid_rows * valid_cols + + src0_ptr = src0.as_ptr() + src1_ptr = src1.as_ptr() + dst_ptr = dst.as_ptr() + + elem_loop = pto.for_(0, valid_elems, step=lanes).carry( + remained=valid_elems, + src0_ptr=src0_ptr, + src1_ptr=src1_ptr, + dst_ptr=dst_ptr, + ) + with elem_loop: + mask, remained = pto.make_mask(dtype, elem_loop.remained) + vreg0, src0_next = pto.vlds( + elem_loop.src0_ptr, lanes, post_update=pto.PostUpdate.ON + ) + vreg1, src1_next = pto.vlds( + elem_loop.src1_ptr, lanes, post_update=pto.PostUpdate.ON + ) + vreg2 = op.BinInstr(vreg0, vreg1, mask) + dst_next = pto.vsts( + vreg2, elem_loop.dst_ptr, lanes, mask, post_update=pto.PostUpdate.ON + ) + elem_loop.update( + remained=remained, + src0_ptr=src0_next, + src1_ptr=src1_next, + dst_ptr=dst_next, + ) + + +def TBinOps_2D_NoPostUpdate(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, op): + """Emit the generic row/column no-post-update binary op.""" + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + lanes = pto.get_lanes(dtype) + + with pto.for_(0, valid_rows, step=1) as row: + col_loop = pto.for_(0, valid_cols, step=lanes).carry(remained=valid_cols) + with col_loop: + col = col_loop.iv + mask, remained = pto.make_mask(dtype, col_loop.remained) + vreg0 = pto.vlds(src0[row, col:]) + vreg1 = pto.vlds(src1[row, col:]) + vreg2 = op.BinInstr(vreg0, vreg1, mask) + pto.vsts(vreg2, dst[row, col:], mask) + col_loop.update(remained=remained) + + +def TBinOps_2D_PostUpdate(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, op): + """Emit a row-wise post-update binary op.""" + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + lanes = pto.get_lanes(dtype) + full_cols = (valid_cols // lanes) * lanes + tail_count = valid_cols % lanes + full_mask = pto.make_mask(dtype, "PAT_ALL") + dst_row_stride = dst.shape[1] + src0_row_stride = src0.shape[1] + src1_row_stride = src1.shape[1] + + src0_base = src0.as_ptr() + src1_base = src1.as_ptr() + dst_base = dst.as_ptr() + + with pto.for_(0, valid_rows, step=1) as row: + src0_row = pto.addptr(src0_base, row * src0_row_stride) + src1_row = pto.addptr(src1_base, row * src1_row_stride) + dst_row = pto.addptr(dst_base, row * dst_row_stride) + + col_loop = pto.for_(0, full_cols, step=lanes).carry( + src0_ptr=src0_row, + src1_ptr=src1_row, + dst_ptr=dst_row, + ) + with col_loop: + vreg0, src0_next = pto.vlds( + col_loop.src0_ptr, lanes, post_update=pto.PostUpdate.ON + ) + vreg1, src1_next = pto.vlds( + col_loop.src1_ptr, lanes, post_update=pto.PostUpdate.ON + ) + vreg2 = op.BinInstr(vreg0, vreg1, full_mask) + dst_next = pto.vsts( + vreg2, + col_loop.dst_ptr, + lanes, + full_mask, + post_update=pto.PostUpdate.ON, + ) + col_loop.update( + src0_ptr=src0_next, + src1_ptr=src1_next, + dst_ptr=dst_next, + ) + + with pto.if_(tail_count != 0) as tail: + with tail.then_: + mask, _ = pto.make_mask(dtype, tail_count) + vreg0, _ = pto.vlds( + col_loop.final("src0_ptr"), lanes, post_update=pto.PostUpdate.ON + ) + vreg1, _ = pto.vlds( + col_loop.final("src1_ptr"), lanes, post_update=pto.PostUpdate.ON + ) + vreg2 = op.BinInstr(vreg0, vreg1, mask) + pto.vsts( + vreg2, + col_loop.final("dst_ptr"), + lanes, + mask, + post_update=pto.PostUpdate.ON, + ) + + +__all__ = [ + "is_single_row_tile", + "has_tail", + "BinaryInstr", + "TBinOps_1D_NoPostUpdate", + "TBinOps_1D_PostUpdate", + "TBinOps_2D_NoPostUpdate", + "TBinOps_2D_PostUpdate", + "VFIMPL_1D_NO_POST_UPDATE", + "VFIMPL_2D_NO_POST_UPDATE", + "VFIMPL_1D_POST_UPDATE", + "VFIMPL_2D_POST_UPDATE", +] diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tcolmax.py b/ptodsl/ptodsl/tilelib/templates/a5/tcolmax.py new file mode 100644 index 000000000..c5d84cbcf --- /dev/null +++ b/ptodsl/ptodsl/tilelib/templates/a5/tcolmax.py @@ -0,0 +1,66 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""PTODSL TileLib template for pto.tcolmax (ported from lib/TileOps/tcolmax_template.py). + +Column-wise max reduction → 1-row output. Body is verbatim from the tilelang template +(only the import + decorator changed). The legality predicate ``_validate_tcolmax`` is +copied as-is; it is matched by parameter name against the per-operand selection context +(see constraints.py). +""" + +import ptodsl.tilelib as pto + + +def _validate_tcolmax( + src_shape=(), + src_valid_shape=(), + dst_shape=(), + dst_valid_shape=(), + src_config=None, + dst_config=None, +): + if src_config is None or dst_config is None: + return False + if src_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if dst_config.b_layout != pto.BLayout.ROW_MAJOR: + return False + if src_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_config.s_layout != pto.SLayout.NONE_BOX: + return False + if dst_valid_shape[0] != 1: + return False + return True + + +@pto.tile_template( + op="pto.tcolmax", + target="a5", + name="template_tcolmax", + dtypes=[("f32", "f32")], + layouts=["row_major"], + memory_spaces=["ub"], + constraints=[_validate_tcolmax], + priority=0, +) +def template_tcolmax(src: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = src.valid_shape + + lanes = pto.get_lanes(dtype) + remained = valid_cols + + for col_chunk in range(0, valid_cols, lanes): + mask, remained = pto.make_mask(dtype, remained) + + acc = pto.vlds(src[0, col_chunk:]) + for row in range(1, valid_rows, 1): + row_vec = pto.vlds(src[row, col_chunk:]) + acc = pto.vmax(acc, row_vec, mask) + pto.vsts(acc, dst[0, col_chunk:], mask) diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tdiv.py b/ptodsl/ptodsl/tilelib/templates/a5/tdiv.py new file mode 100644 index 000000000..cd4113cda --- /dev/null +++ b/ptodsl/ptodsl/tilelib/templates/a5/tdiv.py @@ -0,0 +1,39 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""PTODSL TileLib template for pto.tdiv — default precision only. + +Ported from lib/TileOps/tdiv_template.py, but only the default-precision branch (plain +pto.vdiv). The high-precision (IEEE-754) path is deferred: it needs a `get_op_attr` +bridge to read the `precisionType` context attr the daemon already receives, plus the +div_hp algorithm — tracked as a follow-up. +""" + +import ptodsl.tilelib as pto + + +@pto.tile_template( + op="pto.tdiv", + target="a5", + name="template_tdiv", + dtypes=[("f32", "f32", "f32")], + layouts=["row_major"], + memory_spaces=["ub"], + priority=0, +) +def template_tdiv(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + divided = pto.vdiv(lhs, rhs, mask) + pto.vsts(divided, dst[row, col:], mask) diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tmax.py b/ptodsl/ptodsl/tilelib/templates/a5/tmax.py new file mode 100644 index 000000000..3c388541e --- /dev/null +++ b/ptodsl/ptodsl/tilelib/templates/a5/tmax.py @@ -0,0 +1,33 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""PTODSL TileLib template for pto.tmax (ported from lib/TileOps/tmax_template.py).""" + +import ptodsl.tilelib as pto + + +@pto.tile_template( + op="pto.tmax", + target="a5", + name="template_tmax", + dtypes=[("f32", "f32", "f32")], + layouts=["row_major"], + memory_spaces=["ub"], + priority=0, +) +def template_tmax(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + max_val = pto.vmax(lhs, rhs, mask) + pto.vsts(max_val, dst[row, col:], mask) diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tmin.py b/ptodsl/ptodsl/tilelib/templates/a5/tmin.py new file mode 100644 index 000000000..d0bcc28fd --- /dev/null +++ b/ptodsl/ptodsl/tilelib/templates/a5/tmin.py @@ -0,0 +1,33 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""PTODSL TileLib template for pto.tmin (ported from lib/TileOps/tmin_template.py).""" + +import ptodsl.tilelib as pto + + +@pto.tile_template( + op="pto.tmin", + target="a5", + name="template_tmin", + dtypes=[("f32", "f32", "f32")], + layouts=["row_major"], + memory_spaces=["ub"], + priority=0, +) +def template_tmin(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + min_val = pto.vmin(lhs, rhs, mask) + pto.vsts(min_val, dst[row, col:], mask) diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tmul.py b/ptodsl/ptodsl/tilelib/templates/a5/tmul.py new file mode 100644 index 000000000..91c216790 --- /dev/null +++ b/ptodsl/ptodsl/tilelib/templates/a5/tmul.py @@ -0,0 +1,110 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""PTODSL TileLib templates for pto.tmul.""" + +import ptodsl.tilelib as pto + +from . import tbinop + + +class MulOp: + @staticmethod + def BinInstr(reg_src0, reg_src1, preg): + return pto.vmul(reg_src0, reg_src1, preg) + + +def TMul(dst: pto.Tile, src0: pto.Tile, src1: pto.Tile, version): + tbinop.BinaryInstr(dst, src0, src1, MulOp, version) + + +@pto.tile_template( + op="pto.tmul", + target="a5", + name="template_tmul_2d_no_post_update", + id=0, + constraints=[ + pto.check_type(("f32", "f32", "f32")), + pto.check_memory_space("ub"), + pto.check_layout("row_major"), + pto.require_contiguous(False), + ], + priority=0, + loop_depth=2, + Tail=tbinop.has_tail, + is_post_update=False, + tags=["binop", "2d", "no_post_update"], +) +def template_tmul_2d_no_post_update(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + TMul(dst, src0, src1, tbinop.VFIMPL_2D_NO_POST_UPDATE) + + +@pto.tile_template( + op="pto.tmul", + target="a5", + name="template_tmul_1d_no_post_update", + id=1, + constraints=[ + pto.check_type(("f32", "f32", "f32")), + pto.check_memory_space("ub"), + pto.check_layout("row_major"), + pto.require_contiguous(False), + ], + priority=0, + loop_depth=1, + Tail=tbinop.has_tail, + is_post_update=False, + tags=["binop", "1d", "no_post_update"], +) +def template_tmul_1d_no_post_update(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + TMul(dst, src0, src1, tbinop.VFIMPL_1D_NO_POST_UPDATE) + + +@pto.tile_template( + op="pto.tmul", + target="a5", + name="template_tmul_2d_post_update", + id=2, + constraints=[ + pto.check_type(("f32", "f32", "f32")), + pto.check_memory_space("ub"), + pto.check_layout("row_major"), + pto.require_contiguous(False), + ], + priority=0, + loop_depth=2, + Tail=tbinop.has_tail, + is_post_update=True, + tags=["binop", "2d", "post_update"], +) +def template_tmul_2d_post_update(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + TMul(dst, src0, src1, tbinop.VFIMPL_2D_POST_UPDATE) + + +@pto.tile_template( + op="pto.tmul", + target="a5", + name="template_tmul_1d_post_update", + id=3, + constraints=[ + pto.check_type(("f32", "f32", "f32")), + pto.check_memory_space("ub"), + pto.check_layout("row_major"), + pto.require_contiguous(True), + ], + priority=0, + loop_depth=1, + Tail=tbinop.has_tail, + is_post_update=True, + tags=["binop", "1d", "post_update"], +) +def template_tmul_1d_post_update(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + TMul(dst, src0, src1, tbinop.VFIMPL_1D_POST_UPDATE) + + +# Compatibility alias for tests and examples that import the original tmul template. +template_tmul = template_tmul_2d_no_post_update diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tsub.py b/ptodsl/ptodsl/tilelib/templates/a5/tsub.py new file mode 100644 index 000000000..d1318990e --- /dev/null +++ b/ptodsl/ptodsl/tilelib/templates/a5/tsub.py @@ -0,0 +1,33 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""PTODSL TileLib template for pto.tsub (ported from lib/TileOps/tsub_template.py).""" + +import ptodsl.tilelib as pto + + +@pto.tile_template( + op="pto.tsub", + target="a5", + name="template_tsub", + dtypes=[("f32", "f32", "f32")], + layouts=["row_major"], + memory_spaces=["ub"], + priority=0, +) +def template_tsub(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): + dtype = dst.element_type + valid_rows, valid_cols = dst.valid_shape + + for row in range(0, valid_rows, 1): + remained = valid_cols + for col in range(0, valid_cols, pto.get_lanes(dtype)): + mask, remained = pto.make_mask(dtype, remained) + lhs = pto.vlds(src0[row, col:]) + rhs = pto.vlds(src1[row, col:]) + subtracted = pto.vsub(lhs, rhs, mask) + pto.vsts(subtracted, dst[row, col:], mask) diff --git a/ptodsl/tests/fixtures/tadd_a5_8x64_f32.golden.mlir b/ptodsl/tests/fixtures/tadd_a5_8x64_f32.golden.mlir new file mode 100644 index 000000000..61a8e57c2 --- /dev/null +++ b/ptodsl/tests/fixtures/tadd_a5_8x64_f32.golden.mlir @@ -0,0 +1,38 @@ +// tilelang.target = a5 + // tilelang.op = pto.tadd + // tilelang.dtypes = (f32, f32, f32) + // tilelang.verify = True + // tilelang.advanced = False + // tilelang.specialize dst shape=(8, 64) memory_space=ub config=None + // tilelang.specialize src0 shape=(8, 64) memory_space=ub config=None + // tilelang.specialize src1 shape=(8, 64) memory_space=ub config=None + module attributes {pto.target_arch = "a5"} { + func.func @template_tadd(%arg0: !pto.tile_buf, %arg1: !pto.tile_buf, %arg2: !pto.tile_buf) attributes { pto.tilelang.instance, pto.kernel_kind = #pto.kernel_kind } { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %tmp_0 = pto.tile_buf_addr %arg0 : !pto.tile_buf -> memref<8x64xf32, #pto.address_space> + %tmp_1 = pto.tile_buf_addr %arg1 : !pto.tile_buf -> memref<8x64xf32, #pto.address_space> + %tmp_2 = pto.tile_buf_addr %arg2 : !pto.tile_buf -> memref<8x64xf32, #pto.address_space> + %valid_rows_1 = pto.tile_valid_rows %arg2 : !pto.tile_buf -> index + %valid_cols_2 = pto.tile_valid_cols %arg2 : !pto.tile_buf -> index + scf.for %row_3 = %c0 to %valid_rows_1 step %c1 { + %tmp_3 = arith.index_cast %valid_cols_2 : index to i32 + %remained_11:1 = scf.for %col_5 = %c0 to %valid_cols_2 step %c64 iter_args(%remained_iter_0 = %tmp_3) -> (i32) { + %mask_6, %remained_7 = pto.plt_b32 %remained_iter_0 : i32 -> !pto.mask, i32 + %tmp_4 = arith.subi %c64, %col_5 : index + %tmp_5 = memref.subview %tmp_0[%row_3, %col_5] [%c1, %tmp_4] [%c1, %c1] : memref<8x64xf32, #pto.address_space> to memref, #pto.address_space> + %lhs_8 = pto.vlds %tmp_5[%c0] : memref, #pto.address_space> -> !pto.vreg<64xf32> + %tmp_6 = arith.subi %c64, %col_5 : index + %tmp_7 = memref.subview %tmp_1[%row_3, %col_5] [%c1, %tmp_6] [%c1, %c1] : memref<8x64xf32, #pto.address_space> to memref, #pto.address_space> + %rhs_9 = pto.vlds %tmp_7[%c0] : memref, #pto.address_space> -> !pto.vreg<64xf32> + %summed_10 = pto.vadd %lhs_8, %rhs_9, %mask_6 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %tmp_8 = arith.subi %c64, %col_5 : index + %tmp_9 = memref.subview %tmp_2[%row_3, %col_5] [%c1, %tmp_8] [%c1, %c1] : memref<8x64xf32, #pto.address_space> to memref, #pto.address_space> + pto.vsts %summed_10, %tmp_9[%c0], %mask_6 : !pto.vreg<64xf32>, memref, #pto.address_space>, !pto.mask + scf.yield %remained_7 : i32 + } + } + return + } + } diff --git a/ptodsl/tests/test_tilelib_constraints.py b/ptodsl/tests/test_tilelib_constraints.py new file mode 100644 index 000000000..ebf600163 --- /dev/null +++ b/ptodsl/tests/test_tilelib_constraints.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Constraint-driven selection test for the reduction op pto.tcolmax. + +tcolmax is legal only for row-major / none-box operands with a 1-row output +(dst_valid_shape[0] == 1). Selection must accept the legal case and reject the others. +""" + +import unittest + +from ptodsl.tilelib import ScalarType, TileSpec, select +from ptodsl.tilelib.registry import NoMatchingTemplate + +F32 = ScalarType("f32") + + +def _specs(*, dst_valid=(1, 64), dst_blayout="row_major", dst_slayout="none_box"): + src = TileSpec(shape=(8, 64), dtype=F32, valid_shape=(8, 64)) + dst = TileSpec(shape=(8, 64), dtype=F32, valid_shape=dst_valid, + b_layout=dst_blayout, s_layout=dst_slayout) + return {"src": src, "dst": dst} + + +class TileLibConstraintTest(unittest.TestCase): + def test_legal_colmax_selected(self): + chosen = select("pto.tcolmax", "a5", _specs()) + self.assertEqual(chosen.name, "template_tcolmax") + + def test_rejected_when_dst_not_single_row(self): + with self.assertRaises(NoMatchingTemplate): + select("pto.tcolmax", "a5", _specs(dst_valid=(8, 64))) + + def test_rejected_when_not_row_major(self): + with self.assertRaises(NoMatchingTemplate): + select("pto.tcolmax", "a5", _specs(dst_blayout="col_major")) + + def test_legal_colmax_renders_structured_mlir(self): + chosen = select("pto.tcolmax", "a5", _specs()) + mlir = chosen.specialize(**_specs()).mlir_text() + for op in ("pto.tile_valid_rows", "memref.subview", "scf.for", "iter_args", + "pto.vmax", "pto.vsts", "pto.tilelang.instance"): + self.assertIn(op, mlir) + self.assertNotIn("pto.castptr", mlir) + + +if __name__ == "__main__": + unittest.main() diff --git a/ptodsl/tests/test_tilelib_elementwise.py b/ptodsl/tests/test_tilelib_elementwise.py new file mode 100644 index 000000000..ecd8bd520 --- /dev/null +++ b/ptodsl/tests/test_tilelib_elementwise.py @@ -0,0 +1,49 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Phase-4 breadth test: each ported elementwise op selects + renders to the structured +abstraction, using the right vector op.""" + +import unittest + +from ptodsl.tilelib import ScalarType, TileSpec, select + +# op -> (expected template name, expected vector op in the rendered MLIR) +ELEMENTWISE = { + "pto.tsub": ("template_tsub", "pto.vsub"), + "pto.tmul": ("template_tmul_2d_no_post_update", "pto.vmul"), + "pto.tmax": ("template_tmax", "pto.vmax"), + "pto.tmin": ("template_tmin", "pto.vmin"), + "pto.tdiv": ("template_tdiv", "pto.vdiv"), +} + +# Structured abstraction every elementwise template must preserve. +SHARED_OPS = ["pto.tile_buf_addr", "memref.subview", "scf.for", "iter_args", + "pto.plt_b32", "pto.vlds", "pto.vsts", "pto.tilelang.instance"] + + +def _f32_specs(): + spec = TileSpec(shape=(8, 64), dtype=ScalarType("f32")) + return {"src0": spec, "src1": spec, "dst": spec} + + +class TileLibElementwiseTest(unittest.TestCase): + def test_each_op_selects_and_renders(self): + for op, (name, vop) in ELEMENTWISE.items(): + with self.subTest(op=op): + descriptor = select(op, "a5", _f32_specs(), candidate_id=name) + self.assertEqual(descriptor.name, name) + + mlir = descriptor.specialize(**_f32_specs()).mlir_text() + self.assertIn(vop, mlir) # the op's own vector instruction + for shared in SHARED_OPS: + self.assertIn(shared, mlir) + self.assertNotIn("pto.castptr", mlir) # structured, not bare-pointer + + +if __name__ == "__main__": + unittest.main() diff --git a/ptodsl/tests/test_tilelib_render.py b/ptodsl/tests/test_tilelib_render.py new file mode 100644 index 000000000..d1405ef0a --- /dev/null +++ b/ptodsl/tests/test_tilelib_render.py @@ -0,0 +1,71 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLib render abstraction-level test for pto.tadd. + +Asserts the rendered MLIR is *on par* with the tilelang golden +(fixtures/tadd_a5_8x64_f32.golden.mlir): same structured abstraction, not byte-identical +(ptodsl differs in SSA naming, constant hoisting, index-vs-i32 carry, subview rank). +""" + +import unittest +from pathlib import Path + +from ptodsl.tilelib import TileSpec, f32 +from ptodsl.tilelib.templates.a5.tadd import template_tadd + +FIXTURE = Path(__file__).parent / "fixtures" / "tadd_a5_8x64_f32.golden.mlir" + +# The structured abstraction the migration must preserve (see golden fixture). +REQUIRED_OPS = [ + "pto.tile_buf_addr", + "memref<8x64xf32, #pto.address_space>", + "pto.tile_valid_rows", + "pto.tile_valid_cols", + "scf.for", + "iter_args", # the inner loop carries `remained` (AST-rewrite .carry path) + "pto.plt_b32", + "memref.subview", + "pto.vlds", + "pto.vadd", + "pto.vsts", + "!pto.vreg<64xf32>", +] + +# The low-level pointer style the team explicitly rejected for TileLib templates. +FORBIDDEN_OPS = ["pto.castptr", "pto.addptr"] + + +def _render(): + spec = TileSpec(shape=(8, 64), dtype=f32) + return template_tadd.specialize(src0=spec, src1=spec, dst=spec).mlir_text() + + +class TileLibRenderTest(unittest.TestCase): + def test_renders_structured_abstraction(self): + text = _render() + for op in REQUIRED_OPS: + self.assertIn(op, text) + for op in FORBIDDEN_OPS: + self.assertNotIn(op, text) + + def test_func_is_a_tilelang_instance(self): + text = _render() + self.assertIn('pto.target_arch = "a5"', text) + self.assertIn("pto.tilelang.instance", text) + self.assertIn("#pto.kernel_kind", text) + self.assertIn("func.func @template_tadd", text) + + def test_golden_fixture_uses_same_abstraction(self): + self.assertTrue(FIXTURE.exists(), f"missing golden fixture {FIXTURE}") + golden = FIXTURE.read_text(encoding="utf-8") + for op in ("pto.tile_buf_addr", "memref.subview", "pto.vlds", "pto.vadd", "pto.vsts", "pto.plt_b32"): + self.assertIn(op, golden) + + +if __name__ == "__main__": + unittest.main() diff --git a/ptodsl/tests/test_tilelib_select.py b/ptodsl/tests/test_tilelib_select.py new file mode 100644 index 000000000..a3415a3b4 --- /dev/null +++ b/ptodsl/tests/test_tilelib_select.py @@ -0,0 +1,108 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""TileLib selection test: metadata-driven legality and descriptor metadata.""" + +import unittest + +from ptodsl.tilelib import ( + AmbiguousTemplate, + ScalarType, + TileSpec, + legal_candidates, + select, +) +from ptodsl.tilelib import constraints as _constraints +from ptodsl.tilelib.registry import NoMatchingTemplate + + +def _f32_specs(): + spec = TileSpec(shape=(8, 64), dtype=ScalarType("f32")) + return {"src0": spec, "src1": spec, "dst": spec} + + +def _f32_single_row_specs(): + spec = TileSpec(shape=(1, 64), dtype=ScalarType("f32")) + return {"src0": spec, "src1": spec, "dst": spec} + + +class TileLibSelectTest(unittest.TestCase): + def test_four_tadd_versions_registered(self): + names = { + candidate.name + for candidate in legal_candidates("pto.tadd", "a5", _f32_specs()) + } + self.assertEqual({ + "template_tadd_2d_no_post_update", + "template_tadd_1d_no_post_update", + "template_tadd_2d_post_update", + "template_tadd_1d_post_update", + }, names) + + def test_plain_tadd_select_is_ambiguous(self): + with self.assertRaises(AmbiguousTemplate): + select("pto.tadd", "a5", _f32_specs()) + + def test_named_tadd_selects_2d_no_post_update_template(self): + chosen = select( + "pto.tadd", + "a5", + _f32_specs(), + candidate_id="template_tadd_2d_no_post_update", + ) + self.assertEqual(chosen.name, "template_tadd_2d_no_post_update") + self.assertFalse(chosen.metadata.is_post_update) + self.assertEqual(chosen.metadata.loop_depth, 2) + self.assertTrue(callable(chosen.metadata.Tail)) + self.assertEqual(chosen.metadata.tags, ("binop", "2d", "no_post_update")) + + def test_single_row_tadd_candidates_are_still_all_visible(self): + candidates = legal_candidates("pto.tadd", "a5", _f32_single_row_specs()) + self.assertEqual(len(candidates), 4) + + def test_legal_candidates_include_loop_depth_metadata(self): + candidates = legal_candidates("pto.tadd", "a5", _f32_specs()) + by_name = {candidate.name: candidate for candidate in candidates} + self.assertEqual(set(by_name), { + "template_tadd_2d_no_post_update", + "template_tadd_1d_no_post_update", + "template_tadd_2d_post_update", + "template_tadd_1d_post_update", + }) + self.assertEqual(by_name["template_tadd_2d_no_post_update"].metadata.loop_depth, 2) + self.assertFalse(by_name["template_tadd_2d_no_post_update"].metadata.is_post_update) + self.assertTrue(callable(by_name["template_tadd_2d_no_post_update"].metadata.Tail)) + self.assertEqual(by_name["template_tadd_1d_no_post_update"].metadata.loop_depth, 1) + self.assertTrue(callable(by_name["template_tadd_1d_no_post_update"].metadata.Tail)) + self.assertTrue(by_name["template_tadd_2d_post_update"].metadata.is_post_update) + self.assertTrue(callable(by_name["template_tadd_2d_post_update"].metadata.Tail)) + self.assertTrue(by_name["template_tadd_1d_post_update"].metadata.is_post_update) + self.assertTrue(callable(by_name["template_tadd_1d_post_update"].metadata.Tail)) + context = _constraints.build_context(_f32_specs(), "a5", "pto.tadd") + self.assertFalse(by_name["template_tadd_1d_post_update"].metadata.Tail(**context)) + + def test_can_select_named_legal_candidate(self): + chosen = select( + "pto.tadd", + "a5", + _f32_specs(), + candidate_id="template_tadd_2d_no_post_update", + ) + self.assertEqual(chosen.name, "template_tadd_2d_no_post_update") + + def test_no_matching_dtype_raises(self): + spec = TileSpec(shape=(8, 64), dtype=ScalarType("i8")) + with self.assertRaises(NoMatchingTemplate): + select("pto.tadd", "a5", {"src0": spec, "src1": spec, "dst": spec}) + + def test_unknown_op_raises(self): + with self.assertRaises(NoMatchingTemplate): + select("pto.tnope", "a5", _f32_specs()) + + +if __name__ == "__main__": + unittest.main() diff --git a/scripts/ptoas_env.sh b/scripts/ptoas_env.sh index ac5eaa68b..46677d5d2 100644 --- a/scripts/ptoas_env.sh +++ b/scripts/ptoas_env.sh @@ -15,6 +15,7 @@ # export WORKSPACE_DIR=/path/to/workspace # export LLVM_BUILD_DIR=/path/to/llvm-project/build-shared # export PTO_SOURCE_DIR=/path/to/PTOAS +# export PTODSL_PYTHON_ROOT=/path/to/PTOAS/ptodsl # export PTO_INSTALL_DIR=/path/to/PTOAS/install # export PTO_PYTHON_BIN=/path/to/python3 # export PTOAS_ENV_SKIP_SMOKE_TEST=1 # skip legacy MatMul/Abs sample checks @@ -55,6 +56,7 @@ fi export PTOAS_PYTHON_SITE export PTO_PYTHON_ROOT="${PTO_PYTHON_ROOT:-${PTO_INSTALL_DIR}}" export PTO_PYTHON_BUILD_ROOT="${PTO_PYTHON_BUILD_ROOT:-${PTO_SOURCE_DIR}/build/python}" +export PTODSL_PYTHON_ROOT="${PTODSL_PYTHON_ROOT:-${PTO_SOURCE_DIR}/ptodsl}" export PYBIND11_CMAKE_DIR=$(python3 -m pybind11 --cmakedir) export PTOAS_FLAGS="${PTOAS_FLAGS:-}" export PTOAS_OUT_DIR="${PTOAS_OUT_DIR:-${PTO_SOURCE_DIR}/build/output}" @@ -116,6 +118,7 @@ _ptoas_prepend_path PYTHONPATH "${PTO_PYTHON_ROOT}" _ptoas_prepend_path PYTHONPATH "${PTOAS_PYTHON_SITE}" _ptoas_prepend_path PYTHONPATH "${MLIR_PYTHON_ROOT}" _ptoas_prepend_path PYTHONPATH "${PTO_PYTHON_BUILD_ROOT}" +_ptoas_prepend_path PYTHONPATH "${PTODSL_PYTHON_ROOT}" _ptoas_prepend_path LD_LIBRARY_PATH "${LLVM_BUILD_DIR}/lib" _ptoas_prepend_path LD_LIBRARY_PATH "${PTO_INSTALL_DIR}/lib" From 4635db9958fc43aa4fdfb5ad9c0860ca7f83d97f Mon Sep 17 00:00:00 2001 From: ManiSadati Date: Tue, 30 Jun 2026 20:11:31 +0000 Subject: [PATCH 2/6] feat(ptodsl): add TileLib daemon serving layer --- ptodsl/ptodsl/tilelib/registry.py | 2 + ptodsl/ptodsl/tilelib/serving/__init__.py | 33 +++ ptodsl/ptodsl/tilelib/serving/client.py | 77 +++++ ptodsl/ptodsl/tilelib/serving/daemon.py | 333 ++++++++++++++++++++++ ptodsl/ptodsl/tilelib/serving/helper.py | 73 +++++ ptodsl/ptodsl/tilelib/serving/wire.py | 41 +++ ptodsl/tests/test_tilelib_daemon.py | 169 +++++++++++ ptodsl/tests/test_tilelib_select.py | 34 +++ 8 files changed, 762 insertions(+) create mode 100644 ptodsl/ptodsl/tilelib/serving/__init__.py create mode 100644 ptodsl/ptodsl/tilelib/serving/client.py create mode 100644 ptodsl/ptodsl/tilelib/serving/daemon.py create mode 100644 ptodsl/ptodsl/tilelib/serving/helper.py create mode 100644 ptodsl/ptodsl/tilelib/serving/wire.py create mode 100644 ptodsl/tests/test_tilelib_daemon.py diff --git a/ptodsl/ptodsl/tilelib/registry.py b/ptodsl/ptodsl/tilelib/registry.py index 404971f72..72d09e8b3 100644 --- a/ptodsl/ptodsl/tilelib/registry.py +++ b/ptodsl/ptodsl/tilelib/registry.py @@ -64,6 +64,8 @@ def legal_candidates(self, op: str, target: str, tile_specs: dict, # Hard legality constraints (e.g. layout / valid-shape rules). context = _constraints.build_context(tile_specs, target, op) + if context_attrs: + context.update(context_attrs) legal = [d for d in legal if _constraints.passes(d.metadata.constraints, context)] if not legal: raise NoMatchingTemplate( diff --git a/ptodsl/ptodsl/tilelib/serving/__init__.py b/ptodsl/ptodsl/tilelib/serving/__init__.py new file mode 100644 index 000000000..3ac3c8f6a --- /dev/null +++ b/ptodsl/ptodsl/tilelib/serving/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Unix-socket serving layer for the PTODSL TileLib.""" + +from .client import DaemonClient, DaemonError + + +def __getattr__(name): + # Keep daemon.py unloaded when executing it with ``python -m``. + if name in {"TileLibDaemonServer", "metadata_request", "render_request"}: + from .daemon import TileLibDaemonServer, metadata_request, render_request + + exports = { + "TileLibDaemonServer": TileLibDaemonServer, + "metadata_request": metadata_request, + "render_request": render_request, + } + return exports[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + "DaemonClient", + "DaemonError", + "TileLibDaemonServer", + "metadata_request", + "render_request", +] diff --git a/ptodsl/ptodsl/tilelib/serving/client.py b/ptodsl/ptodsl/tilelib/serving/client.py new file mode 100644 index 000000000..9a4db37b0 --- /dev/null +++ b/ptodsl/ptodsl/tilelib/serving/client.py @@ -0,0 +1,77 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Synchronous client for the PTODSL TileLib daemon.""" + +from __future__ import annotations + +import socket + +from .wire import recv_message, send_message + + +class DaemonError(Exception): + """An RPC reached the daemon but the requested operation failed.""" + + +class DaemonClient: + """Issue one daemon RPC per Unix-socket connection.""" + + def __init__(self, socket_path: str): + self.socket_path = socket_path + + def _call(self, method: str, params: dict | None = None): + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: + sock.connect(self.socket_path) + send_message(sock, {"method": method, "params": params or {}}) + response = recv_message(sock) + + if not response.get("success"): + raise DaemonError(response.get("error", "unknown daemon error")) + return response["result"] + + def ping(self): + return self._call("ping") + + def get_metadata(self, target, op, operand_specs, context_attrs=None): + return self._call( + "get_metadata", + { + "target": target, + "op": op, + "operand_specs": operand_specs, + "context_attrs": context_attrs or {}, + }, + ) + + def instantiate( + self, + target, + op, + operand_specs, + context_attrs=None, + candidate_id=None, + ): + return self._call( + "instantiate", + { + "target": target, + "op": op, + "operand_specs": operand_specs, + "context_attrs": context_attrs or {}, + "candidate_id": candidate_id, + }, + ) + + def get_stats(self): + return self._call("get_stats") + + def clear(self): + return self._call("clear") + + +__all__ = ["DaemonClient", "DaemonError"] diff --git a/ptodsl/ptodsl/tilelib/serving/daemon.py b/ptodsl/ptodsl/tilelib/serving/daemon.py new file mode 100644 index 000000000..02ac4533b --- /dev/null +++ b/ptodsl/ptodsl/tilelib/serving/daemon.py @@ -0,0 +1,333 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""PTODSL TileLib daemon for the ExpandTileOp Unix-socket RPC contract. + +The daemon owns template discovery, selection, specialization, rendering, and an +in-memory instance cache. PTODSL templates are loaded from the Python package, so +the daemon does not scan or depend on an external template directory. + +Run it with: + + python3 -m ptodsl.tilelib.serving.daemon --socket +""" + +from __future__ import annotations + +import argparse +import json +import os +import signal +import socketserver +import threading + +from .. import registry as _registry +from ..metadata import ScalarType, TileSpec +from ..templates import load_template +from .wire import recv_message, send_message + + +def _build_tile_specs(descriptor, operand_specs: list) -> dict: + """Map positional daemon operands onto a template's parameter names.""" + if not isinstance(operand_specs, list): + raise TypeError("operand_specs must be a list") + if len(operand_specs) != len(descriptor.param_names): + raise ValueError( + f"template {descriptor.name!r} expects {len(descriptor.param_names)} " + f"operands, got {len(operand_specs)}" + ) + + tile_specs = {} + for index, (name, spec) in enumerate(zip(descriptor.param_names, operand_specs)): + if not isinstance(spec, dict): + raise TypeError(f"operand_specs[{index}] must be an object") + + kind = spec.get("kind") + if kind != "tile": + raise NotImplementedError( + "PTODSL TileLib daemon currently supports only tile operands; " + f"operand {index} ({name!r}) has kind {kind!r}" + ) + + config = spec.get("config") or {} + if not isinstance(config, dict): + raise TypeError(f"operand_specs[{index}].config must be an object") + + try: + shape = tuple(spec["shape"]) + dtype = ScalarType(spec["dtype"]) + except KeyError as exc: + raise ValueError( + f"tile operand {index} ({name!r}) is missing {exc.args[0]!r}" + ) from exc + + valid_shape = spec.get("valid_shape") + tile_specs[name] = TileSpec( + shape=shape, + dtype=dtype, + memory_space=spec.get("memory_space", "ub"), + valid_shape=tuple(valid_shape) if valid_shape is not None else None, + b_layout=config.get("b_layout", "row_major"), + s_layout=config.get("s_layout", "none_box"), + ) + return tile_specs + + +def _constraint_name(predicate) -> str: + return getattr(predicate, "__name__", repr(predicate)) + + +def _metadata_value(value): + if callable(value): + return {"callable": _constraint_name(value)} + return value + + +def _metadata_for_descriptor(descriptor) -> dict: + metadata = descriptor.metadata + return { + "op": metadata.op, + "target": metadata.target, + "name": metadata.name, + "dtypes": [list(signature) for signature in metadata.dtypes], + "layouts": list(metadata.layouts), + "memory_spaces": list(metadata.memory_spaces), + "constraints": [ + _constraint_name(predicate) for predicate in metadata.constraints + ], + "priority": metadata.priority, + "fusible": metadata.fusible, + "loop_depth": metadata.loop_depth, + "id": metadata.id, + "Tail": _metadata_value(metadata.Tail), + "is_post_update": metadata.is_post_update, + "tags": list(metadata.tags), + } + + +def _tile_specs_for_request(target: str, op: str, operand_specs: list) -> dict: + # Import only this op's template module. Registration happens as an import + # side effect and repeated requests are no-ops because the loader is cached. + load_template(op, target) + candidates = _registry.default_registry().lookup(op, target) + if not candidates: + raise _registry.NoMatchingTemplate( + f"no template registered for op={op!r} target={target!r}" + ) + + # All versions of an op share one parameter order. Any candidate can map the + # positional wire operands before legality filtering chooses a version. + return _build_tile_specs(candidates[0], operand_specs) + + +def metadata_request( + target: str, + op: str, + operand_specs: list, + context_attrs: dict | None = None, +) -> dict: + """Return every legal candidate and its selection metadata.""" + tile_specs = _tile_specs_for_request(target, op, operand_specs) + legal = _registry.legal_candidates(op, target, tile_specs, context_attrs) + return { + "target": target, + "op": op, + "candidates": { + descriptor.name: _metadata_for_descriptor(descriptor) + for descriptor in legal + }, + } + + +def render_request( + target: str, + op: str, + operand_specs: list, + context_attrs: dict | None = None, + candidate_id: str | None = None, +) -> str: + """Select and render one PTODSL template as MLIR text.""" + tile_specs = _tile_specs_for_request(target, op, operand_specs) + descriptor = _registry.select( + op, + target, + tile_specs, + context_attrs, + candidate_id, + ) + return descriptor.specialize(**tile_specs).mlir_text() + + +class TileLibDaemonServer(socketserver.ThreadingUnixStreamServer): + """Threaded Unix-socket RPC server with an in-memory render cache.""" + + allow_reuse_address = True + daemon_threads = True + + def __init__(self, socket_path: str, max_entries: int = 1000): + if max_entries <= 0: + raise ValueError("max_entries must be greater than zero") + super().__init__(socket_path, _Handler) + self._cache: dict[str, str] = {} + self._state_lock = threading.Lock() + self._max_entries = max_entries + self._stats = {"hits": 0, "misses": 0, "evictions": 0} + + @property + def stats(self) -> dict: + """Return a snapshot of cache counters for diagnostics and tests.""" + with self._state_lock: + return dict(self._stats) + + def dispatch(self, request: dict) -> dict: + if not isinstance(request, dict): + return {"success": False, "error": "request must be a JSON object"} + + method = request.get("method") + params = request.get("params") or {} + if not isinstance(params, dict): + return {"success": False, "error": "request params must be a JSON object"} + + try: + if method == "instantiate": + result = self._instantiate(**params) + elif method == "get_metadata": + result = self._get_metadata(**params) + elif method == "ping": + result = "pong" + elif method == "get_stats": + result = self._get_stats() + elif method == "clear": + result = self._clear() + else: + return {"success": False, "error": f"unknown method {method!r}"} + return {"success": True, "result": result} + except Exception as exc: + return { + "success": False, + "error": f"{type(exc).__name__}: {exc}", + } + + def _get_metadata(self, target, op, operand_specs, context_attrs=None): + return metadata_request(target, op, operand_specs, context_attrs) + + def _get_stats(self): + with self._state_lock: + requests = self._stats["hits"] + self._stats["misses"] + total_entries = len(self._cache) + return { + **self._stats, + "entries": total_entries, + "total_entries": total_entries, + "max_entries": self._max_entries, + "hit_rate": self._stats["hits"] / requests if requests else 0.0, + } + + def _clear(self): + with self._state_lock: + self._cache.clear() + return {"cleared": True} + + def _instantiate( + self, + target, + op, + operand_specs, + context_attrs=None, + candidate_id=None, + ): + key = json.dumps( + { + "target": target, + "op": op, + "operand_specs": operand_specs, + "context_attrs": context_attrs, + "candidate_id": candidate_id, + }, + sort_keys=True, + separators=(",", ":"), + ) + + with self._state_lock: + cached = self._cache.get(key) + if cached is not None: + self._stats["hits"] += 1 + return cached + self._stats["misses"] += 1 + + mlir_text = render_request( + target, + op, + operand_specs, + context_attrs, + candidate_id, + ) + + with self._state_lock: + if len(self._cache) >= self._max_entries: + self._cache.pop(next(iter(self._cache))) + self._stats["evictions"] += 1 + self._cache[key] = mlir_text + return mlir_text + + +class _Handler(socketserver.BaseRequestHandler): + def handle(self): + try: + request = recv_message(self.request) + except (ConnectionError, UnicodeDecodeError, ValueError): + return + send_message(self.request, self.server.dispatch(request)) + + +def _parse_args(argv): + parser = argparse.ArgumentParser(prog="ptodsl.tilelib.serving.daemon") + parser.add_argument("--socket", required=True) + parser.add_argument( + "--template-dir", + default=None, + help="accepted during migration but ignored; PTODSL templates are in-package", + ) + parser.add_argument("--max-entries", type=int, default=1000) + parser.add_argument("--verbose", action="store_true") + return parser.parse_args(argv) + + +def main(argv=None): + args = _parse_args(argv) + + if os.path.exists(args.socket): + os.unlink(args.socket) + + server = TileLibDaemonServer(args.socket, max_entries=args.max_entries) + stop = threading.Event() + + def _request_shutdown(*_): + stop.set() + + signal.signal(signal.SIGTERM, _request_shutdown) + signal.signal(signal.SIGINT, _request_shutdown) + + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + if args.verbose: + print(f"PTODSL TileLib daemon listening on {args.socket}", flush=True) + + try: + stop.wait() + finally: + server.shutdown() + server.server_close() + if os.path.exists(args.socket): + os.unlink(args.socket) + + +if __name__ == "__main__": + main() + + +__all__ = ["TileLibDaemonServer", "main", "metadata_request", "render_request"] diff --git a/ptodsl/ptodsl/tilelib/serving/helper.py b/ptodsl/ptodsl/tilelib/serving/helper.py new file mode 100644 index 000000000..5925556ef --- /dev/null +++ b/ptodsl/ptodsl/tilelib/serving/helper.py @@ -0,0 +1,73 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""One-shot command-line client for the ExpandTileOp daemon contract. + +Example: + + python3 -m ptodsl.tilelib.serving.helper --socket --target a5 \ + --op pto.tadd --operand-specs '[...]' +""" + +from __future__ import annotations + +import argparse +import json +import sys + +from .client import DaemonClient, DaemonError + + +def main(argv=None): + parser = argparse.ArgumentParser(prog="ptodsl.tilelib.serving.helper") + parser.add_argument("--socket", required=True) + parser.add_argument("--target", required=True) + parser.add_argument("--op", required=True) + parser.add_argument("--operand-specs", required=True) + parser.add_argument("--context-attrs", default=None) + parser.add_argument( + "--method", + choices=("instantiate", "get_metadata"), + default="instantiate", + ) + parser.add_argument("--candidate-id", default=None) + args = parser.parse_args(argv) + + try: + operand_specs = json.loads(args.operand_specs) + context_attrs = json.loads(args.context_attrs) if args.context_attrs else {} + except json.JSONDecodeError as exc: + parser.error(f"invalid JSON input: {exc}") + + try: + client = DaemonClient(args.socket) + if args.method == "get_metadata": + result = client.get_metadata( + args.target, + args.op, + operand_specs, + context_attrs, + ) + sys.stdout.write(json.dumps(result)) + return + + result = client.instantiate( + args.target, + args.op, + operand_specs, + context_attrs, + args.candidate_id, + ) + except (DaemonError, OSError) as exc: + sys.stderr.write(f"Error: daemon RPC failed: {exc}\n") + raise SystemExit(1) from exc + + sys.stdout.write(result) + + +if __name__ == "__main__": + main() diff --git a/ptodsl/ptodsl/tilelib/serving/wire.py b/ptodsl/ptodsl/tilelib/serving/wire.py new file mode 100644 index 000000000..5a8b8076d --- /dev/null +++ b/ptodsl/ptodsl/tilelib/serving/wire.py @@ -0,0 +1,41 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Length-prefixed JSON framing for the TileLib daemon RPC.""" + +from __future__ import annotations + +import json + + +def recv_exactly(sock, length: int) -> bytes: + """Read exactly ``length`` bytes or fail if the peer closes early.""" + chunks = [] + remaining = length + while remaining: + chunk = sock.recv(remaining) + if not chunk: + raise ConnectionError("socket closed mid-message") + chunks.append(chunk) + remaining -= len(chunk) + return b"".join(chunks) + + +def send_message(sock, message: dict) -> None: + """Send one UTF-8 JSON message with a 4-byte big-endian length prefix.""" + payload = json.dumps(message).encode("utf-8") + sock.sendall(len(payload).to_bytes(4, byteorder="big")) + sock.sendall(payload) + + +def recv_message(sock) -> dict: + """Receive one length-prefixed UTF-8 JSON message.""" + length = int.from_bytes(recv_exactly(sock, 4), byteorder="big") + return json.loads(recv_exactly(sock, length).decode("utf-8")) + + +__all__ = ["recv_exactly", "recv_message", "send_message"] diff --git a/ptodsl/tests/test_tilelib_daemon.py b/ptodsl/tests/test_tilelib_daemon.py new file mode 100644 index 000000000..9f5b7db3e --- /dev/null +++ b/ptodsl/tests/test_tilelib_daemon.py @@ -0,0 +1,169 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""End-to-end tests for the PTODSL TileLib daemon's Unix-socket RPC.""" + +import os +import tempfile +import threading +import unittest + +from ptodsl.tilelib.serving.client import DaemonClient, DaemonError +from ptodsl.tilelib.serving.daemon import TileLibDaemonServer + + +def _tile_spec(dtype="f32"): + return { + "kind": "tile", + "dtype": dtype, + "shape": [8, 64], + "valid_shape": [8, 64], + "memory_space": "ub", + "config": { + "b_layout": "row_major", + "s_layout": "none_box", + "s_fractal_size": 512, + "pad_value": "0x0", + }, + } + + +# ExpandTileOp sends tadd as ins(src0, src1), outs(dst), matching the +# template parameter order (src0, src1, dst). +TADD_OPERANDS = [_tile_spec(), _tile_spec(), _tile_spec()] +TADD_2D_NO_POST_UPDATE = "template_tadd_2d_no_post_update" + + +class TileLibDaemonTest(unittest.TestCase): + def setUp(self): + self._temporary_directory = tempfile.TemporaryDirectory() + self.socket_path = os.path.join( + self._temporary_directory.name, + "ptodsl_tilelib.sock", + ) + self.server = TileLibDaemonServer(self.socket_path) + self._thread = threading.Thread( + target=self.server.serve_forever, + daemon=True, + ) + self._thread.start() + self.client = DaemonClient(self.socket_path) + + def tearDown(self): + self.server.shutdown() + self.server.server_close() + self._thread.join() + self._temporary_directory.cleanup() + + def test_ping(self): + self.assertEqual(self.client.ping(), "pong") + + def test_instantiate_named_candidate_returns_structured_mlir(self): + mlir = self.client.instantiate( + "a5", + "pto.tadd", + TADD_OPERANDS, + candidate_id=TADD_2D_NO_POST_UPDATE, + ) + self.assertIn(f"func.func @{TADD_2D_NO_POST_UPDATE}", mlir) + for operation in ( + "pto.tile_buf_addr", + "memref.subview", + "pto.vlds", + "pto.vadd", + "pto.vsts", + "pto.plt_b32", + "pto.tilelang.instance", + ): + self.assertIn(operation, mlir) + self.assertNotIn("pto.castptr", mlir) + + def test_instantiate_requires_candidate_when_top_priority_ties(self): + with self.assertRaises(DaemonError): + self.client.instantiate("a5", "pto.tadd", TADD_OPERANDS) + + def test_get_metadata_returns_legal_candidates(self): + metadata = self.client.get_metadata("a5", "pto.tadd", TADD_OPERANDS) + candidates = metadata["candidates"] + self.assertEqual( + set(candidates), + { + TADD_2D_NO_POST_UPDATE, + "template_tadd_1d_no_post_update", + "template_tadd_2d_post_update", + "template_tadd_1d_post_update", + }, + ) + + selected = candidates[TADD_2D_NO_POST_UPDATE] + self.assertEqual(selected["loop_depth"], 2) + self.assertEqual(selected["Tail"], {"callable": "has_tail"}) + self.assertFalse(selected["is_post_update"]) + self.assertEqual(selected["tags"], ["binop", "2d", "no_post_update"]) + + def test_cache_stats_and_clear_are_available_over_rpc(self): + arguments = ( + "a5", + "pto.tadd", + TADD_OPERANDS, + ) + self.client.instantiate( + *arguments, + candidate_id=TADD_2D_NO_POST_UPDATE, + ) + self.client.instantiate( + *arguments, + candidate_id=TADD_2D_NO_POST_UPDATE, + ) + + stats = self.client.get_stats() + self.assertEqual(stats["misses"], 1) + self.assertEqual(stats["hits"], 1) + self.assertEqual(stats["entries"], 1) + + self.assertEqual(self.client.clear(), {"cleared": True}) + self.assertEqual(self.client.get_stats()["entries"], 0) + + def test_cache_key_includes_context_attributes(self): + self.client.instantiate( + "a5", + "pto.tadd", + TADD_OPERANDS, + context_attrs={"variant": 0}, + candidate_id=TADD_2D_NO_POST_UPDATE, + ) + self.client.instantiate( + "a5", + "pto.tadd", + TADD_OPERANDS, + context_attrs={"variant": 1}, + candidate_id=TADD_2D_NO_POST_UPDATE, + ) + self.assertEqual(self.client.get_stats()["misses"], 2) + + def test_non_tile_operand_is_rejected_explicitly(self): + operands = list(TADD_OPERANDS) + operands[0] = {"kind": "scalar", "dtype": "f32", "value": 1.0} + + with self.assertRaisesRegex( + DaemonError, + "currently supports only tile operands", + ): + self.client.instantiate( + "a5", + "pto.tadd", + operands, + candidate_id=TADD_2D_NO_POST_UPDATE, + ) + + def test_unknown_op_errors(self): + with self.assertRaises(DaemonError): + self.client.instantiate("a5", "pto.tnope", TADD_OPERANDS) + + +if __name__ == "__main__": + unittest.main() diff --git a/ptodsl/tests/test_tilelib_select.py b/ptodsl/tests/test_tilelib_select.py index a3415a3b4..8e6b67915 100644 --- a/ptodsl/tests/test_tilelib_select.py +++ b/ptodsl/tests/test_tilelib_select.py @@ -8,11 +8,14 @@ """TileLib selection test: metadata-driven legality and descriptor metadata.""" import unittest +from types import SimpleNamespace from ptodsl.tilelib import ( AmbiguousTemplate, ScalarType, + TemplateMetadata, TileSpec, + TileTemplateRegistry, legal_candidates, select, ) @@ -31,6 +34,37 @@ def _f32_single_row_specs(): class TileLibSelectTest(unittest.TestCase): + def test_context_attributes_are_available_to_constraints(self): + registry = TileTemplateRegistry() + registry.register(SimpleNamespace( + op="pto.test_context", + target="a5", + name="context_candidate", + param_names=("src0", "src1", "dst"), + metadata=TemplateMetadata.build( + op="pto.test_context", + target="a5", + name="context_candidate", + constraints=(lambda mode: mode == "enabled",), + ), + )) + + legal = registry.legal_candidates( + "pto.test_context", + "a5", + _f32_specs(), + context_attrs={"mode": "enabled"}, + ) + self.assertEqual([candidate.name for candidate in legal], ["context_candidate"]) + + with self.assertRaises(NoMatchingTemplate): + registry.legal_candidates( + "pto.test_context", + "a5", + _f32_specs(), + context_attrs={"mode": "disabled"}, + ) + def test_four_tadd_versions_registered(self): names = { candidate.name From 432b386f9bb00b86153df796e8ede4cb8bab976f Mon Sep 17 00:00:00 2001 From: ManiSadati Date: Wed, 1 Jul 2026 14:41:22 +0000 Subject: [PATCH 3/6] feat(ptodsl): connect PTOAS to the PTODSL TileLib daemon --- include/PTO/Transforms/Passes.td | 22 +++-- lib/PTO/Transforms/ExpandTileOp.cpp | 78 +++++++++++------ ptodsl/README.md | 22 +++++ test/lit/vpto/expand_tile_op_ptodsl_tsub.pto | 47 ++++++++++ tools/ptoas/CMakeLists.txt | 9 +- tools/ptoas/TilelangDaemon.cpp | 49 +++++++---- tools/ptoas/TilelangDaemon.h | 13 +-- tools/ptoas/ptoas.cpp | 91 +++++++++++++++++--- 8 files changed, 262 insertions(+), 69 deletions(-) create mode 100644 test/lit/vpto/expand_tile_op_ptodsl_tsub.pto diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 63b06b6db..58bc02a63 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -452,12 +452,13 @@ def PTOResolveReservedBuffers : Pass<"pto-resolve-reserved-buffers", "ModuleOp"> } def ExpandTileOp : Pass<"pto-expand-tile-op", "ModuleOp"> { - let summary = "Expand tile ops into calls to TileLang DSL template functions"; + let summary = "Expand tile ops into calls to TileLib template functions"; let description = [{ Expands tile-level operations (pto.tadd, pto.tsub, etc.) by invoking the - TileLang Python DSL to instantiate template libraries. The generated - template functions use tile_buf parameters and contain vector-level - implementations (pto.vecscope, pto.vlds, pto.vadd, pto.vsts, etc.). + selected Python TileLib backend to instantiate template libraries. The + generated template functions use tile_buf parameters and contain + vector-level implementations (pto.vecscope, pto.vlds, pto.vadd, + pto.vsts, etc.). Each tile op is replaced by a func.call to the generated template function, with tile_buf operands passed directly (no type bridging). @@ -484,10 +485,19 @@ def ExpandTileOp : Pass<"pto-expand-tile-op", "ModuleOp"> { "PYTHONPATH for tilelang_dsl package (added to env)">, Option<"pythonExe", "python-exe", "std::string", /*default=*/"\"python3\"", - "Python executable for tilelang DSL invocation">, + "Python executable for TileLib invocation">, Option<"daemonSocketPath", "daemon-socket-path", "std::string", /*default=*/"\"\"", - "Path to Unix domain socket for daemon RPC (if empty, uses subprocess)"> + "Path to Unix domain socket for daemon RPC">, + Option<"tileLibBackend", "tile-lib-backend", "std::string", + /*default=*/"\"tilelang\"", + "TileLib backend: tilelang or ptodsl">, + Option<"tileLibPkgPath", "tile-lib-pkg-path", "std::string", + /*default=*/"\"\"", + "PYTHONPATH root for the selected TileLib backend">, + Option<"daemonHelperModule", "daemon-helper-module", "std::string", + /*default=*/"\"tilelang_dsl.daemon_helper\"", + "Python module used for daemon helper RPC calls"> ]; } diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 91769e629..3f3c4a3a4 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -9,8 +9,8 @@ //===- ExpandTileOp.cpp ---------------------------------------------------===// //===----------------------------------------------------------------------===// // -// Expand tile-level ops (pto.tadd, pto.tsub, ...) by invoking the TileLang -// Python DSL to instantiate template libraries. +// Expand tile-level ops (pto.tadd, pto.tsub, ...) by invoking the selected +// Python TileLib backend to instantiate template libraries. // // The generated template functions use tile_buf parameters. After this pass, // the Inline pass inlines the template body, and FoldTileBufIntrinsics @@ -700,13 +700,16 @@ struct ExpandState { std::string tilelangPath; std::string tilelangPkgPath; + std::string tileLibBackend; + std::string tileLibPkgPath; + std::string daemonHelperModule; std::string pythonExe; std::string daemonSocketPath; - func::FuncOp invokeTilelangDSL(const SpecKey &key, Operation *tileOp, - ModuleOp mod, MLIRContext *ctx); - func::FuncOp invokeTilelangDaemon(const SpecKey &key, Operation *tileOp, - ModuleOp mod, MLIRContext *ctx); + func::FuncOp invokeTileLib(const SpecKey &key, Operation *tileOp, + ModuleOp mod, MLIRContext *ctx); + func::FuncOp invokeTileLibDaemon(const SpecKey &key, Operation *tileOp, + ModuleOp mod, MLIRContext *ctx); LogicalResult expandTileOpsInFunction(func::FuncOp func, ModuleOp mod, MLIRContext *ctx); @@ -872,9 +875,9 @@ static std::string buildContextAttrsJson(const SpecKey &key) { // ============================================================================ // Invoke Python DSL daemon RPC to generate a specialized template function. // ============================================================================ -func::FuncOp ExpandState::invokeTilelangDaemon(const SpecKey &key, - Operation *tileOp, - ModuleOp mod, MLIRContext *ctx) { +func::FuncOp ExpandState::invokeTileLibDaemon(const SpecKey &key, + Operation *tileOp, ModuleOp mod, + MLIRContext *ctx) { // 1. Locate the Python executable. auto pythonPath = llvm::sys::findProgramByName(pythonExe); if (!pythonPath) { @@ -893,7 +896,7 @@ func::FuncOp ExpandState::invokeTilelangDaemon(const SpecKey &key, // 3. Create temp file for stdout redirect. SmallString<128> tmpPath; int tmpFD; - if (auto ec = llvm::sys::fs::createTemporaryFile("tilelang_daemon", "mlir", + if (auto ec = llvm::sys::fs::createTemporaryFile("tilelib_daemon", "mlir", tmpFD, tmpPath)) { llvm::errs() << "ExpandTileOp: cannot create temp file: " << ec.message() << "\n"; @@ -904,7 +907,7 @@ func::FuncOp ExpandState::invokeTilelangDaemon(const SpecKey &key, // 4. Build command args for daemon helper. std::string opName = "pto." + key.opName; SmallVector args = { - *pythonPath, "-m", "tilelang_dsl.daemon_helper", + *pythonPath, "-m", daemonHelperModule, "--socket", daemonSocketPath, "--target", key.targetArch, "--op", opName, @@ -922,10 +925,10 @@ func::FuncOp ExpandState::invokeTilelangDaemon(const SpecKey &key, SmallVector envp; std::string pythonPathEnv; std::vector envStorage; - bool hasPythonPath = !tilelangPkgPath.empty(); + bool hasPythonPath = !tileLibPkgPath.empty(); if (hasPythonPath) { const char *existingPath = ::getenv("PYTHONPATH"); - pythonPathEnv = "PYTHONPATH=" + tilelangPkgPath; + pythonPathEnv = "PYTHONPATH=" + tileLibPkgPath; if (existingPath && existingPath[0] != '\0') { pythonPathEnv += ":"; pythonPathEnv += existingPath; @@ -1039,18 +1042,29 @@ func::FuncOp ExpandState::invokeTilelangDaemon(const SpecKey &key, } // ============================================================================ -// Invoke Python DSL helper to generate a specialized template function. +// Invoke the selected TileLib backend to generate a specialized template. // ============================================================================ -func::FuncOp ExpandState::invokeTilelangDSL(const SpecKey &key, - Operation *tileOp, - ModuleOp mod, MLIRContext *ctx) { +func::FuncOp ExpandState::invokeTileLib(const SpecKey &key, + Operation *tileOp, ModuleOp mod, + MLIRContext *ctx) { // Try daemon first if daemon socket path is provided. if (!daemonSocketPath.empty()) { - func::FuncOp daemonResult = invokeTilelangDaemon(key, tileOp, mod, ctx); + func::FuncOp daemonResult = invokeTileLibDaemon(key, tileOp, mod, ctx); if (daemonResult) return daemonResult; - // Daemon failed, fall back to subprocess mode. - llvm::errs() << "ExpandTileOp: daemon RPC failed, falling back to subprocess mode\n"; + if (tileLibBackend == "ptodsl") { + llvm::errs() + << "ExpandTileOp: PTODSL daemon RPC failed; refusing to fall back " + "to TileLang\n"; + return nullptr; + } + llvm::errs() << "ExpandTileOp: daemon RPC failed, falling back to legacy " + "TileLang subprocess mode\n"; + } + + if (tileLibBackend == "ptodsl") { + llvm::errs() << "ExpandTileOp: PTODSL backend requires its daemon\n"; + return nullptr; } // 1. Locate the Python executable. @@ -1264,11 +1278,11 @@ LogicalResult ExpandState::expandTileOpsInFunction(func::FuncOp func, return failure(); } - // Invoke tilelang DSL (with caching). - func::FuncOp dslFn = invokeTilelangDSL(*specKeyOpt, op, mod, ctx); + // Invoke the selected TileLib backend (with daemon-side caching). + func::FuncOp dslFn = invokeTileLib(*specKeyOpt, op, mod, ctx); if (!dslFn) { StringRef opName = getTileOpName(op); - op->emitError("ExpandTileOp: failed to instantiate tilelang template for " + + op->emitError("ExpandTileOp: failed to instantiate TileLib template for " + opName); return failure(); } @@ -1302,16 +1316,32 @@ void ExpandTileOpPass::runOnOperation() { ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); - if (tilelangPath.empty()) { + if (tileLibBackend != "tilelang" && tileLibBackend != "ptodsl") { + mod.emitError("ExpandTileOp received unsupported tile-lib-backend '" + + std::string(tileLibBackend) + "'"); + signalPassFailure(); + return; + } + + if (tileLibBackend == "tilelang" && tilelangPath.empty()) { mod.emitError( "ExpandTileOp requires a non-empty tilelang-path on the VPTO backend"); signalPassFailure(); return; } + if (tileLibBackend == "ptodsl" && daemonSocketPath.empty()) { + mod.emitError("ExpandTileOp requires a running PTODSL TileLib daemon"); + signalPassFailure(); + return; + } + ExpandState state; state.tilelangPath = std::string(tilelangPath); state.tilelangPkgPath = std::string(tilelangPkgPath); + state.tileLibBackend = std::string(tileLibBackend); + state.tileLibPkgPath = std::string(tileLibPkgPath); + state.daemonHelperModule = std::string(daemonHelperModule); state.pythonExe = std::string(pythonExe); state.daemonSocketPath = std::string(daemonSocketPath); diff --git a/ptodsl/README.md b/ptodsl/README.md index c2e034ac5..97485a884 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -55,6 +55,28 @@ pip install -e . --- +## PTODSL TileLib backend + +PTOAS uses the legacy TileLang TileLib by default. Select the PTODSL TileLib +daemon for VPTO expansion with: + +```bash +ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto \ + --tile-lib-backend=ptodsl input.pto -o - +``` + +The source-tree build bakes in `$PTOAS_REPO_ROOT/ptodsl` as the package root. +The `PTODSL_PYTHON_ROOT` environment variable from `scripts/ptoas_env.sh` +overrides that default. Use `--ptodsl-pkg-path=/path/to/package/root` for an +explicit command-line override. PTODSL daemon failures are reported as errors +and never fall back to the TileLang implementation. + +At this migration stage, end-to-end expansion is intended for operations with +one legal PTODSL candidate, such as `pto.tsub`. Operations with tied template +candidates will be enabled by the metadata-selection milestone. + +--- + ## JIT examples `ptodsl/examples/` contains self-contained `@pto.jit` examples that cover diff --git a/test/lit/vpto/expand_tile_op_ptodsl_tsub.pto b/test/lit/vpto/expand_tile_op_ptodsl_tsub.pto new file mode 100644 index 000000000..79d46afeb --- /dev/null +++ b/test/lit/vpto/expand_tile_op_ptodsl_tsub.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Test that PTOAS can select the PTODSL TileLib daemon and expand the +// single-candidate pto.tsub template without using the legacy TileLang path. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --tile-lib-backend=ptodsl %s -o - 2>/dev/null | FileCheck %s + +// CHECK: func.func @TSUB +// CHECK-NOT: pto.tsub ins +// CHECK: pto.vecscope +// CHECK: pto.castptr +// CHECK: pto.addptr +// CHECK: pto.vlds +// CHECK: pto.vsub +// CHECK: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TSUB() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %tile_buf = pto.alloc_tile + : !pto.tile_buf + + pto.tsub ins( + %a, %b + : !pto.tile_buf, + !pto.tile_buf) + outs( + %tile_buf + : !pto.tile_buf) + return + } +} diff --git a/tools/ptoas/CMakeLists.txt b/tools/ptoas/CMakeLists.txt index ced12af99..ce13efd4c 100644 --- a/tools/ptoas/CMakeLists.txt +++ b/tools/ptoas/CMakeLists.txt @@ -31,12 +31,13 @@ add_llvm_executable(pto-opt set_target_properties(pto-opt PROPERTIES OUTPUT_NAME "ptoas") target_compile_definitions(pto-opt PRIVATE PTOAS_RELEASE_VERSION="${PTOAS_CLI_VERSION}" - # Source-tree defaults for TileLang DSL expansion. These let ptoas run - # directly from the build tree without passing --tilelang-path / - # --tilelang-pkg-path. Installed layouts that move these directories - # still need to override the flags explicitly. + # Source-tree defaults for TileLib expansion. These let ptoas run directly + # from the build tree without passing --tilelang-path / + # --tilelang-pkg-path / --ptodsl-pkg-path. Installed layouts that move + # these directories still need to override the flags explicitly. PTOAS_DEFAULT_TILELANG_PATH="${CMAKE_SOURCE_DIR}/lib/TileOps" PTOAS_DEFAULT_TILELANG_PKG_PATH="${CMAKE_SOURCE_DIR}/tilelang-dsl/python" + PTOAS_DEFAULT_PTODSL_PKG_PATH="${CMAKE_SOURCE_DIR}/ptodsl" ) # [修改 2] 更新链接库名称 # 原因:In-tree 时你的库叫 MLIRPTODialect,但现在 Out-of-tree 它们是你自己定义的 diff --git a/tools/ptoas/TilelangDaemon.cpp b/tools/ptoas/TilelangDaemon.cpp index 570585cc2..61083da93 100644 --- a/tools/ptoas/TilelangDaemon.cpp +++ b/tools/ptoas/TilelangDaemon.cpp @@ -7,15 +7,16 @@ // See LICENSE in the root of the software repository for the full text of the License. #include "TilelangDaemon.h" -#include "llvm/Support/FileSystem.h" -#include "llvm/Support/Program.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Program.h" #include #include #include #include #include +#include extern char **environ; @@ -24,12 +25,13 @@ namespace ptoas { std::optional> DaemonManager::processInfo; std::string DaemonManager::generateSocketPath() { - return "/tmp/tilelang_daemon_" + std::to_string(::getpid()) + ".sock"; + return "/tmp/tilelib_daemon_" + std::to_string(::getpid()) + ".sock"; } bool DaemonManager::start(const std::string &socketPath, - const std::string &templateDir, - const std::string &pkgPath) { + const std::string &daemonModule, + const std::string &pkgPath, + const std::string &templateDir) { auto pythonPath = llvm::sys::findProgramByName("python3"); if (!pythonPath) { llvm::errs() << "Error: Cannot find python3 executable for daemon\n"; @@ -37,10 +39,12 @@ bool DaemonManager::start(const std::string &socketPath, } llvm::SmallVector args = { - *pythonPath, "-m", "tilelang_dsl.daemon", - "--socket", socketPath, - "--template-dir", templateDir, + *pythonPath, "-m", daemonModule, "--socket", socketPath, }; + if (!templateDir.empty()) { + args.push_back("--template-dir"); + args.push_back(templateDir); + } llvm::SmallVector envp; std::string pythonPathEnv; @@ -69,26 +73,41 @@ bool DaemonManager::start(const std::string &socketPath, llvm::sys::ProcessInfo procInfo = llvm::sys::ExecuteNoWait( *pythonPath, args, - !pkgPath.empty() ? std::optional>(envp) : std::nullopt, + !pkgPath.empty() + ? std::optional>(envp) + : std::nullopt, {}, 0, &errMsg, &executionFailed, nullptr, true); if (executionFailed || procInfo.Pid == llvm::sys::ProcessInfo::InvalidPid) { - llvm::errs() << "Error: Failed to start TileLang daemon: " << errMsg << "\n"; + llvm::errs() << "Error: Failed to start TileLib daemon module '" + << daemonModule << "': " << errMsg << "\n"; return false; } processInfo = std::make_pair(procInfo.Pid, socketPath); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + // Python startup time depends on the selected TileLib frontend and its + // imports. Poll instead of relying on one fixed sleep. + bool socketReady = false; + for (int attempt = 0; attempt < 200; ++attempt) { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + if (llvm::sys::fs::exists(socketPath)) { + socketReady = true; + break; + } + } - if (!llvm::sys::fs::exists(socketPath)) { + if (!socketReady) { llvm::errs() << "Error: Daemon socket not created at " << socketPath << "\n"; - llvm::errs() << "Note: Daemon process started (pid=" << procInfo.Pid + llvm::errs() << "Note: Daemon process started (pid=" << procInfo.Pid << ") but socket not found. Check daemon logs.\n"; + kill(procInfo.Pid, SIGTERM); + processInfo = std::nullopt; return false; } - llvm::errs() << "TileLang daemon started (pid=" << procInfo.Pid + llvm::errs() << "TileLib daemon '" << daemonModule << "' started (pid=" + << procInfo.Pid << ", socket=" << socketPath << ")\n"; return true; } @@ -108,7 +127,7 @@ void DaemonManager::stop() { llvm::sys::fs::remove(socketPath); } - llvm::errs() << "TileLang daemon stopped (pid=" << pid << ")\n"; + llvm::errs() << "TileLib daemon stopped (pid=" << pid << ")\n"; processInfo = std::nullopt; } diff --git a/tools/ptoas/TilelangDaemon.h b/tools/ptoas/TilelangDaemon.h index f51ebf8ed..e9a59d5d3 100644 --- a/tools/ptoas/TilelangDaemon.h +++ b/tools/ptoas/TilelangDaemon.h @@ -22,13 +22,14 @@ namespace ptoas { class DaemonManager { public: static std::string generateSocketPath(); - + static bool start(const std::string &socketPath, - const std::string &templateDir, - const std::string &pkgPath); - + const std::string &daemonModule, + const std::string &pkgPath, + const std::string &templateDir = ""); + static void stop(); - + static bool isRunning(); private: @@ -39,4 +40,4 @@ void registerDaemonCleanup(); } // namespace ptoas -#endif // PTOAS_TILELANG_DAEMON_H \ No newline at end of file +#endif // PTOAS_TILELANG_DAEMON_H diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 8e73de48e..94d6158eb 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -27,6 +27,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Math/IR/Math.h" #include +#include #include #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -362,6 +363,9 @@ static llvm::cl::opt enableTileOpExpand( #ifndef PTOAS_DEFAULT_TILELANG_PKG_PATH #define PTOAS_DEFAULT_TILELANG_PKG_PATH "" #endif +#ifndef PTOAS_DEFAULT_PTODSL_PKG_PATH +#define PTOAS_DEFAULT_PTODSL_PKG_PATH "" +#endif static llvm::cl::opt tilelangPath( "tilelang-path", @@ -375,33 +379,78 @@ static llvm::cl::opt tilelangPkgPath( "(default: /tilelang-dsl/python, baked in at build time)"), llvm::cl::init(PTOAS_DEFAULT_TILELANG_PKG_PATH)); +static llvm::cl::opt ptodslPkgPath( + "ptodsl-pkg-path", + llvm::cl::desc("PYTHONPATH for the ptodsl package " + "(default: /ptodsl, baked in at build time)"), + llvm::cl::init(PTOAS_DEFAULT_PTODSL_PKG_PATH)); + static llvm::cl::opt daemonSocketPath( "daemon-socket-path", llvm::cl::desc("Path to Unix domain socket for daemon RPC " - "(default: /tmp/tilelang_daemon_{pid}.sock)"), + "(default: /tmp/tilelib_daemon_{pid}.sock)"), llvm::cl::init("")); +enum class TileLibBackend { + TileLang, + PTODSL, +}; + +static llvm::cl::opt tileLibBackend( + "tile-lib-backend", + llvm::cl::desc("TileLib backend used by ExpandTileOp"), + llvm::cl::values( + clEnumValN(TileLibBackend::TileLang, "tilelang", + "Use the legacy TileLang DSL TileLib"), + clEnumValN(TileLibBackend::PTODSL, "ptodsl", + "Use the PTODSL TileLib daemon")), + llvm::cl::init(TileLibBackend::TileLang)); + static pto::ExpandTileOpOptions resolveExpandTileOpOptions(int argc, char **argv) { pto::ExpandTileOpOptions expandOpts; expandOpts.tilelangPath = tilelangPath; expandOpts.tilelangPkgPath = tilelangPkgPath; + const bool usePTODSLTileLib = tileLibBackend == TileLibBackend::PTODSL; + std::string resolvedPtodslPkgPath = ptodslPkgPath; - if (!hasCLIOption(argc, argv, "--tilelang-path")) { - std::string detectedTilelangPath = detectInstalledTilelangPath(argv[0]); - if (!detectedTilelangPath.empty()) - expandOpts.tilelangPath = detectedTilelangPath; + if (!hasCLIOption(argc, argv, "--ptodsl-pkg-path")) { + const char *envPtodslRoot = ::getenv("PTODSL_PYTHON_ROOT"); + if (envPtodslRoot && envPtodslRoot[0] != '\0') + resolvedPtodslPkgPath = envPtodslRoot; } - if (!hasCLIOption(argc, argv, "--tilelang-pkg-path")) { - std::string detectedTilelangPkgPath = detectInstalledTilelangPkgPath(argv[0]); - if (!detectedTilelangPkgPath.empty()) - expandOpts.tilelangPkgPath = detectedTilelangPkgPath; + if (usePTODSLTileLib) { + // The PTODSL backend is package-based and must not depend on legacy + // TileLang template or package paths. + expandOpts.tilelangPath.clear(); + expandOpts.tilelangPkgPath.clear(); + } else { + if (!hasCLIOption(argc, argv, "--tilelang-path")) { + std::string detectedTilelangPath = detectInstalledTilelangPath(argv[0]); + if (!detectedTilelangPath.empty()) + expandOpts.tilelangPath = detectedTilelangPath; + } + + if (!hasCLIOption(argc, argv, "--tilelang-pkg-path")) { + std::string detectedTilelangPkgPath = + detectInstalledTilelangPkgPath(argv[0]); + if (!detectedTilelangPkgPath.empty()) + expandOpts.tilelangPkgPath = detectedTilelangPkgPath; + } } + expandOpts.tileLibBackend = usePTODSLTileLib ? "ptodsl" : "tilelang"; + expandOpts.daemonHelperModule = + usePTODSLTileLib ? "ptodsl.tilelib.serving.helper" + : "tilelang_dsl.daemon_helper"; + expandOpts.tileLibPkgPath = + usePTODSLTileLib ? resolvedPtodslPkgPath + : std::string(expandOpts.tilelangPkgPath); + // Daemon mode is default (no CLI option needed) // Automatically start daemon for instance caching - if (!expandOpts.tilelangPath.empty()) { + if (usePTODSLTileLib || !expandOpts.tilelangPath.empty()) { std::string socket = daemonSocketPath; if (socket.empty()) socket = ptoas::DaemonManager::generateSocketPath(); @@ -409,14 +458,28 @@ static pto::ExpandTileOpOptions resolveExpandTileOpOptions(int argc, // Register cleanup handler (daemon will be stopped on PTOAS exit) ptoas::registerDaemonCleanup(); + const std::string daemonModule = + usePTODSLTileLib ? "ptodsl.tilelib.serving.daemon" + : "tilelang_dsl.daemon"; + const std::string templateDir = + usePTODSLTileLib ? "" : std::string(expandOpts.tilelangPath); + // Try to start daemon automatically - if (ptoas::DaemonManager::start(socket, expandOpts.tilelangPath, expandOpts.tilelangPkgPath)) { + if (ptoas::DaemonManager::start(socket, daemonModule, + expandOpts.tileLibPkgPath, templateDir)) { expandOpts.daemonSocketPath = socket; - llvm::errs() << "Info: TileLang daemon started successfully\n"; + llvm::errs() << "Info: " << expandOpts.tileLibBackend + << " TileLib daemon started successfully\n"; } else { - // Fallback: daemon failed, use subprocess mode (current approach) expandOpts.daemonSocketPath = ""; - llvm::errs() << "Warning: Failed to start daemon, using subprocess mode (fallback)\n"; + if (usePTODSLTileLib) { + llvm::errs() + << "Error: Failed to start the PTODSL TileLib daemon; no TileLang " + "fallback will be used\n"; + } else { + llvm::errs() << "Warning: Failed to start daemon, using legacy " + "TileLang subprocess mode\n"; + } } } From a49f43df1a02fd0264ad0576edcdb711374f80be Mon Sep 17 00:00:00 2001 From: ManiSadati Date: Wed, 1 Jul 2026 14:58:30 +0000 Subject: [PATCH 4/6] feat(ptodsl): split TileLib metadata and render daemon calls --- lib/PTO/Transforms/ExpandTileOp.cpp | 145 +++++++++++++----- ptodsl/README.md | 11 +- ptodsl/docs/tilelib-migration-testing.md | 103 +++++++++++++ ...tile_op_ptodsl_tadd_requires_selection.pto | 40 +++++ 4 files changed, 262 insertions(+), 37 deletions(-) create mode 100644 ptodsl/docs/tilelib-migration-testing.md create mode 100644 test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 3f3c4a3a4..9f314dac1 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -18,10 +18,11 @@ // // Workflow per tile op: // 1. Extract SpecKey from ALL operands' tile_buf types. -// 2. Invoke Python DSL helper to generate a specialized MLIR function -// (with tile_buf parameters). -// 3. Parse the generated MLIR and clone the function into the module. -// 4. Replace the original tile op with func.call, passing tile_buf +// 2. For PTODSL, query legal-candidate metadata and require one candidate. +// 3. Invoke the selected TileLib helper to generate a specialized MLIR +// function (with tile_buf parameters). +// 4. Parse the generated MLIR and clone the function into the module. +// 5. Replace the original tile op with func.call, passing tile_buf // operands directly (no type bridging needed). // @@ -48,13 +49,16 @@ #include "llvm/ADT/StringSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/JSON.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Path.h" #include "llvm/Support/Program.h" #include "llvm/Support/raw_ostream.h" #include +#include #include #include @@ -706,9 +710,14 @@ struct ExpandState { std::string pythonExe; std::string daemonSocketPath; - func::FuncOp invokeTileLib(const SpecKey &key, Operation *tileOp, - ModuleOp mod, MLIRContext *ctx); - func::FuncOp invokeTileLibDaemon(const SpecKey &key, Operation *tileOp, + std::optional + invokeTileLibHelper(const SpecKey &key, StringRef method = {}, + StringRef candidateId = {}); + std::optional + discoverSingleTileLibCandidate(const SpecKey &key); + func::FuncOp invokeTileLib(const SpecKey &key, ModuleOp mod, + MLIRContext *ctx); + func::FuncOp invokeTileLibDaemon(const SpecKey &key, StringRef candidateId, ModuleOp mod, MLIRContext *ctx); LogicalResult expandTileOpsInFunction(func::FuncOp func, ModuleOp mod, @@ -873,38 +882,34 @@ static std::string buildContextAttrsJson(const SpecKey &key) { } // ============================================================================ -// Invoke Python DSL daemon RPC to generate a specialized template function. +// Invoke the configured one-shot helper and return its stdout. // ============================================================================ -func::FuncOp ExpandState::invokeTileLibDaemon(const SpecKey &key, - Operation *tileOp, ModuleOp mod, - MLIRContext *ctx) { - // 1. Locate the Python executable. +std::optional +ExpandState::invokeTileLibHelper(const SpecKey &key, StringRef method, + StringRef candidateId) { auto pythonPath = llvm::sys::findProgramByName(pythonExe); if (!pythonPath) { llvm::errs() << "ExpandTileOp: cannot find '" << pythonExe << "'\n"; - return nullptr; + return std::nullopt; } - // 2. Build operand schema JSON for daemon RPC. std::string operandSpecsJson = buildOperandSpecsJson(key); std::string contextAttrsJson = buildContextAttrsJson(key); if (key.targetArch.empty()) { llvm::errs() << "ExpandTileOp: missing pto.target_arch module attribute\n"; - return nullptr; + return std::nullopt; } - // 3. Create temp file for stdout redirect. SmallString<128> tmpPath; int tmpFD; - if (auto ec = llvm::sys::fs::createTemporaryFile("tilelib_daemon", "mlir", - tmpFD, tmpPath)) { + if (auto ec = llvm::sys::fs::createTemporaryFile("tilelib_helper", "out", + tmpFD, tmpPath)) { llvm::errs() << "ExpandTileOp: cannot create temp file: " << ec.message() << "\n"; - return nullptr; + return std::nullopt; } ::close(tmpFD); - // 4. Build command args for daemon helper. std::string opName = "pto." + key.opName; SmallVector args = { *pythonPath, "-m", daemonHelperModule, @@ -913,12 +918,19 @@ func::FuncOp ExpandState::invokeTileLibDaemon(const SpecKey &key, "--op", opName, "--operand-specs", operandSpecsJson, }; + if (!method.empty()) { + args.push_back("--method"); + args.push_back(method); + } if (!key.contextAttrs.empty()) { args.push_back("--context-attrs"); args.push_back(contextAttrsJson); } + if (!candidateId.empty()) { + args.push_back("--candidate-id"); + args.push_back(candidateId); + } - // 5. Set up environment with PYTHONPATH. std::optional redirects[] = {std::nullopt, StringRef(tmpPath), std::nullopt}; @@ -944,7 +956,6 @@ func::FuncOp ExpandState::invokeTileLibDaemon(const SpecKey &key, envp.push_back(s); } - // 6. Execute daemon helper. std::string errMsg; int rc = llvm::sys::ExecuteAndWait( *pythonPath, args, @@ -952,27 +963,83 @@ func::FuncOp ExpandState::invokeTileLibDaemon(const SpecKey &key, redirects, /*secondsToWait=*/30, /*memoryLimit=*/0, &errMsg); if (rc != 0) { - llvm::errs() << "ExpandTileOp: daemon helper failed (rc=" << rc + StringRef operation = + method.empty() ? StringRef("instantiate") : method; + llvm::errs() << "ExpandTileOp: daemon helper " << operation + << " failed (rc=" << rc << "): " << errMsg << "\n"; llvm::sys::fs::remove(tmpPath); - return nullptr; + return std::nullopt; } - // 7. Read the generated MLIR. auto bufOrErr = llvm::MemoryBuffer::getFile(tmpPath); llvm::sys::fs::remove(tmpPath); if (!bufOrErr) { llvm::errs() << "ExpandTileOp: cannot read daemon output\n"; - return nullptr; + return std::nullopt; } - StringRef mlirText = (*bufOrErr)->getBuffer(); - if (mlirText.empty()) { + std::string output = (*bufOrErr)->getBuffer().str(); + if (output.empty()) { llvm::errs() << "ExpandTileOp: empty daemon output\n"; - return nullptr; + return std::nullopt; } + return output; +} - // 8. Parse the MLIR text. - auto parsedMod = parseSourceString(mlirText, ctx); +// ============================================================================ +// Discover the only legal candidate supported by this migration milestone. +// ============================================================================ +std::optional +ExpandState::discoverSingleTileLibCandidate(const SpecKey &key) { + auto metadataText = invokeTileLibHelper(key, "get_metadata"); + if (!metadataText) + return std::nullopt; + + auto parsed = llvm::json::parse(*metadataText); + if (!parsed) { + llvm::errs() << "ExpandTileOp: failed to parse PTODSL metadata: " + << llvm::toString(parsed.takeError()) << "\n"; + return std::nullopt; + } + + auto *root = parsed->getAsObject(); + auto *candidates = root ? root->getObject("candidates") : nullptr; + if (!candidates) { + llvm::errs() << "ExpandTileOp: PTODSL metadata is missing the " + "'candidates' object\n"; + return std::nullopt; + } + + std::string opName = "pto." + key.opName; + if (candidates->size() != 1) { + llvm::errs() << "ExpandTileOp: PTODSL metadata returned " + << candidates->size() << " legal candidates for " << opName + << "; version selection is required before rendering\n"; + return std::nullopt; + } + + const auto &candidate = *candidates->begin(); + if (!candidate.second.getAsObject()) { + llvm::errs() << "ExpandTileOp: malformed metadata for candidate '" + << candidate.first.str() << "'\n"; + return std::nullopt; + } + return candidate.first.str(); +} + +// ============================================================================ +// Invoke the daemon RPC to generate a specialized template function. +// ============================================================================ +func::FuncOp ExpandState::invokeTileLibDaemon(const SpecKey &key, + StringRef candidateId, + ModuleOp mod, + MLIRContext *ctx) { + auto mlirText = invokeTileLibHelper(key, /*method=*/{}, candidateId); + if (!mlirText) + return nullptr; + + // Parse the rendered MLIR. + auto parsedMod = parseSourceString(*mlirText, ctx); if (!parsedMod) { llvm::errs() << "ExpandTileOp: failed to parse daemon output\n"; return nullptr; @@ -1044,12 +1111,20 @@ func::FuncOp ExpandState::invokeTileLibDaemon(const SpecKey &key, // ============================================================================ // Invoke the selected TileLib backend to generate a specialized template. // ============================================================================ -func::FuncOp ExpandState::invokeTileLib(const SpecKey &key, - Operation *tileOp, ModuleOp mod, +func::FuncOp ExpandState::invokeTileLib(const SpecKey &key, ModuleOp mod, MLIRContext *ctx) { // Try daemon first if daemon socket path is provided. if (!daemonSocketPath.empty()) { - func::FuncOp daemonResult = invokeTileLibDaemon(key, tileOp, mod, ctx); + std::string candidateId; + if (tileLibBackend == "ptodsl") { + auto discoveredCandidate = discoverSingleTileLibCandidate(key); + if (!discoveredCandidate) + return nullptr; + candidateId = std::move(*discoveredCandidate); + } + + func::FuncOp daemonResult = + invokeTileLibDaemon(key, candidateId, mod, ctx); if (daemonResult) return daemonResult; if (tileLibBackend == "ptodsl") { @@ -1279,7 +1354,7 @@ LogicalResult ExpandState::expandTileOpsInFunction(func::FuncOp func, } // Invoke the selected TileLib backend (with daemon-side caching). - func::FuncOp dslFn = invokeTileLib(*specKeyOpt, op, mod, ctx); + func::FuncOp dslFn = invokeTileLib(*specKeyOpt, mod, ctx); if (!dslFn) { StringRef opName = getTileOpName(op); op->emitError("ExpandTileOp: failed to instantiate TileLib template for " + diff --git a/ptodsl/README.md b/ptodsl/README.md index 97485a884..9247dba88 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -72,8 +72,15 @@ explicit command-line override. PTODSL daemon failures are reported as errors and never fall back to the TileLang implementation. At this migration stage, end-to-end expansion is intended for operations with -one legal PTODSL candidate, such as `pto.tsub`. Operations with tied template -candidates will be enabled by the metadata-selection milestone. +one legal PTODSL candidate, such as `pto.tsub`. `ExpandTileOp` first requests +legal-candidate metadata and then renders with the sole candidate's ID. +Operations with tied template candidates will be enabled when candidate +discovery moves into the planned `DiscoverTileLibCandidates` pass and version +selection becomes a separate stage. + +See the +[PTODSL TileLib migration test checklist](docs/tilelib-migration-testing.md) +for the complete test inventory, commands, and expected outcomes. --- diff --git a/ptodsl/docs/tilelib-migration-testing.md b/ptodsl/docs/tilelib-migration-testing.md new file mode 100644 index 000000000..d28107c19 --- /dev/null +++ b/ptodsl/docs/tilelib-migration-testing.md @@ -0,0 +1,103 @@ +# PTODSL TileLib Migration Test Checklist + +This page tracks the tests used while migrating PTOAS TileLib expansion from +the legacy TileLang implementation to PTODSL. Run commands from the repository +root. + +## Environment + +Set up PTOAS, PTODSL, MLIR, and LLVM test-tool paths: + +```bash +export PTOAS_ENV_SKIP_SMOKE_TEST=1 +source scripts/ptoas_env.sh +export FILECHECK="$LLVM_BUILD_DIR/bin/FileCheck" +``` + +The Python-only tests do not require rebuilding PTOAS. Tests that invoke +`ptoas` must use a binary rebuilt after the corresponding C++ or TableGen +changes. + +## Milestone coverage + +| Milestone | Test | Purpose | +|---|---|---| +| Legacy baseline | `expand_tile_op_tilelang_tsub.pto` | Confirms the default TileLang backend still works | +| PTODSL TileLib package | `test_tilelib_constraints.py`, `test_tilelib_elementwise.py`, `test_tilelib_render.py`, `test_tilelib_select.py` | Covers legality constraints, template registration and selection, and rendering | +| PTODSL daemon | `test_tilelib_daemon.py` | Covers the Unix-socket protocol, metadata, rendering, candidate IDs, and caching | +| PTOAS daemon selection | `expand_tile_op_ptodsl_tsub.pto` | Confirms `--tile-lib-backend=ptodsl` starts and uses the PTODSL daemon | +| Two-call expansion | `expand_tile_op_ptodsl_tsub.pto` | Confirms metadata discovery followed by rendering with the sole candidate ID | +| Multi-candidate boundary | `expand_tile_op_ptodsl_tadd_requires_selection.pto` | Confirms four legal `tadd` candidates require a separate selection stage | + +## Python TileLib tests + +Run every Python TileLib test: + +```bash +python3 -m unittest discover -s ptodsl/tests -p 'test_tilelib_*.py' +``` + +Run the layers individually: + +```bash +python3 ptodsl/tests/test_tilelib_constraints.py +python3 ptodsl/tests/test_tilelib_elementwise.py +python3 ptodsl/tests/test_tilelib_render.py +python3 ptodsl/tests/test_tilelib_select.py +python3 ptodsl/tests/test_tilelib_daemon.py +``` + +Each command prints `OK` when successful. + +## PTOAS integration tests + +### PTODSL positive path: one legal candidate + +`pto.tsub` has one legal PTODSL candidate. The test checks that PTOAS expands +it into vector operations: + +```bash +ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto \ + --tile-lib-backend=ptodsl \ + test/lit/vpto/expand_tile_op_ptodsl_tsub.pto -o - 2>/dev/null | +"$FILECHECK" test/lit/vpto/expand_tile_op_ptodsl_tsub.pto +``` + +### PTODSL negative path: selection is still required + +`pto.tadd` currently has four legal candidates. PTOAS is expected to reject it +after metadata discovery because version selection is not a separate stage yet: + +```bash +ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto \ + --tile-lib-backend=ptodsl \ + test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto \ + -o /dev/null 2>&1 | +"$FILECHECK" \ + test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto +``` + +The `ptoas` process fails intentionally in this test. `FileCheck` succeeds +only when it sees the expected four-candidate diagnostic. + +### Legacy backend regression + +Omitting `--tile-lib-backend` must continue to select TileLang: + +```bash +ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto \ + test/lit/vpto/expand_tile_op_tilelang_tsub.pto -o - 2>/dev/null | +"$FILECHECK" test/lit/vpto/expand_tile_op_tilelang_tsub.pto +``` + +## Reading the result + +`FileCheck` is silent when it succeeds. Immediately check its status with: + +```bash +echo $? +``` + +`0` means the check passed. When candidate selection is implemented, replace +the expected-failure `tadd` coverage with a positive selected-version test and +update the milestone table above. diff --git a/test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto b/test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto new file mode 100644 index 000000000..7279a4456 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// The two-call migration milestone discovers candidates before rendering, but +// deliberately does not choose among multiple legal versions yet. +// +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --tile-lib-backend=ptodsl %s -o /dev/null 2>&1 | FileCheck %s + +// CHECK: ExpandTileOp: PTODSL metadata returned 4 legal candidates for pto.tadd; version selection is required before rendering + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TADD() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tadd ins( + %a, %b + : !pto.tile_buf, + !pto.tile_buf) + outs( + %dst + : !pto.tile_buf) + return + } +} From f509b39f6d9278a7861aa3e523398dbcd93ba103 Mon Sep 17 00:00:00 2001 From: ManiSadati Date: Wed, 1 Jul 2026 16:17:13 +0000 Subject: [PATCH 5/6] feat(ptodsl): add a pass to insert template candidates --- include/PTO/Transforms/Passes.h | 3 + include/PTO/Transforms/Passes.td | 30 ++ lib/PTO/Transforms/CMakeLists.txt | 1 + lib/PTO/Transforms/ExpandTileOp.cpp | 101 ++-- .../Transforms/InsertTemplateAttributes.cpp | 431 ++++++++++++++++++ ptodsl/README.md | 12 +- ptodsl/docs/tilelib-migration-testing.md | 66 ++- ptodsl/ptodsl/tilelib/serving/daemon.py | 16 +- ptodsl/ptodsl/tilelib/templates/a5/tcolmax.py | 3 + ptodsl/ptodsl/tilelib/templates/a5/tdiv.py | 3 + ptodsl/ptodsl/tilelib/templates/a5/tmax.py | 3 + ptodsl/ptodsl/tilelib/templates/a5/tmin.py | 3 + ptodsl/ptodsl/tilelib/templates/a5/tsub.py | 3 + ptodsl/tests/test_tilelib_daemon.py | 20 +- test/lit/vpto/expand_tile_op_ptodsl_tadd.pto | 82 ++++ ...tile_op_ptodsl_tadd_requires_selection.pto | 40 -- tools/ptoas/ptoas.cpp | 62 ++- 17 files changed, 741 insertions(+), 138 deletions(-) create mode 100644 lib/PTO/Transforms/InsertTemplateAttributes.cpp create mode 100644 test/lit/vpto/expand_tile_op_ptodsl_tadd.pto delete mode 100644 test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 3de31a89b..4fbbb41ae 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -107,6 +107,9 @@ LogicalResult validateVPTOEmissionIR(ModuleOp module, llvm::raw_ostream *diagOS = nullptr); std::unique_ptr createPTOValidateVPTOIRPass(); std::unique_ptr createPTOValidateVPTOEmissionIRPass(); +std::unique_ptr createInsertTemplateAttributesPass(); +std::unique_ptr createInsertTemplateAttributesPass( + const InsertTemplateAttributesOptions &options); std::unique_ptr createExpandTileOpPass(); std::unique_ptr createExpandTileOpPass(const ExpandTileOpOptions &options); std::unique_ptr createFoldTileBufIntrinsicsPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 58bc02a63..1c1d5f34c 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -451,6 +451,36 @@ def PTOResolveReservedBuffers : Pass<"pto-resolve-reserved-buffers", "ModuleOp"> ]; } +def InsertTemplateAttributes + : Pass<"pto-insert-template-attributes", "ModuleOp"> { + let summary = "Attach legal PTODSL template candidates to tile operations"; + let description = [{ + Queries the PTODSL TileLib daemon for legal template candidates and stores + the compact candidate list on each tile operation as the `candidates` + attribute. Each candidate contains only id, name, loop_depth, postupdate, + and tail metadata. + }]; + let constructor = "mlir::pto::createInsertTemplateAttributesPass()"; + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::func::FuncDialect" + ]; + let options = [ + Option<"pythonExe", "python-exe", "std::string", + /*default=*/"\"python3\"", + "Python executable for TileLib metadata invocation">, + Option<"daemonSocketPath", "daemon-socket-path", "std::string", + /*default=*/"\"\"", + "Path to the PTODSL TileLib daemon Unix socket">, + Option<"tileLibPkgPath", "tile-lib-pkg-path", "std::string", + /*default=*/"\"\"", + "PYTHONPATH root for PTODSL">, + Option<"daemonHelperModule", "daemon-helper-module", "std::string", + /*default=*/"\"ptodsl.tilelib.serving.helper\"", + "Python module used for daemon metadata RPC calls"> + ]; +} + def ExpandTileOp : Pass<"pto-expand-tile-op", "ModuleOp"> { let summary = "Expand tile ops into calls to TileLib template functions"; let description = [{ diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 3a6c04aed..54674bb2a 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -42,6 +42,7 @@ add_mlir_dialect_library(PTOTransforms InsertSync/InsertSyncDebug.cpp PTOViewToMemref.cpp PTOValidateIntToPtrUses.cpp + InsertTemplateAttributes.cpp ExpandTileOp.cpp FoldTileBufIntrinsics.cpp PTOLowerToOpLibCalls.cpp diff --git a/lib/PTO/Transforms/ExpandTileOp.cpp b/lib/PTO/Transforms/ExpandTileOp.cpp index 9f314dac1..189abc7a5 100644 --- a/lib/PTO/Transforms/ExpandTileOp.cpp +++ b/lib/PTO/Transforms/ExpandTileOp.cpp @@ -18,7 +18,8 @@ // // Workflow per tile op: // 1. Extract SpecKey from ALL operands' tile_buf types. -// 2. For PTODSL, query legal-candidate metadata and require one candidate. +// 2. For PTODSL, read candidates attached by InsertTemplateAttributes and +// select the first candidate still present. // 3. Invoke the selected TileLib helper to generate a specialized MLIR // function (with tile_buf parameters). // 4. Parse the generated MLIR and clone the function into the module. @@ -49,9 +50,7 @@ #include "llvm/ADT/StringSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" -#include "llvm/Support/JSON.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Path.h" #include "llvm/Support/Program.h" @@ -79,6 +78,8 @@ namespace pto { namespace { +constexpr llvm::StringLiteral kCandidatesAttr = "candidates"; + // ============================================================================ // OperandTypeInfo: describes one operand for template specialization. // @@ -711,12 +712,9 @@ struct ExpandState { std::string daemonSocketPath; std::optional - invokeTileLibHelper(const SpecKey &key, StringRef method = {}, - StringRef candidateId = {}); - std::optional - discoverSingleTileLibCandidate(const SpecKey &key); - func::FuncOp invokeTileLib(const SpecKey &key, ModuleOp mod, - MLIRContext *ctx); + invokeTileLibHelper(const SpecKey &key, StringRef candidateId = {}); + func::FuncOp invokeTileLib(const SpecKey &key, Operation *tileOp, + ModuleOp mod, MLIRContext *ctx); func::FuncOp invokeTileLibDaemon(const SpecKey &key, StringRef candidateId, ModuleOp mod, MLIRContext *ctx); @@ -885,7 +883,7 @@ static std::string buildContextAttrsJson(const SpecKey &key) { // Invoke the configured one-shot helper and return its stdout. // ============================================================================ std::optional -ExpandState::invokeTileLibHelper(const SpecKey &key, StringRef method, +ExpandState::invokeTileLibHelper(const SpecKey &key, StringRef candidateId) { auto pythonPath = llvm::sys::findProgramByName(pythonExe); if (!pythonPath) { @@ -918,10 +916,6 @@ ExpandState::invokeTileLibHelper(const SpecKey &key, StringRef method, "--op", opName, "--operand-specs", operandSpecsJson, }; - if (!method.empty()) { - args.push_back("--method"); - args.push_back(method); - } if (!key.contextAttrs.empty()) { args.push_back("--context-attrs"); args.push_back(contextAttrsJson); @@ -963,10 +957,8 @@ ExpandState::invokeTileLibHelper(const SpecKey &key, StringRef method, redirects, /*secondsToWait=*/30, /*memoryLimit=*/0, &errMsg); if (rc != 0) { - StringRef operation = - method.empty() ? StringRef("instantiate") : method; - llvm::errs() << "ExpandTileOp: daemon helper " << operation - << " failed (rc=" << rc + llvm::errs() << "ExpandTileOp: daemon helper instantiate failed (rc=" + << rc << "): " << errMsg << "\n"; llvm::sys::fs::remove(tmpPath); return std::nullopt; @@ -986,47 +978,6 @@ ExpandState::invokeTileLibHelper(const SpecKey &key, StringRef method, return output; } -// ============================================================================ -// Discover the only legal candidate supported by this migration milestone. -// ============================================================================ -std::optional -ExpandState::discoverSingleTileLibCandidate(const SpecKey &key) { - auto metadataText = invokeTileLibHelper(key, "get_metadata"); - if (!metadataText) - return std::nullopt; - - auto parsed = llvm::json::parse(*metadataText); - if (!parsed) { - llvm::errs() << "ExpandTileOp: failed to parse PTODSL metadata: " - << llvm::toString(parsed.takeError()) << "\n"; - return std::nullopt; - } - - auto *root = parsed->getAsObject(); - auto *candidates = root ? root->getObject("candidates") : nullptr; - if (!candidates) { - llvm::errs() << "ExpandTileOp: PTODSL metadata is missing the " - "'candidates' object\n"; - return std::nullopt; - } - - std::string opName = "pto." + key.opName; - if (candidates->size() != 1) { - llvm::errs() << "ExpandTileOp: PTODSL metadata returned " - << candidates->size() << " legal candidates for " << opName - << "; version selection is required before rendering\n"; - return std::nullopt; - } - - const auto &candidate = *candidates->begin(); - if (!candidate.second.getAsObject()) { - llvm::errs() << "ExpandTileOp: malformed metadata for candidate '" - << candidate.first.str() << "'\n"; - return std::nullopt; - } - return candidate.first.str(); -} - // ============================================================================ // Invoke the daemon RPC to generate a specialized template function. // ============================================================================ @@ -1034,7 +985,7 @@ func::FuncOp ExpandState::invokeTileLibDaemon(const SpecKey &key, StringRef candidateId, ModuleOp mod, MLIRContext *ctx) { - auto mlirText = invokeTileLibHelper(key, /*method=*/{}, candidateId); + auto mlirText = invokeTileLibHelper(key, candidateId); if (!mlirText) return nullptr; @@ -1060,6 +1011,8 @@ func::FuncOp ExpandState::invokeTileLibDaemon(const SpecKey &key, SmallVector clonedFuncs; std::string uniqueName = buildUniqueFunctionBaseName(key); + if (!candidateId.empty()) + uniqueName += "__" + candidateId.str(); SymbolTable targetSymTable(mod); if (auto existingFunc = targetSymTable.lookup(uniqueName)) return cast(existingFunc); @@ -1111,16 +1064,34 @@ func::FuncOp ExpandState::invokeTileLibDaemon(const SpecKey &key, // ============================================================================ // Invoke the selected TileLib backend to generate a specialized template. // ============================================================================ -func::FuncOp ExpandState::invokeTileLib(const SpecKey &key, ModuleOp mod, +func::FuncOp ExpandState::invokeTileLib(const SpecKey &key, + Operation *tileOp, ModuleOp mod, MLIRContext *ctx) { // Try daemon first if daemon socket path is provided. if (!daemonSocketPath.empty()) { std::string candidateId; if (tileLibBackend == "ptodsl") { - auto discoveredCandidate = discoverSingleTileLibCandidate(key); - if (!discoveredCandidate) + auto candidates = + tileOp->getAttrOfType(kCandidatesAttr); + if (!candidates || candidates.empty()) { + tileOp->emitError( + "ExpandTileOp requires at least one template candidate"); + return nullptr; + } + + auto selected = dyn_cast(candidates[0]); + if (!selected) { + tileOp->emitError( + "ExpandTileOp candidate 0 must be a dictionary"); return nullptr; - candidateId = std::move(*discoveredCandidate); + } + auto selectedName = selected.getAs("name"); + if (!selectedName) { + tileOp->emitError( + "ExpandTileOp candidate 0 requires a string name"); + return nullptr; + } + candidateId = selectedName.getValue().str(); } func::FuncOp daemonResult = @@ -1354,7 +1325,7 @@ LogicalResult ExpandState::expandTileOpsInFunction(func::FuncOp func, } // Invoke the selected TileLib backend (with daemon-side caching). - func::FuncOp dslFn = invokeTileLib(*specKeyOpt, mod, ctx); + func::FuncOp dslFn = invokeTileLib(*specKeyOpt, op, mod, ctx); if (!dslFn) { StringRef opName = getTileOpName(op); op->emitError("ExpandTileOp: failed to instantiate TileLib template for " + diff --git a/lib/PTO/Transforms/InsertTemplateAttributes.cpp b/lib/PTO/Transforms/InsertTemplateAttributes.cpp new file mode 100644 index 000000000..071779a3f --- /dev/null +++ b/lib/PTO/Transforms/InsertTemplateAttributes.cpp @@ -0,0 +1,431 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include +#include + +extern "C" { +extern char **environ; +} + +using namespace mlir; + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_INSERTTEMPLATEATTRIBUTES +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +namespace { + +constexpr llvm::StringLiteral kCandidatesAttr = "candidates"; + +struct CandidateMetadata { + int64_t id; + std::string name; + int64_t loopDepth; + bool postUpdate; + bool tail; +}; + +static std::string getDtypeString(Type elementType) { + if (elementType.isF32()) + return "f32"; + if (elementType.isF16()) + return "f16"; + if (elementType.isBF16()) + return "bf16"; + if (elementType.isSignlessInteger(32)) + return "i32"; + if (elementType.isSignlessInteger(16)) + return "i16"; + if (elementType.isSignlessInteger(8)) + return "i8"; + return ""; +} + +static std::string stringifyMemorySpace(pto::AddressSpace space) { + switch (space) { + case pto::AddressSpace::GM: + return "gm"; + case pto::AddressSpace::MAT: + return "mat"; + case pto::AddressSpace::LEFT: + return "left"; + case pto::AddressSpace::RIGHT: + return "right"; + case pto::AddressSpace::ACC: + return "acc"; + case pto::AddressSpace::BIAS: + return "bias"; + case pto::AddressSpace::SCALING: + return "scaling"; + case pto::AddressSpace::VEC: + case pto::AddressSpace::Zero: + return "ub"; + } + return "ub"; +} + +static std::string getMemorySpaceString(pto::TileBufType tileType) { + auto memorySpace = + dyn_cast_or_null(tileType.getMemorySpace()); + return memorySpace ? stringifyMemorySpace(memorySpace.getAddressSpace()) + : "ub"; +} + +static StringRef getBLayoutString(pto::BLayout layout) { + return layout == pto::BLayout::ColMajor ? "col_major" : "row_major"; +} + +static StringRef getSLayoutString(pto::SLayout layout) { + if (layout == pto::SLayout::RowMajor) + return "row_major"; + if (layout == pto::SLayout::ColMajor) + return "col_major"; + return "none_box"; +} + +static void appendJsonIntArray(std::string &json, ArrayRef values) { + json += "["; + for (auto [index, value] : llvm::enumerate(values)) { + if (index != 0) + json += ","; + json += std::to_string(value); + } + json += "]"; +} + +static std::optional +buildOperandSpecsJson(Operation *operation) { + std::string json = "["; + for (auto [index, operand] : llvm::enumerate(operation->getOperands())) { + auto tileType = dyn_cast(operand.getType()); + if (!tileType) { + operation->emitError( + "InsertTemplateAttributes currently supports only tile operands"); + return std::nullopt; + } + + std::string dtype = getDtypeString(tileType.getElementType()); + if (dtype.empty()) { + operation->emitError( + "InsertTemplateAttributes encountered an unsupported tile dtype"); + return std::nullopt; + } + + if (index != 0) + json += ","; + json += "{\"kind\":\"tile\",\"dtype\":\"" + dtype + "\",\"shape\":"; + appendJsonIntArray(json, tileType.getShape()); + json += ",\"valid_shape\":"; + auto validShape = tileType.getValidShape(); + appendJsonIntArray(json, + validShape.empty() ? tileType.getShape() : validShape); + json += ",\"memory_space\":\""; + json += getMemorySpaceString(tileType); + json += "\",\"config\":{"; + + pto::BLayout bLayout = pto::BLayout::RowMajor; + pto::SLayout sLayout = pto::SLayout::NoneBox; + int64_t fractalSize = 0; + uint64_t padValue = 0; + if (auto config = tileType.getConfigAttr()) { + bLayout = config.getBLayout().getValue(); + sLayout = config.getSLayout().getValue(); + if (config.getSFractalSize()) + fractalSize = config.getSFractalSize().getInt(); + padValue = static_cast(config.getPad().getValue()); + } + + json += "\"b_layout\":\""; + json += getBLayoutString(bLayout); + json += "\",\"s_layout\":\""; + json += getSLayoutString(sLayout); + json += "\",\"s_fractal_size\":"; + json += std::to_string(fractalSize); + json += ",\"pad_value\":\"0x"; + json += llvm::utohexstr(padValue, /*LowerCase=*/false); + json += "\"}}"; + } + json += "]"; + return json; +} + +static std::optional +getTargetArch(Operation *operation) { + auto module = operation->getParentOfType(); + if (!module) { + operation->emitError( + "InsertTemplateAttributes requires a parent module"); + return std::nullopt; + } + auto target = module->getAttrOfType("pto.target_arch"); + if (!target) { + operation->emitError( + "InsertTemplateAttributes requires pto.target_arch"); + return std::nullopt; + } + return target.getValue().str(); +} + +static std::optional +invokeMetadataHelper(Operation *operation, StringRef pythonExe, + StringRef daemonSocketPath, StringRef tileLibPkgPath, + StringRef daemonHelperModule) { + auto pythonPath = llvm::sys::findProgramByName(pythonExe); + if (!pythonPath) { + operation->emitError("InsertTemplateAttributes cannot find Python '") + << pythonExe << "'"; + return std::nullopt; + } + + auto target = getTargetArch(operation); + auto operandSpecs = buildOperandSpecsJson(operation); + if (!target || !operandSpecs) + return std::nullopt; + + llvm::SmallString<128> outputPath; + int outputFd; + if (auto error = llvm::sys::fs::createTemporaryFile( + "tilelib_metadata", "json", outputFd, outputPath)) { + operation->emitError("InsertTemplateAttributes cannot create temporary " + "metadata output: ") + << error.message(); + return std::nullopt; + } + ::close(outputFd); + + std::string opName = operation->getName().getStringRef().str(); + SmallVector args = { + *pythonPath, "-m", daemonHelperModule, + "--method", "get_metadata", "--socket", + daemonSocketPath, "--target", *target, + "--op", opName, "--operand-specs", + *operandSpecs, + }; + + std::optional redirects[] = { + std::nullopt, + StringRef(outputPath), + std::nullopt, + }; + + SmallVector environment; + std::string pythonPathEnvironment; + std::vector environmentStorage; + bool hasPythonPath = !tileLibPkgPath.empty(); + if (hasPythonPath) { + const char *existingPath = ::getenv("PYTHONPATH"); + pythonPathEnvironment = "PYTHONPATH=" + tileLibPkgPath.str(); + if (existingPath && existingPath[0] != '\0') + pythonPathEnvironment += ":" + std::string(existingPath); + + for (char **entry = environ; *entry; ++entry) { + StringRef value(*entry); + if (!value.starts_with("PYTHONPATH=")) + environmentStorage.push_back(value.str()); + } + environmentStorage.push_back(pythonPathEnvironment); + for (std::string &value : environmentStorage) + environment.push_back(value); + } + + std::string errorMessage; + int result = llvm::sys::ExecuteAndWait( + *pythonPath, args, + hasPythonPath + ? std::optional>(environment) + : std::nullopt, + redirects, /*secondsToWait=*/30, /*memoryLimit=*/0, &errorMessage); + if (result != 0) { + llvm::sys::fs::remove(outputPath); + operation->emitError("InsertTemplateAttributes metadata RPC failed: ") + << errorMessage; + return std::nullopt; + } + + auto output = llvm::MemoryBuffer::getFile(outputPath); + llvm::sys::fs::remove(outputPath); + if (!output) { + operation->emitError( + "InsertTemplateAttributes cannot read metadata output"); + return std::nullopt; + } + return (*output)->getBuffer().str(); +} + +static FailureOr +parseCandidateAttributes(Operation *operation, StringRef metadataJson) { + auto parsed = llvm::json::parse(metadataJson); + if (!parsed) { + llvm::consumeError(parsed.takeError()); + operation->emitError( + "InsertTemplateAttributes received invalid metadata JSON"); + return failure(); + } + + auto *root = parsed->getAsObject(); + auto *candidates = root ? root->getObject("candidates") : nullptr; + if (!candidates || candidates->empty()) { + operation->emitError( + "InsertTemplateAttributes found no legal template candidates"); + return failure(); + } + + SmallVector parsedCandidates; + parsedCandidates.reserve(candidates->size()); + for (const auto &entry : *candidates) { + auto *metadata = entry.second.getAsObject(); + if (!metadata) { + operation->emitError( + "InsertTemplateAttributes candidate metadata must be an object"); + return failure(); + } + + auto name = metadata->getString("name"); + auto id = metadata->getInteger("id"); + auto loopDepth = metadata->getInteger("loop_depth"); + auto postUpdate = metadata->getBoolean("is_post_update"); + auto tail = metadata->getBoolean("has_tail"); + if (!name || !loopDepth || !postUpdate || !tail) { + operation->emitError( + "InsertTemplateAttributes candidate metadata is missing name, " + "loop_depth, is_post_update, or has_tail"); + return failure(); + } + if (!id && candidates->size() != 1) { + operation->emitError( + "InsertTemplateAttributes requires an id for every " + "multi-candidate template"); + return failure(); + } + + parsedCandidates.push_back(CandidateMetadata{ + id.value_or(0), + name->str(), + *loopDepth, + *postUpdate, + *tail, + }); + } + + llvm::sort(parsedCandidates, + [](const CandidateMetadata &left, + const CandidateMetadata &right) { + if (left.id != right.id) + return left.id < right.id; + return left.name < right.name; + }); + for (auto [index, candidate] : llvm::enumerate(parsedCandidates)) { + if (index != 0 && candidate.id == parsedCandidates[index - 1].id) { + operation->emitError( + "InsertTemplateAttributes candidate ids must be unique"); + return failure(); + } + } + + Builder builder(operation->getContext()); + SmallVector attributes; + attributes.reserve(parsedCandidates.size()); + for (const CandidateMetadata &candidate : parsedCandidates) { + attributes.push_back(DictionaryAttr::get( + operation->getContext(), + { + builder.getNamedAttr("id", builder.getI64IntegerAttr(candidate.id)), + builder.getNamedAttr("name", + builder.getStringAttr(candidate.name)), + builder.getNamedAttr( + "loop_depth", + builder.getI64IntegerAttr(candidate.loopDepth)), + builder.getNamedAttr( + "postupdate", + builder.getI64IntegerAttr(candidate.postUpdate ? 1 : 0)), + builder.getNamedAttr( + "tail", builder.getI64IntegerAttr(candidate.tail ? 1 : 0)), + })); + } + return builder.getArrayAttr(attributes); +} + +struct InsertTemplateAttributesPass + : public pto::impl::InsertTemplateAttributesBase< + InsertTemplateAttributesPass> { + using InsertTemplateAttributesBase::InsertTemplateAttributesBase; + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (daemonSocketPath.empty()) { + module.emitError( + "InsertTemplateAttributes requires a PTODSL daemon socket"); + return signalPassFailure(); + } + + SmallVector tileOperations; + module.walk([&](Operation *operation) { + if (isa(operation)) + return; + if (isa(operation)) + tileOperations.push_back(operation); + }); + + for (Operation *operation : tileOperations) { + auto metadata = invokeMetadataHelper( + operation, pythonExe, daemonSocketPath, tileLibPkgPath, + daemonHelperModule); + if (!metadata) + return signalPassFailure(); + + auto candidates = parseCandidateAttributes(operation, *metadata); + if (failed(candidates)) + return signalPassFailure(); + operation->setAttr(kCandidatesAttr, *candidates); + } + } +}; + +} // namespace + +namespace mlir { +namespace pto { + +std::unique_ptr createInsertTemplateAttributesPass() { + return std::make_unique(); +} + +std::unique_ptr createInsertTemplateAttributesPass( + const InsertTemplateAttributesOptions &options) { + return std::make_unique(options); +} + +} // namespace pto +} // namespace mlir diff --git a/ptodsl/README.md b/ptodsl/README.md index 9247dba88..d4d9a1460 100644 --- a/ptodsl/README.md +++ b/ptodsl/README.md @@ -71,12 +71,12 @@ overrides that default. Use `--ptodsl-pkg-path=/path/to/package/root` for an explicit command-line override. PTODSL daemon failures are reported as errors and never fall back to the TileLang implementation. -At this migration stage, end-to-end expansion is intended for operations with -one legal PTODSL candidate, such as `pto.tsub`. `ExpandTileOp` first requests -legal-candidate metadata and then renders with the sole candidate's ID. -Operations with tied template candidates will be enabled when candidate -discovery moves into the planned `DiscoverTileLibCandidates` pass and version -selection becomes a separate stage. +`InsertTemplateAttributes` queries legal-candidate metadata before fusion and +stores an ordered `candidates` array containing only `id`, `name`, +`loop_depth`, `postupdate`, and `tail`. Fusion may filter this array. +Candidates are ordered by unique `id`. `ExpandTileOp` renders the first +candidate that remains, providing a deterministic fallback when several +candidates reach expansion. See the [PTODSL TileLib migration test checklist](docs/tilelib-migration-testing.md) diff --git a/ptodsl/docs/tilelib-migration-testing.md b/ptodsl/docs/tilelib-migration-testing.md index d28107c19..caaf7dc98 100644 --- a/ptodsl/docs/tilelib-migration-testing.md +++ b/ptodsl/docs/tilelib-migration-testing.md @@ -26,8 +26,8 @@ changes. | PTODSL TileLib package | `test_tilelib_constraints.py`, `test_tilelib_elementwise.py`, `test_tilelib_render.py`, `test_tilelib_select.py` | Covers legality constraints, template registration and selection, and rendering | | PTODSL daemon | `test_tilelib_daemon.py` | Covers the Unix-socket protocol, metadata, rendering, candidate IDs, and caching | | PTOAS daemon selection | `expand_tile_op_ptodsl_tsub.pto` | Confirms `--tile-lib-backend=ptodsl` starts and uses the PTODSL daemon | -| Two-call expansion | `expand_tile_op_ptodsl_tsub.pto` | Confirms metadata discovery followed by rendering with the sole candidate ID | -| Multi-candidate boundary | `expand_tile_op_ptodsl_tadd_requires_selection.pto` | Confirms four legal `tadd` candidates require a separate selection stage | +| Separate metadata/render passes | `expand_tile_op_ptodsl_tadd.pto` | Confirms `InsertTemplateAttributes` records compact metadata before `ExpandTileOp` renders | +| Multi-candidate fallback | `expand_tile_op_ptodsl_tadd.pto` | Confirms `ExpandTileOp` renders candidate index zero when several candidates remain | ## Python TileLib tests @@ -51,6 +51,15 @@ Each command prints `OK` when successful. ## PTOAS integration tests +Run the focused lit tests through the generated site configuration. Start lit +from `build/test/lit`; passing source files under `test/lit` directly bypasses +the generated LLVM configuration: + +```bash +"$LLVM_BUILD_DIR/bin/llvm-lit" -sv build/test/lit \ + --filter='expand_tile_op_(ptodsl_tsub|ptodsl_tadd|tilelang_tsub)' +``` + ### PTODSL positive path: one legal candidate `pto.tsub` has one legal PTODSL candidate. The test checks that PTOAS expands @@ -63,22 +72,54 @@ ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto \ "$FILECHECK" test/lit/vpto/expand_tile_op_ptodsl_tsub.pto ``` -### PTODSL negative path: selection is still required +### PTODSL candidate attributes and multi-candidate fallback + +Inspect the compact candidate list inserted before fusion: + +```bash +ptoas --pto-arch=a5 --pto-backend=vpto --emit-pto-ir \ + --tile-lib-backend=ptodsl \ + --mlir-print-ir-after=pto-insert-template-attributes \ + test/lit/vpto/expand_tile_op_ptodsl_tadd.pto \ + -o /dev/null 2>&1 | +"$FILECHECK" test/lit/vpto/expand_tile_op_ptodsl_tadd.pto \ + --check-prefix=META +``` + +Confirm that insertion also runs before `FusionPlan` when fusion is enabled: + +```bash +ptoas --pto-arch=a5 --pto-backend=vpto --pto-level=level2 \ + --enable-op-fusion --emit-pto-ir --tile-lib-backend=ptodsl \ + --mlir-print-ir-before=pto-fusion-plan \ + test/lit/vpto/expand_tile_op_ptodsl_tadd.pto \ + -o /dev/null 2>&1 | +"$FILECHECK" test/lit/vpto/expand_tile_op_ptodsl_tadd.pto \ + --check-prefix=PREFUSION +``` -`pto.tadd` currently has four legal candidates. PTOAS is expected to reject it -after metadata discovery because version selection is not a separate stage yet: +Inspect `ExpandTileOp` immediately after selection and confirm candidate zero +was used: ```bash ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto \ --tile-lib-backend=ptodsl \ - test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto \ + --mlir-print-ir-after=pto-expand-tile-op \ + test/lit/vpto/expand_tile_op_ptodsl_tadd.pto \ -o /dev/null 2>&1 | -"$FILECHECK" \ - test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto +"$FILECHECK" test/lit/vpto/expand_tile_op_ptodsl_tadd.pto \ + --check-prefix=SELECT ``` -The `ptoas` process fails intentionally in this test. `FileCheck` succeeds -only when it sees the expected four-candidate diagnostic. +Confirm the selected template expands through the full VPTO pipeline: + +```bash +ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto \ + --tile-lib-backend=ptodsl \ + test/lit/vpto/expand_tile_op_ptodsl_tadd.pto -o - 2>/dev/null | +"$FILECHECK" test/lit/vpto/expand_tile_op_ptodsl_tadd.pto \ + --check-prefix=EXPAND +``` ### Legacy backend regression @@ -98,6 +139,5 @@ ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto \ echo $? ``` -`0` means the check passed. When candidate selection is implemented, replace -the expected-failure `tadd` coverage with a positive selected-version test and -update the milestone table above. +`0` means the check passed. When fusion begins filtering candidates, add +coverage for the filtered array while retaining the index-zero fallback test. diff --git a/ptodsl/ptodsl/tilelib/serving/daemon.py b/ptodsl/ptodsl/tilelib/serving/daemon.py index 02ac4533b..1445c990f 100644 --- a/ptodsl/ptodsl/tilelib/serving/daemon.py +++ b/ptodsl/ptodsl/tilelib/serving/daemon.py @@ -25,6 +25,7 @@ import socketserver import threading +from .. import constraints as _constraints from .. import registry as _registry from ..metadata import ScalarType, TileSpec from ..templates import load_template @@ -87,8 +88,12 @@ def _metadata_value(value): return value -def _metadata_for_descriptor(descriptor) -> dict: +def _metadata_for_descriptor(descriptor, constraint_context: dict) -> dict: metadata = descriptor.metadata + if callable(metadata.Tail): + has_tail = _constraints.passes((metadata.Tail,), constraint_context) + else: + has_tail = bool(metadata.Tail) return { "op": metadata.op, "target": metadata.target, @@ -104,6 +109,7 @@ def _metadata_for_descriptor(descriptor) -> dict: "loop_depth": metadata.loop_depth, "id": metadata.id, "Tail": _metadata_value(metadata.Tail), + "has_tail": has_tail, "is_post_update": metadata.is_post_update, "tags": list(metadata.tags), } @@ -133,11 +139,17 @@ def metadata_request( """Return every legal candidate and its selection metadata.""" tile_specs = _tile_specs_for_request(target, op, operand_specs) legal = _registry.legal_candidates(op, target, tile_specs, context_attrs) + constraint_context = _constraints.build_context(tile_specs, target, op) + if context_attrs: + constraint_context.update(context_attrs) return { "target": target, "op": op, "candidates": { - descriptor.name: _metadata_for_descriptor(descriptor) + descriptor.name: _metadata_for_descriptor( + descriptor, + constraint_context, + ) for descriptor in legal }, } diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tcolmax.py b/ptodsl/ptodsl/tilelib/templates/a5/tcolmax.py index c5d84cbcf..c7ec44b63 100644 --- a/ptodsl/ptodsl/tilelib/templates/a5/tcolmax.py +++ b/ptodsl/ptodsl/tilelib/templates/a5/tcolmax.py @@ -48,6 +48,9 @@ def _validate_tcolmax( memory_spaces=["ub"], constraints=[_validate_tcolmax], priority=0, + id=0, + loop_depth=2, + is_post_update=False, ) def template_tcolmax(src: pto.Tile, dst: pto.Tile): dtype = dst.element_type diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tdiv.py b/ptodsl/ptodsl/tilelib/templates/a5/tdiv.py index cd4113cda..0d9366fd4 100644 --- a/ptodsl/ptodsl/tilelib/templates/a5/tdiv.py +++ b/ptodsl/ptodsl/tilelib/templates/a5/tdiv.py @@ -24,6 +24,9 @@ layouts=["row_major"], memory_spaces=["ub"], priority=0, + id=0, + loop_depth=2, + is_post_update=False, ) def template_tdiv(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): dtype = dst.element_type diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tmax.py b/ptodsl/ptodsl/tilelib/templates/a5/tmax.py index 3c388541e..a424bc460 100644 --- a/ptodsl/ptodsl/tilelib/templates/a5/tmax.py +++ b/ptodsl/ptodsl/tilelib/templates/a5/tmax.py @@ -18,6 +18,9 @@ layouts=["row_major"], memory_spaces=["ub"], priority=0, + id=0, + loop_depth=2, + is_post_update=False, ) def template_tmax(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): dtype = dst.element_type diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tmin.py b/ptodsl/ptodsl/tilelib/templates/a5/tmin.py index d0bcc28fd..d03b81c0a 100644 --- a/ptodsl/ptodsl/tilelib/templates/a5/tmin.py +++ b/ptodsl/ptodsl/tilelib/templates/a5/tmin.py @@ -18,6 +18,9 @@ layouts=["row_major"], memory_spaces=["ub"], priority=0, + id=0, + loop_depth=2, + is_post_update=False, ) def template_tmin(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): dtype = dst.element_type diff --git a/ptodsl/ptodsl/tilelib/templates/a5/tsub.py b/ptodsl/ptodsl/tilelib/templates/a5/tsub.py index d1318990e..f885e9b2e 100644 --- a/ptodsl/ptodsl/tilelib/templates/a5/tsub.py +++ b/ptodsl/ptodsl/tilelib/templates/a5/tsub.py @@ -18,6 +18,9 @@ layouts=["row_major"], memory_spaces=["ub"], priority=0, + id=0, + loop_depth=2, + is_post_update=False, ) def template_tsub(src0: pto.Tile, src1: pto.Tile, dst: pto.Tile): dtype = dst.element_type diff --git a/ptodsl/tests/test_tilelib_daemon.py b/ptodsl/tests/test_tilelib_daemon.py index 9f5b7db3e..81fcd5805 100644 --- a/ptodsl/tests/test_tilelib_daemon.py +++ b/ptodsl/tests/test_tilelib_daemon.py @@ -16,12 +16,12 @@ from ptodsl.tilelib.serving.daemon import TileLibDaemonServer -def _tile_spec(dtype="f32"): +def _tile_spec(dtype="f32", shape=(8, 64)): return { "kind": "tile", "dtype": dtype, - "shape": [8, 64], - "valid_shape": [8, 64], + "shape": list(shape), + "valid_shape": list(shape), "memory_space": "ub", "config": { "b_layout": "row_major", @@ -102,9 +102,23 @@ def test_get_metadata_returns_legal_candidates(self): selected = candidates[TADD_2D_NO_POST_UPDATE] self.assertEqual(selected["loop_depth"], 2) self.assertEqual(selected["Tail"], {"callable": "has_tail"}) + self.assertFalse(selected["has_tail"]) self.assertFalse(selected["is_post_update"]) self.assertEqual(selected["tags"], ["binop", "2d", "no_post_update"]) + def test_get_metadata_evaluates_tail_for_each_request(self): + tail_operands = [ + _tile_spec(shape=(7, 65)), + _tile_spec(shape=(7, 65)), + _tile_spec(shape=(7, 65)), + ] + + metadata = self.client.get_metadata("a5", "pto.tadd", tail_operands) + + self.assertTrue( + metadata["candidates"][TADD_2D_NO_POST_UPDATE]["has_tail"] + ) + def test_cache_stats_and_clear_are_available_over_rpc(self): arguments = ( "a5", diff --git a/test/lit/vpto/expand_tile_op_ptodsl_tadd.pto b/test/lit/vpto/expand_tile_op_ptodsl_tadd.pto new file mode 100644 index 000000000..1c34a5ab0 --- /dev/null +++ b/test/lit/vpto/expand_tile_op_ptodsl_tadd.pto @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// InsertTemplateAttributes records a compact, ordered candidate list before +// fusion. ExpandTileOp renders the first candidate that remains. +// +// Running without --enable-op-fusion proves metadata insertion is not gated +// by fusion. Printing before FusionPlan proves its position when fusion is on. +// +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-pto-ir --tile-lib-backend=ptodsl --mlir-print-ir-after=pto-insert-template-attributes %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=META +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --pto-level=level2 --enable-op-fusion --emit-pto-ir --tile-lib-backend=ptodsl --mlir-print-ir-before=pto-fusion-plan %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=PREFUSION +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --tile-lib-backend=ptodsl --mlir-print-ir-after=pto-expand-tile-op %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=SELECT +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --tile-lib-backend=ptodsl %s -o - 2>/dev/null | FileCheck %s --check-prefix=EXPAND + +// META: pto.tadd +// META-SAME: candidates = [ +// META-SAME: id = 0 : i64 +// META-SAME: loop_depth = 2 : i64 +// META-SAME: name = "template_tadd_2d_no_post_update" +// META-SAME: postupdate = 0 : i64 +// META-SAME: tail = 0 : i64 +// META-SAME: id = 1 : i64 +// META-SAME: loop_depth = 1 : i64 +// META-SAME: name = "template_tadd_1d_no_post_update" +// META-SAME: postupdate = 0 : i64 +// META-SAME: tail = 0 : i64 +// META-SAME: id = 2 : i64 +// META-SAME: loop_depth = 2 : i64 +// META-SAME: name = "template_tadd_2d_post_update" +// META-SAME: postupdate = 1 : i64 +// META-SAME: tail = 0 : i64 +// META-SAME: id = 3 : i64 +// META-SAME: loop_depth = 1 : i64 +// META-SAME: name = "template_tadd_1d_post_update" +// META-SAME: postupdate = 1 : i64 +// META-SAME: tail = 0 : i64}] +// META-NOT: priority = +// META-NOT: tags = +// META-NOT: fusible = + +// PREFUSION: IR Dump Before FusionPlan +// PREFUSION: pto.tadd +// PREFUSION-SAME: candidates = [ + +// SELECT: func.func {{.*}}@{{.*}}__template_tadd_2d_no_post_update + +// EXPAND: func.func @TADD +// EXPAND-NOT: pto.tadd ins +// EXPAND: pto.vecscope +// EXPAND: pto.vadd +// EXPAND: pto.vsts + +module attributes {pto.kernel_kind = #pto.kernel_kind} { + func.func @TADD() { + %a = pto.alloc_tile + : !pto.tile_buf + %b = pto.alloc_tile + : !pto.tile_buf + %dst = pto.alloc_tile + : !pto.tile_buf + + pto.tadd ins( + %a, %b + : !pto.tile_buf, + !pto.tile_buf) + outs( + %dst + : !pto.tile_buf) + return + } +} diff --git a/test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto b/test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto deleted file mode 100644 index 7279a4456..000000000 --- a/test/lit/vpto/expand_tile_op_ptodsl_tadd_requires_selection.pto +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// The two-call migration milestone discovers candidates before rendering, but -// deliberately does not choose among multiple legal versions yet. -// -// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto --tile-lib-backend=ptodsl %s -o /dev/null 2>&1 | FileCheck %s - -// CHECK: ExpandTileOp: PTODSL metadata returned 4 legal candidates for pto.tadd; version selection is required before rendering - -module attributes {pto.kernel_kind = #pto.kernel_kind} { - func.func @TADD() { - %a = pto.alloc_tile - : !pto.tile_buf - %b = pto.alloc_tile - : !pto.tile_buf - %dst = pto.alloc_tile - : !pto.tile_buf - - pto.tadd ins( - %a, %b - : !pto.tile_buf, - !pto.tile_buf) - outs( - %dst - : !pto.tile_buf) - return - } -} diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 94d6158eb..70fce968a 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -486,6 +486,17 @@ static pto::ExpandTileOpOptions resolveExpandTileOpOptions(int argc, return expandOpts; } +static pto::InsertTemplateAttributesOptions +buildInsertTemplateAttributesOptions( + const pto::ExpandTileOpOptions &expandOptions) { + pto::InsertTemplateAttributesOptions options; + options.pythonExe = expandOptions.pythonExe; + options.daemonSocketPath = expandOptions.daemonSocketPath; + options.tileLibPkgPath = expandOptions.tileLibPkgPath; + options.daemonHelperModule = expandOptions.daemonHelperModule; + return options; +} + static llvm::cl::opt enableOpFusion( "enable-op-fusion", llvm::cl::desc("Enable A5 tile fusion on level2/level3. EmitC uses " @@ -1610,15 +1621,15 @@ static void prepareVPTOForEmission(PassManager &pm) { kernelModulePM.addPass(pto::createPTOValidateVPTOEmissionIRPass()); } -static void lowerPTOToVPTOBackend(PassManager &pm, ModuleOp module, int argc, - char **argv) { +static void +lowerPTOToVPTOBackend(PassManager &pm, ModuleOp module, + const pto::ExpandTileOpOptions &expandOpts) { auto &kernelModulePM = pm.nest(); auto moduleArchAttr = module->getAttrOfType("pto.target_arch"); const bool enableA5VPTOPostLoweringFusionLifecycle = enableOpFusion && moduleArchAttr && moduleArchAttr.getValue() == "a5"; - pto::ExpandTileOpOptions expandOpts = resolveExpandTileOpOptions(argc, argv); kernelModulePM.addPass(pto::createExpandTileOpPass(expandOpts)); kernelModulePM.addPass(pto::createPTOInlineLibCallPass()); @@ -1698,14 +1709,21 @@ static int emitVPTOBackendResult(ModuleOp module, PTOASCompileResult &result, } static LogicalResult runVPTOBackendPipeline(OwningOpRef &module, - int argc, char **argv, - bool hasTileOpsToExpand) { + bool hasTileOpsToExpand, + const pto::ExpandTileOpOptions + *expandOptions) { PassManager pm(module->getContext()); pm.enableVerifier(); pm.addPass(pto::createVPTOSplitCVModulePass()); pm.addPass(pto::createVPTONormalizeContainerPass()); - if (hasTileOpsToExpand) - lowerPTOToVPTOBackend(pm, module.get(), argc, argv); + if (hasTileOpsToExpand) { + if (!expandOptions) { + llvm::errs() << "Error: tile expansion requires resolved TileLib " + "options.\n"; + return failure(); + } + lowerPTOToVPTOBackend(pm, module.get(), *expandOptions); + } prepareVPTOForEmission(pm); if (failed(applyConfiguredPassManagerCLOptions( pm, "VPTO unified emission pipeline"))) @@ -1886,6 +1904,10 @@ int mlir::pto::compilePTOASModule( } const bool hasTileOpsToExpand = hasUnexpandedTileOps(*module); + std::optional expandOptions; + if (effectiveBackend == PTOBackend::VPTO && hasTileOpsToExpand && + tileLibBackend == TileLibBackend::PTODSL) + expandOptions = resolveExpandTileOpOptions(argc, argv); if (effectiveBackend == PTOBackend::VPTO && !hasTileOpsToExpand) { if (ptoPrintSeamIR || !ptoSeamIRFile.empty()) { @@ -1893,7 +1915,8 @@ int mlir::pto::compilePTOASModule( "skipping the shared PTO-to-VPTO lowering pipeline.\n"; return 1; } - if (failed(runVPTOBackendPipeline(module, argc, argv, hasTileOpsToExpand))) + if (failed(runVPTOBackendPipeline(module, hasTileOpsToExpand, + /*expandOptions=*/nullptr))) return 1; return emitVPTOBackendResult(*module, result, emitVPTOHostStub, context.getCANNVersionOrDefault()); @@ -1926,6 +1949,16 @@ int mlir::pto::compilePTOASModule( pm.addNestedPass( pto::createPTOValidateIntToPtrUsesPass()); + // PTODSL legality discovery happens on tile-native PTO IR before fusion. + // Fusion may later filter the ordered `candidates` array; ExpandTileOp + // consumes the first candidate that remains. + if (expandOptions && expandOptions->tileLibBackend == "ptodsl") { + auto insertOptions = + buildInsertTemplateAttributesOptions(*expandOptions); + pm.addPass( + pto::createInsertTemplateAttributesPass(insertOptions)); + } + // Keep frontend fusion on tile-native PTO IR and annotate last_use directly // on scheduled block-local spans before the shared mainline lowers tiles. // The shape-inference switch drives FusionPlan only: that is where the @@ -1989,6 +2022,9 @@ int mlir::pto::compilePTOASModule( // or an `arith.select` chain (dynamic slot). The multi-address cast // produced by PlanMemory survives as the alloc anchor. pm.addPass(pto::createPTOResolveBufferSelectPass()); + module->getOperation()->setAttr( + "pto.target_arch", + mlir::StringAttr::get(module->getContext(), arch)); if (emitMlirIR) { if (failed(pm.run(*module))) { @@ -2021,6 +2057,12 @@ int mlir::pto::compilePTOASModule( return 1; } + // The PTODSL daemon is needed before the main pipeline for metadata. + // Legacy TileLang can still be resolved lazily immediately before + // ExpandTileOp, preserving the prior --emit-pto-ir behavior. + if (hasTileOpsToExpand && !expandOptions) + expandOptions = resolveExpandTileOpOptions(argc, argv); + if (ptoPrintSeamIR) { module->print(llvm::errs()); llvm::errs() << "\n"; @@ -2028,7 +2070,9 @@ int mlir::pto::compilePTOASModule( if (failed(emitSharedPreBackendSeamIR(*module, ptoSeamIRFile))) return 1; - if (failed(runVPTOBackendPipeline(module, argc, argv, hasTileOpsToExpand))) + if (failed(runVPTOBackendPipeline( + module, hasTileOpsToExpand, + expandOptions ? &*expandOptions : nullptr))) return 1; return emitVPTOBackendResult(*module, result, emitVPTOHostStub, context.getCANNVersionOrDefault()); From 839ce4cb96bef17a7ac9ffba7a1499d6978664d6 Mon Sep 17 00:00:00 2001 From: ManiSadati Date: Wed, 1 Jul 2026 20:53:34 +0000 Subject: [PATCH 6/6] fix(ptodsl): harden TileLib daemon socket handling --- ptodsl/ptodsl/tilelib/serving/daemon.py | 66 ++++++++++++------------- ptodsl/ptodsl/tilelib/serving/wire.py | 18 ++++++- ptodsl/tests/test_tilelib_daemon.py | 36 +++++++++++++- 3 files changed, 85 insertions(+), 35 deletions(-) diff --git a/ptodsl/ptodsl/tilelib/serving/daemon.py b/ptodsl/ptodsl/tilelib/serving/daemon.py index 1445c990f..e71cd374d 100644 --- a/ptodsl/ptodsl/tilelib/serving/daemon.py +++ b/ptodsl/ptodsl/tilelib/serving/daemon.py @@ -32,6 +32,14 @@ from .wire import recv_message, send_message +def _remove_socket_path(socket_path: str) -> None: + """Remove an existing socket entry, including a broken symlink.""" + try: + os.unlink(socket_path) + except FileNotFoundError: + pass + + def _build_tile_specs(descriptor, operand_specs: list) -> dict: """Map positional daemon operands onto a template's parameter names.""" if not isinstance(operand_specs, list): @@ -174,26 +182,24 @@ def render_request( return descriptor.specialize(**tile_specs).mlir_text() -class TileLibDaemonServer(socketserver.ThreadingUnixStreamServer): - """Threaded Unix-socket RPC server with an in-memory render cache.""" +class TileLibDaemonServer(socketserver.UnixStreamServer): + """Sequential Unix-socket RPC server with an in-memory render cache.""" allow_reuse_address = True - daemon_threads = True def __init__(self, socket_path: str, max_entries: int = 1000): if max_entries <= 0: raise ValueError("max_entries must be greater than zero") super().__init__(socket_path, _Handler) + os.chmod(socket_path, 0o600) self._cache: dict[str, str] = {} - self._state_lock = threading.Lock() self._max_entries = max_entries self._stats = {"hits": 0, "misses": 0, "evictions": 0} @property def stats(self) -> dict: """Return a snapshot of cache counters for diagnostics and tests.""" - with self._state_lock: - return dict(self._stats) + return dict(self._stats) def dispatch(self, request: dict) -> dict: if not isinstance(request, dict): @@ -228,20 +234,18 @@ def _get_metadata(self, target, op, operand_specs, context_attrs=None): return metadata_request(target, op, operand_specs, context_attrs) def _get_stats(self): - with self._state_lock: - requests = self._stats["hits"] + self._stats["misses"] - total_entries = len(self._cache) - return { - **self._stats, - "entries": total_entries, - "total_entries": total_entries, - "max_entries": self._max_entries, - "hit_rate": self._stats["hits"] / requests if requests else 0.0, - } + requests = self._stats["hits"] + self._stats["misses"] + total_entries = len(self._cache) + return { + **self._stats, + "entries": total_entries, + "total_entries": total_entries, + "max_entries": self._max_entries, + "hit_rate": self._stats["hits"] / requests if requests else 0.0, + } def _clear(self): - with self._state_lock: - self._cache.clear() + self._cache.clear() return {"cleared": True} def _instantiate( @@ -264,12 +268,11 @@ def _instantiate( separators=(",", ":"), ) - with self._state_lock: - cached = self._cache.get(key) - if cached is not None: - self._stats["hits"] += 1 - return cached - self._stats["misses"] += 1 + cached = self._cache.get(key) + if cached is not None: + self._stats["hits"] += 1 + return cached + self._stats["misses"] += 1 mlir_text = render_request( target, @@ -279,11 +282,10 @@ def _instantiate( candidate_id, ) - with self._state_lock: - if len(self._cache) >= self._max_entries: - self._cache.pop(next(iter(self._cache))) - self._stats["evictions"] += 1 - self._cache[key] = mlir_text + if len(self._cache) >= self._max_entries: + self._cache.pop(next(iter(self._cache))) + self._stats["evictions"] += 1 + self._cache[key] = mlir_text return mlir_text @@ -312,8 +314,7 @@ def _parse_args(argv): def main(argv=None): args = _parse_args(argv) - if os.path.exists(args.socket): - os.unlink(args.socket) + _remove_socket_path(args.socket) server = TileLibDaemonServer(args.socket, max_entries=args.max_entries) stop = threading.Event() @@ -334,8 +335,7 @@ def _request_shutdown(*_): finally: server.shutdown() server.server_close() - if os.path.exists(args.socket): - os.unlink(args.socket) + _remove_socket_path(args.socket) if __name__ == "__main__": diff --git a/ptodsl/ptodsl/tilelib/serving/wire.py b/ptodsl/ptodsl/tilelib/serving/wire.py index 5a8b8076d..8f9185137 100644 --- a/ptodsl/ptodsl/tilelib/serving/wire.py +++ b/ptodsl/ptodsl/tilelib/serving/wire.py @@ -12,6 +12,9 @@ import json +MAX_MESSAGE_SIZE = 64 * 1024 * 1024 + + def recv_exactly(sock, length: int) -> bytes: """Read exactly ``length`` bytes or fail if the peer closes early.""" chunks = [] @@ -28,6 +31,10 @@ def recv_exactly(sock, length: int) -> bytes: def send_message(sock, message: dict) -> None: """Send one UTF-8 JSON message with a 4-byte big-endian length prefix.""" payload = json.dumps(message).encode("utf-8") + if len(payload) > MAX_MESSAGE_SIZE: + raise ValueError( + f"message length {len(payload)} exceeds limit {MAX_MESSAGE_SIZE}" + ) sock.sendall(len(payload).to_bytes(4, byteorder="big")) sock.sendall(payload) @@ -35,7 +42,16 @@ def send_message(sock, message: dict) -> None: def recv_message(sock) -> dict: """Receive one length-prefixed UTF-8 JSON message.""" length = int.from_bytes(recv_exactly(sock, 4), byteorder="big") + if length > MAX_MESSAGE_SIZE: + raise ValueError( + f"message length {length} exceeds limit {MAX_MESSAGE_SIZE}" + ) return json.loads(recv_exactly(sock, length).decode("utf-8")) -__all__ = ["recv_exactly", "recv_message", "send_message"] +__all__ = [ + "MAX_MESSAGE_SIZE", + "recv_exactly", + "recv_message", + "send_message", +] diff --git a/ptodsl/tests/test_tilelib_daemon.py b/ptodsl/tests/test_tilelib_daemon.py index 81fcd5805..528fd80ec 100644 --- a/ptodsl/tests/test_tilelib_daemon.py +++ b/ptodsl/tests/test_tilelib_daemon.py @@ -8,12 +8,18 @@ """End-to-end tests for the PTODSL TileLib daemon's Unix-socket RPC.""" import os +import socket +import stat import tempfile import threading import unittest from ptodsl.tilelib.serving.client import DaemonClient, DaemonError -from ptodsl.tilelib.serving.daemon import TileLibDaemonServer +from ptodsl.tilelib.serving.daemon import ( + TileLibDaemonServer, + _remove_socket_path, +) +from ptodsl.tilelib.serving.wire import MAX_MESSAGE_SIZE, recv_message def _tile_spec(dtype="f32", shape=(8, 64)): @@ -62,6 +68,10 @@ def tearDown(self): def test_ping(self): self.assertEqual(self.client.ping(), "pong") + def test_socket_is_accessible_only_by_owner(self): + mode = stat.S_IMODE(os.stat(self.socket_path).st_mode) + self.assertEqual(mode, 0o600) + def test_instantiate_named_candidate_returns_structured_mlir(self): mlir = self.client.instantiate( "a5", @@ -159,6 +169,30 @@ def test_cache_key_includes_context_attributes(self): ) self.assertEqual(self.client.get_stats()["misses"], 2) + def test_oversized_wire_message_is_rejected_before_payload_read(self): + receiver, sender = socket.socketpair() + self.addCleanup(receiver.close) + self.addCleanup(sender.close) + sender.sendall((MAX_MESSAGE_SIZE + 1).to_bytes(4, byteorder="big")) + + with self.assertRaisesRegex(ValueError, "exceeds limit"): + recv_message(receiver) + + def test_socket_cleanup_removes_broken_symlink(self): + missing_target = os.path.join( + self._temporary_directory.name, + "missing.sock", + ) + broken_link = os.path.join( + self._temporary_directory.name, + "broken.sock", + ) + os.symlink(missing_target, broken_link) + + _remove_socket_path(broken_link) + + self.assertFalse(os.path.lexists(broken_link)) + def test_non_tile_operand_is_rejected_explicitly(self): operands = list(TADD_OPERANDS) operands[0] = {"kind": "scalar", "dtype": "f32", "value": 1.0}