From 4533059c8fbe5ce183b8826f227180ff36bb2320 Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Wed, 1 Jul 2026 21:01:01 +0800 Subject: [PATCH 01/54] feat(ptodsl): support launching directly from the pto ir file --- .../03-kernel-entry-and-subkernels.md | 92 ++++++++ ptodsl/ptodsl/_diagnostics.py | 75 +++++++ ptodsl/ptodsl/_jit.py | 20 ++ ptodsl/ptodsl/_kernel_compilation.py | 39 +++- ptodsl/ptodsl/_runtime/native_build.py | 16 ++ ptodsl/ptodsl/_source_loader.py | 179 +++++++++++++++ ptodsl/ptodsl/_tracing/artifacts.py | 10 +- ptodsl/ptodsl/_tracing/module_builder.py | 1 + .../tests/support/docs_fragment_fixtures.py | 21 ++ ptodsl/tests/test_docs_as_test.py | 73 ++++-- ptodsl/tests/test_jit_compile.py | 210 +++++++++++++++++- ptodsl/tests/test_jit_diagnostics.py | 195 ++++++++++++++++ .../cases/micro-op/a5-extra/vmadd/compare.py | 59 ----- .../cases/micro-op/a5-extra/vmadd/golden.py | 40 ---- .../cases/micro-op/a5-extra/vmadd/kernel.py | 76 +++++++ .../cases/micro-op/a5-extra/vmadd/launch.cpp | 45 ---- .../cases/micro-op/a5-extra/vmadd/main.cpp | 120 ---------- test/vpto/scripts/run_host_vpto_validation.sh | 96 ++++++-- .../run_host_vpto_validation_parallel.sh | 48 ++-- 19 files changed, 1086 insertions(+), 329 deletions(-) create mode 100644 ptodsl/ptodsl/_source_loader.py delete mode 100644 test/vpto/cases/micro-op/a5-extra/vmadd/compare.py delete mode 100644 test/vpto/cases/micro-op/a5-extra/vmadd/golden.py create mode 100644 test/vpto/cases/micro-op/a5-extra/vmadd/kernel.py delete mode 100644 test/vpto/cases/micro-op/a5-extra/vmadd/launch.cpp delete mode 100644 test/vpto/cases/micro-op/a5-extra/vmadd/main.cpp diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 7bb02647ca..0c6150ef5c 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -224,6 +224,98 @@ compiled[grid, stream](A.ctypes.data, O.ctypes.data, 4, 128) **Only `entry=True` kernels support `.compile()` and `[grid, stream]` launch.** Calling `.compile()` on an `entry=False` module raises an error. +### Loading an existing PTO file + +Use `source=` when the kernel implementation already exists as a hand-written +PTO file, but you still want PTODSL's Python compile and launch workflow. The +decorated Python function declares the host-side ABI; the PTO file provides the +kernel body. + + +```python +@pto.jit( + name="tadd_entry", + target="a5", + backend="vpto", + source="kernels/tadd_entry.pto", +) +def tadd( + A_ptr: pto.ptr(pto.f32, "gm"), + B_ptr: pto.ptr(pto.f32, "gm"), + O_ptr: pto.ptr(pto.f32, "gm"), + numel: pto.i32, +): + # The body is ignored when source= is provided. + pass + + +compiled = tadd.compile() +compiled[grid, stream](A, B, O, numel) +``` + +The Python function body is not traced in this form. Keep it empty, or leave a +short comment for readers. Positional parameters still matter: PTODSL uses them +to build the launch wrapper and marshal Python, NumPy, or torch-npu arguments. + +The PTO file must contain one non-declaration `func.func` whose symbol matches +the JIT entry name. By default, the entry name is the Python function name. Use +`name=` when the PTO symbol has a different name, or when the file contains more +than one kernel: + +```mlir +module { + func.func @tadd_entry( + %a: !pto.ptr, + %b: !pto.ptr, + %o: !pto.ptr, + %numel: i32) { + // hand-written PTO body + return + } +} +``` + +PTODSL checks the selected PTO function before compiling: + +- The number of PTO function arguments must match the Python positional + parameters. +- Each argument type must match the Python annotation, position by position. +- The PTO entry must return no values. +- If the file or entry cannot be found, the diagnostic names the requested entry + and source path. + +`source` is a filesystem path. Relative paths are resolved from the Python file +that declares the decorated function, so tests can keep the Python wrapper next +to the PTO file: + +```text +case.py +kernels/tadd_entry.pto +``` + +```python +# case.py +@pto.jit(name="tadd_entry", source="kernels/tadd_entry.pto") +def tadd(A_ptr: pto.ptr(pto.f32, "gm"), O_ptr: pto.ptr(pto.f32, "gm")): + pass +``` + +Source-backed entries use the same `.compile()` and `compiled[grid, stream](...)` +launch syntax as ordinary traced entries. If the PTO file contents change, +compiling the same declaration again rebuilds the cached artifact. + +Limitations: + +- `source=` is only supported for launchable `entry=True` kernels. +- Keyword-only `pto.const_expr` parameters are not supported with `source=`. + Source files are loaded as fixed PTO IR text; PTODSL does not template or + specialize the source file. +- `.compile(...)` does not accept constexpr bindings for source-backed kernels. +- `backend=`, `mode=`, and `insert_sync=` still matter. For source-backed VPTO + files, set `backend="vpto"`. When `mode="explicit"`, PTODSL compiles the + source as explicit PTO; otherwise sync insertion follows the same policy as + ordinary `@pto.jit` entries. + ### SPMD built-ins Available inside an `entry=True` body: diff --git a/ptodsl/ptodsl/_diagnostics.py b/ptodsl/ptodsl/_diagnostics.py index de1f303d58..572b1d5950 100644 --- a/ptodsl/ptodsl/_diagnostics.py +++ b/ptodsl/ptodsl/_diagnostics.py @@ -164,6 +164,75 @@ def kernel_module_launch_error(function_name: str | None = None) -> RuntimeError ) +def jit_source_entry_false_error( + source: object, + *, + function_name: str | None = None, +) -> TypeError: + """Return one diagnostic for unsupported ``@pto.jit(entry=False, source=...)``.""" + target = "@pto.jit(source=...) kernel" + if function_name: + target = f"@pto.jit(source=...) kernel {function_name!r}" + return TypeError( + f"{target} does not support entry=False while source={source!r}. " + "Source-backed JIT is currently limited to launchable entry kernels." + ) + + +def jit_source_constexpr_error( + name: str, + source: object, + *, + function_name: str | None = None, +) -> TypeError: + """Return one diagnostic for unsupported source-backed ``pto.const_expr`` params.""" + target = "@pto.jit(source=...) kernel" + if function_name: + target = f"@pto.jit(source=...) kernel {function_name!r}" + return TypeError( + f"{target} does not support keyword-only pto.const_expr parameter '{name}' while source={source!r}. " + "Source-backed JIT currently loads a fixed PTO IR file and does not template or specialize source text." + ) + + +def jit_source_compile_constexpr_error( + names: list[str] | tuple[str, ...], + source: object, + *, + function_name: str | None = None, +) -> TypeError: + """Return one diagnostic for ``.compile(...)`` constexpr bindings in source mode.""" + target = "@pto.jit(source=...) kernel" + if function_name: + target = f"@pto.jit(source=...) kernel {function_name!r}" + joined = ", ".join(names) + return TypeError( + f"{target} does not accept .compile(...) constexpr binding(s) {joined} while source={source!r}. " + "Source-backed JIT currently loads a fixed PTO IR file and does not template or specialize source text." + ) + + +def jit_source_file_error(source: object, resolved_path: object, reason: str) -> FileNotFoundError: + """Return one diagnostic for source path resolution/loading failures.""" + return FileNotFoundError( + f"@pto.jit(source={source!r}) could not load PTO IR source file {str(resolved_path)!r}: {reason}" + ) + + +def jit_source_entry_error(source_path: object, entry_name: str, reason: str) -> TypeError: + """Return one diagnostic for source entry selection failures.""" + return TypeError( + f"@pto.jit(source=...) could not bind entry {entry_name!r} in {str(source_path)!r}: {reason}" + ) + + +def jit_source_abi_error(source_path: object, entry_name: str, reason: str) -> TypeError: + """Return one diagnostic for source ABI verification failures.""" + return TypeError( + f"@pto.jit(source=...) ABI mismatch for entry {entry_name!r} in {str(source_path)!r}: {reason}" + ) + + def jit_keyword_only_non_constexpr_error(name: str, annotation: object) -> TypeError: """Return one diagnostic for keyword-only params that are not ``pto.const_expr``.""" return TypeError( @@ -419,6 +488,12 @@ def unsupported_public_surface_error(name: str) -> AttributeError: "kernel_module_compile_error", "kernel_module_launch_error", "kernel_module_return_value_error", + "jit_source_abi_error", + "jit_source_compile_constexpr_error", + "jit_source_constexpr_error", + "jit_source_entry_false_error", + "jit_source_entry_error", + "jit_source_file_error", "invalid_jit_mode_error", "invalid_jit_backend_error", "jit_legacy_tensor_spec_helper_error", diff --git a/ptodsl/ptodsl/_jit.py b/ptodsl/ptodsl/_jit.py index 5c23bc97f0..fc86a21b9f 100644 --- a/ptodsl/ptodsl/_jit.py +++ b/ptodsl/ptodsl/_jit.py @@ -15,6 +15,8 @@ from ._diagnostics import ( invalid_jit_backend_error, invalid_jit_mode_error, + jit_source_constexpr_error, + jit_source_entry_false_error, kernel_module_launch_error, ) from ._kernel_compilation import CompiledKernelHandle, KernelCompiler @@ -169,6 +171,7 @@ def jit( insert_sync: bool | None = None, ast_rewrite: bool | None = None, frontend_options: Mapping | None = None, + source: str | None = None, ): """ Decorator that wraps a Python function as a PTODSL JIT kernel template. @@ -195,6 +198,11 @@ def jit( frontend_options: Reserved structured frontend options. Currently supports ``ast_rewrite`` and ``rewrite_part={"control_flow"}``. + source: + Optional filesystem path to a PTO IR source file. When + provided, PTODSL keeps the Python signature as the host ABI + declaration but loads the kernel implementation from the + source file instead of tracing the Python body. The decorated function is replaced by a :class:`KernelHandle` that: @@ -214,6 +222,17 @@ def decorator(fn): kernel_signature = parse_jit_kernel_signature(fn, entry=entry) normalized_mode = _normalize_mode(mode, fn=fn) normalized_backend = _normalize_backend(backend, fn=fn) + if source is not None: + if not isinstance(source, str): + raise TypeError("@pto.jit source must be a filesystem path string when provided") + if entry is False: + raise jit_source_entry_false_error(source, function_name=fn_name) + if kernel_signature.constexpr_parameters: + raise jit_source_constexpr_error( + kernel_signature.constexpr_parameters[0].name, + source, + function_name=fn_name, + ) source_file = None try: source_file = inspect.getsourcefile(fn) or inspect.getfile(fn) @@ -232,6 +251,7 @@ def decorator(fn): module_style=ModuleStyle.BACKEND_PARTITIONED, source_file=source_file, source_line=getattr(fn.__code__, "co_firstlineno", None), + jit_source=source, ), kernel_signature, fn, diff --git a/ptodsl/ptodsl/_kernel_compilation.py b/ptodsl/ptodsl/_kernel_compilation.py index 1029d3888d..3d2217ee44 100644 --- a/ptodsl/ptodsl/_kernel_compilation.py +++ b/ptodsl/ptodsl/_kernel_compilation.py @@ -12,8 +12,13 @@ import inspect from ._ast_rewrite import rewrite_jit_function -from ._diagnostics import kernel_module_compile_error, kernel_module_launch_error +from ._diagnostics import ( + jit_source_compile_constexpr_error, + kernel_module_compile_error, + kernel_module_launch_error, +) from ._runtime.launch import LaunchHandle, parse_launch_spec +from ._source_loader import SourceModuleLoader from ._tracing import ModuleArtifact, SignatureTracingRuntime @@ -89,6 +94,8 @@ def tracing_callback(self): def compile(self, **constexpr_bindings): if self._module_spec.entry is False: raise kernel_module_compile_error(self._py_name) + if self._module_spec.jit_source is not None: + return self._compile_source_backed(**constexpr_bindings) normalized_bindings = self._kernel_signature.bind_constexpr_bindings(constexpr_bindings) kernel_identity = self._kernel_identity if self._ast_rewrite: @@ -124,6 +131,36 @@ def compile(self, **constexpr_bindings): self._compiled_cache[specialization_key] = compiled return compiled + def _compile_source_backed(self, **constexpr_bindings): + if constexpr_bindings: + raise jit_source_compile_constexpr_error( + tuple(sorted(constexpr_bindings)), + self._module_spec.jit_source, + function_name=self._module_spec.function_name, + ) + + loader = SourceModuleLoader(self._module_spec, self._kernel_signature) + specialization_key = self._kernel_signature.specialization_key( + loader.cache_identity(), + {}, + ) + + cached = self._compiled_cache.get(specialization_key) + if cached is not None: + return cached + + compiled = CompiledKernelHandle( + self._py_name, + specialization_key=specialization_key, + constexpr_bindings={}, + module_factory=loader.build_module, + module_spec=self._module_spec, + kernel_signature=self._kernel_signature, + ) + compiled.build() + self._compiled_cache[specialization_key] = compiled + return compiled + def cached_specializations(self): return tuple(self._compiled_cache.values()) diff --git a/ptodsl/ptodsl/_runtime/native_build.py b/ptodsl/ptodsl/_runtime/native_build.py index 777fa54201..0821c326be 100644 --- a/ptodsl/ptodsl/_runtime/native_build.py +++ b/ptodsl/ptodsl/_runtime/native_build.py @@ -45,12 +45,18 @@ def _run_ptoas( *, target_arch: str, insert_sync: bool | None = None, + backend: str | None = None, + pto_level: str | None = None, ) -> None: ptoas = resolve_ptoas_binary() cmd = [ str(ptoas), f"--pto-arch={target_arch}", ] + if backend is not None: + cmd.append(f"--pto-backend={backend}") + if pto_level is not None: + cmd.append(f"--pto-level={pto_level}") if insert_sync is True: cmd.append("--enable-insert-sync") cmd.extend([ @@ -70,6 +76,15 @@ def _effective_insert_sync(*, mode: str, insert_sync: bool | None) -> bool: return mode != "explicit" +def _source_ptoas_overrides(module_spec) -> dict: + if getattr(module_spec, "jit_source", None) is None: + return {} + overrides = {"backend": module_spec.backend} + if module_spec.mode == "explicit": + overrides["pto_level"] = "level3" + return overrides + + def _host_compile_flags() -> list[str]: return common_include_flags() + [ "-std=gnu++17", @@ -199,6 +214,7 @@ def build_native_library( mode=module_spec.mode, insert_sync=module_spec.insert_sync, ), + **_source_ptoas_overrides(module_spec), ) launch_object = artifacts.cache_dir / "launch.o" diff --git a/ptodsl/ptodsl/_source_loader.py b/ptodsl/ptodsl/_source_loader.py new file mode 100644 index 0000000000..834c1dbaf6 --- /dev/null +++ b/ptodsl/ptodsl/_source_loader.py @@ -0,0 +1,179 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Source-backed module loading for ``@pto.jit(source=...)``.""" + +from __future__ import annotations + +import hashlib +from dataclasses import dataclass +from pathlib import Path + +from ._bootstrap import make_context +from ._diagnostics import ( + jit_source_abi_error, + jit_source_entry_error, + jit_source_file_error, +) + +from mlir.ir import Location, Module + + +@dataclass(frozen=True) +class SourceModuleArtifact: + """Parsed source-backed module plus identity metadata.""" + + module: Module + mlir_text: str + resolved_path: Path + content_digest: str + + +class SourceModuleLoader: + """Resolve, parse, and verify one source-backed JIT module.""" + + def __init__(self, module_spec, kernel_signature): + self._module_spec = module_spec + self._kernel_signature = kernel_signature + self._artifact: SourceModuleArtifact | None = None + + @property + def source(self) -> str: + source = self._module_spec.jit_source + if source is None: + raise RuntimeError("source-backed loader requires KernelModuleSpec.jit_source") + return source + + def cache_identity(self) -> tuple: + """Return source identity for the specialization key.""" + artifact = self._load() + return ( + "source", + str(artifact.resolved_path), + artifact.content_digest, + self._module_spec.function_name, + self._module_spec.entry, + self._module_spec.target_arch, + self._module_spec.kernel_kind, + self._module_spec.backend, + self._module_spec.mode, + self._module_spec.insert_sync, + self._module_spec.module_style, + ) + + def build_module(self): + """Return ``(module, metadata)`` for ``ModuleArtifact``.""" + artifact = self._load() + return artifact.module, { + "mlir_text": artifact.mlir_text, + "source_path": str(artifact.resolved_path), + "source_digest": artifact.content_digest, + } + + def _load(self) -> SourceModuleArtifact: + if self._artifact is None: + resolved_path = self._resolve_source_path() + mlir_text = self._read_source_text(resolved_path) + content_digest = hashlib.sha256(mlir_text.encode("utf-8")).hexdigest() + ctx = make_context() + with ctx, Location.unknown(): + module = Module.parse(mlir_text) + entry = self._select_entry(module, resolved_path) + self._verify_entry_abi(entry, resolved_path) + module.operation.verify() + self._artifact = SourceModuleArtifact( + module=module, + mlir_text=mlir_text, + resolved_path=resolved_path, + content_digest=content_digest, + ) + return self._artifact + + def _resolve_source_path(self) -> Path: + raw_path = Path(self.source) + if raw_path.is_absolute(): + return raw_path.resolve() + declaring_file = self._module_spec.source_file + if declaring_file: + return (Path(declaring_file).resolve().parent / raw_path).resolve() + return raw_path.resolve() + + def _read_source_text(self, resolved_path: Path) -> str: + try: + return resolved_path.read_text(encoding="utf-8") + except FileNotFoundError as exc: + raise jit_source_file_error(self.source, resolved_path, "file does not exist") from exc + except OSError as exc: + raise jit_source_file_error(self.source, resolved_path, str(exc)) from exc + + def _select_entry(self, module: Module, resolved_path: Path): + matches = [] + for op in _walk_ops(module.operation): + if op.operation.name != "func.func": + continue + if getattr(op, "is_external", False): + continue + if _symbol_name(op) == self._module_spec.function_name: + matches.append(op) + + if not matches: + raise jit_source_entry_error( + resolved_path, + self._module_spec.function_name, + "missing non-declaration func.func with this symbol name", + ) + if len(matches) > 1: + raise jit_source_entry_error( + resolved_path, + self._module_spec.function_name, + f"found {len(matches)} matching non-declaration func.func ops", + ) + return matches[0] + + def _verify_entry_abi(self, entry, resolved_path: Path) -> None: + expected = tuple(str(type_obj) for type_obj in self._kernel_signature.compute_entry_arg_types()) + actual = tuple(str(type_obj) for type_obj in entry.type.inputs) + results = tuple(str(type_obj) for type_obj in entry.type.results) + + if results: + raise jit_source_abi_error( + resolved_path, + self._module_spec.function_name, + f"source entry must return no values, got ({', '.join(results)})", + ) + if len(actual) != len(expected): + raise jit_source_abi_error( + resolved_path, + self._module_spec.function_name, + "parameter count differs; " + f"expected ({', '.join(expected)}), got ({', '.join(actual)})", + ) + for index, (expected_type, actual_type) in enumerate(zip(expected, actual)): + if expected_type != actual_type: + raise jit_source_abi_error( + resolved_path, + self._module_spec.function_name, + f"parameter {index} differs; expected {expected_type}, got {actual_type}", + ) + + +def _symbol_name(op) -> str | None: + attrs = op.attributes + if "sym_name" not in attrs: + return None + return str(attrs["sym_name"]).strip('"') + + +def _walk_ops(root_op): + for region in root_op.regions: + for block in region.blocks: + for op in block.operations: + yield op + yield from _walk_ops(op.operation) + + +__all__ = ["SourceModuleLoader"] diff --git a/ptodsl/ptodsl/_tracing/artifacts.py b/ptodsl/ptodsl/_tracing/artifacts.py index 14a50ec4a2..5dc51e3feb 100644 --- a/ptodsl/ptodsl/_tracing/artifacts.py +++ b/ptodsl/ptodsl/_tracing/artifacts.py @@ -19,10 +19,11 @@ class ModuleArtifact: Subclasses may either pass an eager ``module`` or a lazy ``module_factory``. """ - def __init__(self, py_name: str, *, module=None, module_factory=None): + def __init__(self, py_name: str, *, module=None, module_factory=None, mlir_text: str | None = None): self._py_name = py_name self._cached_module = module self._module_factory = module_factory + self._cached_mlir_text = mlir_text self._build_metadata = {} def build(self): @@ -34,6 +35,8 @@ def build(self): if isinstance(built, tuple): self._cached_module, metadata = built self._build_metadata = dict(metadata or {}) + if "mlir_text" in self._build_metadata: + self._cached_mlir_text = self._build_metadata["mlir_text"] else: self._cached_module = built self._build_metadata = {} @@ -45,7 +48,10 @@ def mlir_module(self): def mlir_text(self) -> str: """Return the textual MLIR form.""" - return str(self.build()) + self.build() + if self._cached_mlir_text is not None: + return self._cached_mlir_text + return str(self._cached_module) def verify(self) -> None: """Verify the cached module operation.""" diff --git a/ptodsl/ptodsl/_tracing/module_builder.py b/ptodsl/ptodsl/_tracing/module_builder.py index f15f62e398..f4108724c5 100644 --- a/ptodsl/ptodsl/_tracing/module_builder.py +++ b/ptodsl/ptodsl/_tracing/module_builder.py @@ -38,6 +38,7 @@ class KernelModuleSpec: module_style: ModuleStyle = ModuleStyle.NESTED source_file: str | None = None source_line: int | None = None + jit_source: str | None = None def _build_flat_aicore_module(spec: KernelModuleSpec, arg_types): diff --git a/ptodsl/tests/support/docs_fragment_fixtures.py b/ptodsl/tests/support/docs_fragment_fixtures.py index 788cf31786..db833700c2 100644 --- a/ptodsl/tests/support/docs_fragment_fixtures.py +++ b/ptodsl/tests/support/docs_fragment_fixtures.py @@ -369,6 +369,27 @@ def kernel_name( assert record.marshaled_arg_count == 4 """ ), + "launch.source_backed_tadd": _fixture( + f""" + grid = 2 + stream = object() + A = 0x1000 + B = 0x2000 + O = 0x3000 + numel = 128 + + {SNIPPET_PLACEHOLDER} + + assert compiled.ir_function_name == "tadd_entry" + assert compiled.build_metadata()["source_path"].endswith("kernels/tadd_entry.pto") + assert len(PTODSL_DOC_LAUNCH_RECORDS) == 1 + record = PTODSL_DOC_LAUNCH_RECORDS[0] + assert record.grid == grid + assert record.stream is stream + assert record.args == (A, B, O, numel) + assert record.marshaled_arg_count == 4 + """ + ), "launch.mat_add_wrapper": _fixture( f""" import numpy as np diff --git a/ptodsl/tests/test_docs_as_test.py b/ptodsl/tests/test_docs_as_test.py index 27a6910bec..60fc852fba 100644 --- a/ptodsl/tests/test_docs_as_test.py +++ b/ptodsl/tests/test_docs_as_test.py @@ -69,6 +69,7 @@ class DocTestDirective: symbol: str | None = None compile_kwargs: dict[str, object] | None = None fixture: str | None = None + files: dict[str, str] | None = None @dataclass(frozen=True) @@ -257,11 +258,25 @@ def parse_test_directive(block: MarkdownCodeBlock) -> DocTestDirective: symbol = payload.get("symbol") compile_kwargs = payload.get("compile") fixture = payload.get("fixture") + files = payload.get("files") expect( isinstance(mode, str) and mode, f"{block_label(block)}: ptodsl-doc-test metadata must define a non-empty string 'mode'", ) + if files is not None: + expect( + isinstance(files, dict) and all(isinstance(path, str) and isinstance(text, str) for path, text in files.items()), + f"{block_label(block, symbol if isinstance(symbol, str) and symbol else None)}: " + "ptodsl-doc-test metadata 'files' must be an object mapping relative file paths to text", + ) + for path in files: + file_path = Path(path) + expect( + not file_path.is_absolute() and ".." not in file_path.parts, + f"{block_label(block, symbol if isinstance(symbol, str) and symbol else None)}: " + f"ptodsl-doc-test metadata file path must be relative and stay inside the snippet directory: {path!r}", + ) if mode in ("compile", "compile_fragment"): expect( isinstance(symbol, str) and symbol, @@ -282,8 +297,9 @@ def parse_test_directive(block: MarkdownCodeBlock) -> DocTestDirective: symbol=symbol, compile_kwargs=compile_kwargs, fixture=fixture, + files=files, ) - return DocTestDirective(mode=mode, symbol=symbol, compile_kwargs=compile_kwargs) + return DocTestDirective(mode=mode, symbol=symbol, compile_kwargs=compile_kwargs, files=files) if mode == "launch_fragment": expect( @@ -300,7 +316,7 @@ def parse_test_directive(block: MarkdownCodeBlock) -> DocTestDirective: f"{block_label(block, symbol if isinstance(symbol, str) and symbol else None)}: " "ptodsl-doc-test launch_fragment does not accept a 'compile' object; the snippet owns its compile/launch flow", ) - return DocTestDirective(mode=mode, symbol=symbol, fixture=fixture) + return DocTestDirective(mode=mode, symbol=symbol, fixture=fixture, files=files) expect( False, @@ -310,23 +326,46 @@ def parse_test_directive(block: MarkdownCodeBlock) -> DocTestDirective: return DocTestDirective(mode=mode) +def _write_directive_files(snippet_dir: Path, files: dict[str, str] | None) -> None: + if not files: + return + for relative_path, text in files.items(): + output_path = snippet_dir / relative_path + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(text, encoding="utf-8") + + +@contextmanager +def directive_execution_dir(directive: DocTestDirective): + if not directive.files: + yield None + return + + with tempfile.TemporaryDirectory() as temp_dir: + snippet_dir = Path(temp_dir) + _write_directive_files(snippet_dir, directive.files) + yield snippet_dir + + def execute_source( source: str, block: MarkdownCodeBlock, symbol: str | None = None, *, extra_namespace: dict[str, object] | None = None, + source_dir: Path | None = None, ) -> dict[str, object]: + source_file = block.path if source_dir is None else source_dir / "case.py" namespace: dict[str, object] = { "__builtins__": __builtins__, "__name__": "__ptodsl_doc_snippet__", - "__file__": str(block.path), + "__file__": str(source_file), "pto": pto, "scalar": scalar, } if extra_namespace is not None: namespace.update(extra_namespace) - filename = f"{block.path}::codeblock:{block.start_line}" + filename = f"{source_file}::codeblock:{block.start_line}" source_lines = source.splitlines(keepends=True) linecache.cache[filename] = (len(source), None, source_lines, filename) try: @@ -409,8 +448,9 @@ def verify_compiled_target( def run_compile_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None: directive = parse_test_directive(block) - namespace = execute_source(block.text, block, directive.symbol) - verify_compiled_target(block, directive, namespace, ptoas_bin, frontend_verify=False) + with directive_execution_dir(directive) as source_dir: + namespace = execute_source(block.text, block, directive.symbol, source_dir=source_dir) + verify_compiled_target(block, directive, namespace, ptoas_bin, frontend_verify=False) def run_compile_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None: @@ -429,8 +469,9 @@ def run_compile_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> Non raise AssertionError( f"{block_label(block, directive.symbol)}: fragment fixture {directive.fixture!r} is invalid: {exc}" ) from exc - namespace = execute_source(rendered_source, block, directive.symbol) - verify_compiled_target(block, directive, namespace, ptoas_bin, frontend_verify=False) + with directive_execution_dir(directive) as source_dir: + namespace = execute_source(rendered_source, block, directive.symbol, source_dir=source_dir) + verify_compiled_target(block, directive, namespace, ptoas_bin, frontend_verify=False) def run_launch_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None: @@ -450,13 +491,15 @@ def run_launch_fragment_block(block: MarkdownCodeBlock, ptoas_bin: Path) -> None f"{block_label(block, directive.symbol)}: fragment fixture {directive.fixture!r} is invalid: {exc}" ) from exc - with capture_launch_records() as launch_records: - execute_source( - rendered_source, - block, - directive.symbol, - extra_namespace={"PTODSL_DOC_LAUNCH_RECORDS": launch_records}, - ) + with directive_execution_dir(directive) as source_dir: + with capture_launch_records() as launch_records: + execute_source( + rendered_source, + block, + directive.symbol, + extra_namespace={"PTODSL_DOC_LAUNCH_RECORDS": launch_records}, + source_dir=source_dir, + ) expect( bool(launch_records), diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index 2157acdb1a..ad086dd56d 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -12,6 +12,7 @@ import re import sys from tempfile import TemporaryDirectory +from importlib.util import module_from_spec, spec_from_file_location from unittest import mock @@ -3117,6 +3118,160 @@ def main() -> None: pointer_block64.specialization_key.constexpr_signature == (("BLOCK", 64),), "pointer-first specialization key should change only with constexpr bindings", ) + source_native_build_compiled = None + source_explicit_native_build_compiled = None + source_no_insert_sync_native_build_compiled = None + with TemporaryDirectory() as tmpdir: + source_path = Path(tmpdir) / "source_kernel.pto" + source_text_v1 = ( + "module {\n" + " func.func @selected_source_entry(%arg0: !pto.ptr, %arg1: i32) {\n" + " return\n" + " }\n" + " func.func @other_source_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + source_text_v2 = ( + "module {\n" + " func.func @selected_source_entry(%arg0: !pto.ptr, %arg1: i32) {\n" + " %c0 = arith.constant 0 : i32\n" + " return\n" + " }\n" + " func.func @other_source_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + source_path.write_text(source_text_v1, encoding="utf-8") + + @pto.jit(name="selected_source_entry", target="a5", source=str(source_path)) + def source_backed_probe(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): + raise RuntimeError("source-backed JIT should not trace the Python body") + + source_compiled_v1 = source_backed_probe.compile() + source_compiled_v1_again = source_backed_probe.compile() + expect(source_compiled_v1 is source_compiled_v1_again, "unchanged source-backed JIT should hit specialization cache") + expect( + source_compiled_v1.mlir_text() == source_text_v1, + "source-backed JIT mlir_text() should preserve the authored source text", + ) + expect( + source_compiled_v1.ir_function_name == "selected_source_entry", + "source-backed JIT should use name= for entry selection and launch wrapper naming", + ) + expect( + source_compiled_v1.build_metadata()["source_path"] == str(source_path.resolve()), + "source-backed JIT metadata should expose the resolved source path", + ) + + source_path.write_text(source_text_v2, encoding="utf-8") + source_compiled_v2 = source_backed_probe.compile() + expect( + source_compiled_v2 is not source_compiled_v1, + "editing the source file should materialize a new specialization", + ) + expect( + source_compiled_v2.specialization_key != source_compiled_v1.specialization_key, + "source-backed specialization key should include source content", + ) + expect( + source_compiled_v2.mlir_text() == source_text_v2, + "source-backed JIT should reload changed source text", + ) + source_auto_path = Path(tmpdir) / "source_auto.pto" + source_auto_text = ( + "module {\n" + " func.func @source_auto_native(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + source_auto_path.write_text(source_auto_text, encoding="utf-8") + + @pto.jit(name="source_auto_native", target="a5", backend="vpto", source=str(source_auto_path)) + def source_auto_native(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + source_native_build_compiled = source_auto_native.compile() + + source_explicit_path = Path(tmpdir) / "source_explicit.pto" + source_explicit_text = ( + "module {\n" + " func.func @source_explicit_native(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + source_explicit_path.write_text(source_explicit_text, encoding="utf-8") + + @pto.jit( + name="source_explicit_native", + target="a5", + backend="vpto", + mode="explicit", + source=str(source_explicit_path), + ) + def source_explicit_native(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + source_explicit_native_build_compiled = source_explicit_native.compile() + + source_no_insert_sync_path = Path(tmpdir) / "source_no_insert_sync.pto" + source_no_insert_sync_text = ( + "module {\n" + " func.func @source_no_insert_sync_native(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + source_no_insert_sync_path.write_text(source_no_insert_sync_text, encoding="utf-8") + + @pto.jit( + name="source_no_insert_sync_native", + target="a5", + backend="vpto", + insert_sync=False, + source=str(source_no_insert_sync_path), + ) + def source_no_insert_sync_native(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + source_no_insert_sync_native_build_compiled = source_no_insert_sync_native.compile() + + relative_case_dir = Path(tmpdir) / "relative_case" + relative_case_dir.mkdir() + relative_source_path = relative_case_dir / "relative_kernel.pto" + relative_source_text = ( + "module {\n" + " func.func @relative_source_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + relative_source_path.write_text(relative_source_text, encoding="utf-8") + relative_module_path = relative_case_dir / "relative_case.py" + relative_module_path.write_text( + "from ptodsl import pto\n" + "@pto.jit(name='relative_source_entry', target='a5', source='relative_kernel.pto')\n" + "def relative_kernel(ptr: pto.ptr(pto.f32, 'gm')):\n" + " raise RuntimeError('source-backed JIT should not trace the Python body')\n", + encoding="utf-8", + ) + spec = spec_from_file_location("ptodsl_relative_source_case", relative_module_path) + expect(spec is not None and spec.loader is not None, "relative source test module should be importable") + relative_module = module_from_spec(spec) + spec.loader.exec_module(relative_module) + relative_compiled = relative_module.relative_kernel.compile() + expect( + relative_compiled.mlir_text() == relative_source_text, + "relative source paths should resolve against the declaring Python file", + ) + expect( + relative_compiled.build_metadata()["source_path"] == str(relative_source_path.resolve()), + "relative source-backed metadata should expose the resolved source path", + ) pointer_artifacts_default = artifact_paths( pointer_default._py_name, pointer_default.ir_function_name, @@ -3481,6 +3636,9 @@ def main() -> None: ("pure-container", host_vec_copy.compile()), ("same-backend-multi-child-container", kernel_module_compiled), ("mixed-backend-container", emitc_entry_calls_vpto_kernel_module_probe.compile()), + ("source-auto", source_native_build_compiled), + ("source-explicit", source_explicit_native_build_compiled), + ("source-no-insert-sync", source_no_insert_sync_native_build_compiled), ) native_build_observations = [] @@ -3498,13 +3656,15 @@ def fake_artifacts(py_name, ir_function_name, specialization_key): manifest_path=cache_dir / "manifest.json", ) - def fake_run_ptoas(mlir_path, kernel_object, *, target_arch, insert_sync=None): + def fake_run_ptoas(mlir_path, kernel_object, *, target_arch, insert_sync=None, backend=None, pto_level=None): native_build_observations.append( { "mlir_path": mlir_path, "kernel_object": kernel_object, "target_arch": target_arch, "insert_sync": insert_sync, + "backend": backend, + "pto_level": pto_level, "mlir_text": mlir_path.read_text(encoding="utf-8"), } ) @@ -3561,14 +3721,25 @@ def fake_link_shared_library(launch_object, kernel_object, shared_library, *, ke observation["insert_sync"] == expected_insert_sync, f"{label} native build should forward the effective insert_sync policy to ptoas", ) + expected_backend = compiled._module_spec.backend if compiled._module_spec.jit_source is not None else None + expected_pto_level = "level3" if compiled._module_spec.jit_source is not None and compiled._module_spec.mode == "explicit" else None expect( - observation["mlir_text"] == compiled.mlir_text(), - f"{label} native build should hand the backend-partitioned container MLIR to ptoas unchanged", + observation["backend"] == expected_backend, + f"{label} native build should only forward ptoas backend overrides for source-backed kernels", + ) + expect( + observation["pto_level"] == expected_pto_level, + f"{label} native build should only map explicit mode to ptoas level3 for source-backed kernels", ) expect( - observation["mlir_text"].count("module") >= 2, - f"{label} native build should route the unified outer+child container through ptoas", + observation["mlir_text"] == compiled.mlir_text(), + f"{label} native build should hand the backend-partitioned container MLIR to ptoas unchanged", ) + if compiled._module_spec.jit_source is None: + expect( + observation["mlir_text"].count("module") >= 2, + f"{label} native build should route the unified outer+child container through ptoas", + ) with TemporaryDirectory() as tmpdir: tmpdir_path = Path(tmpdir) mlir_path = tmpdir_path / "kernel.mlir" @@ -3625,6 +3796,32 @@ def fake_run_ptoas_cmd(cmd, *, cwd=None): "--enable-insert-sync" in ptoas_cmds[0], "native build should pass --enable-insert-sync when the compiled module explicitly requests it", ) + ptoas_cmds.clear() + with mock.patch.object(native_build_runtime, "resolve_ptoas_binary", return_value=Path("/tmp/fake-ptoas")), mock.patch.object( + native_build_runtime, "_run", side_effect=fake_run_ptoas_cmd + ): + native_build_runtime._run_ptoas( + mlir_path, + kernel_object, + target_arch="a5", + backend="vpto", + pto_level="level3", + insert_sync=True, + ) + expect(len(ptoas_cmds) == 1, "native build should issue exactly one ptoas command with source-backed overrides") + source_ptoas_cmd = ptoas_cmds[0] + expect( + "--pto-backend=vpto" in source_ptoas_cmd, + "source-backed native build should pass the decorator backend to ptoas", + ) + expect( + "--pto-level=level3" in source_ptoas_cmd, + "source-backed explicit mode should pass --pto-level=level3 to ptoas", + ) + expect( + "--enable-insert-sync" in source_ptoas_cmd, + "source-backed native build should still pass explicit/effective insert-sync to ptoas", + ) expect("valid=?" not in default_text, "default alloc_tile() should keep full static valid-shape when valid_shape= is omitted") auto_mode_violation = expect_raises( RuntimeError, @@ -5004,6 +5201,9 @@ def _enter_inline_simt_with_resource_attr(): launch_handle = block64[1, None] expect(callable(launch_handle), "compiled[grid, stream] should return a launch callable") expect(hasattr(launch_handle, "__call__"), "launch handle should support __call__") + source_launch_handle = source_native_build_compiled[1, None] + expect(callable(source_launch_handle), "source-backed compiled[grid, stream] should return a launch callable") + expect(hasattr(source_launch_handle, "__call__"), "source-backed launch handle should support __call__") print("ptodsl_jit_compile: PASS") os._exit(0) diff --git a/ptodsl/tests/test_jit_diagnostics.py b/ptodsl/tests/test_jit_diagnostics.py index 96ebcd4a37..135f0e296e 100644 --- a/ptodsl/tests/test_jit_diagnostics.py +++ b/ptodsl/tests/test_jit_diagnostics.py @@ -9,6 +9,7 @@ from pathlib import Path import sys +from tempfile import TemporaryDirectory sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "ptodsl")) @@ -404,6 +405,31 @@ def bad_probe(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): return bad_probe +def define_source_non_string_probe(): + @pto.jit(target="a5", source=123) + def bad_probe(rows: pto.i32): + _ = rows + + return bad_probe + + +def define_source_entry_false_probe(): + @pto.jit(target="a5", entry=False, source="kernel.pto") + def bad_probe(tile: pto.Tile): + pto.pipe_barrier(pto.Pipe.ALL) + + return bad_probe + + +def define_source_constexpr_probe(): + @pto.jit(target="a5", source="kernel.pto") + def bad_probe(ptr: pto.ptr(pto.f32, "gm"), *, BLOCK: pto.const_expr = 8): + _ = ptr + _ = BLOCK + + return bad_probe + + @pto.jit(target="a5") def regular_entry_probe(rows: pto.i32): _ = rows @@ -788,6 +814,175 @@ def main() -> None: "mask", "do not belong at this kernel-module boundary", ) + expect_raises( + define_source_non_string_probe, + TypeError, + "@pto.jit source must be a filesystem path string when provided", + ) + expect_raises( + define_source_entry_false_probe, + TypeError, + "@pto.jit(source=...) kernel 'bad_probe' does not support entry=False while source='kernel.pto'", + "Source-backed JIT is currently limited to launchable entry kernels", + ) + expect_raises( + define_source_constexpr_probe, + TypeError, + "@pto.jit(source=...) kernel 'bad_probe' does not support keyword-only pto.const_expr parameter 'BLOCK' while source='kernel.pto'", + "does not template or specialize source text", + ) + with TemporaryDirectory() as tmpdir: + source_dir = Path(tmpdir) + missing_path = source_dir / "missing.pto" + + @pto.jit(target="a5", source=str(missing_path)) + def source_missing_file_probe(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + source_missing_file_probe.compile, + FileNotFoundError, + "@pto.jit(source=", + "missing.pto", + "file does not exist", + ) + + missing_entry_path = source_dir / "missing_entry.pto" + missing_entry_path.write_text( + "module {\n" + " func.func @other_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(name="wanted_entry", target="a5", source=str(missing_entry_path)) + def source_missing_entry_probe(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + source_missing_entry_probe.compile, + TypeError, + "could not bind entry 'wanted_entry'", + "missing non-declaration func.func", + ) + + ambiguous_entry_path = source_dir / "ambiguous_entry.pto" + ambiguous_entry_path.write_text( + "module {\n" + " func.func @ambiguous_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + " builtin.module {\n" + " func.func @ambiguous_entry(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(target="a5", source=str(ambiguous_entry_path)) + def ambiguous_entry(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + ambiguous_entry.compile, + TypeError, + "could not bind entry 'ambiguous_entry'", + "found 2 matching non-declaration func.func ops", + ) + + count_mismatch_path = source_dir / "count_mismatch.pto" + count_mismatch_path.write_text( + "module {\n" + " func.func @count_mismatch(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(target="a5", source=str(count_mismatch_path)) + def count_mismatch(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + count_mismatch.compile, + TypeError, + "ABI mismatch for entry 'count_mismatch'", + "parameter count differs", + "expected (!pto.ptr, i32)", + "got (!pto.ptr)", + ) + + type_mismatch_path = source_dir / "type_mismatch.pto" + type_mismatch_path.write_text( + "module {\n" + " func.func @type_mismatch(%arg0: !pto.ptr, %arg1: i32) {\n" + " return\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(target="a5", source=str(type_mismatch_path)) + def type_mismatch(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + type_mismatch.compile, + TypeError, + "ABI mismatch for entry 'type_mismatch'", + "parameter 0 differs", + "expected !pto.ptr", + "got !pto.ptr", + ) + + non_void_path = source_dir / "non_void.pto" + non_void_path.write_text( + "module {\n" + " func.func @non_void(%arg0: !pto.ptr) -> i32 {\n" + " %c0 = arith.constant 0 : i32\n" + " return %c0 : i32\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(target="a5", source=str(non_void_path)) + def non_void(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + non_void.compile, + TypeError, + "ABI mismatch for entry 'non_void'", + "source entry must return no values", + "i32", + ) + + compile_constexpr_path = source_dir / "compile_constexpr.pto" + compile_constexpr_path.write_text( + "module {\n" + " func.func @compile_constexpr(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n", + encoding="utf-8", + ) + + @pto.jit(target="a5", source=str(compile_constexpr_path)) + def compile_constexpr(ptr: pto.ptr(pto.f32, "gm")): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + lambda: compile_constexpr.compile(BLOCK=8), + TypeError, + "@pto.jit(source=...) kernel 'compile_constexpr' does not accept .compile(...) constexpr binding(s) BLOCK", + "does not template or specialize source text", + ) kernel_module_return_value_probe = define_kernel_module_return_value_probe() expect_raises( kernel_module_return_value_probe.compile, diff --git a/test/vpto/cases/micro-op/a5-extra/vmadd/compare.py b/test/vpto/cases/micro-op/a5-extra/vmadd/compare.py deleted file mode 100644 index baed02d3bd..0000000000 --- a/test/vpto/cases/micro-op/a5-extra/vmadd/compare.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import os -import sys - -import numpy as np - - -def compare_float(golden_path: str, output_path: str, label: str, atol: float) -> bool: - if not os.path.exists(golden_path) or not os.path.exists(output_path): - print(f"[ERROR] missing file for {label}") - return False - golden = np.fromfile(golden_path, dtype=np.float32) - output = np.fromfile(output_path, dtype=np.float32) - ok = golden.shape == output.shape and np.allclose( - golden, output, atol=atol, rtol=atol, equal_nan=True - ) - if not ok: - diff = np.max(np.abs(golden.astype(np.float64) - output.astype(np.float64))) - print(f"[ERROR] compare failed: {label}, max_abs_diff={diff}") - return ok - - -def compare_exact(golden_path: str, output_path: str, dtype, label: str) -> bool: - if not os.path.exists(golden_path) or not os.path.exists(output_path): - print(f"[ERROR] missing file for {label}") - return False - golden = np.fromfile(golden_path, dtype=dtype) - output = np.fromfile(output_path, dtype=dtype) - ok = golden.shape == output.shape and np.array_equal(golden, output) - if not ok: - mismatch = np.flatnonzero(golden != output) - first = int(mismatch[0]) if mismatch.size else -1 - print( - f"[ERROR] compare failed: {label}, first_mismatch={first}, " - f"golden={golden[first] if first >= 0 else 'n/a'}, " - f"output={output[first] if first >= 0 else 'n/a'}" - ) - return ok - - -def main() -> None: - checks = [ - compare_float("golden_vmadd.bin", "out_vmadd.bin", "vmadd", 2e-4), - ] - if not all(checks): - sys.exit(2) - print("[INFO] compare passed (a5 extra vmadd)") - - -if __name__ == "__main__": - main() diff --git a/test/vpto/cases/micro-op/a5-extra/vmadd/golden.py b/test/vpto/cases/micro-op/a5-extra/vmadd/golden.py deleted file mode 100644 index 276b6c82cd..0000000000 --- a/test/vpto/cases/micro-op/a5-extra/vmadd/golden.py +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import argparse -from pathlib import Path - -import numpy as np - -ELEMS = 1024 -SEED = 29 - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument("--output-dir", type=Path, default=Path(".")) - parser.add_argument("--seed", type=int, default=SEED) - args = parser.parse_args() - out = args.output_dir - out.mkdir(parents=True, exist_ok=True) - rng = np.random.default_rng(args.seed) - - f_acc = rng.uniform(-2.0, 2.0, size=ELEMS).astype(np.float32) - f_lhs = rng.uniform(-3.0, 3.0, size=ELEMS).astype(np.float32) - f_rhs = rng.uniform(-1.0, 1.0, size=ELEMS).astype(np.float32) - - f_acc.tofile(out / "f_acc.bin") - f_lhs.tofile(out / "f_lhs.bin") - f_rhs.tofile(out / "f_rhs.bin") - - (f_lhs * f_acc + f_rhs).astype(np.float32).tofile(out / "golden_vmadd.bin") - - -if __name__ == "__main__": - main() diff --git a/test/vpto/cases/micro-op/a5-extra/vmadd/kernel.py b/test/vpto/cases/micro-op/a5-extra/vmadd/kernel.py new file mode 100644 index 0000000000..a5323de78a --- /dev/null +++ b/test/vpto/cases/micro-op/a5-extra/vmadd/kernel.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from pathlib import Path +import sys + +import numpy as np + + +def _bootstrap_dsl_st_common() -> None: + here = Path(__file__).resolve() + for candidate in here.parents: + common_dir = candidate / "test" / "dsl-st" + if (common_dir / "common.py").exists(): + sys.path.insert(0, str(common_dir)) + return + raise RuntimeError("Unable to locate test/dsl-st/common.py from vmadd kernel.py") + + +_bootstrap_dsl_st_common() + +from common import auto_main, golden_output_case +from ptodsl import pto + + +ELEMS = 1024 +SEED = 29 + + +@pto.jit( + name="a5_extra_vmadd_kernel", + target="a5", + backend="vpto", + mode="explicit", + source="kernel.pto", +) +def a5_extra_vmadd_kernel( + f_acc: pto.ptr(pto.f32, "gm"), + f_lhs: pto.ptr(pto.f32, "gm"), + f_rhs: pto.ptr(pto.f32, "gm"), + out_vmadd: pto.ptr(pto.f32, "gm"), +): + pass + + +def make_inputs(): + rng = np.random.default_rng(SEED) + f_acc = rng.uniform(-2.0, 2.0, size=ELEMS).astype(np.float32) + f_lhs = rng.uniform(-3.0, 3.0, size=ELEMS).astype(np.float32) + f_rhs = rng.uniform(-1.0, 1.0, size=ELEMS).astype(np.float32) + return [f_acc, f_lhs, f_rhs] + + +def make_expected(f_acc, f_lhs, f_rhs): + return (f_lhs * f_acc + f_rhs).astype(np.float32) + + +CASES = [ + golden_output_case( + "a5_extra_vmadd", + a5_extra_vmadd_kernel, + inputs=make_inputs, + expected=make_expected, + rtol=2e-4, + atol=2e-4, + ), +] + + +auto_main(globals()) diff --git a/test/vpto/cases/micro-op/a5-extra/vmadd/launch.cpp b/test/vpto/cases/micro-op/a5-extra/vmadd/launch.cpp deleted file mode 100644 index 91edee7f64..0000000000 --- a/test/vpto/cases/micro-op/a5-extra/vmadd/launch.cpp +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#ifndef __VEC_SCOPE__ -#define __VEC_SCOPE__ -#endif -#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) -typedef struct { unsigned char v; } hifloat8_t; -typedef struct { unsigned char v; } float8_e4m3_t; -typedef struct { unsigned char v; } float8_e5m2_t; -typedef struct { unsigned char v; } float8_e8m0_t; -typedef struct { unsigned char v; } float4_e1m2x2_t; -typedef struct { unsigned char v; } float4_e2m1x2_t; -#endif -#include -#if defined(__CCE_AICORE__) && defined(PTOAS_ENABLE_CCE_PRINT) -#include -#endif -#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) -struct MrgSortExecutedNumList { - uint16_t mrgSortList0; - uint16_t mrgSortList1; - uint16_t mrgSortList2; - uint16_t mrgSortList3; -}; -#endif -#ifndef __CPU_SIM -#include "acl/acl.h" -#endif - -extern "C" __global__ [aicore] void a5_extra_vmadd_kernel( - __gm__ float *f_acc, __gm__ float *f_lhs, __gm__ float *f_rhs, - __gm__ float *out_vmadd); - -void LaunchA5ExtraVmadd(float *p0, float *p1, float *p2, float *p3, - void *stream) { - a5_extra_vmadd_kernel<<<1, nullptr, stream>>>( - (__gm__ float *)p0, (__gm__ float *)p1, (__gm__ float *)p2, - (__gm__ float *)p3); -} diff --git a/test/vpto/cases/micro-op/a5-extra/vmadd/main.cpp b/test/vpto/cases/micro-op/a5-extra/vmadd/main.cpp deleted file mode 100644 index d5a3b781d2..0000000000 --- a/test/vpto/cases/micro-op/a5-extra/vmadd/main.cpp +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#include "acl/acl.h" -#include "test_common.h" -#include -#include -#include - -using namespace PtoTestCommon; - -#define ACL_CHECK(expr) \ - do { \ - const aclError _ret = (expr); \ - if (_ret != ACL_SUCCESS) { \ - std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ - (int)_ret, __FILE__, __LINE__); \ - const char *_recent = aclGetRecentErrMsg(); \ - if (_recent != nullptr && _recent[0] != '\0') \ - std::fprintf(stderr, "[ERROR] RecentErrMsg: %s\n", _recent); \ - rc = 1; \ - goto cleanup; \ - } \ - } while (0) - -#define FILE_CHECK(expr, path) \ - do { \ - if (!(expr)) { \ - std::fprintf(stderr, "[ERROR] file operation failed: %s (%s:%d)\n", \ - path, __FILE__, __LINE__); \ - rc = 1; \ - goto cleanup; \ - } \ - } while (0) - -void LaunchA5ExtraVmadd(float *p0, float *p1, float *p2, float *p3, - void *stream); - -struct Buffer { - const char *path; - size_t size; - bool input; - void *host; - void *device; -}; - -int main() { - constexpr size_t kElems = 1024; - constexpr size_t kF32Bytes = kElems * sizeof(float); - - Buffer bufs[] = { - {"./f_acc.bin", kF32Bytes, true, nullptr, nullptr}, - {"./f_lhs.bin", kF32Bytes, true, nullptr, nullptr}, - {"./f_rhs.bin", kF32Bytes, true, nullptr, nullptr}, - {"./out_vmadd.bin", kF32Bytes, false, nullptr, nullptr}, - }; - - int rc = 0; - bool aclInited = false; - bool deviceSet = false; - int deviceId = 0; - aclrtStream stream = nullptr; - - ACL_CHECK(aclInit(nullptr)); - aclInited = true; - if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) - deviceId = std::atoi(envDevice); - ACL_CHECK(aclrtSetDevice(deviceId)); - deviceSet = true; - ACL_CHECK(aclrtCreateStream(&stream)); - - for (Buffer &buf : bufs) { - ACL_CHECK(aclrtMallocHost(&buf.host, buf.size)); - ACL_CHECK(aclrtMalloc(&buf.device, buf.size, ACL_MEM_MALLOC_HUGE_FIRST)); - if (!buf.input) - continue; - size_t fileSize = buf.size; - FILE_CHECK(ReadFile(buf.path, fileSize, buf.host, buf.size) && - fileSize == buf.size, - buf.path); - ACL_CHECK(aclrtMemcpy(buf.device, buf.size, buf.host, buf.size, - ACL_MEMCPY_HOST_TO_DEVICE)); - } - - LaunchA5ExtraVmadd( - static_cast(bufs[0].device), static_cast(bufs[1].device), - static_cast(bufs[2].device), - static_cast(bufs[3].device), - stream); - - ACL_CHECK(aclrtSynchronizeStream(stream)); - - for (Buffer &buf : bufs) { - if (buf.input) - continue; - ACL_CHECK(aclrtMemcpy(buf.host, buf.size, buf.device, buf.size, - ACL_MEMCPY_DEVICE_TO_HOST)); - FILE_CHECK(WriteFile(buf.path, buf.host, buf.size), buf.path); - } - -cleanup: - for (Buffer &buf : bufs) { - if (buf.device != nullptr) - aclrtFree(buf.device); - if (buf.host != nullptr) - aclrtFreeHost(buf.host); - } - if (stream != nullptr) - aclrtDestroyStream(stream); - if (deviceSet) - aclrtResetDevice(deviceId); - if (aclInited) - aclFinalize(); - return rc; -} diff --git a/test/vpto/scripts/run_host_vpto_validation.sh b/test/vpto/scripts/run_host_vpto_validation.sh index 709b29747f..34ce0b25f7 100755 --- a/test/vpto/scripts/run_host_vpto_validation.sh +++ b/test/vpto/scripts/run_host_vpto_validation.sh @@ -25,6 +25,7 @@ CASE_NAME="${CASE_NAME:-}" DEVICE="${DEVICE:-SIM}" SIM_LIB_DIR="${SIM_LIB_DIR:-}" COMPILE_ONLY="${COMPILE_ONLY:-0}" +PTODSL_SIM_SOC_VERSION="${PTODSL_SIM_SOC_VERSION:-Ascend950PR_9599}" log() { echo "[$(date +'%F %T')] $*" @@ -94,8 +95,6 @@ resolve_sim_lib_dir() { die "SIM_LIB_DIR is required for DEVICE=SIM and no dav_3510 simulator lib dir was found under: ${ASCEND_HOME_PATH}" } -resolve_sim_lib_dir - BISHENG_BIN="${BISHENG_BIN:-${ASCEND_HOME_PATH}/bin/bisheng}" command -v "${BISHENG_BIN}" >/dev/null 2>&1 || die "bisheng not found: ${BISHENG_BIN}" @@ -104,13 +103,25 @@ command -v python3 >/dev/null 2>&1 || die "python3 not found" mkdir -p "${WORK_SPACE}" WORK_SPACE="$(cd "${WORK_SPACE}" && pwd)" +is_ptodsl_case_dir() { + [[ -f "$1/kernel.py" ]] +} + +validate_case_dir() { + local case_name="$1" + local case_dir="$2" + + [[ -f "${case_dir}/kernel.pto" ]] || + die "case ${case_name} must provide kernel.pto" + if is_ptodsl_case_dir "${case_dir}"; then + return 0 + fi + for f in launch.cpp main.cpp golden.py compare.py; do + [[ -f "${case_dir}/${f}" ]] || die "case ${case_name} is missing ${f}" + done +} + discover_cases() { - local required_files=( - launch.cpp - main.cpp - golden.py - compare.py - ) local onboard_only_prefix="onboard-only/" if [[ -n "${CASE_NAME}" ]]; then @@ -121,26 +132,24 @@ discover_cases() { fi local requested_dir="${CASES_ROOT}/${CASE_NAME}" [[ -d "${requested_dir}" ]] || die "unknown case: ${CASE_NAME}" - for f in "${required_files[@]}"; do - [[ -f "${requested_dir}/${f}" ]] || die "case ${CASE_NAME} is missing ${f}" - done - [[ -f "${requested_dir}/kernel.pto" ]] || - die "case ${CASE_NAME} must provide kernel.pto" + validate_case_dir "${CASE_NAME}" "${requested_dir}" printf "%s\n" "${CASE_NAME}" return 0 fi find "${CASES_ROOT}" -mindepth 1 -type d | sort | while read -r dir; do - local ok=1 - for f in "${required_files[@]}"; do - if [[ ! -f "${dir}/${f}" ]]; then - ok=0 - break - fi - done - [[ "${ok}" -eq 1 ]] || continue [[ -f "${dir}/kernel.pto" ]] || continue local rel="${dir#${CASES_ROOT}/}" + if ! is_ptodsl_case_dir "${dir}"; then + local ok=1 + for f in launch.cpp main.cpp golden.py compare.py; do + if [[ ! -f "${dir}/${f}" ]]; then + ok=0 + break + fi + done + [[ "${ok}" -eq 1 ]] || continue + fi if [[ "${DEVICE}" == "SIM" && "${COMPILE_ONLY}" != "1" && "${rel}" == "${onboard_only_prefix}"* ]]; then continue @@ -157,6 +166,19 @@ fi readarray -t CASES < <(discover_cases) [[ "${#CASES[@]}" -gt 0 ]] || die "no cases found under ${CASES_ROOT}" +needs_legacy_sim_lib=0 +if [[ "${DEVICE}" == "SIM" ]]; then + for case_name in "${CASES[@]}"; do + if ! is_ptodsl_case_dir "${CASES_ROOT}/${case_name}"; then + needs_legacy_sim_lib=1 + break + fi + done +fi +if [[ "${needs_legacy_sim_lib}" == "1" ]]; then + resolve_sim_lib_dir +fi + case_output_token() { printf '%s' "$1" | sed 's#[/[:space:]]#_#g' } @@ -246,6 +268,29 @@ build_host_executable() { -lstdc++ -lascendcl -lm -ltiling_api -lplatform -lc_sec -ldl -lnnopbase } +run_ptodsl_case() { + local case_name="$1" + local case_dir="$2" + local out_dir="$3" + + log "[$case_name] run PTODSL source-backed case" + ( + cd "${out_dir}" + export PTODSL_CACHE_DIR="${out_dir}/ptodsl-cache" + export PATH="$(dirname "${PTOAS_BIN}"):${PATH}" + if [[ "${DEVICE}" == "SIM" ]]; then + "${ROOT_DIR}/scripts/sim_dsl.sh" \ + --soc-version "${PTODSL_SIM_SOC_VERSION}" \ + --output "${out_dir}/msprof" \ + "${case_dir}/kernel.py" + else + export LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" + python3 "${case_dir}/kernel.py" + fi + ) + log "[$case_name] output dir: ${out_dir}" +} + build_one_impl() { local case_name="$1" local case_dir="${CASES_ROOT}/${case_name}" @@ -257,12 +302,17 @@ build_one_impl() { local kernel_so="${out_dir}/lib${case_token}_kernel.so" local -a ptoas_args=() + [[ -f "${case_dir}/kernel.pto" ]] || + die "missing kernel.pto for ${case_name}" + if is_ptodsl_case_dir "${case_dir}"; then + run_ptodsl_case "${case_name}" "${case_dir}" "${out_dir}" + return 0 + fi + [[ -f "${case_dir}/main.cpp" ]] || die "missing main.cpp for ${case_name}" [[ -f "${case_dir}/launch.cpp" ]] || die "missing launch.cpp for ${case_name}" [[ -f "${case_dir}/golden.py" ]] || die "missing golden.py for ${case_name}" [[ -f "${case_dir}/compare.py" ]] || die "missing compare.py for ${case_name}" - [[ -f "${case_dir}/kernel.pto" ]] || - die "missing kernel.pto for ${case_name}" if [[ -f "${case_dir}/ptoas.flags" ]]; then read -r -a ptoas_args < "${case_dir}/ptoas.flags" diff --git a/test/vpto/scripts/run_host_vpto_validation_parallel.sh b/test/vpto/scripts/run_host_vpto_validation_parallel.sh index 98be7d669d..bffa991bef 100755 --- a/test/vpto/scripts/run_host_vpto_validation_parallel.sh +++ b/test/vpto/scripts/run_host_vpto_validation_parallel.sh @@ -74,13 +74,25 @@ WORK_SPACE="$(cd "${WORK_SPACE}" && pwd)" SUMMARY_FILE="${WORK_SPACE}/parallel-summary.tsv" RUNNER_LOG="${WORK_SPACE}/parallel-runner.log" +is_ptodsl_case_dir() { + [[ -f "$1/kernel.py" ]] +} + +validate_case_dir() { + local case_name="$1" + local case_dir="$2" + + [[ -f "${case_dir}/kernel.pto" ]] || + die "case ${case_name} must provide kernel.pto" + if is_ptodsl_case_dir "${case_dir}"; then + return 0 + fi + for f in launch.cpp main.cpp golden.py compare.py; do + [[ -f "${case_dir}/${f}" ]] || die "case ${case_name} is missing ${f}" + done +} + discover_cases() { - local required_files=( - launch.cpp - main.cpp - golden.py - compare.py - ) local onboard_only_prefix="onboard-only/" if [[ -n "${CASE_NAME}" ]]; then @@ -90,26 +102,24 @@ discover_cases() { fi local requested_dir="${CASES_ROOT}/${CASE_NAME}" [[ -d "${requested_dir}" ]] || die "unknown case: ${CASE_NAME}" - for f in "${required_files[@]}"; do - [[ -f "${requested_dir}/${f}" ]] || die "case ${CASE_NAME} is missing ${f}" - done - [[ -f "${requested_dir}/kernel.pto" ]] || - die "case ${CASE_NAME} must provide kernel.pto" + validate_case_dir "${CASE_NAME}" "${requested_dir}" printf "%s\n" "${CASE_NAME}" return 0 fi find "${CASES_ROOT}" -mindepth 1 -type d | sort | while read -r dir; do - local ok=1 - for f in "${required_files[@]}"; do - if [[ ! -f "${dir}/${f}" ]]; then - ok=0 - break - fi - done - [[ "${ok}" -eq 1 ]] || continue [[ -f "${dir}/kernel.pto" ]] || continue local rel="${dir#${CASES_ROOT}/}" + if ! is_ptodsl_case_dir "${dir}"; then + local ok=1 + for f in launch.cpp main.cpp golden.py compare.py; do + if [[ ! -f "${dir}/${f}" ]]; then + ok=0 + break + fi + done + [[ "${ok}" -eq 1 ]] || continue + fi if [[ "${DEVICE:-SIM}" == "SIM" && "${COMPILE_ONLY:-0}" != "1" && "${rel}" == "${onboard_only_prefix}"* ]]; then continue From 7e1736bec0126f8b3093878721d2ee69ee7051dd Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 2 Jul 2026 00:08:10 +0800 Subject: [PATCH 02/54] feat(ptodsl): support using string as pto.jit source input --- .../03-kernel-entry-and-subkernels.md | 33 +++++-- ptodsl/ptodsl/_source_loader.py | 50 ++++++++--- ptodsl/tests/test_jit_compile.py | 27 ++++++ ptodsl/tests/test_jit_diagnostics.py | 20 +++++ .../a5-extra/{vmadd/kernel.pto => vmadd.py} | 88 +++++++++++++++++-- .../cases/micro-op/a5-extra/vmadd/kernel.py | 76 ---------------- test/vpto/scripts/run_host_vpto_validation.sh | 80 +++++++++++++---- .../run_host_vpto_validation_parallel.sh | 49 ++++++++--- 8 files changed, 288 insertions(+), 135 deletions(-) rename test/vpto/cases/micro-op/a5-extra/{vmadd/kernel.pto => vmadd.py} (51%) delete mode 100644 test/vpto/cases/micro-op/a5-extra/vmadd/kernel.py diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md index 0c6150ef5c..61e4b8e50a 100644 --- a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -226,10 +226,10 @@ Calling `.compile()` on an `entry=False` module raises an error. ### Loading an existing PTO file -Use `source=` when the kernel implementation already exists as a hand-written -PTO file, but you still want PTODSL's Python compile and launch workflow. The -decorated Python function declares the host-side ABI; the PTO file provides the -kernel body. +Use `source=` when the kernel implementation already exists as hand-written PTO +IR, but you still want PTODSL's Python compile and launch workflow. The +decorated Python function declares the host-side ABI; `source=` provides the +kernel body either as a PTO file path or as PTO text directly. ```python @@ -257,9 +257,9 @@ The Python function body is not traced in this form. Keep it empty, or leave a short comment for readers. Positional parameters still matter: PTODSL uses them to build the launch wrapper and marshal Python, NumPy, or torch-npu arguments. -The PTO file must contain one non-declaration `func.func` whose symbol matches +The PTO source must contain one non-declaration `func.func` whose symbol matches the JIT entry name. By default, the entry name is the Python function name. Use -`name=` when the PTO symbol has a different name, or when the file contains more +`name=` when the PTO symbol has a different name, or when the source contains more than one kernel: ```mlir @@ -284,9 +284,9 @@ PTODSL checks the selected PTO function before compiling: - If the file or entry cannot be found, the diagnostic names the requested entry and source path. -`source` is a filesystem path. Relative paths are resolved from the Python file -that declares the decorated function, so tests can keep the Python wrapper next -to the PTO file: +When `source` is a filesystem path, relative paths are resolved from the Python +file that declares the decorated function, so tests can keep the Python wrapper +next to the PTO file: ```text case.py @@ -300,6 +300,21 @@ def tadd(A_ptr: pto.ptr(pto.f32, "gm"), O_ptr: pto.ptr(pto.f32, "gm")): pass ``` +For short tests, `source` can also embed the PTO text directly: + +```python +tadd_source = """module { + func.func @tadd_entry(%a: !pto.ptr, %o: !pto.ptr) { + return + } +} +""" + +@pto.jit(name="tadd_entry", source=tadd_source) +def tadd(A_ptr: pto.ptr(pto.f32, "gm"), O_ptr: pto.ptr(pto.f32, "gm")): + pass +``` + Source-backed entries use the same `.compile()` and `compiled[grid, stream](...)` launch syntax as ordinary traced entries. If the PTO file contents change, compiling the same declaration again rebuilds the cached artifact. diff --git a/ptodsl/ptodsl/_source_loader.py b/ptodsl/ptodsl/_source_loader.py index 834c1dbaf6..b936937ddf 100644 --- a/ptodsl/ptodsl/_source_loader.py +++ b/ptodsl/ptodsl/_source_loader.py @@ -29,7 +29,8 @@ class SourceModuleArtifact: module: Module mlir_text: str - resolved_path: Path + resolved_path: Path | None + source_kind: str content_digest: str @@ -68,16 +69,18 @@ def cache_identity(self) -> tuple: def build_module(self): """Return ``(module, metadata)`` for ``ModuleArtifact``.""" artifact = self._load() - return artifact.module, { + metadata = { "mlir_text": artifact.mlir_text, - "source_path": str(artifact.resolved_path), + "source_kind": artifact.source_kind, "source_digest": artifact.content_digest, } + if artifact.resolved_path is not None: + metadata["source_path"] = str(artifact.resolved_path) + return artifact.module, metadata def _load(self) -> SourceModuleArtifact: if self._artifact is None: - resolved_path = self._resolve_source_path() - mlir_text = self._read_source_text(resolved_path) + resolved_path, mlir_text, source_kind = self._resolve_source() content_digest = hashlib.sha256(mlir_text.encode("utf-8")).hexdigest() ctx = make_context() with ctx, Location.unknown(): @@ -89,6 +92,7 @@ def _load(self) -> SourceModuleArtifact: module=module, mlir_text=mlir_text, resolved_path=resolved_path, + source_kind=source_kind, content_digest=content_digest, ) return self._artifact @@ -102,6 +106,20 @@ def _resolve_source_path(self) -> Path: return (Path(declaring_file).resolve().parent / raw_path).resolve() return raw_path.resolve() + def _looks_like_inline_source(self, source: str) -> bool: + stripped = source.lstrip() + if "\n" in source or "\r" in source: + return True + return stripped.startswith("module {") or stripped.startswith("builtin.module {") or ( + "module {" in stripped and "func.func @" in stripped + ) + + def _resolve_source(self) -> tuple[Path | None, str, str]: + if self._looks_like_inline_source(self.source): + return None, self.source, "inline" + resolved_path = self._resolve_source_path() + return resolved_path, self._read_source_text(resolved_path), "path" + def _read_source_text(self, resolved_path: Path) -> str: try: return resolved_path.read_text(encoding="utf-8") @@ -110,7 +128,7 @@ def _read_source_text(self, resolved_path: Path) -> str: except OSError as exc: raise jit_source_file_error(self.source, resolved_path, str(exc)) from exc - def _select_entry(self, module: Module, resolved_path: Path): + def _select_entry(self, module: Module, resolved_path: Path | None): matches = [] for op in _walk_ops(module.operation): if op.operation.name != "func.func": @@ -122,32 +140,33 @@ def _select_entry(self, module: Module, resolved_path: Path): if not matches: raise jit_source_entry_error( - resolved_path, + _source_location_label(self.source, resolved_path), self._module_spec.function_name, "missing non-declaration func.func with this symbol name", ) if len(matches) > 1: raise jit_source_entry_error( - resolved_path, + _source_location_label(self.source, resolved_path), self._module_spec.function_name, f"found {len(matches)} matching non-declaration func.func ops", ) return matches[0] - def _verify_entry_abi(self, entry, resolved_path: Path) -> None: + def _verify_entry_abi(self, entry, resolved_path: Path | None) -> None: + source_label = _source_location_label(self.source, resolved_path) expected = tuple(str(type_obj) for type_obj in self._kernel_signature.compute_entry_arg_types()) actual = tuple(str(type_obj) for type_obj in entry.type.inputs) results = tuple(str(type_obj) for type_obj in entry.type.results) if results: raise jit_source_abi_error( - resolved_path, + source_label, self._module_spec.function_name, f"source entry must return no values, got ({', '.join(results)})", ) if len(actual) != len(expected): raise jit_source_abi_error( - resolved_path, + source_label, self._module_spec.function_name, "parameter count differs; " f"expected ({', '.join(expected)}), got ({', '.join(actual)})", @@ -155,7 +174,7 @@ def _verify_entry_abi(self, entry, resolved_path: Path) -> None: for index, (expected_type, actual_type) in enumerate(zip(expected, actual)): if expected_type != actual_type: raise jit_source_abi_error( - resolved_path, + source_label, self._module_spec.function_name, f"parameter {index} differs; expected {expected_type}, got {actual_type}", ) @@ -176,4 +195,11 @@ def _walk_ops(root_op): yield from _walk_ops(op.operation) +def _source_location_label(source: str, resolved_path: Path | None): + if resolved_path is not None: + return resolved_path + digest = hashlib.sha256(source.encode("utf-8")).hexdigest()[:12] + return f"" + + __all__ = ["SourceModuleLoader"] diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index ad086dd56d..b3ab3694c6 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -3272,6 +3272,33 @@ def source_no_insert_sync_native(ptr: pto.ptr(pto.f32, "gm")): relative_compiled.build_metadata()["source_path"] == str(relative_source_path.resolve()), "relative source-backed metadata should expose the resolved source path", ) + + inline_source_text = ( + "module {\n" + " func.func @inline_source_entry(%arg0: !pto.ptr, %arg1: i32) {\n" + " return\n" + " }\n" + "}\n" + ) + + @pto.jit(name="inline_source_entry", target="a5", source=inline_source_text) + def inline_source_backed_probe(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): + raise RuntimeError("source-backed JIT should not trace the Python body") + + inline_compiled = inline_source_backed_probe.compile() + expect( + inline_compiled.mlir_text() == inline_source_text, + "inline source-backed JIT mlir_text() should preserve the authored source text", + ) + inline_metadata = inline_compiled.build_metadata() + expect( + inline_metadata["source_kind"] == "inline", + "inline source-backed JIT metadata should mark the source kind as inline", + ) + expect( + "source_path" not in inline_metadata, + "inline source-backed JIT metadata should not synthesize a filesystem path", + ) pointer_artifacts_default = artifact_paths( pointer_default._py_name, pointer_default.ir_function_name, diff --git a/ptodsl/tests/test_jit_diagnostics.py b/ptodsl/tests/test_jit_diagnostics.py index 135f0e296e..7c1d344bcd 100644 --- a/ptodsl/tests/test_jit_diagnostics.py +++ b/ptodsl/tests/test_jit_diagnostics.py @@ -983,6 +983,26 @@ def compile_constexpr(ptr: pto.ptr(pto.f32, "gm")): "@pto.jit(source=...) kernel 'compile_constexpr' does not accept .compile(...) constexpr binding(s) BLOCK", "does not template or specialize source text", ) + + inline_count_mismatch_source = ( + "module {\n" + " func.func @inline_count_mismatch(%arg0: !pto.ptr) {\n" + " return\n" + " }\n" + "}\n" + ) + + @pto.jit(target="a5", source=inline_count_mismatch_source) + def inline_count_mismatch(ptr: pto.ptr(pto.f32, "gm"), rows: pto.i32): + raise RuntimeError("source-backed JIT should not trace the Python body") + + expect_raises( + inline_count_mismatch.compile, + TypeError, + "ABI mismatch for entry 'inline_count_mismatch'", + "parameter count differs", + "} { +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from pathlib import Path +import sys + +import numpy as np + + +def _bootstrap_dsl_st_common() -> None: + here = Path(__file__).resolve() + for candidate in here.parents: + common_dir = candidate / "test" / "dsl-st" + if (common_dir / "common.py").exists(): + sys.path.insert(0, str(common_dir)) + return + raise RuntimeError("Unable to locate test/dsl-st/common.py from vmadd.py") + + +_bootstrap_dsl_st_common() + +from common import auto_main, golden_output_case +from ptodsl import pto + + +ELEMS = 1024 +SEED = 29 + +VMADD_SOURCE = """module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { func.func @a5_extra_vmadd_kernel( %f_acc: !pto.ptr, %f_lhs: !pto.ptr, %f_rhs: !pto.ptr, %out_vmadd: !pto.ptr) @@ -57,3 +83,47 @@ return } } +""" + + +@pto.jit( + name="a5_extra_vmadd_kernel", + target="a5", + backend="vpto", + mode="explicit", + source=VMADD_SOURCE, +) +def a5_extra_vmadd_kernel( + f_acc: pto.ptr(pto.f32, "gm"), + f_lhs: pto.ptr(pto.f32, "gm"), + f_rhs: pto.ptr(pto.f32, "gm"), + out_vmadd: pto.ptr(pto.f32, "gm"), +): + pass + + +def make_inputs(): + rng = np.random.default_rng(SEED) + f_acc = rng.uniform(-2.0, 2.0, size=ELEMS).astype(np.float32) + f_lhs = rng.uniform(-3.0, 3.0, size=ELEMS).astype(np.float32) + f_rhs = rng.uniform(-1.0, 1.0, size=ELEMS).astype(np.float32) + return [f_acc, f_lhs, f_rhs] + + +def make_expected(f_acc, f_lhs, f_rhs): + return (f_lhs * f_acc + f_rhs).astype(np.float32) + + +CASES = [ + golden_output_case( + "a5_extra_vmadd", + a5_extra_vmadd_kernel, + inputs=make_inputs, + expected=make_expected, + rtol=2e-4, + atol=2e-4, + ), +] + + +auto_main(globals()) diff --git a/test/vpto/cases/micro-op/a5-extra/vmadd/kernel.py b/test/vpto/cases/micro-op/a5-extra/vmadd/kernel.py deleted file mode 100644 index a5323de78a..0000000000 --- a/test/vpto/cases/micro-op/a5-extra/vmadd/kernel.py +++ /dev/null @@ -1,76 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -from pathlib import Path -import sys - -import numpy as np - - -def _bootstrap_dsl_st_common() -> None: - here = Path(__file__).resolve() - for candidate in here.parents: - common_dir = candidate / "test" / "dsl-st" - if (common_dir / "common.py").exists(): - sys.path.insert(0, str(common_dir)) - return - raise RuntimeError("Unable to locate test/dsl-st/common.py from vmadd kernel.py") - - -_bootstrap_dsl_st_common() - -from common import auto_main, golden_output_case -from ptodsl import pto - - -ELEMS = 1024 -SEED = 29 - - -@pto.jit( - name="a5_extra_vmadd_kernel", - target="a5", - backend="vpto", - mode="explicit", - source="kernel.pto", -) -def a5_extra_vmadd_kernel( - f_acc: pto.ptr(pto.f32, "gm"), - f_lhs: pto.ptr(pto.f32, "gm"), - f_rhs: pto.ptr(pto.f32, "gm"), - out_vmadd: pto.ptr(pto.f32, "gm"), -): - pass - - -def make_inputs(): - rng = np.random.default_rng(SEED) - f_acc = rng.uniform(-2.0, 2.0, size=ELEMS).astype(np.float32) - f_lhs = rng.uniform(-3.0, 3.0, size=ELEMS).astype(np.float32) - f_rhs = rng.uniform(-1.0, 1.0, size=ELEMS).astype(np.float32) - return [f_acc, f_lhs, f_rhs] - - -def make_expected(f_acc, f_lhs, f_rhs): - return (f_lhs * f_acc + f_rhs).astype(np.float32) - - -CASES = [ - golden_output_case( - "a5_extra_vmadd", - a5_extra_vmadd_kernel, - inputs=make_inputs, - expected=make_expected, - rtol=2e-4, - atol=2e-4, - ), -] - - -auto_main(globals()) diff --git a/test/vpto/scripts/run_host_vpto_validation.sh b/test/vpto/scripts/run_host_vpto_validation.sh index 34ce0b25f7..fa530911af 100755 --- a/test/vpto/scripts/run_host_vpto_validation.sh +++ b/test/vpto/scripts/run_host_vpto_validation.sh @@ -107,18 +107,47 @@ is_ptodsl_case_dir() { [[ -f "$1/kernel.py" ]] } -validate_case_dir() { +is_ptodsl_case_file() { + local case_path="$1" + local base_name + base_name="$(basename "${case_path}")" + [[ -f "${case_path}" ]] && + [[ "${case_path}" == *.py ]] && + [[ "${base_name}" != "golden.py" ]] && + [[ "${base_name}" != "compare.py" ]] && + [[ "${base_name}" != "kernel.py" ]] && + [[ "${base_name}" != _* ]] +} + +is_ptodsl_case_path() { + is_ptodsl_case_dir "$1" || is_ptodsl_case_file "$1" +} + +ptodsl_case_script() { + local case_path="$1" + if is_ptodsl_case_dir "${case_path}"; then + printf '%s\n' "${case_path}/kernel.py" + return 0 + fi + if is_ptodsl_case_file "${case_path}"; then + printf '%s\n' "${case_path}" + return 0 + fi + die "path is not a PTODSL case: ${case_path}" +} + +validate_case_path() { local case_name="$1" - local case_dir="$2" + local case_path="$2" - [[ -f "${case_dir}/kernel.pto" ]] || - die "case ${case_name} must provide kernel.pto" - if is_ptodsl_case_dir "${case_dir}"; then + if is_ptodsl_case_path "${case_path}"; then return 0 fi + [[ -d "${case_path}" ]] || die "case ${case_name} is neither a directory nor a PTODSL case file" for f in launch.cpp main.cpp golden.py compare.py; do - [[ -f "${case_dir}/${f}" ]] || die "case ${case_name} is missing ${f}" + [[ -f "${case_path}/${f}" ]] || die "case ${case_name} is missing ${f}" done + [[ -f "${case_path}/kernel.pto" ]] || die "case ${case_name} must provide kernel.pto" } discover_cases() { @@ -130,15 +159,15 @@ discover_cases() { "${CASE_NAME}" == "${onboard_only_prefix}"* ]]; then die "case ${CASE_NAME} is onboard-only and cannot run with DEVICE=SIM" fi - local requested_dir="${CASES_ROOT}/${CASE_NAME}" - [[ -d "${requested_dir}" ]] || die "unknown case: ${CASE_NAME}" - validate_case_dir "${CASE_NAME}" "${requested_dir}" + local requested_path="${CASES_ROOT}/${CASE_NAME}" + [[ -e "${requested_path}" ]] || die "unknown case: ${CASE_NAME}" + validate_case_path "${CASE_NAME}" "${requested_path}" printf "%s\n" "${CASE_NAME}" return 0 fi find "${CASES_ROOT}" -mindepth 1 -type d | sort | while read -r dir; do - [[ -f "${dir}/kernel.pto" ]] || continue + [[ -f "${dir}/kernel.pto" || -f "${dir}/kernel.py" ]] || continue local rel="${dir#${CASES_ROOT}/}" if ! is_ptodsl_case_dir "${dir}"; then local ok=1 @@ -156,6 +185,16 @@ discover_cases() { fi printf "%s\n" "${rel}" done + + find "${CASES_ROOT}" -type f -name '*.py' | sort | while read -r path; do + is_ptodsl_case_file "${path}" || continue + local rel="${path#${CASES_ROOT}/}" + if [[ "${DEVICE}" == "SIM" && "${COMPILE_ONLY}" != "1" && + "${rel}" == "${onboard_only_prefix}"* ]]; then + continue + fi + printf "%s\n" "${rel}" + done } if [[ "${DEVICE}" == "SIM" && "${COMPILE_ONLY}" != "1" && @@ -169,7 +208,7 @@ readarray -t CASES < <(discover_cases) needs_legacy_sim_lib=0 if [[ "${DEVICE}" == "SIM" ]]; then for case_name in "${CASES[@]}"; do - if ! is_ptodsl_case_dir "${CASES_ROOT}/${case_name}"; then + if ! is_ptodsl_case_path "${CASES_ROOT}/${case_name}"; then needs_legacy_sim_lib=1 break fi @@ -270,8 +309,10 @@ build_host_executable() { run_ptodsl_case() { local case_name="$1" - local case_dir="$2" + local case_path="$2" local out_dir="$3" + local case_script + case_script="$(ptodsl_case_script "${case_path}")" log "[$case_name] run PTODSL source-backed case" ( @@ -282,10 +323,10 @@ run_ptodsl_case() { "${ROOT_DIR}/scripts/sim_dsl.sh" \ --soc-version "${PTODSL_SIM_SOC_VERSION}" \ --output "${out_dir}/msprof" \ - "${case_dir}/kernel.py" + "${case_script}" else export LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/lib64:${LD_LIBRARY_PATH:-}" - python3 "${case_dir}/kernel.py" + python3 "${case_script}" fi ) log "[$case_name] output dir: ${out_dir}" @@ -293,7 +334,7 @@ run_ptodsl_case() { build_one_impl() { local case_name="$1" - local case_dir="${CASES_ROOT}/${case_name}" + local case_path="${CASES_ROOT}/${case_name}" local case_token case_token="$(case_output_token "${case_name}")" local out_dir="${WORK_SPACE}/${case_token}" @@ -302,12 +343,13 @@ build_one_impl() { local kernel_so="${out_dir}/lib${case_token}_kernel.so" local -a ptoas_args=() - [[ -f "${case_dir}/kernel.pto" ]] || - die "missing kernel.pto for ${case_name}" - if is_ptodsl_case_dir "${case_dir}"; then - run_ptodsl_case "${case_name}" "${case_dir}" "${out_dir}" + if is_ptodsl_case_path "${case_path}"; then + run_ptodsl_case "${case_name}" "${case_path}" "${out_dir}" return 0 fi + local case_dir="${case_path}" + + [[ -f "${case_dir}/kernel.pto" ]] || die "missing kernel.pto for ${case_name}" [[ -f "${case_dir}/main.cpp" ]] || die "missing main.cpp for ${case_name}" [[ -f "${case_dir}/launch.cpp" ]] || die "missing launch.cpp for ${case_name}" diff --git a/test/vpto/scripts/run_host_vpto_validation_parallel.sh b/test/vpto/scripts/run_host_vpto_validation_parallel.sh index bffa991bef..2bb116972a 100755 --- a/test/vpto/scripts/run_host_vpto_validation_parallel.sh +++ b/test/vpto/scripts/run_host_vpto_validation_parallel.sh @@ -78,18 +78,34 @@ is_ptodsl_case_dir() { [[ -f "$1/kernel.py" ]] } -validate_case_dir() { +is_ptodsl_case_file() { + local case_path="$1" + local base_name + base_name="$(basename "${case_path}")" + [[ -f "${case_path}" ]] && + [[ "${case_path}" == *.py ]] && + [[ "${base_name}" != "golden.py" ]] && + [[ "${base_name}" != "compare.py" ]] && + [[ "${base_name}" != "kernel.py" ]] && + [[ "${base_name}" != _* ]] +} + +is_ptodsl_case_path() { + is_ptodsl_case_dir "$1" || is_ptodsl_case_file "$1" +} + +validate_case_path() { local case_name="$1" - local case_dir="$2" + local case_path="$2" - [[ -f "${case_dir}/kernel.pto" ]] || - die "case ${case_name} must provide kernel.pto" - if is_ptodsl_case_dir "${case_dir}"; then + if is_ptodsl_case_path "${case_path}"; then return 0 fi + [[ -d "${case_path}" ]] || die "case ${case_name} is neither a directory nor a PTODSL case file" for f in launch.cpp main.cpp golden.py compare.py; do - [[ -f "${case_dir}/${f}" ]] || die "case ${case_name} is missing ${f}" + [[ -f "${case_path}/${f}" ]] || die "case ${case_name} is missing ${f}" done + [[ -f "${case_path}/kernel.pto" ]] || die "case ${case_name} must provide kernel.pto" } discover_cases() { @@ -100,15 +116,15 @@ discover_cases() { "${CASE_NAME}" == "${onboard_only_prefix}"* ]]; then die "case ${CASE_NAME} is onboard-only and cannot run with DEVICE=SIM" fi - local requested_dir="${CASES_ROOT}/${CASE_NAME}" - [[ -d "${requested_dir}" ]] || die "unknown case: ${CASE_NAME}" - validate_case_dir "${CASE_NAME}" "${requested_dir}" + local requested_path="${CASES_ROOT}/${CASE_NAME}" + [[ -e "${requested_path}" ]] || die "unknown case: ${CASE_NAME}" + validate_case_path "${CASE_NAME}" "${requested_path}" printf "%s\n" "${CASE_NAME}" return 0 fi find "${CASES_ROOT}" -mindepth 1 -type d | sort | while read -r dir; do - [[ -f "${dir}/kernel.pto" ]] || continue + [[ -f "${dir}/kernel.pto" || -f "${dir}/kernel.py" ]] || continue local rel="${dir#${CASES_ROOT}/}" if ! is_ptodsl_case_dir "${dir}"; then local ok=1 @@ -129,6 +145,19 @@ discover_cases() { fi printf "%s\n" "${rel}" done + + find "${CASES_ROOT}" -type f -name '*.py' | sort | while read -r path; do + is_ptodsl_case_file "${path}" || continue + local rel="${path#${CASES_ROOT}/}" + if [[ "${DEVICE:-SIM}" == "SIM" && "${COMPILE_ONLY:-0}" != "1" && + "${rel}" == "${onboard_only_prefix}"* ]]; then + continue + fi + if [[ -n "${CASE_PREFIX}" && "${rel}" != "${CASE_PREFIX}"* ]]; then + continue + fi + printf "%s\n" "${rel}" + done } if [[ "${DEVICE:-SIM}" == "SIM" && "${COMPILE_ONLY:-0}" != "1" && From 0ba433d442ff0191051e60d4567ddce034d8eafe Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 17 Jun 2026 22:55:44 +0800 Subject: [PATCH 03/54] feat: first stage of vmi --- docs/designs/vmi-dialect-design.md | 2078 ++++++ docs/designs/vmi-implementation-manual.md | 4233 +++++++++++ include/PTO/IR/PTOAttrs.td | 2 + include/PTO/IR/PTOOps.td | 1 + include/PTO/IR/PTOTypeDefs.td | 1 + include/PTO/IR/VMIAttrs.td | 34 + include/PTO/IR/VMIOps.td | 562 ++ include/PTO/IR/VMITypeDefs.td | 67 + include/PTO/IR/VMIUtils.h | 53 + include/PTO/Transforms/Passes.h | 8 + include/PTO/Transforms/Passes.td | 71 + .../PTO/Transforms/VMITargetCapabilities.h | 318 + lib/PTO/IR/CMakeLists.txt | 1 + lib/PTO/IR/VMI.cpp | 1407 ++++ lib/PTO/Transforms/CMakeLists.txt | 3 + lib/PTO/Transforms/PTOValidateVMIIR.cpp | 445 ++ lib/PTO/Transforms/VMILayoutAssignment.cpp | 1330 ++++ lib/PTO/Transforms/VMIToVPTO.cpp | 6269 +++++++++++++++++ test/lit/CMakeLists.txt | 1 + test/lit/lit.cfg.py | 6 +- test/lit/vmi/vmi_absf_integer_invalid.pto | 19 + test/lit/vmi/vmi_absi_float_invalid.pto | 19 + ...ctive_prefix_index_result_type_invalid.pto | 21 + .../vmi/vmi_addf_lane_mismatch_invalid.pto | 21 + .../vmi/vmi_bitcast_total_bits_invalid.pto | 19 + test/lit/vmi/vmi_bitwise_float_invalid.pto | 64 + .../vmi_broadcast_type_mismatch_invalid.pto | 18 + ...i_channel_merge_input_mismatch_invalid.pto | 21 + ..._channel_merge_result_mismatch_invalid.pto | 21 + .../vmi_channel_split_lane_count_invalid.pto | 20 + ...vmi_channel_split_result_count_invalid.pto | 20 + .../vmi_compress_result_mismatch_invalid.pto | 23 + .../vmi/vmi_constant_attr_kind_invalid.pto | 20 + .../vmi_constant_element_count_invalid.pto | 20 + .../vmi/vmi_constant_element_type_invalid.pto | 20 + .../vmi_constant_mask_attr_kind_invalid.pto | 20 + ...mi_constant_mask_element_count_invalid.pto | 20 + ...vmi_constant_mask_element_type_invalid.pto | 20 + test/lit/vmi/vmi_divf_integer_invalid.pto | 22 + test/lit/vmi/vmi_elementwise_kind_invalid.pto | 63 + .../vmi/vmi_ensure_layout_surface_invalid.pto | 47 + test/lit/vmi/vmi_extf_direction_invalid.pto | 19 + .../vmi/vmi_extf_lane_mismatch_invalid.pto | 19 + test/lit/vmi/vmi_fma_integer_invalid.pto | 23 + test/lit/vmi/vmi_gather_indices_invalid.pto | 25 + .../lit/vmi/vmi_iota_element_type_invalid.pto | 19 + test/lit/vmi/vmi_iota_order_invalid.pto | 19 + ..._layout_assignment_active_prefix_index.pto | 26 + .../vmi_layout_assignment_broadcast_remat.pto | 52 + .../vmi_layout_assignment_call_boundary.pto | 45 + .../vmi/vmi_layout_assignment_cf_branch.pto | 56 + .../vmi/vmi_layout_assignment_cf_switch.pto | 51 + ...hannel_merge_count_unsupported_invalid.pto | 23 + ...hannel_split_count_unsupported_invalid.pto | 21 + .../vmi/vmi_layout_assignment_compress.pto | 30 + .../vmi_layout_assignment_compress_store.pto | 31 + .../vmi_layout_assignment_constant_remat.pto | 54 + .../vmi/vmi_layout_assignment_expand_load.pto | 35 + ...ayout_assignment_external_call_invalid.pto | 25 + ...ayout_assignment_external_decl_invalid.pto | 15 + ...yout_assignment_external_decl_preserve.pto | 23 + test/lit/vmi/vmi_layout_assignment_fma.pto | 29 + test/lit/vmi/vmi_layout_assignment_gather.pto | 35 + ...ayout_assignment_indirect_call_invalid.pto | 24 + .../vmi/vmi_layout_assignment_iota_remat.pto | 54 + .../vmi/vmi_layout_assignment_load_truncf.pto | 133 + ...ment_mask_granularity_conflict_invalid.pto | 33 + .../vmi/vmi_layout_assignment_mask_remat.pto | 73 + .../vmi_layout_assignment_mask_use_ensure.pto | 36 + .../vmi/vmi_layout_assignment_masked_load.pto | 32 + .../vmi_layout_assignment_multi_return.pto | 39 + ...signment_multi_return_conflict_invalid.pto | 30 + ...assignment_post_gate_type_attr_invalid.pto | 17 + .../vmi/vmi_layout_assignment_reduce_addf.pto | 30 + .../vmi/vmi_layout_assignment_reduce_addi.pto | 33 + .../vmi_layout_assignment_reduce_minmaxf.pto | 49 + .../lit/vmi/vmi_layout_assignment_scatter.pto | 32 + ...i_layout_assignment_scf_execute_region.pto | 38 + .../lit/vmi/vmi_layout_assignment_scf_for.pto | 43 + test/lit/vmi/vmi_layout_assignment_scf_if.pto | 50 + ...vmi_layout_assignment_scf_index_switch.pto | 48 + .../vmi/vmi_layout_assignment_scf_while.pto | 47 + .../vmi_layout_assignment_store_ensure.pto | 48 + .../vmi_layout_assignment_truncf_ensure.pto | 39 + test/lit/vmi/vmi_layout_assignment_widen.pto | 39 + test/lit/vmi/vmi_layout_factor_invalid.pto | 18 + .../vmi/vmi_layout_gate_surface_invalid.pto | 18 + .../vmi_layout_gate_surface_mask_invalid.pto | 20 + ...gate_type_attr_nested_physical_invalid.pto | 17 + ..._layout_gate_type_attr_surface_invalid.pto | 17 + test/lit/vmi/vmi_layout_gate_valid.pto | 23 + ...i_mask_concrete_without_layout_invalid.pto | 18 + test/lit/vmi/vmi_mask_granularity_invalid.pto | 18 + test/lit/vmi/vmi_mask_logic_invalid.pto | 67 + .../vmi/vmi_mask_pred_with_layout_invalid.pto | 18 + ..._masked_store_mask_granularity_invalid.pto | 25 + .../vmi/vmi_memory_element_type_invalid.pto | 57 + test/lit/vmi/vmi_min_max_integer_invalid.pto | 37 + test/lit/vmi/vmi_negf_integer_invalid.pto | 19 + test/lit/vmi/vmi_op_verifier_basic.pto | 106 + test/lit/vmi/vmi_pack_arity_invalid.pto | 20 + .../vmi_producer_boundary_helper_invalid.pto | 22 + .../vmi_producer_boundary_layout_invalid.pto | 19 + ..._producer_boundary_mask_layout_invalid.pto | 19 + ...i_producer_boundary_non_vmi_op_invalid.pto | 21 + ...vmi_producer_boundary_physical_invalid.pto | 30 + ...ucer_boundary_type_attr_layout_invalid.pto | 17 + ...ucer_boundary_type_attr_nested_invalid.pto | 17 + ...cer_boundary_type_attr_surface_invalid.pto | 17 + test/lit/vmi/vmi_producer_boundary_valid.pto | 27 + .../vmi_ptoas_backend_required_invalid.pto | 17 + test/lit/vmi/vmi_ptoas_cli_control_flow.pto | 43 + test/lit/vmi/vmi_ptoas_cli_pipeline.pto | 45 + test/lit/vmi/vmi_ptoas_public_abi_invalid.pto | 20 + .../vmi_ptoas_public_result_abi_invalid.pto | 22 + ...mi_reduce_addf_missing_reassoc_invalid.pto | 23 + test/lit/vmi/vmi_scatter_indices_invalid.pto | 24 + .../vmi_select_mask_granularity_invalid.pto | 25 + test/lit/vmi/vmi_shli_float_invalid.pto | 22 + test/lit/vmi/vmi_shrui_float_invalid.pto | 22 + test/lit/vmi/vmi_shrui_signed_invalid.pto | 22 + test/lit/vmi/vmi_to_vpto_abs.pto | 44 + .../vmi/vmi_to_vpto_active_prefix_index.pto | 33 + ...active_prefix_index_multichunk_invalid.pto | 26 + ..._vpto_active_prefix_index_tail_invalid.pto | 22 + test/lit/vmi/vmi_to_vpto_add.pto | 57 + test/lit/vmi/vmi_to_vpto_bf16_arith.pto | 50 + test/lit/vmi/vmi_to_vpto_bitcast.pto | 29 + test/lit/vmi/vmi_to_vpto_bitcast_partial.pto | 29 + test/lit/vmi/vmi_to_vpto_bitwise.pto | 53 + test/lit/vmi/vmi_to_vpto_broadcast.pto | 69 + test/lit/vmi/vmi_to_vpto_call_boundary.pto | 52 + test/lit/vmi/vmi_to_vpto_cf_branch.pto | 78 + .../vmi_to_vpto_channel_merge4_contiguous.pto | 40 + ...hannel_merge_count_unsupported_invalid.pto | 25 + ...i_to_vpto_channel_merge_layout_invalid.pto | 23 + ...to_channel_merge_partial_group_invalid.pto | 25 + ...hannel_split_count_unsupported_invalid.pto | 23 + ...i_to_vpto_channel_split_layout_invalid.pto | 22 + .../vmi/vmi_to_vpto_channel_split_merge.pto | 95 + .../vmi_to_vpto_channel_split_merge_tail.pto | 35 + ...to_channel_split_partial_group_invalid.pto | 24 + .../vmi_to_vpto_cmp_element_type_invalid.pto | 24 + ...vpto_cmp_predicate_unsupported_invalid.pto | 28 + test/lit/vmi/vmi_to_vpto_cmp_select.pto | 140 + ...unsigned_predicate_unsupported_invalid.pto | 28 + .../vmi_to_vpto_compaction_deint_invalid.pto | 58 + test/lit/vmi/vmi_to_vpto_compress.pto | 32 + ...mi_to_vpto_compress_multichunk_invalid.pto | 28 + test/lit/vmi/vmi_to_vpto_compress_store.pto | 33 + ...vpto_compress_store_multichunk_invalid.pto | 26 + .../vmi/vmi_to_vpto_compress_tail_invalid.pto | 28 + test/lit/vmi/vmi_to_vpto_constant.pto | 33 + test/lit/vmi/vmi_to_vpto_constant_mask.pto | 128 + .../vmi_to_vpto_constant_mask_nonprefix.pto | 34 + ...mi_to_vpto_constant_mask_rematerialize.pto | 42 + .../vmi_to_vpto_constant_nonsplat_invalid.pto | 23 + ...vmi_to_vpto_construction_width_invalid.pto | 34 + test/lit/vmi/vmi_to_vpto_create_mask.pto | 87 + .../vmi/vmi_to_vpto_create_mask_dynamic.pto | 132 + .../vmi_to_vpto_create_mask_plt_fallback.pto | 30 + .../vmi_to_vpto_create_mask_rematerialize.pto | 47 + test/lit/vmi/vmi_to_vpto_divf.pto | 33 + .../vmi/vmi_to_vpto_e2e_widen_add_store.pto | 74 + .../vmi_to_vpto_elementwise_width_invalid.pto | 41 + test/lit/vmi/vmi_to_vpto_ensure_identity.pto | 80 + .../vmi/vmi_to_vpto_ensure_layout_deint4.pto | 57 + ..._to_vpto_ensure_layout_partial_invalid.pto | 23 + .../vmi/vmi_to_vpto_ensure_layout_vdintlv.pto | 30 + .../vmi/vmi_to_vpto_ensure_layout_vintlv.pto | 49 + .../vmi_to_vpto_ensure_mask_granularity.pto | 40 + ...to_vpto_ensure_mask_granularity_direct.pto | 31 + ...vpto_ensure_mask_granularity_multistep.pto | 34 + .../vmi/vmi_to_vpto_ensure_mask_layout.pto | 114 + ...pto_ensure_mask_layout_partial_invalid.pto | 23 + .../vmi_to_vpto_ensure_mask_layout_widths.pto | 78 + .../vmi_to_vpto_expand_load_all_active.pto | 66 + ...oad_all_active_negative_offset_invalid.pto | 35 + ..._vpto_expand_load_partial_mask_invalid.pto | 33 + .../vmi_to_vpto_expand_load_runtime_mask.pto | 41 + test/lit/vmi/vmi_to_vpto_extf.pto | 74 + test/lit/vmi/vmi_to_vpto_extf_f8.pto | 59 + test/lit/vmi/vmi_to_vpto_extf_multichunk.pto | 35 + test/lit/vmi/vmi_to_vpto_fma.pto | 83 + .../vmi_to_vpto_fma_element_type_invalid.pto | 26 + ...vpto_function_type_layout_free_invalid.pto | 16 + test/lit/vmi/vmi_to_vpto_gather.pto | 37 + .../vmi/vmi_to_vpto_gather_f16_invalid.pto | 28 + ...i_to_vpto_gather_scatter_shape_invalid.pto | 91 + test/lit/vmi/vmi_to_vpto_iota.pto | 120 + test/lit/vmi/vmi_to_vpto_iota_tail.pto | 57 + test/lit/vmi/vmi_to_vpto_load_deint.pto | 53 + .../vmi/vmi_to_vpto_load_deint_multichunk.pto | 31 + .../vmi/vmi_to_vpto_load_nonfull_invalid.pto | 27 + .../vmi/vmi_to_vpto_load_safe_tail_memref.pto | 73 + ..._to_vpto_load_safe_tail_memref_invalid.pto | 25 + ...fe_tail_memref_negative_offset_invalid.pto | 25 + .../vmi/vmi_to_vpto_load_store_contiguous.pto | 33 + test/lit/vmi/vmi_to_vpto_mask_logic.pto | 126 + test/lit/vmi/vmi_to_vpto_masked_load.pto | 36 + ...mi_to_vpto_masked_load_nonfull_invalid.pto | 33 + ...i_to_vpto_masked_load_safe_tail_memref.pto | 69 + ...fe_tail_memref_negative_offset_invalid.pto | 31 + test/lit/vmi/vmi_to_vpto_masked_store.pto | 38 + .../vmi_to_vpto_masked_store_deint_tail.pto | 42 + ...i_to_vpto_masked_store_nonfull_invalid.pto | 26 + .../lit/vmi/vmi_to_vpto_masked_store_tail.pto | 40 + .../vmi_to_vpto_math_element_type_invalid.pto | 131 + .../vmi/vmi_to_vpto_memory_space_invalid.pto | 130 + test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto | 44 + .../vmi/vmi_to_vpto_memref_layout_invalid.pto | 177 + test/lit/vmi/vmi_to_vpto_min_max.pto | 39 + test/lit/vmi/vmi_to_vpto_negf.pto | 29 + test/lit/vmi/vmi_to_vpto_pack_unpack.pto | 46 + test/lit/vmi/vmi_to_vpto_quant_dequant.pto | 310 + test/lit/vmi/vmi_to_vpto_quant_fp8.pto | 51 + test/lit/vmi/vmi_to_vpto_reduce_addf.pto | 36 + .../vmi_to_vpto_reduce_addf_f16_invalid.pto | 26 + .../vmi_to_vpto_reduce_addf_multichunk.pto | 38 + test/lit/vmi/vmi_to_vpto_reduce_addi.pto | 36 + .../vmi_to_vpto_reduce_addi_i16_invalid.pto | 26 + .../vmi_to_vpto_reduce_addi_multichunk.pto | 38 + .../vmi_to_vpto_reduce_maxf_multichunk.pto | 65 + .../vmi_to_vpto_reduce_maxf_tail_invalid.pto | 29 + test/lit/vmi/vmi_to_vpto_reduce_minf.pto | 36 + .../vmi/vmi_to_vpto_reduce_shape_invalid.pto | 85 + .../vmi_to_vpto_relu_element_type_invalid.pto | 22 + test/lit/vmi/vmi_to_vpto_scatter.pto | 31 + ...to_vpto_scatter_missing_unique_invalid.pto | 27 + test/lit/vmi/vmi_to_vpto_scf_for.pto | 44 + test/lit/vmi/vmi_to_vpto_scf_if.pto | 57 + test/lit/vmi/vmi_to_vpto_shli.pto | 33 + test/lit/vmi/vmi_to_vpto_shrui.pto | 33 + .../vmi/vmi_to_vpto_shuffle_forwarding.pto | 159 + .../vmi/vmi_to_vpto_shuffle_lane0_splat.pto | 44 + ...stable_gather_masked_load_todo_invalid.pto | 29 + test/lit/vmi/vmi_to_vpto_store_deint.pto | 64 + .../vmi/vmi_to_vpto_store_deint_invalid.pto | 22 + test/lit/vmi/vmi_to_vpto_store_deint_tail.pto | 35 + test/lit/vmi/vmi_to_vpto_store_tail.pto | 29 + .../vmi/vmi_to_vpto_store_width_invalid.pto | 38 + test/lit/vmi/vmi_to_vpto_sub_mul.pto | 60 + test/lit/vmi/vmi_to_vpto_tile_read_write.pto | 64 + .../vmi/vmi_to_vpto_tile_write_deint_tail.pto | 34 + test/lit/vmi/vmi_to_vpto_tile_write_tail.pto | 33 + ..._to_vpto_tile_write_tail_deint_invalid.pto | 22 + test/lit/vmi/vmi_to_vpto_truncf.pto | 56 + ...vpto_truncf_fp8_128_contiguous_invalid.pto | 25 + ..._vpto_truncf_unsupported_shape_invalid.pto | 23 + test/lit/vmi/vmi_to_vpto_type_arity.pto | 63 + ...vpto_type_attr_nested_residual_invalid.pto | 16 + ...vmi_to_vpto_type_attr_residual_invalid.pto | 16 + test/lit/vmi/vmi_to_vpto_type_only.pto | 27 + test/lit/vmi/vmi_to_vpto_unary_math.pto | 89 + ..._vpto_unrealized_cast_residual_invalid.pto | 20 + .../vmi_to_vpto_unsupported_op_invalid.pto | 25 + test/lit/vmi/vmi_truncf_direction_invalid.pto | 19 + .../vmi/vmi_truncf_lane_mismatch_invalid.pto | 19 + test/lit/vmi/vmi_type_attr_parse.pto | 40 + .../vmi/vmi_type_element_count_invalid.pto | 18 + .../vmi/vmi_unary_math_integer_invalid.pto | 55 + test/lit/vmi/vmi_unpack_arity_invalid.pto | 20 + .../vmi/dequant-f16-to-f32-tail/compare.py | 27 + .../vmi/dequant-f16-to-f32-tail/golden.py | 44 + .../vmi/dequant-f16-to-f32-tail/kernel.pto | 60 + .../vmi/dequant-f16-to-f32-tail/launch.cpp | 40 + .../vmi/dequant-f16-to-f32-tail/main.cpp | 78 + .../vmi/dequant-f16-to-f32-tail/ptoas.flags | 1 + .../vmi/dequant-f8-to-f32-tail/compare.py | 27 + .../vmi/dequant-f8-to-f32-tail/golden.py | 45 + .../vmi/dequant-f8-to-f32-tail/kernel.pto | 59 + .../vmi/dequant-f8-to-f32-tail/launch.cpp | 40 + .../cases/vmi/dequant-f8-to-f32-tail/main.cpp | 78 + .../vmi/dequant-f8-to-f32-tail/ptoas.flags | 1 + .../vmi/quant-f32-to-f16-tail/compare.py | 27 + .../cases/vmi/quant-f32-to-f16-tail/golden.py | 44 + .../vmi/quant-f32-to-f16-tail/kernel.pto | 60 + .../vmi/quant-f32-to-f16-tail/launch.cpp | 40 + .../cases/vmi/quant-f32-to-f16-tail/main.cpp | 78 + .../vmi/quant-f32-to-f16-tail/ptoas.flags | 1 + .../cases/vmi/quant-f32-to-f8-full/compare.py | 27 + .../cases/vmi/quant-f32-to-f8-full/golden.py | 40 + .../cases/vmi/quant-f32-to-f8-full/kernel.pto | 47 + .../cases/vmi/quant-f32-to-f8-full/launch.cpp | 40 + .../cases/vmi/quant-f32-to-f8-full/main.cpp | 79 + .../vmi/quant-f32-to-f8-full/ptoas.flags | 1 + .../cases/vmi/quant-f32-to-f8-tail/compare.py | 27 + .../cases/vmi/quant-f32-to-f8-tail/golden.py | 44 + .../cases/vmi/quant-f32-to-f8-tail/kernel.pto | 56 + .../cases/vmi/quant-f32-to-f8-tail/launch.cpp | 40 + .../cases/vmi/quant-f32-to-f8-tail/main.cpp | 78 + .../vmi/quant-f32-to-f8-tail/ptoas.flags | 1 + .../vmi/reduce-f16-f8-mul-store/compare.py | 27 + .../vmi/reduce-f16-f8-mul-store/golden.py | 46 + .../vmi/reduce-f16-f8-mul-store/kernel.pto | 66 + .../vmi/reduce-f16-f8-mul-store/launch.cpp | 43 + .../vmi/reduce-f16-f8-mul-store/main.cpp | 88 + .../vmi/reduce-f16-f8-mul-store/ptoas.flags | 1 + tools/CMakeLists.txt | 1 + tools/pto-test-opt/CMakeLists.txt | 35 + tools/pto-test-opt/pto-test-opt.cpp | 35 + tools/ptoas/ptoas.cpp | 62 + 302 files changed, 28543 insertions(+), 1 deletion(-) create mode 100644 docs/designs/vmi-dialect-design.md create mode 100644 docs/designs/vmi-implementation-manual.md create mode 100644 include/PTO/IR/VMIAttrs.td create mode 100644 include/PTO/IR/VMIOps.td create mode 100644 include/PTO/IR/VMITypeDefs.td create mode 100644 include/PTO/IR/VMIUtils.h create mode 100644 include/PTO/Transforms/VMITargetCapabilities.h create mode 100644 lib/PTO/IR/VMI.cpp create mode 100644 lib/PTO/Transforms/PTOValidateVMIIR.cpp create mode 100644 lib/PTO/Transforms/VMILayoutAssignment.cpp create mode 100644 lib/PTO/Transforms/VMIToVPTO.cpp create mode 100644 test/lit/vmi/vmi_absf_integer_invalid.pto create mode 100644 test/lit/vmi/vmi_absi_float_invalid.pto create mode 100644 test/lit/vmi/vmi_active_prefix_index_result_type_invalid.pto create mode 100644 test/lit/vmi/vmi_addf_lane_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_bitcast_total_bits_invalid.pto create mode 100644 test/lit/vmi/vmi_bitwise_float_invalid.pto create mode 100644 test/lit/vmi/vmi_broadcast_type_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_channel_merge_input_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_channel_merge_result_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_channel_split_lane_count_invalid.pto create mode 100644 test/lit/vmi/vmi_channel_split_result_count_invalid.pto create mode 100644 test/lit/vmi/vmi_compress_result_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_attr_kind_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_element_count_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_mask_attr_kind_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_mask_element_count_invalid.pto create mode 100644 test/lit/vmi/vmi_constant_mask_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_divf_integer_invalid.pto create mode 100644 test/lit/vmi/vmi_elementwise_kind_invalid.pto create mode 100644 test/lit/vmi/vmi_ensure_layout_surface_invalid.pto create mode 100644 test/lit/vmi/vmi_extf_direction_invalid.pto create mode 100644 test/lit/vmi/vmi_extf_lane_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_fma_integer_invalid.pto create mode 100644 test/lit/vmi/vmi_gather_indices_invalid.pto create mode 100644 test/lit/vmi/vmi_iota_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_iota_order_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_active_prefix_index.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_call_boundary.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_cf_branch.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_cf_switch.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_channel_merge_count_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_channel_split_count_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_compress.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_compress_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_constant_remat.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_expand_load.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_external_decl_preserve.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_fma.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_gather.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_indirect_call_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_iota_remat.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_load_truncf.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_mask_granularity_conflict_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_mask_remat.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_mask_use_ensure.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_masked_load.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_multi_return.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_post_gate_type_attr_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_reduce_addf.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_reduce_addi.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_reduce_minmaxf.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scatter.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scf_execute_region.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scf_for.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scf_if.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scf_index_switch.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_scf_while.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_store_ensure.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_widen.pto create mode 100644 test/lit/vmi/vmi_layout_factor_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_surface_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_surface_mask_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_type_attr_nested_physical_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_type_attr_surface_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_valid.pto create mode 100644 test/lit/vmi/vmi_mask_concrete_without_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_mask_granularity_invalid.pto create mode 100644 test/lit/vmi/vmi_mask_logic_invalid.pto create mode 100644 test/lit/vmi/vmi_mask_pred_with_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_masked_store_mask_granularity_invalid.pto create mode 100644 test/lit/vmi/vmi_memory_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_min_max_integer_invalid.pto create mode 100644 test/lit/vmi/vmi_negf_integer_invalid.pto create mode 100644 test/lit/vmi/vmi_op_verifier_basic.pto create mode 100644 test/lit/vmi/vmi_pack_arity_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_helper_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_mask_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_non_vmi_op_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_physical_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_type_attr_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_type_attr_nested_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_type_attr_surface_invalid.pto create mode 100644 test/lit/vmi/vmi_producer_boundary_valid.pto create mode 100644 test/lit/vmi/vmi_ptoas_backend_required_invalid.pto create mode 100644 test/lit/vmi/vmi_ptoas_cli_control_flow.pto create mode 100644 test/lit/vmi/vmi_ptoas_cli_pipeline.pto create mode 100644 test/lit/vmi/vmi_ptoas_public_abi_invalid.pto create mode 100644 test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto create mode 100644 test/lit/vmi/vmi_reduce_addf_missing_reassoc_invalid.pto create mode 100644 test/lit/vmi/vmi_scatter_indices_invalid.pto create mode 100644 test/lit/vmi/vmi_select_mask_granularity_invalid.pto create mode 100644 test/lit/vmi/vmi_shli_float_invalid.pto create mode 100644 test/lit/vmi/vmi_shrui_float_invalid.pto create mode 100644 test/lit/vmi/vmi_shrui_signed_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_abs.pto create mode 100644 test/lit/vmi/vmi_to_vpto_active_prefix_index.pto create mode 100644 test/lit/vmi/vmi_to_vpto_active_prefix_index_multichunk_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_active_prefix_index_tail_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_add.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bf16_arith.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitcast.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitcast_partial.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitwise.pto create mode 100644 test/lit/vmi/vmi_to_vpto_broadcast.pto create mode 100644 test/lit/vmi/vmi_to_vpto_call_boundary.pto create mode 100644 test/lit/vmi/vmi_to_vpto_cf_branch.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_merge4_contiguous.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_merge_count_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_merge_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_merge_partial_group_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_split_count_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_split_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_split_merge.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_split_merge_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_channel_split_partial_group_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_cmp_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_cmp_predicate_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_cmp_select.pto create mode 100644 test/lit/vmi/vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compaction_deint_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compress.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compress_multichunk_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compress_store.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compress_store_multichunk_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_compress_tail_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_constant.pto create mode 100644 test/lit/vmi/vmi_to_vpto_constant_mask.pto create mode 100644 test/lit/vmi/vmi_to_vpto_constant_mask_nonprefix.pto create mode 100644 test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto create mode 100644 test/lit/vmi/vmi_to_vpto_constant_nonsplat_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_construction_width_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_create_mask.pto create mode 100644 test/lit/vmi/vmi_to_vpto_create_mask_dynamic.pto create mode 100644 test/lit/vmi/vmi_to_vpto_create_mask_plt_fallback.pto create mode 100644 test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto create mode 100644 test/lit/vmi/vmi_to_vpto_divf.pto create mode 100644 test/lit/vmi/vmi_to_vpto_e2e_widen_add_store.pto create mode 100644 test/lit/vmi/vmi_to_vpto_elementwise_width_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_identity.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_layout_partial_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_layout_vdintlv.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_layout_vintlv.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_granularity.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_direct.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_multistep.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_layout.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_layout_partial_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_ensure_mask_layout_widths.pto create mode 100644 test/lit/vmi/vmi_to_vpto_expand_load_all_active.pto create mode 100644 test/lit/vmi/vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_expand_load_partial_mask_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_expand_load_runtime_mask.pto create mode 100644 test/lit/vmi/vmi_to_vpto_extf.pto create mode 100644 test/lit/vmi/vmi_to_vpto_extf_f8.pto create mode 100644 test/lit/vmi/vmi_to_vpto_extf_multichunk.pto create mode 100644 test/lit/vmi/vmi_to_vpto_fma.pto create mode 100644 test/lit/vmi/vmi_to_vpto_fma_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_function_type_layout_free_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_gather.pto create mode 100644 test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_iota.pto create mode 100644 test/lit/vmi/vmi_to_vpto_iota_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_deint.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_nonfull_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_load_store_contiguous.pto create mode 100644 test/lit/vmi/vmi_to_vpto_mask_logic.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_load.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_store.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_store_deint_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_store_nonfull_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_masked_store_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_math_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto create mode 100644 test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_min_max.pto create mode 100644 test/lit/vmi/vmi_to_vpto_negf.pto create mode 100644 test/lit/vmi/vmi_to_vpto_pack_unpack.pto create mode 100644 test/lit/vmi/vmi_to_vpto_quant_dequant.pto create mode 100644 test/lit/vmi/vmi_to_vpto_quant_fp8.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_addf.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_addf_multichunk.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_addi.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_addi_i16_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_addi_multichunk.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_maxf_multichunk.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_maxf_tail_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_minf.pto create mode 100644 test/lit/vmi/vmi_to_vpto_reduce_shape_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_relu_element_type_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_scatter.pto create mode 100644 test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_scf_for.pto create mode 100644 test/lit/vmi/vmi_to_vpto_scf_if.pto create mode 100644 test/lit/vmi/vmi_to_vpto_shli.pto create mode 100644 test/lit/vmi/vmi_to_vpto_shrui.pto create mode 100644 test/lit/vmi/vmi_to_vpto_shuffle_forwarding.pto create mode 100644 test/lit/vmi/vmi_to_vpto_shuffle_lane0_splat.pto create mode 100644 test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_store_deint.pto create mode 100644 test/lit/vmi/vmi_to_vpto_store_deint_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_store_deint_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_store_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_store_width_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_sub_mul.pto create mode 100644 test/lit/vmi/vmi_to_vpto_tile_read_write.pto create mode 100644 test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_tile_write_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_tile_write_tail_deint_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_truncf.pto create mode 100644 test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_type_arity.pto create mode 100644 test/lit/vmi/vmi_to_vpto_type_attr_nested_residual_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_type_attr_residual_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_type_only.pto create mode 100644 test/lit/vmi/vmi_to_vpto_unary_math.pto create mode 100644 test/lit/vmi/vmi_to_vpto_unrealized_cast_residual_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_unsupported_op_invalid.pto create mode 100644 test/lit/vmi/vmi_truncf_direction_invalid.pto create mode 100644 test/lit/vmi/vmi_truncf_lane_mismatch_invalid.pto create mode 100644 test/lit/vmi/vmi_type_attr_parse.pto create mode 100644 test/lit/vmi/vmi_type_element_count_invalid.pto create mode 100644 test/lit/vmi/vmi_unary_math_integer_invalid.pto create mode 100644 test/lit/vmi/vmi_unpack_arity_invalid.pto create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/compare.py create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/golden.py create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/kernel.pto create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/launch.cpp create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/main.cpp create mode 100644 test/vpto/cases/vmi/dequant-f16-to-f32-tail/ptoas.flags create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/compare.py create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/golden.py create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/kernel.pto create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/launch.cpp create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/main.cpp create mode 100644 test/vpto/cases/vmi/dequant-f8-to-f32-tail/ptoas.flags create mode 100644 test/vpto/cases/vmi/quant-f32-to-f16-tail/compare.py create mode 100644 test/vpto/cases/vmi/quant-f32-to-f16-tail/golden.py create mode 100644 test/vpto/cases/vmi/quant-f32-to-f16-tail/kernel.pto create mode 100644 test/vpto/cases/vmi/quant-f32-to-f16-tail/launch.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f16-tail/main.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f16-tail/ptoas.flags create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/compare.py create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/golden.py create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/kernel.pto create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/launch.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/main.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-full/ptoas.flags create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/compare.py create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/golden.py create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/kernel.pto create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/launch.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/main.cpp create mode 100644 test/vpto/cases/vmi/quant-f32-to-f8-tail/ptoas.flags create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/compare.py create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/golden.py create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/kernel.pto create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/launch.cpp create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/main.cpp create mode 100644 test/vpto/cases/vmi/reduce-f16-f8-mul-store/ptoas.flags create mode 100644 tools/pto-test-opt/CMakeLists.txt create mode 100644 tools/pto-test-opt/pto-test-opt.cpp diff --git a/docs/designs/vmi-dialect-design.md b/docs/designs/vmi-dialect-design.md new file mode 100644 index 0000000000..5578ca93d1 --- /dev/null +++ b/docs/designs/vmi-dialect-design.md @@ -0,0 +1,2078 @@ +# VMI dialect 设计 + +## 背景 + +VPTO 的 `!pto.vreg` 是 256 bytes 物理向量寄存器抽象。很多 VPTO op 暴露的是 +physical placement:`vcvt` part、pack/unpack、interleave/deinterleave、load/store dist、 +predicate granularity 等。TileLang `T.parallel` 或其它前端想表达的是逻辑向量语义,不应该 +手写这些 physical placement。 + +VMI dialect 的目标是提供一层 PTO-friendly 的 semantic vector IR。它不是任何外部向量 dialect +的语法克隆,也不是 VPTO physical dialect。VMI 的设计来源是 PTO virtual vector ISA 需要承接的 +逻辑向量语义、layout、mask granularity、memory safety 和控制流 layout join;后续 lowering 只从 +VMI 决定 physical layout 和 VPTO op。 + +本设计采用 `vmi.vreg` 作为 layout carrier,不再引入单独的 `vbundle` type: + +```text +semantic VMI + -> layout-assigned VMI + -> physical VPTO +``` + +VMI 的 producer 在核心设计之外。TileLang/PTO lowering、手写 VMI 测试或其它 import 工具都可以 +产生 VMI,但它们不能定义 VMI 的 semantic surface。核心设计只要求 producer 在进入 VMI boundary +时生成合法 VMI IR。 + +## 和旧 VMI layout 设计的关系 + +旧文档中的核心形式是: + +```mlir +!pto.vmi.vreg +!pto.vmi.mask +``` + +这个方向是对的:`vmi.vreg` 本身是 virtual aggregate type,可以承载完整 logical vector, +layout 放在它上面比放在 physical `!pto.vreg` 上更合理。 + +旧设计需要补强的地方主要是 layout descriptor 和 lowering contract,而不是推翻 +`vmi.vreg`: + +1. 旧 layout descriptor 把 `logical_shape`、`phys_dtype`、`phys_lanes` 放进 attr,和 + `vreg` / target registry 存在重复信息。重复字段会产生 verifier 漂移。 +2. `axes=[#axis<...>]` 太开放,缺少每个 layout 的精确定义、part ordering 和 lane map。 +3. 旧设计要求 `N * bitwidth(T)` 是 256B 整数倍,无法覆盖 tail / 非整 tile。 +4. mask 只写成 `mask`,但没有定义 data layout、mask layout、mask granularity + conversion 在宽度转换中的同步规则。 +5. 控制流 join 没有定义:`scf.if` 两边 layout 不同、`scf.for` loop-carried layout 如何稳定。 +6. memory access map 和 register layout 没有切开,容易把 strided memory view 误当成 vreg + layout。 +7. hard vector semantics 缺失,例如 padding read、active prefix index、dynamic permute、 + compress/expand、scan/reduction/contract 的 VMI 表达和 lowering contract。 + +因此本设计保留 `vmi.vreg` 这个 carrier,但不沿用旧 layout descriptor 的 +开放式语义。旧文档没有定义 “logical behavior -> hardware mismatch -> physical +decomposition -> lane map -> propagation/sink” 这条 source contract;这是本文新增的核心约束。 + +换句话说,本文不是复述旧 `vmi.layout`,而是把旧的开放式 axis descriptor 收紧成一个很小的 +public layout 集合。本设计只接受 `contiguous`、`deinterleaved = 2`、`deinterleaved = 4`。 +source contract 是新增 layout kind 的准入规则,不是要求实现 generic axes 或任意 lane-map +descriptor。 + +## 目标 + +1. VMI surface 表达逻辑向量语义,不暴露 VPTO part/dist/interleave 细节。 +2. `vmi.vreg` 是 virtual aggregate type,可以表示大于 256B 的 logical vector。 +3. layout 放在 layout-assigned VMI type 上,不再另设 `vbundle`。 +4. VMI mask 是一等类型;surface mask 表达 logical predicate,layout-assigned mask 才携带 + concrete predicate granularity `b8/b16/b32`。 +5. VMI 支持 tail / 非整 tile;padding physical lane 不可观察。 +6. VMI lowering 支持控制流中的 layout join。 +7. VMI producer boundary 后的 IR 必须只依赖 VMI semantic op/type 表达逻辑向量语义。 + +## 非目标 + +1. 不改变 physical `!pto.vreg` 的含义。它仍然是 256 bytes physical register。 +2. 不把 VMI 做成任何外部向量 dialect 的逐 op 复制品;VMI 只表达 PTO lowering 需要的 logical + vector semantics。 +3. 不把 scalar lane extract 当作 VMI vector op。scalar lane extract 是 vector-to-scalar + boundary,必须在进入 VMI 前被 producer 消除,或以明确 diagnostic 退出 PTO 路线。 +4. 不把 VPTO load/store dist 暴露成 VMI surface op。dist 是 lowering 选择。 + +## VMI Producer Boundary Contract + +VMI 是 PTO 路线上的 virtual vector ISA。任何 producer 在进入 VMI boundary 后,必须满足下面之一: + +1. 逻辑向量语义已经表达为 native VMI semantic op。 +2. 逻辑向量语义已经表达为一组 VMI semantic op 的组合,并保持 producer 的 observable semantics。 +3. 该行为不是 VMI 负责的向量计算,而是 vector-to-scalar / tensor / debug / transform boundary, + 已经在进入 VMI 前由 producer 消除,或以明确 diagnostic 退出 PTO 路线。 + +不能把“当前阶段不支持”作为 VMI 设计结果。一个 PTO virtual vector semantic 如果属于 VMI 负责的 +逻辑向量语义,文档必须给出 VMI op、组合 lowering、layout contract、memory fallback 或 target +capability diagnostic。diagnostic 只允许表示语义边界或目标能力缺失,不能表示“VMI 没有设计这个能力”。 + +`pto.vmi -> pto` 的完成条件是: + +```text +at VMI producer boundary: + logical vector semantics are represented by VMI op/type + no physical VPTO op is introduced by the producer + no hidden layout/mask/type side table is required to interpret a VMI value + +after vmi-layout-assignment: + every vmi.vreg/vmi.mask has an explicit #pto.vmi.layout + every mask granularity matches its consumer + every control-flow yield/iter_arg/result has one stable layout + +after vmi-to-vpto: + no pto.vmi op/type remains + every logical VMI value has been lowered to ordered physical VPTO values +``` + +### Capability And Fallback Policy + +所有 direct lowering 和 fallback 选择必须来自显式配置,不能依赖 pass 内隐藏全局状态: + +```text +TargetCapabilityRegistry: + element-type storage/compute/convert support + layout source/sink/conversion support + memory access capability: OOB, masked, gather/scatter, block-strided + predicate capability: granularity conversion, prefix-popcount, rearrangement + reduction/scan/contract capability + scratch memory spaces, alignment, and lifetime rules + +VMIToPTOOptions: + enableScratchFallback + enableGuardedScalarFallback + enableIndexBufferFallback + allowDebugStrip + targetVScaleSpecialization + diagnosticVerbosity +``` + +fallback 被 option 禁用时,diagnostic 必须报告 `disabled_by_option`。target registry 缺能力时, +diagnostic 必须报告 `missing_capability`。debug-only op 只能由 debug pipeline 消费,或在 +`allowDebugStrip` 明确开启时剥离;否则报 `VMI-DEBUG-BOUNDARY`。 + +fallback resource 也必须显式建模: + +```text +scratch fallback: + memory space, alignment, element type, shape, lifetime, and deallocation point + must be explicit in the lowering plan + scratch initialization, such as padding fill, must dominate later scratch load + +guarded scalar/vector fallback: + guard must dominate every memory effect it protects + invalid lane must not compute a memory effect through an OOB memref address + +index-buffer fallback: + index element width, signedness, and address unit must match the consumer + buffer lifetime must dominate gather/scatter or compaction use +``` + +如果无法分配 scratch、无法放置 guard、或 index buffer 宽度不满足目标要求,diagnostic 使用 +`VMI-FALLBACK-RESOURCE`,并说明是 resource 缺失而不是语义不可表达。 + +## 类型模型 + +### Surface Type + +VMI surface type 不显式写 layout: + +```mlir +!pto.vmi.vreg<128xf32> +!pto.vmi.vreg<256xf8> +!pto.vmi.vreg<1xf32> + +!pto.vmi.mask<128xpred> +!pto.vmi.mask<256xpred> +``` + +`N` 是 logical lane count,`T` 是 logical element type。surface `mask` 表示 N 个 +logical predicate lane,不预先绑定 VPTO predicate granularity。layout assignment 根据 consumer +选择 concrete granularity: + +```text +f32/i32 consumer -> b32 +f16/bf16/i16 consumer -> b16 +f8/i8 consumer -> b8 +``` + +如果一个 logical mask 被不同 width consumer 使用,VMI lowering 必须按 use 插入 +`vmi.ensure_mask_granularity` 或重物化 mask producer,不能假设某个 concrete granularity 可直接 +给所有 consumer 使用。 + +VMI type 以 1-D logical vector 为核心。来自 multi-rank producer value 的语义在进入 VMI boundary 前按 row-major flatten 成: + +```mlir +!pto.vmi.vreg<64xf32> +!pto.vmi.mask<64xpred> +``` + +VMI value 本身只承载 flattened lane sequence,不携带隐式 rank side table。需要 rank 信息的 op +必须在自身 attr 中保存 logical shape / indexing map,例如 `logical_shape = [8, 8]`。这样保持 +与既有 `vmi.vreg` 设计一致,同时不丢失 transfer、transpose、reshape 等 op 的语义。 + +shape-sensitive op 的规则是: + +```text +elementwise / select: + operate on flattened lanes and preserve any surrounding op-provided shape context + +tile_read / tile_write: + carry logical_shape and permutation_map attrs + +shape_cast / reshape / transpose / contract: + carry source/result shapes, maps, and iterator metadata as op attrs + +block argument / function argument: + carries only flat vreg type; any later shaped use must provide its own shape attrs +``` + +因此 logical shape 信息不能保存在 C++ side table,也不能要求 consumer 从 defining op 反查。 + +Rank-0 logical vector 仍然是 VMI vector value,不是 scalar SSA value: + +```mlir +rank-0 logical vector -> !pto.vmi.vreg<1xT> +rank-0 logical predicate -> !pto.vmi.mask<1xpred> +``` + +只有产生 scalar result 的 extract 才是 vector-to-scalar boundary。rank-0 logical vector load、 +bitcast、mask 和 arithmetic 仍然走 VMI,不能因为只有一个 lane 就绕开 VMI verifier。 + +Scalable logical vector 不能直接进入 VMI type,因为 `vmi.vreg` 的 `N` 是 concrete logical lane +count。producer 必须先根据 target profile 和 tiling decision 把 scalable semantics specialize 成固定 +`N`;否则在 VMI boundary 报 `VMI-SCALABLE-VECTOR`。这不是 VMI 的临时缺口,而是 +固定 256B physical vreg lowering 的前置约束。 + +### Layout-Assigned Type + +`vmi-layout-assignment` 后,所有 VMI data/mask value 都必须带 layout: + +```mlir +!pto.vmi.vreg<128xf32, #pto.vmi.layout> +!pto.vmi.vreg<128xf32, #pto.vmi.layout> +!pto.vmi.vreg<256xf32, #pto.vmi.layout> + +!pto.vmi.mask<128xb32, #pto.vmi.layout> +!pto.vmi.mask<128xb32, #pto.vmi.layout> +``` + +这里的 `#pto.vmi.layout` 是唯一的 VMI register layout carrier。它不是 `#pto.vlayout` +的直接复用,也不是 `vbundle` 的 type 参数;但它必须采用同一套精确 lane-map 语义,保证后续 +lower 到 physical VPTO 时可验证。 + +### 非整 Tile + +VMI type 不要求 `N * bitwidth(T)` 是 256B 整数倍: + +```mlir +!pto.vmi.vreg<100xf32> +!pto.vmi.mask<100xpred> +``` + +physical lowering 时按 256B part 向上取整。超出 `N` 的 physical lane 是 padding lane: + +```text +padding lane: + may be poison/undef internally + must not be stored + must not affect compare/reduction/scan + must not become visible through layout conversion +``` + +任何 store、reduction、compress、mask-producing op 都必须用 logical lane count 或 explicit +mask 保护 padding lane。 + +## Layout 设计来源 + +VMI layout 的价值必须从逻辑 vector 行为推导,而不是从 layout 名字推导。判断流程是: + +```text +1. 前端想表达一个完整的 logical vector 行为。 +2. VPTO 底层指令不能把这个 logical vector 天然放进一个 contiguous physical sequence。 +3. 但 VPTO 可以把这个 logical vector 拆成一组有固定 lane-map 的 physical parts。 +4. 后续常见 op 可以在这些 parts 上逐 part 保持 logical semantics。 +5. 边界 consumer 能直接消费这种 parts,或存在可验证的 materialize path。 +6. 因此值得把这个 parts relation 提升为 VMI layout。 +``` + +layout 不是“某条指令的名字”,而是一个 representation relation: + +```text +Layout L defines: + logical vector value V[NxT] + <-> ordered physical parts P0, P1, ... + with exact map logical lane i -> (part, lane) +``` + +只有当这个 relation 能让 VMI 保持“用户看到的是一个连续 logical vector”,同时避免前端手写 +parts,layout 才有设计价值。 + +### Register Layout 集合 + +VMI register layout 不采用复杂通用 descriptor,而是定义为封闭集合: + +```text +#pto.vmi.layout +#pto.vmi.layout +#pto.vmi.layout +``` + +`deinterleaved = K` 表示一个 logical vector 被拆成 K 个 physical part,第 `p` 个 part 保存 +logical lane `p, p + K, p + 2K, ...`。这个名字直接描述元素摆放,不绑定到某条 VPTO op,也不 +引入旧 `axes` 的通用维度系统。 + +不加入 `channel`、`packed_bits`、`blocked`、`stride`、`permutation` 等 layout kind。 +这些能力先由 VMI semantic op、memory access plan 或 explicit layout conversion 表达。只有当 +一个新 representation 同时满足下面的 source contract,才允许扩展 layout 目录。 + +### Layout Source Contract + +每个 VMI layout kind 必须来自一条明确的 source contract: + +```text +logical behavior: + VMI 想表达的用户级 vector 行为 + +hardware mismatch: + 为什么 VPTO 不能用一个 contiguous physical sequence 天然承载 + +physical decomposition: + VPTO 实际能产生或消费的 physical parts + +lane map: + logical lane -> physical part/lane 的精确定义 + +propagation rule: + 哪些 VMI op 可以逐 part 保持语义 + +boundary rule: + 哪些 load/store/pack/convert consumer 可以直接消费,哪些必须 materialize + +mask rule: + 对应 mask 如何生成、转换和消费 +``` + +没有这份 source contract 的 lane movement 不能进入 `#pto.vmi.layout`。 + +### Source 1: Widen Cast To Larger Logical Vector + +逻辑行为: + +```mlir +%w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +``` + +用户语义是“128 个 f16 lane 加宽成 128 个连续 f32 lane”。但 128 个 f32 是 512B,超过单个 +256B physical vreg。VPTO 的可行 lowering 不是一个 contiguous 512B register,而是两条 part +conversion: + +```text +even part: + physical even[i] = extf(logical[2*i]) + +odd part: + physical odd[i] = extf(logical[2*i+1]) +``` + +因此需要一个 layout 表达“这个 VMI value 仍然是 logical `128xf32`,但 physical representation +是 even/odd 两个 parts”: + +```mlir +#pto.vmi.layout +``` + +lane map: + +```text +part = i % 2 +lane = floor(i / 2) +physical[part][lane] = logical[i] +``` + +这个 layout 的价值在于后续 elementwise op 不需要 materialize contiguous representation: + +```mlir +%s = pto.vmi.addf %w, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +lowering 可以变成两路 add: + +```text +add even parts +add odd parts +``` + +最后如果 store consumer 能把 even/odd parts 交织写回 contiguous memory,就不需要中途 +`ensure_layout contiguous`。 + +同理: + +```mlir +%w = pto.vmi.extf %a + : !pto.vmi.vreg<256xf8> -> !pto.vmi.vreg<256xf32> +``` + +需要: + +```mlir +#pto.vmi.layout +``` + +这里不再使用抽象 stride 命名。`deinterleaved = 4` 的来源是 `f8 -> f32` 的 VPTO part +conversion contract,不是任意 stride 语义。 + +### Source 2: Narrow / Pack Consumer + +逻辑行为: + +```mlir +%n = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +``` + +如果 `%x` 已经是 `#pto.vmi.layout`,VPTO 可以用 pack/narrow 类 +consumer 把 even/odd f32 parts 合成 contiguous f16 result。这里 layout 的来源不是 producer,而是 +consumer 能直接接受这种 decomposition: + +```text +source layout: + logical f32 value represented as even/odd f32 parts + +consumer: + narrowing pack consumes those parts + +result: + contiguous f16 logical vector +``` + +因此 `deinterleaved` 必须同时登记 producer contract 和 inverse/sink contract。否则 layout 只能 +产生,不能被合法消耗。 + +### Source 3: Same-Width Layout Materialization + +逻辑行为: + +```mlir +%x = pto.vmi.ensure_layout %v + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +这里不新增 surface view op。目标不是产生两个独立 semantic vectors,而是让同一个 logical +vector 继续作为一个 VMI value 存活,只是 physical representation 变成 even/odd parts。IR 中由 +`vmi-layout-assignment` 插入 +`pto.vmi.ensure_layout`,并由 target registry 证明存在 preserving materialization path。VPTO 的 +`vdintlv/vintlv` 类 register rearrangement 可以产生或消费这种 representation。 + +这和 `vcvt` 产生的 even/odd representation 使用同一个 layout: + +```mlir +#pto.vmi.layout +``` + +区别只在 source contract: + +```text +logical behavior: + 同宽 logical vector 保持一个 VMI value,但 physical parts 分别保存 even/odd lanes + +hardware mismatch: + VPTO interleave/deinterleave 指令以两个 physical vreg parts 表达 + +layout: + deinterleaved=2 +``` + +如果 VMI op 的语义本来就是“返回两个独立 vectors”,例如 AoS -> SoA 后用户分别使用 `%x` +和 `%y`,那不需要 layout,直接产生两个 `vmi.vreg`。只有当“一个 logical vector value” +需要以 even/odd parts 长期存活时,才使用 `deinterleaved=2`。 + +### Channel Split / Merge 不是 Register Layout + +channel split/merge 的用户代码通常有两种形态。 + +第一种是把 interleaved data 当作普通 flat vector: + +```text +logical = [r0, g0, b0, a0, r1, g1, b1, a1, ...] +对每个 lane 做同一种逐元素操作 +``` + +这种情况下 `contiguous` representation 就能表达用户语义,不需要 channel layout。 + +第二种是用户按 channel 编程: + +```mlir +%r, %g, %b, %a = pto.vmi.channel_split %rgba + : !pto.vmi.vreg<128xi8> + -> !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8>, + !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8> + +%r2 = pto.vmi.addi %r, %bias_r : !pto.vmi.vreg<32xi8> +%g2 = pto.vmi.addi %g, %bias_g : !pto.vmi.vreg<32xi8> +%b2 = pto.vmi.addi %b, %bias_b : !pto.vmi.vreg<32xi8> +%a2 = pto.vmi.addi %a, %bias_a : !pto.vmi.vreg<32xi8> +%out = pto.vmi.channel_merge %r2, %g2, %b2, %a2 + : !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8>, + !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8> + -> !pto.vmi.vreg<128xi8> +``` + +这里自然的 IR 是多个 semantic VMI values,而不是“一个 VMI value 带 channel layout”。 +目标专用 split/merge 能力是 `channel_split/channel_merge` 的 lowering contract;load/store +memory boundary 的 dist/sink contract 也可以作为等价 lowering path。 + +`channel_split` / `channel_merge` 的语义必须能完全退化成 static shuffle,不能引入额外 +layout 规则。`C` 不需要单独 attr:`channel_split` 的 `C` 来自 result 个数, +`channel_merge` 的 `C` 来自 operand 个数。设 input 有 `N = C * M` 个 logical lanes: + +```text +channel_split(input, C): + out[c][i] = input[i * C + c] + for 0 <= c < C + for 0 <= i < M + +channel_merge(out[0], ..., out[C-1]): + result[i * C + c] = out[c][i] + for 0 <= i < M + for 0 <= c < C +``` + +如果 `N` 不能被 `C` 整除,或者 merge operands 的 logical lane count 不一致,op verifier +必须拒绝。需要 tail 的场景通过外层 mask / valid lane 语义表达,不能让 channel op 自己发明 +padding lane。 + +因此这两个 op 的价值只是 canonical interface:producer 可以直接表达 channel 语义, +外部 import 工具也可以把识别出的 static shuffle pattern canonicalize 成它们;如果没有 +识别或目标没有专用 lowering,保持或退回 `pto.vmi.shuffle` 仍然是等价路径。 +当前 direct VPTO lowering 只接受能形成完整 physical channel groups 的形状:flat contiguous +source/result 与 virtual deinterleaved=C channel layout 必须有相同 physical arity,或已经是 matching +deinterleaved=C layout 的 identity forwarding。arity-changing partial group 需要额外 packing/drop +padding plan,不能直接 lowering。 + +所以 VMI register layout 目录不为 channel-specific representation 引入 layout kind,也不预留 +半成品 layout 语义。本文覆盖的用户形态要么是 flat contiguous vector,要么是多个 channel +semantic value;都不需要“一个 VMI value 带 channel layout”。 + +### Pack / Unpack 不作为长期 Layout + +pack/unpack 的逻辑行为通常是 width conversion 或 memory encoding: + +```text +wide logical vector -> narrow logical vector +narrow memory payload -> wide logical vector +``` + +它们的结果可以是 `contiguous` logical vector;pack/unpack 是 producer/sink/conversion +contract,不是必须长期传播的 register layout。只有当目标 ISA 提供 packed-format arithmetic, +并且 VMI 真的要让 packed representation 跨 compute 存活时,才需要另立 +`packed_bits` layout。本设计没有 packed-format arithmetic source contract,因此 pack/unpack 不进入 +长期 register layout。 + +### 不应成为 Register Layout 的东西 + +以下能力虽然来自 VPTO/VISA,但不是 VMI register layout: + +| 能力 | 原因 | +|---|---| +| `vsldb/vsstb` block stride | 描述 memory address map;result register 可仍是 contiguous representation | +| gather/scatter index | runtime address map,不是 static logical lane 到 physical part 的关系 | +| dynamic `vselr` | runtime permutation,应是 `pto.vmi.permute` op | +| `vsqz/vusqz` compaction | runtime mask 决定 lane destination,应是 `compress/active_prefix_index` op | +| one-shot `vintlv/vdintlv` | 如果只是 boundary conversion,不应提升成长期 layout;若表示一个 VMI value 的 even/odd parts,则归入 `deinterleaved=2` | + +VMI layout 只解决“一个 logical vector value 在寄存器中长期以什么 parts representation 存活” +的问题。memory address、runtime permutation、dynamic compaction 都是其它语义。 + +### Lane Map + +设: + +```text +N = logical lane count +lanesPerDataPart(T) = 256B / sizeof(T) +lanesPerMaskPart(b8) = 256 +lanesPerMaskPart(b16) = 128 +lanesPerMaskPart(b32) = 64 +``` + +`contiguous`: + +```text +chunk = floor(i / lanesPerPart) +lane = i % lanesPerPart +physical[chunk][lane] = logical[i] +``` + +`deinterleaved = K`,其中 `K` 只能是 2 或 4: + +```text +p = i % K +q = floor(i / K) +chunk = floor(q / lanesPerPart) +lane = q % lanesPerPart +physical[p][chunk][lane] = logical[i] +``` + +`deinterleaved=2` 和 `deinterleaved=4` 的 physical value ordering 固定为 part-major: + +```text +p0_chunk0, p0_chunk1, ..., p1_chunk0, p1_chunk1, ..., p(K-1)_chunk0, ... +``` + +所有 verifier、type converter、physical lowering 和 control-flow conversion 必须使用同一套 +ordering。 + +### Physical Arity + +`vmi-to-vpto` 不能按示例猜 physical value 个数,必须由 type + layout 统一推导。 + +对 data vreg: + +```text +lanesPerPart = 256B / sizeof(T) + +contiguous: + chunks = ceil(N / lanesPerPart) + physical values = chunks + +deinterleaved = K: + lanesPerLogicalPart = ceil(N / K) + chunksPerPart = ceil(lanesPerLogicalPart / lanesPerPart) + physical values = K * chunksPerPart +``` + +对 mask: + +```text +lanesPerPart = lanesPerMaskPart(G) +same formula as data, replacing T with mask granularity G +``` + +每个 physical value 的有效 lane 由 lane map 反推: + +```text +contiguous valid: + logical = chunk * lanesPerPart + lane + valid = logical < N + +deinterleaved valid: + logical = K * (chunk * lanesPerPart + lane) + p + valid = logical < N +``` + +padding lane 可以是 poison/undef,但 store、mask-producing op、reduction、scan、compress 和 +layout conversion 都必须显式带着 `valid` 信息,不能只依赖 physical register 宽度。 + +### Broadcast 不作为 Register Layout + +VMI surface 使用 `broadcast` 表达前端语义: + +```mlir +%v = pto.vmi.broadcast %x : f32 -> !pto.vmi.vreg<128xf32> +``` + +也就是: + +```text +for i in 0 .. N: + v[i] = x +``` + +这不是 logical lane 到 physical part/lane 的 placement relation,而是一个 value producer +可以延迟 materialize 的事实。`vmi.broadcast` 应保持为 semantic op 或 layout-polymorphic +producer: + +```text +consumer wants contiguous: + materialize scalar into contiguous physical parts + +consumer wants deinterleaved=2: + materialize same scalar into even/odd parts + +consumer wants deinterleaved=4: + materialize same scalar into p0/p1/p2/p3 parts +``` + +因此 broadcast 不进入 `#pto.vmi.layout` 目录。它由 `vmi-layout-assignment` 按 consumer +layout 重物化或下沉到 consumer lowering,而不是作为 `vreg` 的 layout kind。 + +#### Broadcast Materialization + +MLIR SSA value 不能对不同 use 拥有不同 result type。因此 scalar broadcast 的多 layout +适配不是“一个 VMI value 同时带多个 layout”,而是在 layout assignment 中按 use 重物化。 + +semantic VMI: + +```mlir +%b = pto.vmi.broadcast %x : f32 -> !pto.vmi.vreg<128xf32> +%u = pto.vmi.addf %a_contiguous, %b + : !pto.vmi.vreg<128xf32> +%v = pto.vmi.addf %a_split, %b + : !pto.vmi.vreg<128xf32> +``` + +如果 `%u` 需要 `contiguous`,`%v` 需要 `deinterleaved=2`,layout assignment 重写为: + +```mlir +%b0 = pto.vmi.broadcast %x + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%u = pto.vmi.addf %a_contiguous, %b0 + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b1 = pto.vmi.broadcast %x + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%v = pto.vmi.addf %a_split, %b1 + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +physical materialization: + +```text +contiguous: + each physical chunk is filled with scalar x + +deinterleaved=2: + even part is filled with scalar x + odd part is filled with scalar x + +deinterleaved=4: + p0/p1/p2/p3 parts are all filled with scalar x +``` + +这要求 `pto.vmi.broadcast` 标记为 rematerializable,并满足 dominance:clone 位置必须被 scalar +operand `%x` dominate。跨控制流时,如果 scalar operand 可在各 predecessor/body 内使用, +优先在 consumer 所在 block 重物化;否则必须在控制流 join 处选择一个具体 layout 并 materialize。 + +这个规则只对 scalar-to-vector broadcast 是零语义风险的。低 rank vector 到高 rank vector 的 +broadcast 可能需要真实 lane replication/shuffle,不能默认按任意 consumer layout 免费重物化; +这类 broadcast 必须携带 broadcast map,并按普通 VMI op 做 layout assignment。 + +VMI register layout 目录因此是: + +```text +contiguous +deinterleaved=2 +deinterleaved=4 +``` + +channel split/merge、pack/unpack、memory stride、dynamic permutation、dynamic compaction +不在目录内。它们分别由 VMI semantic op、conversion、memory access plan、`vmi.permute`、 +`vmi.compress/active_prefix_index` 承接。 + +## Pipeline + +### 1. VMI Producer Boundary + +VMI core pipeline 从合法 VMI semantic IR 开始。Producer 可以是 TileLang/PTO lowering、手写 VMI +测试或其它外部 import 工具,但 producer 不属于 VMI core pipeline。 + +进入 VMI boundary 时必须满足: + +```text +all logical vector semantics are represented by pto.vmi semantic ops +all VMI data/mask values use surface VMI type without layout +no physical VPTO op is introduced +no hidden layout/mask/type side table is required +scalar/tensor/debug/transform boundary has already been handled by producer +``` + +该 boundary 需要 verifier gate。它验证 VMI IR 自身完整,不验证某个外部 source dialect 的 +coverage。 + +### 2. `vmi-layout-assignment` + +该阶段把无 layout VMI type 转换成 layout-assigned VMI type,推荐实现为独立 pass: + +```mlir +!pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +!pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +``` + +layout assignment 做三件事: + +1. 为每个 producer 选择 natural layout。 +2. 为每个 consumer 协调 operand/result layout。 +3. 在必要处插入: + +```mlir +pto.vmi.ensure_layout +pto.vmi.ensure_mask_layout +pto.vmi.ensure_mask_granularity +``` + +layout assignment 不是局部 pattern 贪心插 conversion,而是约束求解: + +```text +nodes: + every VMI SSA value + block arguments and region/function results + rematerializable producers such as scalar broadcast/iota/constant + +allowed layouts: + contiguous + deinterleaved=2 + deinterleaved=4 + filtered by element type, mask granularity, op capability, and target registry + +hard constraints: + op verifier constraints, such as same-layout elementwise operands + data/mask layout alignment for predicated ops + control-flow block argument/yield/call signature equality + external ABI layout boundary + source/sink contracts for width conversion, load/store, pack/narrow + +soft costs: + natural producer layout preference + ensure_layout materialization cost from target registry + store/load sink cost + rematerialization cost for broadcast/iota/constant + scratch/guarded fallback resource cost +``` + +求解顺序: + +```text +1. Build constraints for the whole region/SCC, including control-flow and call edges. +2. Propagate impossible layouts and required mask granularities. +3. Choose a minimum-cost layout for each node. +4. Use deterministic tie-break: prefer existing natural layout, then contiguous. +5. Insert ensure_layout/ensure_mask_layout or rematerialize producers at chosen use sites. +6. Re-run verifier gates; no hidden side table may be needed to interpret the result. +``` + +如果 hard constraints 冲突,或所有 legal paths 都缺 target capability/resource,报 +`VMI-LAYOUT-CONTRACT` 或更具体 diagnostic。diagnostic payload 必须列出 conflict value、producer +natural layout、consumer required layouts、available conversion paths 和被禁用的 fallback。 + +#### Consumer Layout Demand + +“consumer 需要某个 layout”不是前端语义要求,而是 layout assignment 为了让 operands/results +的 lane-map 对齐并减少 layout conversion 选择的共同 representation。 + +典型例子: + +```mlir +%w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + +%b = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + +%s = pto.vmi.addf %w, %b + : !pto.vmi.vreg<128xf32> +``` + +`%w` 的 logical 语义是 `128xf32`,但 VPTO `f16 -> f32` 的自然 lowering 产生 even/odd +两路 parts: + +```text +w_even[i] = extf(a[2*i]) +w_odd[i] = extf(a[2*i+1]) +``` + +因此 `%w` 的 natural layout 是: + +```mlir +#pto.vmi.layout +``` + +`addf` 是 layout-polymorphic elementwise op。它有两个合法选择: + +```text +choice A: + materialize %w to contiguous + materialize broadcast to contiguous + do one contiguous add sequence + +choice B: + materialize broadcast directly as deinterleaved=2 + do add on even parts and odd parts separately + keep result as deinterleaved=2 +``` + +choice B 通常更便宜,因为不需要把 `%w_even/%w_odd` 先 interleave 成 contiguous。broadcast +能直接适配 `deinterleaved=2`,是因为它的 logical lanes 全部等于同一个 scalar: + +```text +b_even = [scalar, scalar, ...] +b_odd = [scalar, scalar, ...] +``` + +所以这里说 `addf` consumer “需要” `deinterleaved=2`,准确含义是: + +```text +layout assignment 选择 deinterleaved=2 作为 addf 的共同 operand/result representation, +因为其中一个 operand 的 natural layout 已经是 deinterleaved=2,并且 broadcast 可零语义风险地重物化到该 layout。 +``` + +### 3. `vmi-to-vpto` + +该阶段把 layout-assigned VMI type 做 1:N physical type conversion,推荐实现为独立 pass: + +```text +!pto.vmi.vreg<128xf32, contiguous> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +!pto.vmi.vreg<128xf32, deinterleaved=2> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +!pto.vmi.vreg<256xf32, deinterleaved=4> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + +!pto.vmi.mask<128xb32, deinterleaved=2> + -> !pto.mask, !pto.mask +``` + +需要 internal projection/materialization op: + +```mlir +%p0, %p1 = pto.vmi.unpack %v + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%v = pto.vmi.pack %p0, %p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +`pack/unpack` 不是新的 layout carrier,只是 layouted `vmi.vreg` 到 physical VPTO parts 的 +projection/materialization。 + +`unpack` 必须能作用在任意 SSA value 上,不能依赖 defining op。VMI value 可以来自 block +argument、`scf.if` result、loop iter_arg、function argument 或 call result;这些 value 没有 +可 look-through 的 layout materialization defining op。 + +`pack/unpack` 的 operand/result 个数必须使用 Physical Arity 公式推导。非整 tile 时,最后一个 +chunk 的 padding lane 仍属于 physical value,但不属于 logical value。 + +### Layout Conversion Materialization + +`pto.vmi.ensure_layout` / `pto.vmi.ensure_mask_layout` 是 logical-value-preserving conversion: + +```text +for every logical lane i: + dst.logical[i] = src.logical[i] +for every padding lane: + dst padding remains unobservable +``` + +source/result layout 完全相同时,`ensure_layout` / `ensure_mask_layout` 是 identity forwarding; +即使存在 partial/tail physical chunk,也不需要 target materialization path。source/result layout +不同时才需要 registry 证明 preserving conversion 及其 full-chunk/tail 处理策略。当前 direct path +允许 equal-arity partial/tail conversion:source/result 的 physical arity 必须相同,且两边都能组成完整 +contiguous/deinterleaved=2/4 `intlv` materialization group;arity-changing partial conversion 和 uneven +deinterleaved groups 继续报 unsupported。 + +合法 materialization path 必须来自 target registry: + +```text +same layout: + no-op + +contiguous <-> deinterleaved=2: + direct interleave/deinterleave register op, load/store dist sink/source, + or scratch/ordered fallback + +contiguous <-> deinterleaved=4: + direct 4-way layout sink/source, proven staged 2-way sequence, + or scratch/ordered fallback + +deinterleaved=2 <-> deinterleaved=4: + convert through contiguous only if both legs have preserving paths, + otherwise use scratch/ordered fallback or report VMI-LAYOUT-CONTRACT +``` + +`deinterleaved=4` 不能默认假设“两次二路 interleave”就是正确 materialization。只有当 staged +sequence 的 lane map 被 registry 证明等价于: + +```text +logical = 4 * lane + p +``` + +才允许使用。否则必须选择 store sink、scratch buffer 或 diagnostic。 + +### Verifier Gates + +每个 pipeline 边界都必须有 hard verifier,不能把残缺 IR 留给后续 pass 猜测: + +```text +at VMI producer boundary: + every logical vector value is represented by !pto.vmi.vreg / !pto.vmi.mask + every logical vector operation is represented by pto.vmi semantic op + no physical VPTO op has been introduced + no hidden layout/mask/type side table is required to interpret a value + +after vmi-layout-assignment: + every !pto.vmi.vreg / !pto.vmi.mask has #pto.vmi.layout + layout kind is one of contiguous/deinterleaved=2/deinterleaved=4 + mask granularity matches each consumer + branch operands, block arguments, function arguments/results, and yields agree on layout + no hidden layout/mask/type side table is required to interpret a value + +before vmi-to-vpto: + every pto.vmi.ensure_layout / ensure_mask_layout has a registered preserving materialization path + every fallback path has resource decision and dominance/lifetime proof + +after vmi-to-vpto: + no pto.vmi op or type remains + no UnrealizedConversionCastOp remains + no pto.vmi.pack/unpack/ensure_* helper remains + every physical value arity matches the Physical Arity helper +``` + +layout、mask、valid-lane 和 physical arity 信息必须存在于 IR type/attr/op operand 中,或可由它们 +纯函数推导;不能依赖 C++ side table。违反这些 gate 时使用 `VMI-PASS-INVARIANT` 或更具体的 +diagnostic,例如 `VMI-LAYOUT-CONTRACT`、`VMI-MEMORY-ACCESS`、`VMI-RESIDUAL-OP`。 + +## Layout Assignment 规则 + +### Elementwise + +same-layout operands: + +```text +vmi.addf/vmi.mulf/vmi.cmpi/vmi.select + fan out per physical part + result keeps operand layout +``` + +different-layout operands: + +```text +choose consumer-demanded layout +insert ensure_layout for other operands +vmi.broadcast can rematerialize in consumer-demanded layout +``` + +### Width Conversion + +典型 natural layout: + +```text +vmi.extf 128xf16 -> 128xf32: + source contiguous f16 + result deinterleaved=2 f32 + +vmi.extf 256xf8 -> 256xf32: + source contiguous f8 + result deinterleaved=4 f32 + +vmi.truncf 128xf32 -> 128xf16: + source may be deinterleaved=2 f32 + result contiguous f16 if pack/store sink requires contiguous + +vmi.truncf 256xf32 -> 256xf8: + source may be deinterleaved=4 f32 + result contiguous f8 if pack/store sink requires contiguous +``` + +Direct `vcvt` lowering 可以覆盖同一 contract 下的 partial/tail case:`extf` 的 logical lanes +必须仍然装进一个 contiguous narrow source physical chunk,并自然产生 deinterleaved=2/4 result; +`truncf` 的 deinterleaved=2/4 source parts 必须能 pack 成一个 contiguous narrow result chunk。 +这些路径允许 VPTO 对 padding lanes 执行 conversion,但 padding 只能流向 result padding lanes, +不能变成 logical result。 + +Mask granularity assignment 把 surface `mask` 转成 concrete +`mask`。consumer 决定所需 granularity: + +```text +f16 op consumes mask +f32 op consumes mask +f8 op consumes mask +``` + +如果 data 从 f16 扩到 f32,后续 f32 consumer 需要: + +```mlir +!pto.vmi.mask +``` + +不能继续复用 `mask`。 + +mask-producing op 的 granularity 不是 producer 固有属性: + +```text +vmi.create_mask / constant_mask: + logical predicate producer; granularity chosen by users + create_mask 的 logical prefix 语义不受目标 PAT_VL token 集合限制; + unsupported PAT_VL count 可以用 pto.plt_b* materialize + constant_mask 的 non-prefix chunk 用 prefix 差分和 predicate boolean ops materialize + +vmi.cmpf/cmpi: + result logical lane count follows compared data + concrete granularity chosen by mask consumers, not by compare element type alone + +multi-use mask: + choose one concrete granularity for the original SSA value + insert ensure_mask_granularity or rematerialize cheap mask producers per use +``` + +`ensure_mask_granularity` 必须 preserve logical predicate lane `mask[i]`。当前 direct lowering 对 +concrete `b8/b16/b32` granularity 使用 `pto.punpack` 做 widening,使用 `pto.ppack` 加 `pto.por` +做 narrowing,并按需要串联相邻级别完成 `b8 <-> b32`。如果目标缺少 predicate rearrangement 或 +granularity conversion,报 `VMI-LAYOUT-CONTRACT`,不能把 b16/b32 mask 当成同一 physical bit +pattern 直接复用。 + +### Predication + +Region-style mask 不作为长期 region op 保留到 VPTO lowering。producer 必须把 mask thread 到 +具体 VMI op: + +```text +masked load/store: + use pto.vmi.masked_load / pto.vmi.masked_store + +masked arithmetic with passthru: + compute candidate result + merge with passthru by pto.vmi.select(mask, candidate, passthru) + +masked reduction/scan: + inactive and padding lanes are excluded from the logical iteration +``` + +如果一个 masked op 的 inactive lane 语义要求“不读内存”或“不执行有副作用操作”,不能用 +full op + select 伪装;必须使用对应 masked VMI op、ordered fallback,或报 target capability +diagnostic。 + +### Memory Ops + +VMI memory op 表达 memory semantics,不表达 register layout。lowering 先构造 access plan: + +```text +base +logical lane count +logical_shape attr, if any +lane-to-address map +contiguity +block-strided row classification +read/write validity mask +padding plan +footprint safety proof +target OOB capability +``` + +memory access map 不是 register layout。比如 `tile_read` 的 memref stride 可以识别 +block-strided rows,并选择 `vsldb`,但 result `vmi.vreg` 的 register layout 仍由 +layout assignment 决定。 + +Producer-specific packed element view 不进入 VMI type。它们必须在 VMI memory op 之前规范化为 +element memref + access map: + +```text +memref> + -> base element type T + -> logical address = original index * K + vector_lane +``` + +normalization 必须保留 offset、stride、alignment、memory space 和 alias 信息。无法证明等价 +element view 时,报 `VMI-MEMORY-ACCESS`,不能把 packed element memref 伪装成 contiguous VMI +load/store。 + +direct path examples: + +```text +contiguous full-safe: + vlds/vsts + !pto.ptr source/destination must be UB-backed; memref source/destination + must either have unknown memory space at this stage or explicitly use + #pto.address_space + +32B block-strided rows with block-uniform mask: + vsldb/vsstb + +interleave/deinterleave boundary: + vldsx2/vstsx2 dist or explicit rearrangement + +indexed memory: + gather/scatter if inactive and duplicate-index semantics match +``` + +GM-backed VMI memory is semantic input, not a direct vector load/store target. +Current `vmi-to-vpto` direct memory lowering emits `pto.vlds`, `pto.vldsx2`, +`pto.vsts`, or `pto.vstsx2`; those VPTO ops operate on UB-backed vector memory. +If a `pto.vmi.load/store/tile_read/tile_write` still names GM at this stage, +the missing step is an explicit memory movement/materialization plan, scratch +plan, or UB view normalization. Otherwise the pass must report `VMI-UNSUPPORTED` +instead of silently producing illegal VPTO. + +### Control Flow + +VMI layouted type 可以跨 internal control flow,但 public ABI 不允许 layout leak。 + +MLIR conversion framework 可以做 region/block/signature 的 structural type conversion,但它不会 +自动决定 layout。`vmi-layout-assignment` 必须先把每个 block argument、region yield、branch +operand 和 call boundary 的 layout 固定下来,再交给 `vmi-to-vpto` 做 1:N type conversion。 + +`scf.if` join: + +```text +if all incoming layouts equal: + keep that layout +else: + choose consumer-demanded layout, otherwise contiguous + insert ensure_layout / ensure_mask_layout before yield +``` + +`scf.for` loop-carried value: + +```text +init layout == iter_arg layout == yield layout == loop result layout +``` + +如果 loop body repeatedly consumes deinterleaved=2/deinterleaved=4,优先保持该 natural layout;如果只有 loop +exit 需要 contiguous,则在 exit 后转换,不在 backedge 每轮转换。 + +`cf.br` / `cf.cond_br` block arguments: + +```text +target block argument has one chosen layout +each predecessor operand is converted to that layout before branch +``` + +function boundary: + +```text +internal VMI functions: + function argument/result layout is part of layout assignment + all callsites and returns must agree with the specialized signature layout + +external/public ABI: + must not expose #pto.vmi.layout + materialize to memory, scalar ABI, or final physical PTO ABI before crossing boundary +``` + +recursive or mutually recursive VMI functions require SCC fixed-point layout assignment. If a stable signature +layout cannot be found without inserting conversion on every cycle edge, choose `contiguous` at the function +boundary and keep deinterleaved layouts inside the function body. + +## VMI Op Families + +本节列出 VMI 必须拥有的 semantic op。assembly form 可在 ODS 中微调,但语义边界应保持。 +表中用 `/` 写在一起的名字表示多个独立 op,不表示一个 variadic opcode。去重后,正式 +semantic op 数量是 75 个。 +`ensure_layout`、`ensure_mask_layout`、`ensure_mask_granularity`、`pack`、`unpack` 是内部 +layout/materialization helper,不计入 semantic op;如果把 helper 也算作 VMI op,总数是 80 个。 + +该总表描述目标 semantic surface,不等价于当前第一批实现清单。当前 implementation slice +以 `docs/designs/vmi-implementation-manual.md` 的 Slice 1 为准;例如 `pto.vmi.from_elements` +虽然属于目标 construction family,但没有 scalar lane insert、vreg immediate 或 scratch +materialization plan 前不能宣称 direct lowering 已支持。 + +```text +construction: 6 +memory: 10 +arithmetic/conversion: 36 +permutation/mask/reduction/channel: 23 +semantic total: 75 +internal helpers: 5 +total including helpers: 80 +``` + +### Construction + +| Op | 语义 | +|---|---| +| `pto.vmi.constant` | logical constant vector,layout assignment 决定 materialization | +| `pto.vmi.broadcast` | scalar 或低 rank value broadcast 到 `vreg` | +| `pto.vmi.iota` | 从 scalar base 生成 logical lane index/value vector | +| `pto.vmi.from_elements` | 按 logical lane order 构造 | +| `pto.vmi.create_mask` | prefix 或 logical-shape mask | +| `pto.vmi.constant_mask` | static logical predicate mask, including non-prefix masks | +| `pto.vmi.mask_and/or/xor/not` | logical predicate elementwise operation | + +### Memory + +```mlir +%v = pto.vmi.load %base[%idx] + : memref -> !pto.vmi.vreg<128xf32> + +pto.vmi.store %v, %base[%idx] + : !pto.vmi.vreg<128xf32>, memref + +%v = pto.vmi.masked_load %base[%idx], %mask, %passthru + : memref, !pto.vmi.mask<128xpred>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + +pto.vmi.masked_store %v, %base[%idx], %mask + : !pto.vmi.vreg<128xf32>, memref, !pto.vmi.mask<128xpred> + +%g = pto.vmi.gather %base[%indices], %mask, %passthru + : memref, !pto.vmi.vreg<128xindex>, !pto.vmi.mask<128xpred>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + +pto.vmi.scatter %v, %base[%indices], %mask + : !pto.vmi.vreg<128xf32>, memref, + !pto.vmi.vreg<128xindex>, !pto.vmi.mask<128xpred> + +%e = pto.vmi.expand_load %base[%idx], %mask, %passthru + : memref, !pto.vmi.mask<128xpred>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + +pto.vmi.compress_store %v, %base[%idx], %mask + : !pto.vmi.vreg<128xf32>, memref, !pto.vmi.mask<128xpred> +``` + +`masked_load` 的 inactive lane 不能产生 memory read。full load + select 只有在 inactive +lane 地址 safe-readable 时才合法。 +当前直接 lowering 只覆盖 contiguous result/passthru/mask:full physical chunks 直接 `vlds + vsel`; +partial/tail chunks 必须先证明完整 physical read footprint safe-readable,否则报 `VMI-UNSUPPORTED`。 +在第一阶段的矩阵 quant/dequant lowering 中,默认假设 UB 中的行数据按元素连续,tail load 可以安全读满 +当前物理 vreg;tail 的对外写入效果仍由 `pto.vmi.create_mask` + `pto.vmi.masked_store` +约束。严格 no-read tail 不是这个默认路径的语义,后续通过 stable gather 模式承接:该模式应把 +contiguous tail masked load 转为 `VGATHER2 + Pg` 风格的 per-lane non-faulting load。当前 +`vmi-to-vpto` 只预留 `enable-stable-gather-masked-load` 开关;开关打开且遇到 +`pto.vmi.masked_load` 时必须给 TODO diagnostic,不能退化成普通 `vlds + vsel`。 + +普通 `vmi.store` 和 `vmi.masked_store` 的 contiguous tail 可以用 true predicate store 承接: +full physical chunk 使用 all-true mask 或用户 mask,最后一个 partial chunk 使用 prefix valid-lane +mask;因此普通 `vmi.store` direct lowering 要求 value element width 能对应 +`pto.mask`。`masked_store` 先把用户 mask 与 valid-lane mask 做 logical AND。 +deinterleaved=2/4 tail store/masked_store 只有在每个 deinterleaved part 的 physical chunk 数相同、可先组成完整 +`vintlv/pintlv` group 并 materialize 成 contiguous chunks 时才直接支持;materialized 后 active +lane 为 0 的 padding-only chunk 不发 store。load padding 仍需要独立的 access plan,不能通过未受保护的 +full-footprint memory op 偷跑。 + +`gather/scatter` 使用 logical lane order 解释 `%indices`,index 单位和 memref element type +一致。`gather` inactive lane 返回 `%passthru[i]` 且不能读内存。`scatter` inactive lane 不能写 +内存;如果 active lanes 可能写同一地址,direct VPTO lowering 必须证明目标语义与 logical +lane order 等价,否则使用 ordered fallback 或报 `VMI-MEMORY-ACCESS`。 + +当前 `gather` direct lowering 覆盖一个保守子集: + +```text +source: + !pto.ptr + +layout: + result / indices / mask / passthru all contiguous + all physical chunks are full, so padding lanes cannot trigger memory reads + +type: + T is 32-bit element type + indices are signless or unsigned i32 + mask granularity is b32 + +lowering: + gathered = pto.vgather2_bc source, indices, mask + result = pto.vsel gathered, passthru, mask +``` + +`VGATHER2_BC` false predicate lanes do not read memory but produce zero result lanes. VMI `gather` requires false +lanes to preserve passthru, so the `vsel` is semantically required, not an optimization artifact. `f16/b16/f8/i8` +gather, tail gather, non-contiguous layout, memref/gm source, and fallback through guarded scalar load or scratch are +future target-capability paths. + +当前 `scatter` direct lowering 只在 VMI IR 携带显式 no-conflict proof 时启用: + +```mlir +pto.vmi.scatter %v, %base[%indices], %mask {indices_unique} + : !pto.vmi.vreg<64xf32>, !pto.ptr, + !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> +``` + +`indices_unique` 的含义是:所有 active logical lanes 的 `%indices` 两两不同。这个 proof 可以来自 +producer 的静态分析、前端语义或上游 canonicalization;VMI lowering 不从 runtime 值猜测它。direct +path 的其它限制与 gather 对齐:UB pointer destination、contiguous full physical chunks、32-bit value +element、i32 indices 和 b32 mask。没有 `indices_unique` 时,`vmi-to-vpto` 必须诊断,而不能直接发 +`VSCATTER`,因为 `VSCATTER` 对重复 index 的 grant procedure 是目标相关/未定义的,不等价于 VMI +logical lane order。 + +`expand_load/compress_store` 表达 masked contiguous stream,不是 arbitrary indexed access: + +```text +expand_load: + k = 0 + for i in 0 .. N: + if mask[i]: + result[i] = base[idx + k] + k += 1 + else: + result[i] = passthru[i] + +compress_store: + k = 0 + for i in 0 .. N: + if mask[i]: + base[idx + k] = value[i] + k += 1 +``` + +Current direct `expand_load` lowering supports two paths. The first is the +degenerate all-active case: + +```text +mask == all_true => expand_load(base[idx], mask, passthru) == load(base[idx]) +``` + +The accepted mask must be statically proven all active through +`pto.vmi.create_mask` with constant `active_lanes >= N`, or a dense all-true +`pto.vmi.constant_mask`. The result, passthru, and mask layouts must be +contiguous. Partial/tail chunks still need the same safe full-read proof as +ordinary `vmi.load`; otherwise the direct path reports `VMI-UNSUPPORTED`. + +The second direct path covers one full 32-bit UB physical chunk with a runtime +mask: + +```text +base' = pto.addptr base, idx +indices = pto.vusqz(zero_i32_carrier, mask) +gathered = pto.vgather2_bc base', indices, mask +result = pto.vsel gathered, passthru, mask +``` + +It requires contiguous result/passthru/mask layout, 32-bit element type, b32 +mask granularity and one full physical chunk. Multi-chunk runtime masks need a +cross-chunk prefix-count carry; f16/b16/f8/i8 need a gather packing contract. +Unsupported cases still require guarded load, scratch fallback, or diagnostic, +and must not be lowered as a plain full load. + +Current direct `compress_store` lowering is intentionally narrower than the +surface semantics. It requires contiguous value/mask layout, exactly one full +physical chunk, and a UB `!pto.ptr` destination. The direct sequence is: + +```text +store_base = pto.addptr base, idx +sqz = pto.vsqz value, mask +align0 = pto.init_align +align1 = pto.vstur align0, sqz, store_base, "POST_UPDATE" +pto.vstar align1, store_base +``` + +The paired `vstur` consumer is what makes the later VPTO LLVM emitter select +`VSQZ #st=1`; emitting `vsqz` without that store consumer is only register +compress. Full physical chunk is required in this first path because padding +mask lanes must not be squeezed into memory. Multi-chunk `compress_store` +needs cross-chunk compaction and SQZN/store-state planning; deinterleaved +layouts need logical lane order reconstruction before the store chain. + +### Index And Address Contract + +`!pto.vmi.vreg` 是 logical index vector,不是 physical address vector。进入 VPTO 前, +index 必须按 target registry legalize 成目标支持的整数宽度: + +```text +index legalization: + choose target index bitwidth + prove every lane value fits, or insert preserving extend/trunc/check sequence + preserve signedness required by the consuming op +``` + +memory op 的 index 单位是 memref element,不是 byte。byte address 由 memref layout、element +size、base offset 和 lane index 共同计算: + +```text +logical element offset -> memref affine/strided map -> byte address +``` + +`gather/scatter` 的 `%indices`、`expand_load/compress_store` 的 active-prefix offset、`iota` 生成 +的 lane index 都必须在同一套 address unit 下解释。不能把 element index 直接当 byte offset,也 +不能在没有 range proof 时把 `index` 静默截断成较窄整数。 + +`active_prefix_index(mask)` 返回当前 lane 之前的 active lane 数: + +```text +idx[i] = popcount(mask[0 .. i)) +``` + +因此 `expand_load/compress_store` active lane 使用 `base + idx[i]`。如果目标缺少 prefix-popcount +或 index-vector lowering,必须选择 index-buffer/guarded fallback,或报 `VMI-FALLBACK-RESOURCE` +/ `VMI-LAYOUT-CONTRACT`。 + +`tile_read/tile_write` 承接 transfer-style padding 和 multi-dimensional access semantics: + +```mlir +%tile = pto.vmi.tile_read %view[%c0, %c0], %pad, %mask + {logical_shape = [8, 8], + permutation_map = affine_map<(d0, d1) -> (d0, d1)>} + : memref<8x8xf32, strided<[?, 1], offset: ?>>, f32, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xf32> + +pto.vmi.tile_write %tile, %view[%c0, %c0], %mask + {logical_shape = [8, 8], + permutation_map = affine_map<(d0, d1) -> (d0, d1)>} + : !pto.vmi.vreg<64xf32>, memref<8x8xf32, strided<[?, 1], offset: ?>>, + !pto.vmi.mask<64xpred> +``` + +`tile_read/tile_write` 只承接 memref memory semantics。producer 的 transfer-style read/write 如果作用在 +tensor source/destination 上,必须在进入 VMI 前 bufferize 成 memref access plan,或退出 PTO +路线。tensor write-back style 语义是产生新 tensor,不是对 memref 的 memory effect;不能把它 +伪装成 `pto.vmi.tile_write`。未处理的 tensor transfer 报 `VMI-TENSOR-BOUNDARY`。 + +`tile_read` invalid lane 的 result 必须等于 padding,不是后继 op 的 inactive lane。 + +`tile_read` lowering 必须先构造三个对象: + +```text +validMask(result lane): + logical lane is inside result shape + and explicit transfer mask maps to true + and source address is in bounds + +paddingValue(result lane): + scalar padding: same value for every invalid lane + vector-element padding: select element by suffix coordinate + broadcast/permuted padding: apply the same result-lane map as data + +safeReadProof: + proves the actual physical load footprint is safe-readable + independent from validMask +``` + +`validMask=false` 只说明 result lane 应等于 padding,不说明该 lane 的 source address 可以被读。 +因此 `tile_read` 的 preserving lowering 决策是: + +```text +safeReadProof == full and validMask all-true: + direct load + +safeReadProof == full and validMask not all-true: + loaded = full load + pad = materialize paddingValue in result layout + result = select(validMask, loaded, pad) + +target has true masked/non-faulting load: + loaded = masked load with inactive lanes not read + pad = materialize paddingValue in result layout + result = select(validMask, loaded, pad) unless inactive result is already padding + +safeReadProof != full: + split full-safe and partial paths, or + fill scratch with paddingValue, guarded-copy only valid lanes, then load scratch, or + use guarded scalar/vector fallback +``` + +First implementation stage note: + +```text +The padding-preserving branches above are semantic requirements for the full +design, but they are not part of the first-stage VMI implementation. The first +stage may lower only all-valid direct reads, or physical-tail reads whose extra +lanes are outside the logical VMI value and remain unobservable. If invalid +logical lanes require transfer_read paddingValue materialization, true +masked/non-faulting load, scratch, or guarded fallback, lowering must stop with +the implementation diagnostic code VMI-UNSUPPORTED instead of emitting an +approximate full load. +``` + +如果所有 preserving paths 都因 target capability 或 option 被禁用,报 `VMI-MEMORY-ACCESS`, +payload 必须指出缺的是 unsafe partial `tile_read` padding-preserving path。 + +`tile_write` 没有 padding value,但有 write-valid mask: + +```text +writeMask(source lane): + logical lane is inside source shape + and explicit transfer mask maps to true + and destination address is in bounds +``` + +`writeMask=false` 的 lane 不能产生 memory effect。只有 full physical footprint safe-writable 且 +writeMask all-true 时,才能使用 predicate-ignored store。partial write 必须使用 true masked +store、split/guarded fallback、scatter-like fallback,或报 `VMI-MEMORY-ACCESS`。 +当前 direct `vmi.tile_write` 只覆盖 flat contiguous tail:最后一个 partial chunk 使用 prefix +valid-lane predicate 发 `vsts`,同样要求 value element width 能对应 `pto.mask`。 +deinterleaved=2/4 tail 只有在能先完整 materialize 到 contiguous +chunks 时直接支持,padding-only materialized chunk 不发 store;带 transfer mask coordinate remap 的 +tile write 仍必须走独立 access plan。 + +explicit transfer mask 的坐标属于 transfer access space,不一定等于 flattened result/source lane +坐标。non-minor-identity transfer 必须先做 predicate coordinate remap;缺少 remap capability 时, +diagnostic 必须点名 transfer mask coordinate remap,而不是泛化成普通 memory failure。 + +### Arithmetic And Conversion + +VMI 不复用外部 elementwise arithmetic op。需要定义对应 VMI op: + +| Semantic | VMI op | +|---|---| +| float binary | `pto.vmi.addf/subf/mulf/divf/minf/maxf` | +| float unary | `pto.vmi.negf/sqrt/exp/ln/relu` | +| integer binary | `pto.vmi.addi/subi/muli` | +| bitwise/shift | `pto.vmi.andi/ori/xori/not/shli/shrui` | +| fused multiply-add | `pto.vmi.fma` | +| float casts | `pto.vmi.extf/truncf` | +| bitcast | `pto.vmi.bitcast` | +| compare/select | `pto.vmi.cmpf/cmpi/select` | + +Integer div/rem, arithmetic right shift, integer casts, int-float casts, and +index casts are intentionally not in the current VMI surface. They need +explicit signedness, rounding, saturation, overflow/remainder, and VPTO target +contracts before ODS ops are introduced. + +producer constant 转成 `pto.vmi.constant`,包括 dense、splat 和 rank-0 logical vector。 +constant 的 element type、shape、splatness 和 poison/undef 属性如果存在,必须保留到 VMI +constant attr;padding physical lane 仍按 VMI padding rule 处理,不能把 padding lane 当成用户 +constant lane。 + +当前 VPTO direct lowering 只把 scalar broadcast 和 splat constant materialize 成 +`pto.vdup`。这条路径与逐元素 op 一样要求 physical element width 能对应 +`pto.mask`;其它 element type 或非 splat constant 必须先有明确的 materialization +contract,否则报 `VMI-UNSUPPORTED`。 + +VMI arithmetic op 必须保留原 `arith` op 的 numeric contract: + +```text +floating point: + fastmath flags + rounding mode, if present + NaN / signed-zero / inf behavior implied by flags + +integer: + signedness of div/rem/compare/extend + overflow flags such as nsw/nuw when present + truncation and extension width rules + +compare/select: + cmpf/cmpi predicate + select condition mask granularity and layout +``` + +lowering 不能因为 VPTO 有更快指令就加强或放松这些属性。比如没有 fastmath 允许时,`fma` +不能拆成 `mulf + addf`,也不能把 `mulf + addf` 合成 `fma`;带 `nsw/nuw` 的 integer op +可以利用 flag 做优化,不带 flag 的 op 必须保持 wraparound/defined overflow 语义。 + +`pto.vmi.fma` 不能默认拆成 `mulf + addf`。`bitcast` 只有在当前 layout 下 bit grouping +physically adjacent、且每个对应 physical chunk 的 logical bit footprint 相同时才能 direct; +padding bits 只能流向 result padding bits。否则需要 layout conversion、scratch materialization +或 target capability diagnostic。 + +当前 VPTO direct lowering 对逐元素算术、逻辑、比较和 select 还有一条共同硬约束:物理 element +width 必须能对应到 `pto.mask`。因此 VMI 语义层可以承载 `index` 或 `f64` +这类类型,但在没有独立 lowering contract 前,`vmi-to-vpto` 必须报 `VMI-UNSUPPORTED`, +不能让 OneToN conversion 或 residual gate 隐式失败。 + +这条共同约束不是唯一约束。某些目标 VPTO/VISA op 还有自己的 element type contract, +必须在 `vmi-to-vpto` preflight 中单独检查。当前 direct lowering 明确承诺: + +```text +addf/subf/mulf: f16/bf16/f32 +divf: f16/f32 +minf/maxf: f16/bf16/f32 +negf/absf: f16/f32 +sqrt/exp/ln: f16/f32 +relu: f16/f32 +absi: signless/signed i8/i16/i32 +cmpf: f16/bf16/f32 +cmpi: signless/signed/unsigned i8/i16/i32 +``` + +因此 bf16/f8 虽然可能是合法 VMI float-like type 且能 materialize b16/b8 predicate mask, +但只要目标 direct op 不承诺该 element type,`vmi-to-vpto` 就必须先报 +`VMI-UNSUPPORTED`,直到定义对应 materialization 或 VPTO 目标能力。 + +当前 direct lowering 将 `pto.vmi.fma %lhs, %rhs, %acc` 映射为每个 physical part 上的 +`pto.vmula %acc_part, %lhs_part, %rhs_part, %all_true_mask`。该路径只承诺 f16/bf16/f32 +floating-point fused multiply-add;整数 multiply-accumulate、带 rounding/fastmath 变体或需要 +不同 accumulator 精度的形式必须单独建模,不能复用这个 op 偷换语义。 + +### Permutation, Mask, Reduction, Channel + +| Semantic | VMI op | +|---|---| +| static lane map | `pto.vmi.shuffle` | +| dynamic indexed lane map | `pto.vmi.permute` | +| logical interleave/deinterleave | `pto.vmi.interleave/deinterleave` | +| shape metadata change | `pto.vmi.shape_cast/reshape/transpose` | +| subvector update | `pto.vmi.slice/insert_slice/insert_element` | +| predicate logic | `pto.vmi.mask_and/or/xor/not` | +| prefix active index | `pto.vmi.active_prefix_index` | +| register compaction/expansion | `pto.vmi.compress/expand` | +| reduction/scan | `pto.vmi.reduction/scan` | +| contraction | `pto.vmi.contract/outerproduct` | +| channel split/merge | `pto.vmi.channel_split/channel_merge` | + +`pto.vmi.shuffle` 表达完整 static lane map。当前 VPTO direct lowering 先识别 physical chunk +forwarding:每个 result physical chunk 的所有非 padding lanes 必须来自同一个 source chunk, +且 source lane number 等于 result lane number;result padding lanes 不参与证明,forward 过来的 +物理 padding lanes 仍然不可观察。否则在每个 result physical chunk 都来自同一个 source chunk、 +result chunk 没有 padding lane、且 source lane index 是 ASC/DESC 连续序列时,用 `pto.vci` +生成 index vector 并发 `pto.vselr`。任意非 affine permutation、以及需要 tail lane 重排但无法安全 +materialize tail index vector 的场景,仍然需要通用 index-vector materialization、scratch fallback +或 target capability diagnostic。 + +`channel_split/channel_merge` 是 PTO-specific semantic op。它们表达用户按 channel 编程时的 +多个 logical VMI values,不能降格成 +`#pto.vmi.layout` kind。它们必须拥有 static shuffle 等价定义,canonicalization 可以双向进行: +识别出的 shuffle pattern 可以变成 channel op,channel op 也可以合法展开回 shuffle。 +Direct lowering 还必须证明 physical group 完整;否则即使 logical shuffle 语义成立,也要报 +target capability/materialization diagnostic,而不是让 OneToN pattern 在中途失败。 + +### Internal Layout Helpers + +这些 op 只允许存在于 VMI lowering 的中间阶段,不能作为 VMI semantic surface,也不能残留到 +physical VPTO 之后: + +| Op | 语义 | +|---|---| +| `pto.vmi.ensure_layout` | data vreg layout-preserving conversion | +| `pto.vmi.ensure_mask_layout` | mask layout-preserving conversion | +| `pto.vmi.ensure_mask_granularity` | logical predicate-preserving granularity conversion | +| `pto.vmi.unpack` | layouted VMI value projection to physical VPTO parts | +| `pto.vmi.pack` | physical VPTO parts materialized as one layouted VMI value | + +`active_prefix_index` 语义是: + +```text +idx[i] = popcount(mask[0 .. i)) +``` + +VMI surface 不暴露 VPTO `vusqz` 的无意义 source operand;需要 type/ABI carrier 时在 +`vmi-to-vpto` late materialize。 + +当前直接 lowering 只覆盖 contiguous 单物理 chunk。这个 case 可以用 `pto.vusqz` 精确承接: +`vmi-to-vpto` 先 materialize 一个 zero vreg 作为 VPTO `vusqz` 的 source carrier,再把 VMI mask +作为 governing predicate 传入。多物理 chunk 需要把前一 chunk 的 active count carry 到后一 chunk; +deinterleaved layout 还需要按逻辑 lane 顺序重建 prefix,因此不能逐物理 part 独立发 `vusqz`。 + +`vmi.compress(source, mask)` 语义是按 logical lane order 保留 active source lane 并压缩到结果前缀。 +当前直接 lowering 只覆盖 contiguous 单个 full physical chunk,可以用 `pto.vsqz(source, mask)` 承接。 +partial/tail chunk 不能直接走 `vsqz`,因为 padding mask lane 如果为 true,padding source lane 可能被 +压缩到可观察的 result 前缀。多物理 chunk 需要跨 chunk compaction;`compress_store` 还涉及 +`VSQZ #st=1` 与 `VSTUR`/`SQZN` 的配对约束,不能由 register `compress` 自动推出。 + +`vmi.compress_store(value, base[idx], mask)` 语义是按 logical lane order 把 active lane 写成连续 +memory stream。当前直接 lowering 只覆盖 contiguous、单个 full physical chunk 和 UB pointer +destination,并发出 `pto.vsqz -> pto.vstur POST_UPDATE -> pto.vstar` 的完整 store-state chain。非 full +chunk 暂不直接 lowering,因为 padding mask lane 可能被硬件 squeeze 成额外写出;multi-chunk 需要 +跨 chunk active count 和 SQZN FIFO/VSTUR 配对计划。 + +`shape_cast/reshape/transpose` 必须区分 metadata change 和 lane movement: + +```text +shape_cast / reshape: + preserve row-major flattened lane order + produce explicit result logical_shape attr + +transpose / flat_transpose: + changes logical lane order according to permutation + must lower through shuffle/permute/layout conversion/direct transpose capability +``` + +这些 op 的 source/result shape、permutation 和 broadcast map 都是 op attrs。VMI lowering 不能从 +producer defining op 或 side table 推断缺失 shape。 + +低 rank vector 到高 rank vector 的 broadcast 也不能当成 scalar broadcast 免费重物化。它必须 +保存 broadcast map: + +```text +result[indices] = source[broadcast_map(indices)] +``` + +只有 scalar-to-vector broadcast 可以按 consumer layout 任意重物化。 + +`iota` 是 lane index generation 的 VMI 表达: + +```text +iota(base, ASC): + result[i] = base + i + +iota(base, DESC): + result[i] = base - i +``` + +第一版 `iota` 的 `T` 跟随 VPTO `vci` 能承接的元素类型:integer 8/16/32 和 f16/f32。 +可变 step 不是 surface op 语义的一部分;如果 producer 需要 `base + i * step`,应表达为 +`iota(base=0) -> muli/vmi arithmetic -> addi/addf` 组合,或后续单独引入带 step 的 op。 +tail physical chunk 的 padding lane 可以承接 iota 的自然延续值,但这些 lane 不是 logical lane; +后续 memory/mask/reduction 等有外部效果的 consumer 必须继续按 valid logical lane 保护。 +deinterleaved layout 下的 physical part 需要 strided index materialization: + +```text +part p contains logical lanes p, p + factor, p + 2 * factor, ... +ASC value = base + p + factor * local_lane +DESC value = base - p - factor * local_lane +``` + +因此 direct `vci` 只覆盖 contiguous full-chunk path;deinterleaved path 必须额外物化 +`vci(0) * factor + base +/- p`,不能误降成每个 part 内连续的 `vci(base + p)`。当前 lowering +按 physical part 生成 `vci(0) + vmuls(factor) + vadds/vdup/vsub` 序列;padding/tail chunk +仍然需要独立的 padding-safe materialization plan。 + +`slice/insert_slice` 都按 logical lane order 定义,不读取或写入 padding lane: + +```text +slice(offset, size, stride): + result[j] = source[offset + j * stride] + +insert_slice(offset, stride): + result = dest + result[offset + j * stride] = update[j] + +insert_element(pos): + result = dest + result[pos] = scalar +``` + +`reduction/scan` 的 logical iteration 只覆盖 active logical lanes,padding lanes 不参与: + +```text +reduction(op, init, value, mask): + acc = init + for i in 0 .. N: + if mask is absent or mask[i]: + acc = op(acc, value[i]) + result = acc + +scan(op, init, value, mask): + acc = init + for i in 0 .. N: + if mask is absent or mask[i]: + acc = op(acc, value[i]) + result[i] = acc + else: + result[i] = passthru_or_identity +``` + +Current direct reduction support starts with integer add: + +```mlir +%r = pto.vmi.reduce_addi %value, %init, %mask + : !pto.vmi.vreg<64xi32>, !pto.vmi.vreg<1xi32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xi32> + +%rf = pto.vmi.reduce_addf %value, %init, %mask {reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + +%rmax = pto.vmi.reduce_maxf %value, %init, %mask + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + +%rmin = pto.vmi.reduce_minf %value, %init, %mask + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<1xf16>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf16> +``` + +`reduce_addi` preserves integer wraparound addition semantics. The direct +lowering requires contiguous layout, full 32-bit source physical chunks, +matching mask chunks, and one rank-0 init/result chunk. It emits `pto.vcadd` +for each masked source chunk, then serially accumulates each chunk result into +the rank-0 accumulator with `pto.vadd` under a `PAT_VL1` predicate. Padding +source lanes are rejected instead of being allowed to participate. + +`reduce_addf` is legal only with an explicit `{reassoc}` contract because the +ISA documents pair-wise FP reduction order. The direct lowering supports only +f32, contiguous layout, full source physical chunks, matching b32 mask chunks, +and one rank-0 init/result chunk. It uses the same per-chunk `vcadd` plus +serial `PAT_VL1 vadd` accumulation shape. Without `{reassoc}`, the verifier +rejects the op instead of silently changing ordered floating-point semantics. + +`reduce_maxf` and `reduce_minf` preserve VPTO-compatible floating-point min/max +reduction semantics. Direct lowering supports f16/f32, contiguous layout, full +source physical chunks, matching mask chunks, and one rank-0 init/result chunk. +For each physical source chunk, lowering emits `pto.vcmax` or `pto.vcmin`. +The chunk result's lowest lane is then accumulated into the rank-0 accumulator +with `pto.vmax` or `pto.vmin` under a `PAT_VL1` predicate. The index value that +`vcmax/vcmin` writes to the second lane is intentionally not part of the VMI op +result and is discarded by only observing lane 0. Inactive lane identities, +signed zero handling, and NaN behavior follow the underlying `vcmax/vcmin` and +`vmax/vmin` VPTO instructions. Padding source lanes are rejected, because the +logical reduction must not allow padding to become an inactive-lane identity or +a NaN-producing participant. + +lowering 可以选择 VPTO reduction/scan 指令、tree decomposition、scratch memory 或 scalarized +ordered fallback,但必须保持 numeric contract。没有目标能力时使用 `VMI-ELEMENT-TYPE` 或 +`VMI-LAYOUT-CONTRACT`,不能让未 lower 的逻辑向量 op 残留到 VPTO。 + +`contract/outerproduct` 在 VMI 中保留 indexing maps、iterator types、accumulator、mask 和 +element type,并且不允许绕过 VMI 直接回到其它向量 IR。如果目标有直接 matrix/vector contract +能力,lower 到直接 VPTO sequence;否则按 iterator space 分解成 VMI arithmetic + +reduction/scan,再走普通 VMI lowering。只有当 element type、accumulator 精度或 iterator +semantics 无法由目标表达时,才报 target capability diagnostic。 + +如果 producer 的 extract-like operation 结果仍是 logical vector,应表达成 `pto.vmi.slice`、 +`pto.vmi.shuffle` 或 `pto.vmi.shape_cast`。如果结果是 scalar,则属于 vector-to-scalar boundary, +不进入 VMI vector path,也不产生 `pto.vmi.extract`: + +```text +VMI-SCALAR-EXTRACT-BOUNDARY +``` + +## End-To-End Examples + +### f16 Widen Add Store + +Semantic VMI: + +```mlir +%a = pto.vmi.load %A[%i] + : memref -> !pto.vmi.vreg<128xf16> +%w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +%s = pto.vmi.addf %w, %bias + : !pto.vmi.vreg<128xf32> +pto.vmi.store %s, %C[%i] + : !pto.vmi.vreg<128xf32>, memref +``` + +Layout-assigned VMI: + +```mlir +%a = pto.vmi.load %A[%i] + : memref -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +%w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%s = pto.vmi.addf %w, %bias + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +pto.vmi.store %s, %C[%i] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, memref +``` + +Physical lowering 可以生成 EVEN/ODD `vcvt`、两路 `vadd`,并在 store sink 使用 interleave +store 或显式 layout conversion。 + +### f8 To f32 + +```mlir +%a = pto.vmi.load %A[%i] + : memref -> !pto.vmi.vreg<256xf8> +%w = pto.vmi.extf %a + : !pto.vmi.vreg<256xf8> -> !pto.vmi.vreg<256xf32> +%s = pto.vmi.addf %w, %b + : !pto.vmi.vreg<256xf32> +pto.vmi.store %s, %C[%i] + : !pto.vmi.vreg<256xf32>, memref +``` + +layout assignment 可把 `%w/%s` 设为 `#pto.vmi.layout`。contiguous store 必须使用 +已验证的 layout sink 或先 materialize contiguous representation,不能把 p0/p1/p2/p3 part 当成连续内存写出。 + +### Block-Strided Tile Read + +```mlir +%tile = pto.vmi.tile_read %view[%c0, %c0], %pad, %mask + {logical_shape = [8, 8], + permutation_map = affine_map<(d0, d1) -> (d0, d1)>} + : memref<8x8xf32, strided<[?, 1], offset: ?>>, f32, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xf32> +``` + +如果 access plan 证明每 row 是 32B contiguous block,row 间 stride 可落到 ISA stride 字段, +且 mask block-uniform,lowering 可以选择 `vsldb`。如果 padding 非零,仍需在 load 后用 +valid mask 修正 invalid lane。 + +## Risk Closure Matrix + +| 风险 | 设计闭环 | 测试出口 | +|---|---|---| +| producer 直接绕过 VMI 生成 physical VPTO | VMI Producer Boundary Contract + Verifier Gates | `vmi_producer_boundary.mlir`, `vmi_pipeline_hard_gates.mlir` | +| arith numeric contract 被 VPTO 快速路径改写 | fastmath/rounding/overflow/cmp predicate preservation | `vmi_arith_numeric_contract.mlir` | +| layout 设计泛化失控 | closed `contiguous/deinterleaved=2/4` layout set + source contract | `vmi_f16_ext_add_store_deinterleaved2.mlir`, `vmi_f8_ext_add_store_deinterleaved4.mlir` | +| layout assignment 局部贪心导致控制流/多 use 错误 | region/SCC constraint solver + deterministic tie-break | `vmi_layout_assignment_constraint_solver.mlir`, `vmi_cf_and_call_layout_boundary.mlir` | +| 1:N physicalization arity 漂移 | Physical Arity helper + hard gate | `vmi_physical_arity_non_full_deinterleaved.mlir` | +| `deinterleaved=4` materialization 错 lane | registered preserving materialization path | `vmi_ensure_layout_materialization_contract.mlir` | +| mask granularity 过早固化 | surface `mask` + consumer-driven granularity assignment | `vmi_mask_granularity_width_change.mlir` | +| non-scalar broadcast / transpose 被当成 metadata | explicit broadcast map and lane-movement semantics | `vmi_shape_broadcast_semantics.mlir` | +| transfer padding / OOB read 写成 full load/store | `validMask` / `paddingValue` / `safeReadProof` / `writeMask` decision tree | `vmi_tile_read_padding_decision_tree.mlir`, `vmi_tile_write_oob_no_effect.mlir` | +| index/address 单位或宽度被误用 | index/address legalization contract | `vmi_index_address_legalization.mlir` | +| reduction/scan/contract 回退成 residual logical-vector op | VMI semantic op + direct/decompose/scratch lowering contract | `vmi_reduction_scan_contract_coverage.mlir` | +| shape 信息依赖 hidden side table | flat VMI value + shape-sensitive op attrs | `vmi_shape_broadcast_semantics.mlir`, `vmi_pipeline_hard_gates.mlir` | +| fallback 缺资源时退化成残缺 lowering | explicit fallback resource contract + `VMI-FALLBACK-RESOURCE` | `vmi_fallback_resource_diagnostics.mlir` | +| tensor/debug/scalar boundary 混入 VMI | explicit boundary diagnostics | `vmi_tensor_transfer_boundary.mlir`, `vmi_debug_boundary.mlir`, `vmi_extract_boundary.mlir` | + +## Diagnostics + +| Code | 场景 | +|---|---| +| `VMI-SCALAR-EXTRACT-BOUNDARY` | scalar lane extract 不是 VMI vector op,必须在进入 VMI 前消除或退出 PTO 路线 | +| `VMI-SCALABLE-VECTOR` | scalable vector 未在进入 VMI 前 specialize 成固定 logical lane count | +| `VMI-ELEMENT-TYPE` | target registry 缺 storage/compute/convert capability | +| `VMI-LAYOUT-CONTRACT` | VMI layout、mask granularity 或控制流/调用边界约束冲突 | +| `VMI-MEMORY-ACCESS` | access plan 无 direct/fallback path | +| `VMI-LAYOUT-CONTRACT` | layout conversion 或 sink 未被 target registry 支持 | +| `VMI-FALLBACK-RESOURCE` | scratch、guard、index buffer 或 fallback index width 资源不可用 | +| `VMI-TENSOR-BOUNDARY` | tensor transfer 必须在进入 VMI 前 bufferize 或退出 PTO 路线 | +| `VMI-DEBUG-BOUNDARY` | debug op 必须在进入 VMI 前消费、剥离或退出 PTO 路线 | +| `VMI-PASS-INVARIANT` | pipeline hard gate 被破坏,例如 hidden side table、残留 conversion cast 或 layout 缺失 | +| `VMI-RESIDUAL-OP` | physicalization 后仍有非法 VMI op/type 或 helper | + +diagnostic payload 至少包含 source op、semantic reason、failed contract、available paths、 +missing capability 或 disabled fallback option。 + +## Implementation Plan + +具体文件布局、Slice 切分、ODS/type/op/pass/test 落地步骤见 +`docs/designs/vmi-implementation-manual.md`。本节只保留高层任务顺序。 + +1. 定义 `!pto.vmi.vreg`、`!pto.vmi.vreg`、 + `!pto.vmi.mask`、`!pto.vmi.mask`。 +2. 定义 layout 目录:`#pto.vmi.layout`、 + `#pto.vmi.layout`、 + `#pto.vmi.layout`, + 并实现统一 lane-map / physical-arity helper。 +3. 定义 VMI semantic op families:construction、memory、arith、conversion、mask、 + permutation、active-prefix、compress/expand、channel split/merge、reduction/scan/contract。 +4. 实现 VMI producer boundary verifier,禁止 producer 直接生成 physical VPTO 或依赖 hidden state。 +5. 实现 `vmi-layout-assignment`,包含 op transfer function、cost model、mask granularity + conversion、control-flow join。 +6. 实现 VMI memory lowering:access plan、safe-read/write proof、tile padding materialization、 + transfer mask coordinate remap、masked/guarded/scratch fallback。 +7. 实现 `vmi-to-vpto` 1:N type conversion,包含 `pack/unpack` materialization 和 structural + conversion。 +8. 加 target element-type / layout-sink / ISA contract / fallback resource registry。 +9. 加 VMI hard gate verifier:覆盖 VMI producer boundary、`vmi-layout-assignment`、 + `vmi-to-vpto` 后的残留 op/type、layout、mask granularity、conversion cast 和 hidden-state + invariant。 +10. 加 VMI diagnostic code registry 和 lit tests。 + +## Test Checklist + +1. `vmi_f16_ext_add_store_deinterleaved2.mlir` + - `extf` 后 result 是 `vreg<128xf32, deinterleaved=2>`,store 保持 contiguous logical order。 +2. `vmi_f8_ext_add_store_deinterleaved4.mlir` + - `deinterleaved=4` p0/p1/p2/p3 不被误写成 contiguous memory。 +3. `vmi_non_full_tile_padding_lanes.mlir` + - `vreg<100xf32>` padding lane 不可观察。 +4. `vmi_mask_granularity_width_change.mlir` + - surface `mask` 被不同 width consumer 使用时,正确生成 `mask` / + `mask` 并保持 data layout。 +5. `vmi_control_flow_layout_join.mlir` + - `scf.if/scf.for` layouted VMI type join 稳定。 +6. `vmi_tile_read_padding_safe_footprint.mlir` + - full physical load unsafe 时不偷读 invalid lane。 +7. `vmi_block_strided_rows_vsldb.mlir` + - `tile_read/tile_write` 识别 32B block rows,并拒绝 per-lane mask direct path。 +8. `vmi_active_prefix_index_compress.mlir` + - arbitrary mask compaction 使用 logical prefix order。 +9. `vmi_extract_boundary.mlir` + - scalar extract 输出 `VMI-SCALAR-EXTRACT-BOUNDARY`。 +10. `vmi_channel_split_merge_semantic_op.mlir` + - interleaved channel data 按用户语义拆成多个 VMI values,再通过 merge 写回。 +11. `vmi_producer_boundary.mlir` + - producer boundary 后只有 VMI semantic op/type,不出现 physical VPTO 或 hidden-state 依赖。 +12. `vmi_mask_threading.mlir` + - region-style mask 被 thread 到 masked VMI op 或 `vmi.select` merge,不残留 region mask。 +13. `vmi_gather_scatter_memory_semantics.mlir` + - inactive gather/scatter lane 不读写内存,scatter duplicate-index case 不走非法 direct path。 +14. `vmi_reduction_scan_contract_coverage.mlir` + - reduction/scan/contract 不回退成 residual logical-vector op,按 VMI lowering contract 处理。 +15. `vmi_cf_and_call_layout_boundary.mlir` + - `cf.br/cond_br` block arguments 和 internal call signatures 选择稳定 layout,external ABI 不泄露 layout。 +16. `vmi_iota_bitcast_insert_extract_coverage.mlir` + - lane index、bitcast、vector-result extract-like 和 insert-like 语义都有 VMI 承接。 +17. `vmi_memory_view_normalization.mlir` + - producer-specific vector element view 先规范化为 element view 和 access plan。 +18. `vmi_debug_boundary.mlir` + - debug-only op 不进入 VMI;未被 producer 消费时输出 `VMI-DEBUG-BOUNDARY`。 +19. `vmi_arith_numeric_contract.mlir` + - VMI arithmetic constant、fastmath、cmp predicate、integer signedness/overflow flags 保真。 +20. `vmi_shape_broadcast_semantics.mlir` + - `shape_cast/reshape` 只改 explicit op shape attrs,`transpose/flat_transpose` 和非 scalar broadcast 保持 lane map 语义且不依赖 shape side table。 +21. `vmi_physical_arity_non_full_deinterleaved.mlir` + - 非整 tile 下 `contiguous/deinterleaved=2/4` 的 physical value 个数和 valid lane map 一致。 +22. `vmi_ensure_layout_materialization_contract.mlir` + - `ensure_layout` 保持 logical lane 值,`deinterleaved=4` 只使用 registry 证明过的 materialization path。 +23. `vmi_tile_read_padding_decision_tree.mlir` + - safe full-read + non-all-true valid mask 生成 padding materialization + select;unsafe path 不读 invalid address。 +24. `vmi_tile_write_oob_no_effect.mlir` + - `tile_write` 的 writeMask=false lane 没有 memory effect,不被 lower 成 predicate-ignored store。 +25. `vmi_transfer_mask_coordinate_remap.mlir` + - non-minor-identity `tile_read/tile_write` 的 explicit mask 先映射到 result/source logical lane。 +26. `vmi_tile_read_vector_element_padding.mlir` + - vector-element padding 按 suffix coordinate 展开,invalid lane 使用对应 padding element。 +27. `vmi_index_address_legalization.mlir` + - `vreg`、gather/scatter indices、active-prefix offset 使用 element units 且宽度合法。 +28. `vmi_fallback_resource_diagnostics.mlir` + - scratch、guarded fallback、index-buffer fallback 缺资源时输出 `VMI-FALLBACK-RESOURCE`。 +29. `vmi_tensor_transfer_boundary.mlir` + - tensor transfer-style producer op 不伪装成 VMI memory op,未 bufferize 时输出 `VMI-TENSOR-BOUNDARY`。 +30. `vmi_pipeline_hard_gates.mlir` + - 各 pass 边界拒绝残留 VMI helper/unrealized cast/hidden state,且 final lowering 不残留 VMI op/type。 +31. `vmi_layout_assignment_constraint_solver.mlir` + - 多 use、rematerializable producer、control-flow join、layout conversion cost 冲突时选择稳定 layout 或输出精确 diagnostic。 diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md new file mode 100644 index 0000000000..772194f64d --- /dev/null +++ b/docs/designs/vmi-implementation-manual.md @@ -0,0 +1,4233 @@ +# VMI 实现手册 + +本文是 `docs/designs/vmi-dialect-design.md` 的落地手册。设计文档回答“为什么这样设计”,本文回答 +“按什么顺序改哪些文件、每一步做到什么程度才算完成”。 + +本文不替代最终 ODS / C++ verifier / lit 测试。实现时如果发现本文和 ODS 或 verifier 冲突,以 +更精确的 verifier 约束为准,并同步刷新本文。 + +## 0. 当前仓库约束 + +当前仓库只有一个 MLIR dialect: + +```text +dialect name: pto +cpp namespace: ::mlir::pto +``` + +VPTO 低层 op/type 也在同一个 `pto` dialect 里,通过 `VPTOOps.td`、`VPTOTypeDefs.td` 等文件组织。 +因此第一版 VMI 不新建独立 dialect,采用同一 dialect 下的嵌套 mnemonic: + +```text +types: + !pto.vmi.vreg<...> + !pto.vmi.mask<...> + +attrs: + #pto.vmi.layout<...> + +ops: + pto.vmi.addf + pto.vmi.subf + pto.vmi.mulf + pto.vmi.ensure_layout +``` + +落地方式是:`PTO_Dialect` 仍是唯一 dialect,VMI 只是 `pto` dialect 内的一组 type/attr/op。 +如果后续要拆成真正独立的 `pto.vmi` dialect,必须先保证所有 pass、type converter、parser 测试 +和公开文档同步迁移;第一版不要做这个拆分。 + +风险点:带点 mnemonic 例如 `vmi.vreg`、`vmi.addf` 必须在 Slice 0 先用 parser round-trip 测试 +证明。如果 TableGen 的默认 type/attr parser 不接受该 spelling,就在 VMI type/attr 上实现 +custom assembly format,而不是改公开 spelling。 + +## 1. 文件布局 + +新增文件: + +```text +include/PTO/IR/VMIAttrs.td +include/PTO/IR/VMITypeDefs.td +include/PTO/IR/VMIOps.td +lib/PTO/IR/VMI.cpp +lib/PTO/Transforms/VMILayoutAssignment.cpp +lib/PTO/Transforms/VMIToVPTO.cpp +lib/PTO/Transforms/PTOValidateVMIIR.cpp +test/lit/vmi/ +``` + +修改文件: + +```text +include/PTO/IR/PTOAttrs.td +include/PTO/IR/PTOTypeDefs.td +include/PTO/IR/PTOOps.td +include/PTO/IR/CMakeLists.txt +lib/PTO/IR/CMakeLists.txt +include/PTO/Transforms/Passes.td +lib/PTO/Transforms/CMakeLists.txt +``` + +推荐 include 关系: + +```text +PTOAttrs.td + include "PTO/IR/VMIAttrs.td" + +PTOTypeDefs.td + include "PTO/IR/VMITypeDefs.td" + +PTOOps.td + include "PTO/IR/VMIOps.td" +``` + +放置顺序: + +```text +VMIAttrs.td: + include PTODialect.td, AttrTypeBase.td, EnumAttr.td + must not include PTOAttrs.td + +VMITypeDefs.td: + include PTODialect.td and can rely on PTOAttrs.td having included VMIAttrs.td + +VMIOps.td: + include after PTO_Op is defined in PTOOps.td + do not include VPTOOps.td from VMIOps.td +``` + +这样现有 `LLVM_TARGET_DEFINITIONS PTOOps.td` 的 TableGen 生成路径可以继续覆盖 VMI type、attr +和 op。只有当 TableGen 生成目标不能正确收集新增 td 时,才单独新增 `mlir_tablegen` 目标。 + +`lib/PTO/IR/VMI.cpp` 放 VMI type/attr/op verifier、parse/print helper 和公共 lane-map helper。 +不要把 VMI verifier 塞进 `VPTO.cpp`。 + +Pass 注册要求: + +```text +include/PTO/Transforms/Passes.td: + add VMILayoutAssignment + add VMIToVPTO + add PTOValidateVMIIR + +include/PTO/Transforms/Passes.h: + add explicit create*Pass declarations if generated declarations are not enough + +lib/PTO/Transforms/CMakeLists.txt: + add the three new .cpp files to PTOTransforms + keep DEPENDS PTOPassesIncGen and PTOOpsIncGen + add missing MLIR dialect libraries only when a new source actually includes them +``` + +Driver wiring is explicit and opt-in. `ptoas --enable-vmi` runs the VMI semantic pipeline before the VPTO backend +pipeline: + +```text +pto-validate-vmi-ir +vmi-layout-assignment +pto-validate-vmi-layout-ir +vmi-to-vpto +``` + +`--enable-vmi` requires `--pto-backend=vpto` or `pto.backend = "vpto"` because the pipeline produces physical VPTO +values and ops. It is not part of the default PTOAS pipeline; existing PTO/VPTO inputs keep their previous behavior +unless the flag is set. + +The `ptoas --enable-vmi` user-facing entry also rejects public functions whose signature contains `!pto.vmi.*`. +Internal/private VMI-typed functions may still be specialized by `vmi-layout-assignment` and physicalized by +`vmi-to-vpto`, but a public VMI ABI requires an explicit materialization plan and must not be inferred from the +layout solver. + +CLI coverage: + +```text +vmi_ptoas_cli_pipeline.pto: + --pto-backend=vpto + --enable-vmi lowers the VMI pipeline + pto.backend = "vpto" also selects the VPTO-compatible path + explicit --pto-backend=emitc with --enable-vmi is rejected + +vmi_ptoas_backend_required_invalid.pto: + default emitc backend with --enable-vmi and no pto.backend = "vpto" is rejected + +vmi_ptoas_public_abi_invalid.pto / vmi_ptoas_public_result_abi_invalid.pto: + public VMI argument/result signatures are rejected before layout assignment +``` + +## MLIR Framework Usage + +三个核心 pass 不应该用同一种 MLIR 机制硬套。这里先定义实现框架选择,避免后续把 layout +求解、结构化控制流改写和 1:N physicalization 混在一个 pattern pass 里。 + +当前实现框架按下面的职责切开: + +```text +pto-validate-vmi-ir: + Operation::walk verifier。只看 IR 是否满足阶段不变量,不改 IR,不使用 conversion framework。 + +vmi-layout-assignment: + module-level per-SSA-value constraint solver。先收集等价类、producer natural layout 和 consumer request, + 再把结果写回 VMI type/helper op。它可以使用 IRRewriter 改 IR,但不以 TypeConverter 为主模型。 + +vmi-to-vpto: + MLIR OneToNTypeConversion。每个 layout-assigned VMI value 按统一 physical ordering 展开成多个 + VPTO value,并依靠 OneToN structural patterns 重写函数、return、region result、block argument 和 + branch operand。 +``` + +这三个 pass 的边界必须通过 IR 可见状态传递:layout 写在 `!pto.vmi.*` type 上,必要 materialization +写成 `pto.vmi.ensure_*`,physicalization 后不允许残留 `pto.vmi.*`、`!pto.vmi.*` 或 +`unrealized_conversion_cast`。不能把 layout 决策藏在 pass-private side table 里让后续 pass 猜。 + +源码级实现应该进一步拆成五个独立层次: + +```text +IR layer: + include/PTO/IR/VMIAttrs.td + include/PTO/IR/VMITypeDefs.td + include/PTO/IR/VMIOps.td + lib/PTO/IR/VMI.cpp + + 只定义语义、parse/print、type/op verifier 和公共 lane-map helper。 + 这一层不能知道 layout assignment 的全局选择,也不能直接依赖 VPTO lowering pass。 + +Semantic validation layer: + lib/PTO/Transforms/PTOValidateVMIIR.cpp + + 只检查阶段输入/输出是否满足 contract。它是 hard gate,不做 repair。 + +Layout solving layer: + lib/PTO/Transforms/VMILayoutAssignment.cpp + + 负责从 producer/consumer/control-flow/call 关系解出每个 logical value 的 layout, + 然后把结果写回 type 或 ensure_* helper。 + +Physicalization layer: + lib/PTO/Transforms/VMIToVPTO.cpp + + 负责把 layout-assigned VMI value 通过 OneToNTypeConversion 展成 VPTO physical values, + 并把每个 pto.vmi.* semantic op 改写成 VPTO op 序列。 + +Driver/test layer: + tools/ptoas/ptoas.cpp + tools/pto-test-opt/ + test/lit/vmi/ + + ptoas 只暴露 opt-in pipeline;pto-test-opt 保留单 pass 和中间 IR 的调试入口。 +``` + +每层的 MLIR 框架选择如下: + +```text +ODS/TableGen: + 定义 type/attr/op surface 和 verifier hook。 + +Operation::walk: + 用于 validation 和 layout constraint collection。 + +Union-find + DenseMap: + 用于 layout assignment 的 per-SSA-value 等价类求解。 + +IRRewriter/RewriterBase: + 用于 layout assignment 之后的 type rewrite、helper insertion、cheap producer rematerialization。 + +OneToNTypeConverter + OneToNOpConversionPattern: + 只用于 vmi-to-vpto,把一个 logical VMI value 展成多个 VPTO value。 + +Upstream OneToN structural helpers: + func.func / func.call / func.return / common SCF region-result conversion。 + +Project-local OneToN structural patterns: + cf.br / cf.cond_br / cf.switch / scf.execute_region / scf.index_switch。 +``` + +不要把这些层次合并成一个万能 pattern pass。特别是: + +```text +layout assignment 不能依赖 OneToNTypeConverter: + 因为 layout 不是 type-only 决策,同一个 !pto.vmi.vreg<128xf32> 的不同 SSA value + 可能因 producer/consumer/control-flow 约束得到不同 layout。 + +vmi-to-vpto 不能重新做 layout solving: + 它只消费已经写在 type/helper 上的 layout 决策。遇到未 assignment 的 VMI type 必须失败。 + +structural OneToN pattern 不能知道 VMI 语义: + 它们只负责 flatten/rebuild operands、results、successor operands 和 block arguments。 + 具体 lane 语义只属于 pto.vmi.* op lowering pattern。 + +verifier 不能偷偷修 IR: + 否则后续 pass 会依赖 verifier 的隐式 repair 行为,导致 pipeline 顺序不可推理。 +``` + +一个可以直接对照代码的 pass 边界表: + +```text +pass input output +--------------------------- ---------------------------- ---------------------------- +pto-validate-vmi-ir surface VMI IR same IR, or hard failure +vmi-layout-assignment surface/layout-partial VMI layout-assigned VMI IR +pto-validate-vmi-layout-ir layout-assigned VMI IR same IR, or hard failure +vmi-to-vpto layout-assigned VMI IR physical VPTO IR +final residual verifier physical VPTO candidate no pto.vmi.*, no !pto.vmi.* +``` + +### 代码级落点 + +当前实现应该能按文件直接审计。每个 pass 的核心类、MLIR 机制和失败边界如下: + +```text +lib/PTO/Transforms/PTOValidateVMIIR.cpp + pass: + PTOValidateVMIIRPass + PTOValidateVMILayoutIRPass + public helpers: + validateVMIProducerBoundaryIR + validateVMILayoutAssignedIR + MLIR API: + Operation::walk + func::FuncOp function type inspection + recursive TypeAttr / TypedAttr / ArrayAttr / DictionaryAttr scan + must not: + rewrite IR + create unrealized_conversion_cast + create ConversionTarget + repair illegal helper/type leakage + +lib/PTO/Transforms/VMILayoutAssignment.cpp + pass: + VMILayoutAssignmentPass + core object: + LayoutSolver + state: + DenseMap + SmallVector + SmallVector + SmallVector + SmallVector + MLIR API: + Operation::walk for fact collection + SymbolTable for direct internal calls + concrete cf/scf handlers for control-flow equivalence + IRRewriter/OpBuilder only after solving + must not: + use TypeConverter as the layout decision model + rewrite while collecting constraints + hide chosen layout in a pass-private side table + infer external VMI ABI + +lib/PTO/Transforms/VMIToVPTO.cpp + pass: + VMIToVPTOPass + converter: + VMIToVPTOTypeConverter : OneToNTypeConverter + pattern families: + OneToNOpConversionPattern for pto.vmi.* semantic ops + upstream func/scf OneToN structural patterns + project-local cf/scf structural OneToN patterns + MLIR API: + populateFuncTypeConversionPatterns + scf::populateSCFStructuralOneToNTypeConversions + applyPartialOneToNConversion + final residual walk + must not: + redo layout solving + inspect defining ops to recover physical parts + allow pto.vmi.pack/unpack/ensure_* to survive final output + allow unrealized_conversion_cast to survive final output +``` + +这里最重要的分界是:`vmi-layout-assignment` 解决的是 value-level layout,`vmi-to-vpto` +解决的是 type/value 1:N physicalization。前者的结果必须已经写回 `!pto.vmi.*` type 或显式 +`pto.vmi.ensure_*`;后者只能消费这些 IR-visible facts。 + +这也回答了“有没有充分利用 MLIR 自带能力”:结构化 1:N signature/control-flow conversion 必须用 +MLIR OneToN conversion;layout assignment 则不能强行塞进 converter,因为 converter 看不到 +producer natural layout、consumer request、CFG join 和 call-return slot 这些 value-level facts。 + +### Pass 级实现细则 + +这几个 pass 对 MLIR 自带能力的使用方式应该是“各用其长”,而不是都套成 converter pattern。 +实现时按下面的判断标准拆: + +```text +只检查阶段不变量: + 用 Operation::walk。不要创建 ConversionTarget,也不要 rewrite。 + +需要根据 SSA value、CFG join、call boundary 和 consumer request 决策 layout: + 用 module-level solver。MLIR conversion framework 没有 per-value layout 决策模型。 + +需要把一个 logical value 展成多个 physical value,并同步改 function/block/control-flow signature: + 用 OneToNTypeConversion。这里是 converter framework 最应该发挥作用的地方。 +``` + +#### Pass 框架细化 + +第一版实现按下面的源码和 MLIR infra 对齐。这个表是实现时的边界,不只是文档分层: + +```text +source file pass primary MLIR facility +----------------------------------------- --------------------------- --------------------------------------------- +lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-ir Operation::walk + recursive type/attr scan +lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-layout-ir Operation::walk + recursive type/attr scan +lib/PTO/Transforms/VMILayoutAssignment.cpp vmi-layout-assignment module-level union-find solver + IRRewriter +lib/PTO/Transforms/VMIToVPTO.cpp vmi-to-vpto OneToNTypeConverter + OneToNOpConversionPattern +``` + +这意味着每个 pass 的输入输出 contract 是固定的: + +```text +pto-validate-vmi-ir: + input: + surface VMI IR + legal: + pto.vmi semantic ops + !pto.vmi.vreg + !pto.vmi.mask + func/scf/cf structural ops carrying those types + illegal: + layout-assigned !pto.vmi.* type + physical !pto.vreg / !pto.mask / !pto.align type + pto.vmi.ensure_* / pack / unpack helper + VMI or physical type hidden in non-signature attribute + output: + exactly the same IR, or failure + +vmi-layout-assignment: + input: + verifier-clean surface VMI IR + legal work: + solve per-SSA layout/granularity constraints + rewrite VMI value/function/block types with explicit layout + insert pto.vmi.ensure_* only for use-site materialization + rematerialize cheap producers instead of inserting ensure_* when semantics are replay-safe + illegal work: + physicalize to !pto.vreg / !pto.mask + introduce pto.vmi.pack / pto.vmi.unpack + keep layout only in a pass-private side table + output: + layout-assigned VMI IR, or failure + +pto-validate-vmi-layout-ir: + input: + layout-assigned VMI IR + legal: + pto.vmi semantic ops + pto.vmi.ensure_layout / ensure_mask_layout / ensure_mask_granularity + !pto.vmi.vreg + !pto.vmi.mask + illegal: + surface !pto.vmi.vreg + surface !pto.vmi.mask + physical VPTO register types before vmi-to-vpto + pto.vmi.pack / pto.vmi.unpack + VMI or physical type hidden in non-signature attribute + output: + exactly the same IR, or failure + +vmi-to-vpto: + input: + layout-assigned VMI IR + legal work: + convert each VMI value to an ordered list of physical VPTO values + rewrite function signatures, block arguments, branch operands, region results and calls + lower pto.vmi semantic/helper ops to VPTO ops + illegal work: + infer missing layouts + change a chosen layout because one pattern finds a cheaper lowering + leave pto.vmi.* / !pto.vmi.* / unrealized_conversion_cast in final IR + output: + physical VPTO IR, or failure +``` + +`vmi-layout-assignment` 和 `vmi-to-vpto` 的关键差异是:前者解决“这个 SSA value 应该是什么 layout”, +后者解决“这个已经有 layout 的 SSA value 展开成哪些 physical value”。同一个 surface type 不能用 +`TypeConverter` 得到唯一答案: + +```mlir +%a = pto.vmi.broadcast %s : f32 -> !pto.vmi.vreg<128xf32> +%b = pto.vmi.extf %x : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +%c = scf.if %cond -> !pto.vmi.vreg<128xf32> { + scf.yield %a : !pto.vmi.vreg<128xf32> +} else { + scf.yield %b : !pto.vmi.vreg<128xf32> +} +``` + +这里 `%a` 可以按 consumer 需要 rematerialize 成 contiguous 或 deinterleaved;`%b` 的 natural layout 是 +`deinterleaved=2`;`%c` 的 layout 必须由两个 yield 和后续 consumer 共同约束。这个选择依赖 Value、 +def-use、control-flow join 和 use-site request,不是 `!pto.vmi.vreg<128xf32> -> ...` 的 type-only 规则。 + +因此 layout pass 的代码形态应该固定为: + +```cpp +LogicalResult LayoutSolver::run() { + if (failed(collectAllVMIValues())) + return failure(); + if (failed(collectEquivalenceConstraints())) + return failure(); + if (failed(collectProducerNaturalLayouts())) + return failure(); + if (failed(collectConsumerRequests())) + return failure(); + if (failed(rewriteDataTypes())) + return failure(); + if (failed(insertDataUseMaterializations())) + return failure(); + if (failed(inferAndRewriteMaskTypes())) + return failure(); + if (failed(insertMaskUseMaterializations())) + return failure(); + rewriteFunctionTypesFromSolvedValues(); + return validateVMILayoutAssignedIR(module); +} +``` + +其中 `collect*` 阶段只能记录事实,不能边 walk 边改 IR。原因是控制流和 call boundary 会把后面才遇到的 +operand/result 合并到前面的 value class;边收集边改 type 会让后续约束看到混合状态,错误诊断也会依赖 +walk 顺序。 + +`vmi-to-vpto` 则必须是 converter pass。第一版使用的是 `OneToNTypeConversion`,因为它要同时处理 +value type 和结构签名: + +```text +!pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +func.func @f(%arg0: !pto.vmi.vreg<128xf32, layout>) -> !pto.vmi.vreg<128xf32, layout> + -> func.func @f(%arg0_0: !pto.vreg<64xf32>, %arg0_1: !pto.vreg<64xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +``` + +这里不能用普通 1:1 `TypeConverter`,也不能靠每个 VMI op pattern 自己拆 operand。否则 `func.return`、 +`cf.br`、`scf.for` iter arg 这种没有 VMI defining op 的边界会漏转换。`OneToN` adaptor 才是 semantic +pattern 获取 physical parts 的唯一来源: + +```cpp +ValueRange lhsParts = adaptor.getLhs(); +ValueRange rhsParts = adaptor.getRhs(); +TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); +``` + +结构化转换的实现分工如下: + +```text +upstream helper: + populateFuncTypeConversionPatterns + covers func.func / func.return / direct func.call signature conversion + + scf::populateSCFStructuralOneToNTypeConversions + covers common SCF result/yield/block-argument structural conversions + +project-local OneToN patterns: + cf.br + cf.cond_br + cf.switch + scf.execute_region + scf.index_switch +``` + +项目内 structural pattern 只能做结构搬运: + +```text +1. read OneToNTypeMapping for each original operand/result +2. flatten successor operands or region result types +3. rebuild the same cf/scf op with converted types +4. inline/move original regions when required +``` + +它们不能做下面这些事: + +```text +infer layout from operand defining op +emit vadd/vcvt/vlds/vsts +decide contiguous vs deinterleaved +special-case pto.vmi semantic op +``` + +VMI 语义只能出现在 `OneToNOpConversionPattern` 里。这样才能保证 block argument、function +argument、loop-carried value 和 branch target argument 都按同一套 physical ordering 转换。 + +`vmi-to-vpto` 的 legality 由 preflight + conversion + final gate 三段组成,而不是单靠 +`ConversionTarget`: + +```text +preflight: + verifyVMIToVPTOInputIR + rejects layout-free VMI types + verifySupportedVMIToVPTOOps + rejects unsupported semantic/materialization cases before rewrite starts + +conversion: + applyPartialOneToNConversion + applies structural and semantic OneToN patterns + +final gate: + verifyNoResidualVMIIR + rejects pto.vmi.* + rejects !pto.vmi.* in operand/result/block/function/attribute type trees + rejects pto.vmi.pack/unpack materialization helpers + rejects unrealized_conversion_cast +``` + +这比只设置 `ConversionTarget` 更直接,因为当前 OneToN 工具链的重点是 type/value expansion 和 pattern +rewriter;最终合法性必须递归检查 attribute/type tree,防止 VMI type 被藏在 nested attr 里。 + +#### `pto-validate-vmi-ir` / `pto-validate-vmi-layout-ir` + +这两个 pass 是 hard gate,不是 legalization pass。 + +使用的 MLIR 能力: + +```text +Operation::walk: + 遍历 module 内所有 op、region、block argument、operand/result type 和 attribute。 + +TypeAttr / TypedAttr recursive scan: + 拒绝把 VMI/physical VPTO type 藏在 nested attribute 中。 + +func::FuncOp function type special case: + function_type attr 是签名本身,可以按当前阶段规则检查;其它 attr 不能携带 VMI/physical type。 +``` + +不使用 `ConversionTarget` 的原因: + +```text +ConversionTarget 适合表达“哪些 op/type legal,哪些 pattern 能改掉”。 +这里我们只想回答“当前 IR 是否已经处在某个阶段边界”,失败后必须停机,而不是尝试 repair。 +如果 verifier 顺手改 IR,pipeline 的阶段不变量会变成隐式行为,后续 pass 很难审计。 +``` + +这两个 pass 的输出只能是原 IR 或 failure: + +```cpp +void runOnOperation() override { + if (failed(verifyStageInvariant(getOperation()))) + signalPassFailure(); +} +``` + +#### `vmi-layout-assignment` + +这个 pass 使用 MLIR 的 IR 遍历和 rewrite 基础设施,但不使用 `TypeConverter` 作为主模型。 + +核心原因: + +```text +TypeConverter 的输入是 Type。 +layout assignment 的输入是 Value。 + +同一个 !pto.vmi.vreg<128xf32> 可以因为不同 producer/consumer 关系得到不同 layout: + f16->f32 widen result -> deinterleaved=2 + f8 ->f32 widen result -> deinterleaved=4 + only contiguous store value -> contiguous +``` + +实现应拆成两个阶段,不要边 walk 边 rewrite: + +```text +collect: + 1. 收集所有 VMI data/mask SSA value 和 block argument。 + 2. 用 union-find 合并必须同 layout 的 value。 + 3. 记录 producer natural layout。 + 4. 记录 consumer layout/granularity request。 + 5. 记录 function return slot、call operand/result、branch operand/block argument 关系。 + +rewrite: + 1. 为每个 equivalence class 选 layout。 + 2. 改写 value/function/block/result type。 + 3. 对 use-site mismatch 插入 ensure_* 或 rematerialize cheap producer。 + 4. 运行 pto-validate-vmi-layout-ir。 +``` + +建议的数据结构边界: + +```cpp +struct DataNode { + Value value; + VMIVRegType type; + unsigned parent; + VMILayoutAttr naturalLayout; +}; + +struct MaskNode { + Value value; + VMIMaskType type; + unsigned parent; + VMILayoutAttr requestedLayout; + std::string requestedGranularity; +}; + +struct DataUseRequest { + OpOperand *operand; + VMILayoutAttr layout; +}; + +struct MaskUseRequest { + OpOperand *operand; + VMILayoutAttr layout; + std::string granularity; +}; +``` + +这里可以充分使用 MLIR 的接口,但它们只是 constraint source: + +```text +BranchOpInterface / concrete cf.* handlers: + successor operand[i] == destination block argument[i] + +RegionBranchOpInterface / concrete scf.* handlers: + region yield operand[i] == parent result[i] + loop init/result/iter_arg/yield 同 slot 等价 + +CallOpInterface + SymbolTable: + direct internal call operand/result 和 callee argument/return slot 等价 + external/indirect VMI call 先拒绝,因为缺 ABI materialization + +IRRewriter: + 只在 solve 完成后统一改 type、插 ensure_*、clone cheap producer。 +``` + +`vmi-layout-assignment` 的 pass invariant 是:所有 layout 决策必须写回 IR。后续 `vmi-to-vpto` +只能读取 `!pto.vmi.*` type 和显式 `pto.vmi.ensure_*`,不能依赖 layout solver 的 side table。 + +#### `vmi-to-vpto` + +这个 pass 应该充分使用 MLIR converter framework,具体是 `OneToNTypeConversion`,不是普通 +`DialectConversion`。 + +普通 1:1 dialect conversion 不够的地方: + +```text +!pto.vmi.vreg<128xf32, deinterleaved=2> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +!pto.vmi.vreg<256xf8, deinterleaved=4> + -> !pto.vreg<256xf8>, !pto.vreg<256xf8>, !pto.vreg<256xf8>, !pto.vreg<256xf8> +``` + +函数参数、返回值、block argument、branch operand、region result 都必须做同样的 1:N 展开。 +这正是 `OneToNTypeConverter`、`OneToNOpConversionPattern` 和结构化 OneToN helper 的职责。 + +实现骨架: + +```cpp +void runOnOperation() override { + ModuleOp module = getOperation(); + + if (failed(verifyVMIToVPTOInputIR(module)) || + failed(verifySupportedVMIToVPTOOps(module))) + return signalPassFailure(); + + VMIToVPTOTypeConverter typeConverter; + RewritePatternSet patterns(&getContext()); + + populateFuncTypeConversionPatterns(typeConverter, patterns); + scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); + populateProjectLocalCFOneToNPatterns(typeConverter, patterns); + populateVMISemanticOneToNPatterns(typeConverter, patterns); + + if (failed(applyPartialOneToNConversion(module, typeConverter, + std::move(patterns))) || + failed(verifyNoResidualVMIIR(module))) + signalPassFailure(); +} +``` + +`VMIToVPTOTypeConverter` 只做一种事:把 layout-assigned VMI type 映射到 canonical physical value list。 +它不能重新推导 layout。 + +```text +contiguous: + chunk0, chunk1, ... in logical order + +deinterleaved=2: + part0 chunks for logical lanes 0,2,4,... + part1 chunks for logical lanes 1,3,5,... + +deinterleaved=4: + part0 chunks for lanes 0,4,8,... + part1 chunks for lanes 1,5,9,... + part2 chunks for lanes 2,6,10,... + part3 chunks for lanes 3,7,11,... +``` + +每个 semantic pattern 必须从 adaptor 拿 physical parts,不允许从 defining op 反推: + +```cpp +LogicalResult matchAndRewrite(VMIAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhs = adaptor.getLhs(); + ValueRange rhs = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + if (lhs.size() != rhs.size() || lhs.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical arity mismatch"); + + SmallVector results; + for (auto [i, resultType] : llvm::enumerate(resultTypes)) { + results.push_back( + rewriter.create(op.getLoc(), resultType, lhs[i], rhs[i]) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); +} +``` + +这个约束对控制流是关键的:`scf.for` iter arg、branch target argument、function argument 都没有可用的 +defining op;它们的 physical parts 只能来自 OneToN signature/block argument conversion。 + +`vmi-to-vpto` 应有三层失败点,诊断不要混在一起: + +```text +preflight: + layout 未 assignment、unsupported semantic op、unsupported materialization path + +conversion: + pattern 缺失、arity mismatch、结构化控制流展开失败 + +final residual verifier: + 任何 pto.vmi.*、!pto.vmi.*、pto.vmi.pack/unpack/ensure_*、unrealized_conversion_cast 残留 +``` + +### `pto-validate-vmi-ir` + +`pto-validate-vmi-ir` 是边界 verifier,不使用 DialectConversion。 + +推荐使用: + +```text +Operation::walk +TypeSwitch / isa / dyn_cast +emitOpError / InFlightDiagnostic +SymbolTable, for function/call boundary checks +CallGraph or manual call graph collection, if recursive SCC needs diagnostics +DominanceInfo, if helper placement or resource dominance is checked +``` + +这个 pass 只检查 VMI producer boundary 和阶段不变量: + +```text +before layout assignment: + VMI data/mask values use surface type + no layout-assigned VMI type leaks in unless the test explicitly starts after assignment + no physical VPTO op appears in the semantic VMI region + no VMI helper op appears before the pass that is allowed to create it + no non-signature op/module TypeAttr or TypedAttr payload contains VMI or physical VPTO types + +after layout assignment: + pass: pto-validate-vmi-layout-ir + every VMI data value has a layout + every VMI mask has layout and concrete granularity + control-flow joins have stable type/layout + no non-signature op/module TypeAttr or TypedAttr payload contains VMI or physical VPTO types + +after VMI-to-VPTO: + no VMI op/type/helper remains + no unrealized_conversion_cast remains +``` + +不要把这个 pass 写成 rewrite pass。它可以收集 context 用于诊断,但不能通过局部修补让非法 IR +继续前进;否则后续 pass 会开始依赖 verifier 的隐式 repair 行为。 + +实现上要扫描的不只是 operand/result/block argument: + +```text +func.func function type: + 作为函数签名本身检查,允许出现当前阶段合法的 VMI type。 + +non-signature attributes: + module/op attribute 中只要递归包含 VMI type 或 physical VPTO type 都拒绝。这里包括 TypeAttr、 + TypedAttr,以及 ArrayAttr/DictionaryAttr 这类容器中的 nested attribute/type payload。 +``` + +这样可以堵住 hidden-state 形式的 side table,例如把 `!pto.vmi.vreg<...>` 偷存在 module attribute +里。`func.func` 的内建 `function_type` attr 是唯一例外,因为它只是函数签名的 MLIR 表达,不是额外 +隐藏状态。 + +### `vmi-layout-assignment` + +`vmi-layout-assignment` 不以 MLIR `TypeConverter` 作为主机制。 + +原因是 layout 选择不是单纯的 `Type -> TypeRange` 映射: + +```text +same surface type: + !pto.vmi.vreg<128xf32> + +possible per-value decisions: + value produced by f16->f32 widen: deinterleaved=2 + value loaded only for contiguous store: contiguous + value feeding fp8-like->f32 consumer path: deinterleaved=4 +``` + +两个 SSA value 可以有完全相同的 surface type,但因为 producer natural layout、consumer demand、 +控制流 join 和 target capability 不同,得到不同 layout。因此主模型应该是 per-SSA-value 的约束图, +而不是类型转换表。 + +推荐内部结构: + +```text +DenseMap +DenseMap +DenseMap +SmallVector +SmallVector +``` + +推荐使用的 MLIR 基础能力: + +```text +RegionBranchOpInterface: + collect scf.if/scf.for-like region entry, yield, result relations + +BranchOpInterface: + collect cf.br/cf.cond_br predecessor operand -> block argument relations + +CallOpInterface, CallableOpInterface, FunctionOpInterface: + collect call operand/result and function argument/result relations + +SymbolTable: + resolve direct calls and reject unresolved VMI signature assumptions + +DominanceInfo: + choose legal insertion points for ensure_layout, mask conversion, and rematerialization + +IRRewriter / RewriterBase: + rewrite types, insert helper ops, clone rematerializable producers +``` + +求解结果必须 materialize 回 IR,不能留在 side table: + +```text +1. Rewrite every VMI value type to a layout-assigned type. +2. Rewrite mask type to layout + b8/b16/b32 granularity. +3. Insert pto.vmi.ensure_layout where a consumer requires a different layout. +4. Insert pto.vmi.ensure_mask_layout / ensure_mask_granularity where predicate layout or granularity differs. +5. Clone rematerializable producers such as constant, broadcast, create_mask, iota-like producers when cheaper. +6. Re-run the VMI stage verifier. +``` + +这个 pass 可以用 `RewritePatternSet` 辅助局部 canonicalization,例如删除同 layout 的 +`ensure_layout`,但不能让 greedy pattern driver 决定全局 layout。全局约束必须先收敛,再做改写。 + +更具体地说,这里不用 `TypeConverter` 的原因不是 MLIR converter 不好用,而是此阶段的问题不是 +“一个旧 type 机械变成一个新 type”: + +```text +%a : !pto.vmi.vreg<128xf32> // 只被 contiguous store 消费 +%b : !pto.vmi.vreg<128xf32> // 来自 f16->f32 widen,后续继续 vadd +%c : !pto.vmi.vreg<128xf32> // 控制流 join,两个 predecessor 必须统一 layout +``` + +这三个 value 的 surface type 完全相同,但 layout 决策分别可能是 contiguous、deinterleaved=2、 +以及由 join 两侧约束共同决定。`TypeConverter` 看不到“这个 SSA value 的 producer/consumer/CFG +关系”,所以它只能作为后续 physicalization 的工具,不能作为 layout assignment 的主算法。 + +该 pass 对 MLIR 基础能力的使用边界是: + +```text +Operation::walk: + 收集所有 VMI SSA value、block argument、函数签名和 op transfer facts。 + +Union-find / DenseMap: + 表达必须同 layout 的 equivalence class。 + +SymbolTable: + 解析 direct internal func.call;带 VMI type 的 external/indirect call 先拒绝。 + +IRRewriter: + 改写 function/block/result type,插入 ensure_*,必要时 rematerialize cheap producer。 + +verifyLayoutAssignedVMIIR: + pass 末尾 hard gate,确认所有决策已经 materialize 到 IR。 +``` + +### `vmi-to-vpto` + +`vmi-to-vpto` 应该使用 MLIR 的 1:N conversion framework,而不是普通 `DialectConversion`。 +这个 pass 的核心问题正是一个 logical VMI value physicalize 成多个 VPTO value: + +```text +!pto.vmi.vreg -> !pto.vreg... +!pto.vmi.mask -> !pto.mask... +``` + +普通 `DialectConversion` 的 `OpConversionPattern` 对 1:N fixed operand/result 支持不够直接: +pattern adaptor 可能拿到 source materialization,也可能拿到 flat converted operands;`func.return` +这类“一个 logical operand 展开成多个 physical operands”的场景也容易出现不完整展开。因此这里采用 +MLIR `OneToNTypeConversion` 工具: + +推荐组件: + +```text +OneToNTypeConverter +OneToNOpConversionPattern +OneToNPatternRewriter +OneToNTypeMapping +populateFuncTypeConversionPatterns +scf::populateSCFStructuralOneToNTypeConversions +applyPartialOneToNConversion +final residual verifier +``` + +`OneToNTypeConverter` 负责 layout-assigned VMI type 到 ordered physical VPTO value list: + +```cpp +typeConverter.addConversion([](VMIVRegType type, SmallVectorImpl &results) { + // Use getVMIPhysicalArity(type) and the shared lane-map helper. + // Append one physical !pto.vreg per part/chunk. +}); + +typeConverter.addConversion([](VMIMaskType type, SmallVectorImpl &results) { + // Use mask granularity and physical arity helper. + // Append one physical !pto.mask per part/chunk. +}); +``` + +source/target materialization 可以用 VMI helper 承接中间状态: + +```text +VMI value -> physical values: + pto.vmi.unpack + +physical values -> VMI value: + pto.vmi.pack +``` + +但它们只是 conversion materialization,不是最终 IR 的合法残留。final gate 必须拒绝: + +```text +pto.vmi.pack +pto.vmi.unpack +pto.vmi.ensure_layout +pto.vmi.ensure_mask_layout +pto.vmi.ensure_mask_granularity +unrealized_conversion_cast +``` + +`applyPartialOneToNConversion` 本身不是 legality framework;它负责应用 1:N patterns 并替换内部 +`unrealized_conversion_cast`。因此 `vmi-to-vpto` 必须在 conversion 后运行 final residual verifier, +把下面这些全部作为 hard failure: + +```text +any pto.vmi.* op +any !pto.vmi.* type +any pto.vmi.pack/unpack materialization helper +any pto.vmi.ensure_* helper +any unrealized_conversion_cast +``` + +结构转换必须覆盖: + +```text +func arguments/results and return operands: + use populateFuncTypeConversionPatterns + +call operands/results: + convert callee signature and call sites together + +block arguments and branch operands: + convert target block arguments and predecessor operands in the same conversion + current implementation provides project-local OneToN patterns for cf.br, + cf.cond_br, and cf.switch because MLIR only provides the generic + BranchOpInterface helper for ordinary 1:1 dialect conversion, not for VMI + 1:N physicalization. + +scf.if/scf.for region yields and results: + use scf::populateSCFStructuralOneToNTypeConversions + otherwise write explicit OneToN patterns around RegionBranchOpInterface relations +``` + +如果当前 LLVM/MLIR 版本没有提供对应 OneToN helper,就补项目内 custom `OneToNConversionPattern`。 +选择标准不是“少写代码”,而是能否正确处理 1:N result、block argument、region yield 和 +recursive/function SCC。 + +当前实现的结构转换分工如下: + +```text +upstream OneToN helper: + func.func / func.return / func.call + scf.if / scf.for / scf.while and common SCF structural cases + +project-local OneToN structural patterns: + cf.br + cf.cond_br + cf.switch + scf.execute_region + scf.index_switch +``` + +项目内 structural pattern 只做一件事:按照 `OneToNTypeMapping` 展平/重建 operand、result、 +successor operand 和 block argument。它们不能内嵌 VMI layout 语义,也不能通过 defining op +重新推导物理寄存器列表。VMI 语义只出现在各个 `pto.vmi.*` 的 `OneToNOpConversionPattern` 中。 + +OneToN conversion 的执行顺序: + +```text +1. Populate structural conversion patterns. +2. Populate VMI semantic op lowering patterns. +3. Populate helper lowering/materialization patterns. +4. applyPartialOneToNConversion on the module. +5. Run final residual verifier as the hard legality gate. +``` + +如果 conversion 或 final gate 失败,诊断必须区分: + +```text +unsupported VMI semantic op +unsupported layout materialization path +unconverted function/control-flow boundary +unexpected VMI helper residual +unexpected unrealized_conversion_cast +``` + +这样 pass 边界就是清楚的: + +```text +pto-validate-vmi-ir: + verifier/walk, no conversion + +vmi-layout-assignment: + global per-value layout solver, then IR materialization + +vmi-to-vpto: + OneToNTypeConversion-based 1:N physicalization and final legality gate +``` + +### Concrete Pass Skeleton + +整个 pipeline 按下面的 hard contract 串起来: + +```text +raw VMI producer + -> pto-validate-vmi-ir + -> vmi-layout-assignment + -> pto-validate-vmi-layout-ir + -> vmi-to-vpto + -> final residual verifier +``` + +The `ptoas --enable-vmi` driver entry uses exactly this sequence before the existing VPTO backend pipeline. The +test-opt entry remains useful for isolated pass debugging, while the `ptoas` flag proves the same sequence is wired +through the user-facing compiler driver. + +各阶段之间只通过 IR 传递状态,不通过 pass-private side table 传递语义。也就是说: + +```text +layout assignment output: + VMI value type already contains layout + VMI mask type already contains layout + concrete b8/b16/b32 granularity + required layout conversion already appears as pto.vmi.ensure_* or rematerialized producer + +vmi-to-vpto input: + may contain pto.vmi.* semantic ops and helper ops + must not contain layout-free VMI type + function signatures and op/module TypeAttr or TypedAttr payloads are part of this invariant, + not just SSA operands/results + +vmi-to-vpto output: + must not contain pto.vmi.* op/type/helper + must not contain unrealized_conversion_cast + function type attributes and any other op/module TypeAttr or TypedAttr payloads must not contain !pto.vmi.* +``` + +This prevents a fragile design where `vmi-to-vpto` has to rediscover layout decisions from defining ops. A VMI value +may be a function argument, block argument, `scf.if` result, `scf.for` carried value, or branch target argument; none +of those has a useful defining op. + +#### Layout Assignment State + +`vmi-layout-assignment` should be implemented as one module-level solver object: + +```cpp +struct DataValueState { + Value value; + VMIVRegType surfaceType; + UnionFindNode eqClass; + VMILayoutAttr naturalLayout; // producer-preferred layout + SmallVector uses; // consumer requirements +}; + +struct MaskValueState { + Value value; + VMIMaskType surfaceType; + UnionFindNode eqClass; + VMILayoutAttr requestedLayout; + StringRef requestedGranularity; // b8/b16/b32 after inference + SmallVector uses; // consumer layout/granularity requests +}; + +struct LayoutUseRequest { + Operation *consumer; + VMILayoutAttr layout; + StringRef reason; // add/select/store/widen-source/etc. +}; +``` + +The solver runs in phases: + +```text +1. collect all VMI data/mask SSA values, including block arguments +2. add equivalence constraints +3. add producer natural-layout constraints +4. add consumer layout/granularity requests +5. solve each equivalence class +6. insert ensure_* or rematerialize producers for non-class-compatible uses +7. rewrite value types and function signatures +8. run pto-validate-vmi-layout-ir +``` + +Equivalence is only for cases where two logical values must have the same physical lane order: + +```text +add/sub/mul: + lhs == rhs == result + +cmpf/cmpi: + lhs == rhs + result mask requests lhs layout + element-width granularity + +select: + true_value == false_value == result + mask operand gets a use-site request for result layout + element-width granularity + +scf.if: + result[i] == then yield[i] == else yield[i] + +scf.for: + init_arg[i] == region_iter_arg[i] == yield[i] == result[i] + +cf.br/cf.cond_br: + successor operand[i] == successor block argument[i] + +direct internal func.call: + call operand[i] == callee argument[i] + call result[i] == all callee return operand[i] +``` + +Natural layout is not equivalence. For example: + +```text +extf f16 -> f32: + result natural layout = deinterleaved=2 + +extf f8 -> f32: + result natural layout = deinterleaved=4 + +truncf f32 -> f16: + result natural layout = contiguous + +truncf f32 -> fp8-like: + result natural layout = contiguous + +store/tile_write: + consumer requests contiguous externally visible order +``` + +If one equivalence class has incompatible natural layouts, the pass must diagnose `VMI-LAYOUT-CONTRACT` unless a +defined rematerialization path can split the value before the conflict. The first version should only rematerialize +trivially replayable producers: + +```text +constant +broadcast +constant_mask +create_mask +``` + +For non-rematerializable producers, insert `pto.vmi.ensure_layout` immediately before the consumer that requested the +different layout. This is the conservative first implementation rule. It works for ordinary SSA values, block +arguments, loop-carried values, branch arguments, and call results because the helper is dominated by the value at the +use site and does not need to be hoisted across control flow. `DominanceInfo` may be used later to hoist duplicated +helpers as an optimization, but it must not be required for correctness in the first implementation. + +That helper is a real IR marker: if `vmi-to-vpto` cannot lower its requested conversion, the program fails with an +explicit unsupported materialization diagnostic. + +#### Layout Assignment Implementation Frame + +This pass is a normal `OperationPass`. It deliberately does not use `DialectConversion`, because there is +no stable `Type -> Type` rule until the pass has solved producer preference, consumer demand, and control-flow joins. +The implementation should look like this: + +```cpp +struct LayoutSolver { + ModuleOp module; + MLIRContext *ctx; + + DenseMap dataIds; + SmallVector dataNodes; + DenseMap maskIds; + SmallVector maskNodes; + + SmallVector dataUseRequests; + SmallVector maskUseRequests; + DenseMap> firstReturnOperandsByFunc; + + LogicalResult collectConstraints(); + LogicalResult rewriteIR(); +}; +``` + +The concrete state objects should carry only facts that are materialized back into IR: + +```cpp +struct DataNode { + Value value; + VMIVRegType surfaceType; + unsigned parent; + VMILayoutAttr naturalLayout; // null means no producer preference yet +}; + +struct MaskNode { + Value value; + VMIMaskType surfaceType; + unsigned parent; + VMILayoutAttr requestedLayout; + std::string requestedGranularity; // empty until b8/b16/b32 is known +}; + +struct DataUseRequest { + OpOperand *operand; + VMILayoutAttr layout; +}; + +struct MaskUseRequest { + OpOperand *operand; + VMILayoutAttr layout; + std::string granularity; +}; +``` + +Do not store hidden layout state that `vmi-to-vpto` must rediscover. After this pass, a debugger should be able to read +the IR and know the chosen layout for every VMI value from its type alone. + +The pass body should stay simple: + +```cpp +void runOnOperation() override { + LayoutSolver solver(getOperation()); + if (failed(solver.collectConstraints()) || + failed(solver.rewriteIR()) || + failed(verifyLayoutAssignedVMIIR(getOperation()))) + signalPassFailure(); +} +``` + +The current implementation should map directly to this phase order: + +```cpp +LogicalResult LayoutSolver::run() { + if (failed(collect())) + return failure(); + if (failed(addConstraints())) + return failure(); + + rewriteDataTypes(); + if (failed(insertDataUseMaterializations())) + return failure(); + + if (failed(inferMaskRequests())) + return failure(); + rewriteMaskTypes(); + if (failed(insertMaskUseMaterializations())) + return failure(); + + rewriteFunctionType(); + return validateVMILayoutAssignedIR(module); +} +``` + +This order is intentional: + +```text +collect: + only discovers VMI values and block arguments. + +addConstraints: + only records equivalence, natural layout and consumer request facts. + It must not rewrite IR, because later CFG/call constraints may still merge + two values that were already seen. + +rewriteDataTypes: + commits solved data layouts to !pto.vmi.vreg type. + +insertDataUseMaterializations: + repairs use-site layout mismatch after the producer's committed type is known. + +inferMaskRequests: + uses already committed data layouts and element widths to infer concrete mask + layout/granularity requests. + +rewriteMaskTypes: + commits mask layout and b8/b16/b32 granularity. + +insertMaskUseMaterializations: + repairs mask layout/granularity mismatch. + +rewriteFunctionType: + updates function signatures last, after argument/result value types have been + rewritten. +``` + +Do not move `rewriteFunctionType` before use-site materialization. A function signature is the public shape of the +solved value class; changing it early makes call/return diagnostics depend on walk order and can hide an unresolved +use-site mismatch. + +Constraint collection is a module walk with explicit handlers. The important point is that each handler only records +facts; it must not rewrite while walking: + +```text +Data equivalence: + pto.vmi.addf/addi: lhs == rhs == result + pto.vmi.cmpf/cmpi: lhs == rhs + pto.vmi.select: true_value == false_value == result + pto.vmi.ensure_layout: source and result are not equivalent if layouts differ + +Data natural layout: + pto.vmi.extf f16->f32: result natural = deinterleaved=2 + pto.vmi.extf fp8-like->f32: result natural = deinterleaved=4 + pto.vmi.truncf: result natural = contiguous + pto.vmi.channel_merge with C inputs: result natural = deinterleaved=C + +Data use request: + pto.vmi.store/tile_write: value requested as contiguous + pto.vmi.channel_split with C results: source requested as deinterleaved=C + op requiring a common operand/result layout: request producer class layout + +Mask request: + cmp result: same data layout as operands, granularity from element width + select mask: same data layout as selected value, granularity from element width + store mask path: same data layout as stored value, granularity from element width +``` + +Control flow should be handled as equivalence, not as local op preference: + +```text +scf.if: + result[i] == then yield[i] == else yield[i] + +scf.for: + init_arg[i] == body iter_arg[i] == yield[i] == result[i] + +scf.while: + before argument[i] == condition forwarded operand[i] == after argument[i] + after yield[i] == result[i] + +scf.execute_region: + every nested scf.yield operand[i] == execute_region result[i] + +scf.index_switch: + every case/default yield operand[i] == index_switch result[i] + +cf.br: + operand[i] == destination block argument[i] + +cf.cond_br: + true operand[i] == true destination block argument[i] + false operand[i] == false destination block argument[i] + +cf.switch: + default operand[i] == default destination block argument[i] + case k operand[i] == case k destination block argument[i] + +func.call: + only direct internal callees are supported in the first implementation + call operand[i] == callee argument[i] + call result[i] == every corresponding callee return operand[i] +``` + +Function returns need one extra bookkeeping rule. A function result slot has one public layout in the function type, so +all `func.return` operands at the same index must be equivalent: + +```text +first return operand[i] == every later return operand[i] +function result type[i] is rewritten from the solved type of return operand[i] +call result[i] == every corresponding callee return operand[i] +``` + +If two return paths naturally produce incompatible layouts, the pass should report `VMI-LAYOUT-CONTRACT` instead of +silently choosing one path: + +```mlir +^a: + %x = pto.vmi.extf %f16 : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %x : !pto.vmi.vreg<128xf32> // natural deinterleaved=2 + +^b: + %y = pto.vmi.extf %f8 : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + return %y : !pto.vmi.vreg<256xf32> // different result shape/layout, invalid by verifier/type first +``` + +For equal result shape but incompatible producer preferences, the same rule applies: + +```text +return slot 0 from f16->f32 path: natural deinterleaved=2 +return slot 0 from f8E4M3FN->f32 path with the same logical result shape: natural deinterleaved=4 +diagnostic: VMI-LAYOUT-CONTRACT: conflicting natural layouts ... +``` + +External declarations with VMI types are not a layout problem; they are ABI materialization. The first implementation +must reject them before rewriting: + +```text +VMI-LAYOUT-CONTRACT: VMI typed function declaration requires an explicit external ABI materialization plan +``` + +The rewrite phase has three ordered steps: + +```text +1. Rewrite all data SSA value types to !pto.vmi.vreg. +2. Rewrite all mask SSA value types to !pto.vmi.mask. +3. Repair use-site mismatches by either rematerializing a cheap producer or inserting an explicit helper. +``` + +Rematerialization is allowed only when replaying the producer cannot change memory, control flow, or execution count +semantics: + +```text +allowed: + pto.vmi.constant splat + pto.vmi.broadcast + pto.vmi.constant_mask + pto.vmi.create_mask + +not allowed in the first implementation: + load/tile_read + arithmetic result + conversion result + shuffle/channel_split/channel_merge result + value crossing a call boundary or block argument +``` + +If rematerialization is not legal, insert: + +```text +pto.vmi.ensure_layout +pto.vmi.ensure_mask_layout +pto.vmi.ensure_mask_granularity +``` + +These helpers make the unresolved materialization explicit. `vmi-layout-assignment` is allowed to create them; +`vmi-to-vpto` is responsible for proving and lowering them. If lowering cannot prove the physical transform, the final +diagnostic should be an unsupported layout/materialization diagnostic, not silent incorrect code. + +Layout assignment completion checks: + +```text +1. No surface !pto.vmi.vreg remains. +2. No surface !pto.vmi.mask remains. +3. Every VMI function argument, result, block argument, branch operand, call operand, and return operand has the + layout-assigned type selected by the solved equivalence class. +4. Every consumer-specific mismatch is represented either by a rematerialized cheap producer or by an explicit + pto.vmi.ensure_* op immediately before that consumer. +5. External declarations with VMI types are rejected; they are not rewritten into an implicit ABI. +``` + +#### OneToN Conversion Details + +`vmi-to-vpto` should use MLIR `OneToNTypeConversion` for all structural rewriting that involves VMI values: + +```text +OneToNTypeConverter: + !pto.vmi.vreg -> !pto.vreg... + !pto.vmi.mask -> !pto.mask... + +Patterns: + framework structural OneToN patterns for func/return/scf + explicit OneToNOpConversionPattern for each pto.vmi semantic op + explicit helper patterns for pack/unpack/ensure_* + +Final gate: + reject residual pto.vmi.*, !pto.vmi.*, function signatures containing !pto.vmi.*, and unrealized_conversion_cast +``` + +The implementation is an `OperationPass` with this shape: + +```cpp +struct VMIToVPTOTypeConverter final : OneToNTypeConverter { + VMIToVPTOTypeConverter() { + addConversion([](Type t) { return t; }); + addConversion(convertVMIVRegType); + addConversion(convertVMIMaskType); + + TypeConverter::addSourceMaterialization(materializeVPTOToVMI); + TypeConverter::addArgumentMaterialization(materializeVPTOToVMI); + OneToNTypeConverter::addTargetMaterialization(materializeVMIToVPTO); + } +}; + +void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(verifyVMIToVPTOInputIR(module)) || + failed(verifySupportedVMIToVPTOOps(module))) + return signalPassFailure(); + + VMIToVPTOTypeConverter typeConverter; + RewritePatternSet patterns(module.getContext()); + populateVMIOneToNConversionPatterns(typeConverter, patterns); + + if (failed(applyPartialOneToNConversion(module, typeConverter, + std::move(patterns))) || + failed(verifyNoResidualVMIIR(module))) + signalPassFailure(); +} +``` + +The type converter must define one canonical physical ordering and every pattern must use that ordering: + +```text +!pto.vmi.vreg + -> chunks in logical order: + chunk0 lanes [0..P-1], chunk1 lanes [P..2P-1], ... + +!pto.vmi.vreg + -> part-major chunks: + part0 chunk0 lanes [0,2,4,...] + part0 chunk1 next even lanes + part1 chunk0 lanes [1,3,5,...] + part1 chunk1 next odd lanes + +!pto.vmi.vreg + -> part-major chunks: + part0 lanes [0,4,8,...] + part1 lanes [1,5,9,...] + part2 lanes [2,6,10,...] + part3 lanes [3,7,11,...] + +!pto.vmi.mask + -> same part/chunk ordering as its data layout, one !pto.mask per physical part/chunk +``` + +`materializeVPTOToVMI` and `materializeVMIToVPTO` should use only `pto.vmi.pack` and `pto.vmi.unpack`. These ops are +conversion scaffolding; they are never valid final output. This makes accidental framework materialization visible in +the IR and easy to reject. + +Pattern population should be explicit: + +```cpp +void populateVMIOneToNConversionPatterns(VMIToVPTOTypeConverter &converter, + RewritePatternSet &patterns) { + populateFuncTypeConversionPatterns(converter, patterns); + scf::populateSCFStructuralOneToNTypeConversions(converter, patterns); + + patterns.add(converter, ctx); + + patterns.add(converter, ctx); + + patterns.add(converter, ctx); +} +``` + +Use upstream OneToN helpers where they exist: + +```text +func.func / func.return / func.call: + populateFuncTypeConversionPatterns + +scf.if / scf.for / scf.while and common structural SCF: + scf::populateSCFStructuralOneToNTypeConversions +``` + +Use project-local OneToN patterns where the current MLIR version does not provide a complete 1:N structural rewrite: + +```text +cf.br +cf.cond_br +cf.switch +scf.execute_region +scf.index_switch +``` + +These project-local structural patterns should not know VMI semantics. They only flatten operands/results according to +`OneToNTypeMapping`, convert successor block argument lists, and rebuild the same control-flow op. + +#### Pattern Authoring Checklist + +Every new `pto.vmi.*` lowering pattern should answer the same questions before it is added to +`populateVMIOneToNConversionPatterns`: + +```text +1. Does the op require all data operands/results to have identical physical arity? + If yes, check every ValueRange size against the result mapping before emitting VPTO ops. + +2. Does the op consume a mask? + If yes, the mask must already have concrete granularity and the same physical ordering expected by the data + operand. The pattern must not reinterpret a pred mask by lane count alone. + +3. Does the op observe contiguous logical order outside the register file? + If yes, require contiguous layout or explicitly lower the ensure_layout/materialization before using load/store + style VPTO ops. + +4. Does the op have padding lanes? + If yes, prove padding is unobservable. For load-like ops this requires a full-read safety proof or a fallback. + For store-like ops this requires a true predicate that disables padding writes. + +5. Does the op have target-specific side effects or ordering, such as squeeze/compact/store coupling? + If yes, put that check in verifySupportedVMIToVPTOOps before conversion starts, so the pass fails before partial + rewriting. + +6. Can it create pto.vmi.pack/unpack or unrealized_conversion_cast through framework materialization? + If yes, the semantic pattern still may be correct, but final residual verification must reject any leftover helper. +``` + +This gives a concrete division of labor: + +```text +verifySupportedVMIToVPTOOps: + shape/target/path support checks that should fail before any rewrite. + +OneToNOpConversionPattern: + mechanical lowering for a preflight-approved case. + +verifyNoResidualVMIIR: + final hard gate for missed patterns, illegal materializations and hidden VMI type payloads. +``` + +Do not put target capability probing in a structural pattern. For example, a `cf.br` pattern must never ask whether +`deinterleaved=4` can be materialized. It only converts successor operands. The semantic op that created or consumes +the value is responsible for proving the VPTO lowering path. + +#### Converter Use By Pass + +The implementation should be reviewable with the following rule: + +```text +pto-validate-vmi-ir: + no TypeConverter, no ConversionTarget, no rewrite. + +vmi-layout-assignment: + no TypeConverter for choosing layouts. + It may use RewriterBase after solving, but not DialectConversion as the solving model. + +vmi-to-vpto: + must use OneToNTypeConverter for VMI types. + must use OneToNOpConversionPattern for semantic VMI ops. + should use upstream func/scf OneToN helpers when available. + may add project-local structural OneToN patterns only for missing framework coverage. +``` + +The main reason is not style. It is correctness across values without defining ops: + +```mlir +^bb0(%x: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): + cf.br ^bb1(%x : !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + +^bb1(%y: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): + %z = pto.vmi.addf %y, %y + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + ... +``` + +`%y` has no defining VMI op. Its physical values are the converted block arguments produced by OneToN block signature +conversion. Any implementation that tries to recover physical parts from a defining op is therefore incomplete for +control flow, function arguments and loop-carried values. + +When writing semantic `OneToNOpConversionPattern`, do not infer physical parts from a defining op. Use the OneToN +adaptor's per-original-operand `ValueRange`: + +```cpp +LogicalResult matchAndRewrite(VMIAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + ... + rewriter.replaceOp(op, physicalResults, adaptor.getResultMapping()); +} +``` + +Every VMI semantic lowering then follows the same shape: + +```cpp +ValueRange lhsParts = adaptor.getLhs(); +ValueRange rhsParts = adaptor.getRhs(); +TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + +for each physical part index i: + emit physical VPTO op for lhsParts[i], rhsParts[i] -> resultTypes[i] + +replace op with all physical results using adaptor.getResultMapping() +``` + +This convention is mandatory for values crossing control flow. For example an `scf.for` iter arg has no defining op; +its physical parts are the converted block arguments created by OneToN signature conversion. + +The concrete pattern shape is: + +```cpp +LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange in0 = adaptor.getIn0(); + ValueRange in1 = adaptor.getIn1(); + TypeRange outTypes = adaptor.getResultMapping().getConvertedTypes(0); + + if (in0.size() != in1.size() || in0.size() != outTypes.size()) + return rewriter.notifyMatchFailure(op, "physical arity mismatch"); + + SmallVector results; + for (auto [i, outType] : llvm::enumerate(outTypes)) { + results.push_back(rewriter.create(op.getLoc(), outType, + in0[i], in1[i]).getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); +} +``` + +For non-VMI operands, use a helper like `getSingleValue(op, adaptor.getOffset(), "...")` and fail if the framework +unexpectedly expanded them. This catches malformed conversion rules early. + +#### Semantic Lowering Buckets + +The first implementation should split VMI op lowering into four buckets: + +```text +identity/helper: + pack, unpack, ensure_layout identity/materialization cases, ensure_mask_* identity case + +per-part elementwise: + addf, addi, subf, subi, mulf, muli, divf, minf, maxf, negf, absf, absi, sqrt, exp, ln, relu, andi, ori, xori, shli, shrui, not, cmpf, cmpi, select + +per-part predicate: + mask_and, mask_or, mask_xor, mask_not + +layout-producing conversion: + extf, truncf, bitcast + +externally ordered memory: + load, store, tile_read, tile_write +``` + +Per-part elementwise ops are straightforward only when all operands/results already share the same assigned layout: + +```text +logical deinterleaved=2 value: + part0 contains logical lanes 0, 2, 4, ... + part1 contains logical lanes 1, 3, 5, ... + +vmi.addf/subf/mulf on two such values: + emit the matching VPTO per-part op for part0_lhs, part0_rhs + emit the matching VPTO per-part op for part1_lhs, part1_rhs +``` + +This preserves logical lane semantics because each physical part contains the same logical lane subset for all +operands and the result. + +Memory ops are different because their observable semantics are contiguous logical order: + +```text +vmi.store of deinterleaved=2: + cannot blindly store part0 then part1 as the final memory order + must use a store plan that writes logical lane 0,1,2,3,... order + or materialize source to contiguous before physical store +``` + +Therefore `store/tile_write` lowering must either: + +```text +1. consume contiguous layout directly, or +2. lower ensure_layout(deinterleaved -> contiguous), then store, or +3. use target store instructions whose dist mode proves contiguous external order +``` + +The first implementation uses option 2 for full physical chunks: + +```text +vmi.load: + emit contiguous physical vlds chunks in memory order + materialize contiguous -> assigned result layout + +vmi.masked_load: + only when the full physical read footprint is proven safe + emit contiguous physical vlds chunks in memory order + select loaded lanes against passthru with the VMI mask + if enable-stable-gather-masked-load is set, reject pto.vmi.masked_load with + a stable TODO diagnostic until the VGATHER2-based strict no-read path is + implemented + +vmi.store: + materialize assigned source layout -> contiguous + emit physical vsts chunks in memory order + +vmi.tile_read / vmi.tile_write: + follow the same externally ordered rule +``` + +Current direct memory lowering may only emit VPTO vector memory ops for +UB-backed memory. Concretely, a `!pto.ptr<..., ub>` is legal, a +`!pto.ptr<..., gm>` is not; a memref with `#pto.address_space` is legal, +and a memref without a memory-space attribute is treated as unknown/local to +this stage to preserve existing local-view tests. A memref explicitly marked +GM or another non-VEC space is rejected by `vmi-to-vpto`. + +GM-backed VMI memory is still a valid semantic source/sink before this pass, +but direct lowering does not perform GM<->UB movement. That must be represented +by an earlier/lower memory access plan, scratch materialization, or UB view +normalization before `vmi-to-vpto`; otherwise the diagnostic is +`VMI-UNSUPPORTED` and names the GM-backed source/destination. + +For `deinterleaved=2`, `vldsx2 DINTLV_B*` and `vstsx2 INTLV_B*` are valid optimization candidates because the ISA has +an explicit two-stream de/interleave memory distribution mode. This should be implemented only as a peephole inside +`vmi-to-vpto` after the generic plan is correct: + +```text +vmi.load result layout deinterleaved=2: + vldsx2 DINTLV_B* can directly produce part0/part1 chunks + +vmi.store source layout deinterleaved=2: + vstsx2 INTLV_B* can directly store part0/part1 chunks in logical memory order +``` + +Do not generalize this to `deinterleaved=4` unless the two-level dist composition is proven against the ISA. The +fallback for `deinterleaved=4` remains generic layout materialization plus ordinary memory ops. + +Partial/tail load-style memory is legal only when the lowering can prove the full physical read footprint is safe. The +current direct path supports this limited proof: + +```text +source is a statically shaped memref +offset is a constant non-negative index, or tile_read implicit offset 0 +offset + physical_arity(result) * lanes_per_physical_part <= static memref element count +``` + +When this proof holds, `vmi.load` / `vmi.tile_read` may still issue full `pto.vlds` chunks. The extra padding lanes are +not logical VMI lanes and must remain unobservable through later VMI materialization rules. Pointer sources, dynamic +offsets, dynamic memrefs, and insufficient static footprints remain unsupported: + +```text +VMI-UNSUPPORTED: pto.vmi. requires full physical chunks without padding lanes or a statically safe full-read +footprint (...; safe-read proof failed: ...) +VMI-UNSUPPORTED: pto.vmi. ... (source is GM-backed, but current direct VMI-to-VPTO memory lowering emits +pto.vlds/pto.vsts and requires UB-backed memory) +``` + +Store-style ops are different because inactive lanes can be made write-free with true predicates. `vmi.store`, +`vmi.masked_store`, and `vmi.tile_write` therefore support the explicit contiguous/deinterleaved tail-store +materialization paths described below. + +## 2. Slice 0: Type / Attr Bootstrap + +第一步只实现 VMI type、layout attr 和纯 helper,不实现任何 conversion pass。 + +### 2.1 `#pto.vmi.layout` + +定义 `VMILayoutAttr`: + +```mlir +#pto.vmi.layout +#pto.vmi.layout +#pto.vmi.layout +``` + +建议内部参数: + +```text +kind: enum { contiguous, deinterleaved } +factor: int64_t +``` + +Verifier: + +```text +contiguous: + factor must be 1 + +deinterleaved: + factor must be 2 or 4 +``` + +禁止接受其它 spelling,例如 `stride2`、`stride4`、`parity`、`mod_split`、`blocked`。 + +### 2.2 `!pto.vmi.vreg` + +定义 `VMIVRegType`: + +```mlir +!pto.vmi.vreg<128xf32> +!pto.vmi.vreg<128xf32, #pto.vmi.layout> +!pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +建议参数: + +```text +elementCount: int64_t +elementType: Type +layout: Attribute // null means surface type before layout assignment +``` + +Verifier: + +```text +elementCount > 0 +elementType is scalar-like integer / float / index supported by VMI +layout is null or VMILayoutAttr +deinterleaved=4 only allowed when target registry later supports it; type verifier only checks shape +``` + +不要要求 `elementCount * bitwidth(elementType)` 是 256B 整数倍。 + +### 2.3 `!pto.vmi.mask` + +定义 `VMIMaskType`: + +```mlir +!pto.vmi.mask<128xpred> +!pto.vmi.mask<128xb32, #pto.vmi.layout> +!pto.vmi.mask<128xb32, #pto.vmi.layout> +``` + +建议参数: + +```text +elementCount: int64_t +granularity: enum/string { pred, b8, b16, b32 } +layout: Attribute +``` + +Verifier: + +```text +elementCount > 0 +surface mask may use pred and no layout +layout-assigned mask must use b8/b16/b32 and must have VMILayoutAttr +pred mask must not carry layout +``` + +### 2.4 Lane Map Helper + +在 C++ 中提供纯函数 helper,供 verifier、layout assignment、VMI-to-VPTO 和测试共用: + +```text +getDataLanesPerPart(elementType) +getMaskLanesPerPart(granularity) +getVMIPhysicalArity(type) +mapLogicalLaneToPhysical(type, logicalLane) +mapPhysicalLaneToLogical(type, part, chunk, lane) +isPaddingLane(type, part, chunk, lane) +``` + +这些 helper 是 hard dependency。任何 pass 不能重新手写一套 arity 公式。 + +Slice 0 完成条件: + +```text +1. VMI type/attr 能 parse/print round-trip。 + Covered by vmi_type_attr_parse.pto. +2. 非法 layout factor、非法 mask granularity、非法 element count 有 verifier diagnostic。 + Covered by vmi_layout_factor_invalid.pto, + vmi_mask_granularity_invalid.pto, vmi_type_element_count_invalid.pto, + and vmi_mask_concrete_without_layout_invalid.pto / + vmi_mask_pred_with_layout_invalid.pto. +3. helper 单测或 lit 测试覆盖 contiguous/deinterleaved=2/deinterleaved=4 和非整 tile。 + Covered by vmi_to_vpto_type_only.pto and + vmi_to_vpto_type_arity.pto. +``` + +## 3. Slice 1: Minimal VMI Op Set + +不要一次实现 75 个 semantic op。第一批只实现能跑通 widening + elementwise + store 的闭环。 + +### 3.1 必选 semantic op + +Construction: + +```text +pto.vmi.constant +pto.vmi.broadcast +pto.vmi.iota +pto.vmi.create_mask +pto.vmi.constant_mask +``` + +`pto.vmi.from_elements` belongs to the eventual construction surface, but it is +not part of Slice 1. Do not synthesize it from ad hoc scalar lane inserts until +there is an explicit vreg immediate, scalar-insert, or scratch materialization +contract. + +Mask: + +```text +pto.vmi.mask_and +pto.vmi.mask_or +pto.vmi.mask_xor +pto.vmi.mask_not +``` + +Arithmetic / conversion: + +```text +pto.vmi.addf +pto.vmi.addi +pto.vmi.subf +pto.vmi.subi +pto.vmi.mulf +pto.vmi.muli +pto.vmi.fma +pto.vmi.divf +pto.vmi.minf +pto.vmi.maxf +pto.vmi.negf +pto.vmi.absf +pto.vmi.absi +pto.vmi.sqrt +pto.vmi.exp +pto.vmi.ln +pto.vmi.relu +pto.vmi.andi +pto.vmi.ori +pto.vmi.xori +pto.vmi.shli +pto.vmi.shrui +pto.vmi.not +pto.vmi.cmpf +pto.vmi.cmpi +pto.vmi.select +pto.vmi.extf +pto.vmi.truncf +pto.vmi.bitcast +``` + +`pto.vmi.shrui` represents logical right shift and lowers to `pto.vshr`. +`pto.vmi.shrsi` is intentionally not defined until VPTO exposes or documents +an arithmetic right-shift contract distinct from logical right shift. +Integer div/rem, integer casts, int-float casts, and index casts are also +intentionally outside the current VMI surface until signedness, rounding, +saturation, overflow/remainder, and target lowering contracts are explicit. + +Memory: + +```text +pto.vmi.load +pto.vmi.masked_load +pto.vmi.gather +pto.vmi.expand_load +pto.vmi.store +pto.vmi.masked_store +pto.vmi.scatter +pto.vmi.compress_store +pto.vmi.tile_read +pto.vmi.tile_write +``` + +Current implementation scope note: + +```text +pto.vmi.gather / scatter +pto.vmi.active_prefix_index / compress / compress_store +future scan / contract style ops +``` + +These families are not first-stage completion blockers. The dialect surface may +define them, and the lowering may keep narrow direct paths when the target VPTO +contract is already explicit. Full semantic coverage for these families remains +out of scope until cross-chunk state, duplicate-index ordering, prefix carry, +compaction state, or contraction accumulation contracts are explicitly designed. +Unsupported shapes must fail before OneToN rewrite with `VMI-UNSUPPORTED`; they +must not fall through to residual-op diagnostics. + +Permutation: + +```text +pto.vmi.shuffle +pto.vmi.channel_split +pto.vmi.channel_merge +``` + +Internal helper: + +```text +pto.vmi.ensure_layout +pto.vmi.ensure_mask_layout +pto.vmi.ensure_mask_granularity +pto.vmi.unpack +pto.vmi.pack +``` + +### 3.2 Op Verifier Rules + +Construction op verifier: + +```text +constant value must be a dense elements attr, and its element type/count must match the result vreg +broadcast scalar type must match the result element type +constant_mask value must be a dense elements attr, must have i1 element type, and its element count must match the +result mask +create_mask may produce surface pred mask or concrete layout-assigned mask +mask_and/mask_or/mask_xor/mask_not require all mask operands/results to have the same logical lane count; if any +mask is layout-assigned, all masks must carry the same layout and granularity +``` + +Elementwise op verifier: + +```text +all data operands have same logical lane count +all data operands have same element type except documented conversion op +if any operand has layout, all layouted operands/results must agree +surface op may have no layout before vmi-layout-assignment +``` + +`select` verifier: + +```text +mask lane count == true/false/result lane count +mask layout must match data layout after layout assignment +mask granularity must match selected element width after layout assignment +``` + +`extf/truncf` verifier: + +```text +source/result lane count equal +source/result element types are float +bitwidth changes in the expected direction +``` + +Memory op verifier: + +```text +load/tile_read memory element type must match result VMI data element type when the source is PtrType or MemRefType +store/tile_write memory element type must match stored VMI data element type when the destination is PtrType or MemRefType +``` + +`shuffle` verifier: + +```text +static mask length == result lane count +each mask index selects an existing source logical lane +result element type == source element type +no padding lane may be selected +``` + +`channel_split` verifier: + +```text +result count C >= 2 +input lane count N == C * M +each result is vreg +channel c result semantics: out[c][i] = input[i * C + c] +if any source/result carries layout, all must carry layout +for C=2/4, layout-assigned source must be contiguous or deinterleaved=C +layout-assigned results must be contiguous +``` + +`channel_merge` verifier: + +```text +operand count C >= 2 +all operands have same M and element type T +result is vreg +result semantics: result[i * C + c] = input[c][i] +if any input/result carries layout, all must carry layout +layout-assigned inputs must be contiguous +for C=2/4, layout-assigned result must be contiguous or deinterleaved=C +``` + +`ensure_layout` verifier: + +```text +source/result are both VMIVRegType +same elementCount and elementType +source/result both layout-assigned +source layout may equal result layout; that is a canonical no-op +``` + +`ensure_mask_layout` verifier is identical except it uses `VMIMaskType` and preserves granularity. + +`ensure_mask_granularity` verifier: + +```text +source/result are both VMIMaskType +same elementCount +same layout +source/result granularity are b8/b16/b32 +logical predicate value must be preserved +``` + +`pack/unpack` verifier: + +```text +VMI side must be layout-assigned +physical operand/result count == getVMIPhysicalArity(VMI type) +physical data types are !pto.vreg +physical mask types are !pto.mask +ordering is the shared Physical Arity helper order +``` + +Slice 1 完成条件: + +```text +1. Every Slice 1 op parses, prints, and has negative verifier tests. + Arithmetic/mask/helper verifier coverage includes vmi_elementwise_kind_invalid.pto, + vmi_mask_logic_invalid.pto, vmi_ensure_layout_surface_invalid.pto, + vmi_unpack_arity_invalid.pto, and vmi_pack_arity_invalid.pto. +2. Helper ops are marked internal in docs and rejected by final VMI-to-VPTO gate if residual. +3. `channel_split/channel_merge` have tests proving shuffle-equivalent lane order. +``` + +## 4. Slice 2: VMI Producer Boundary Verifier + +VMI core implementation starts from VMI IR. Producer-specific import is outside this manual's core path. + +实现 `PTOValidateVMIIR.cpp` 中的 VMI boundary verifier: + +```text +recommended pass name: pto-validate-vmi-ir +anchor: func::FuncOp or ModuleOp +source file: lib/PTO/Transforms/PTOValidateVMIIR.cpp +``` + +Boundary verifier checks: + +```text +all logical vector values use !pto.vmi.vreg / !pto.vmi.mask +all logical vector behavior is represented by pto.vmi semantic ops +surface VMI values before layout assignment do not carry layout +no physical VPTO op appears before vmi-to-vpto +no hidden side table is required to interpret VMI values +scalar/tensor/debug/transform boundary has already been resolved by producer +``` + +Slice 2 完成条件: + +```text +1. VMI-native positive tests pass boundary verification. + Covered by vmi_producer_boundary_valid.pto. +2. Physical VPTO op before VMI-to-VPTO is rejected. + Covered by vmi_producer_boundary_physical_invalid.pto, including both + physical function types and physical VPTO ops. +3. Layout-assigned type before layout assignment is rejected unless the test explicitly starts after layout assignment. + Covered by vmi_producer_boundary_layout_invalid.pto and + vmi_producer_boundary_mask_layout_invalid.pto. +4. Missing VMI type/op invariants produce `VMI-PASS-INVARIANT` or a more specific diagnostic. + Covered by vmi_producer_boundary_non_vmi_op_invalid.pto, + vmi_producer_boundary_helper_invalid.pto, and the producer-boundary + TypeAttr nested/surface/layout invalid tests. +``` + +## 5. Slice 3: `vmi-layout-assignment` + +推荐实现为 pass: + +```text +recommended pass name: vmi-layout-assignment +anchor: ModuleOp +source file: lib/PTO/Transforms/VMILayoutAssignment.cpp +``` + +`vmi-layout-assignment` 必须是 module 级 pass。函数参数、`func.return` operand、 +`func.call` operand/result 和 callee signature 需要在同一个约束图里求解;函数级 pass +只能看到局部 body,无法安全地同步 callsite 和 callee。 + +### 5.1 Internal Data Model + +Build one layout node per VMI SSA value: + +```text +Operation result +BlockArgument +Region yield operand +Function argument/result +Call operand/result +``` + +Each node records: + +```text +logical type: VMIVRegType or VMIMaskType +allowed layouts: bitset {contiguous, deinterleaved2, deinterleaved4} +required mask granularity: pred/b8/b16/b32 or unknown +natural layout preference +hard constraints +soft costs +``` + +No information required by later passes may live only in this data structure. After the pass, type/attr/op +operands must fully describe the result. + +### 5.2 Transfer Functions + +Minimum Slice 3 transfer functions: + +```text +constant/broadcast/create_mask/constant_mask: + rematerializable in any legal consumer layout + +mask_and/mask_or/mask_xor/mask_not: + all mask operands/results same layout and granularity + +addf/addi/subf/subi/mulf/muli/divf/minf/maxf/negf/absf/absi/sqrt/exp/ln/relu/andi/ori/xori/shli/shrui/not/cmpf/cmpi/select: + all data operands/results same layout + mask layout follows data layout + +extf f16 -> f32: + result natural layout = deinterleaved=2 + source requires contiguous layout for the direct vcvt part=EVEN/ODD path + partial/tail source chunks are supported when they still fit in one physical + source chunk and produce the natural two-part result; source padding lanes map + only to result padding lanes + +extf f8 -> f32: + result natural layout = deinterleaved=4 + source requires contiguous layout for the direct vcvt part=P0/P1/P2/P3 path + partial/tail source chunks are supported under the same one-source-chunk + contract; source padding lanes map only to result padding lanes + +truncf f32 -> f16: + can consume deinterleaved=2 and produce contiguous + current implementation records a deinterleaved=2 source use-site request and + inserts pto.vmi.ensure_layout when the source value solved to contiguous. + partial/tail source pairs are supported when the two deinterleaved source + parts pack into one contiguous result chunk; source padding lanes map only to + result padding lanes + +truncf f32 -> fp8-like: + can consume deinterleaved=4 and produce contiguous + current implementation records a deinterleaved=4 source use-site request and + inserts pto.vmi.ensure_layout when the source value solved to contiguous. + The lowering emits four pto.vcvt operations with part=P0/P1/P2/P3, then ORs + the mutually exclusive partial destination registers into one contiguous fp8 + result. This mirrors the hardware packed-4 contract: each source part owns + one quarter of the destination byte lanes, so the final externally visible + vector remains logical lane order 0..N-1 after the merge. + +bitcast: + source and result layouts must match + source/result total logical bits must match + current implementation supports identical physical arity when every source/result + physical chunk carries the same number of logical bits. This covers full chunks + and partial/tail chunks such as 65xf32 -> 130xi16, where the second physical + chunk carries 32 logical bits on both sides. Partial/tail bitcast remains + unsupported if source padding bits would become result logical bits. + +load/tile_read: + result layout chosen by consumers unless memory plan has a cheaper registered sink/source + +store/tile_write: + can consume any layout only if target registry has preserving store path + current implementation records a contiguous use-site request for vmi.store and + inserts pto.vmi.ensure_layout when the stored value class solved to a + non-contiguous layout. This makes externally visible memory order explicit in + IR before vmi-to-vpto. If explicit IR reaches vmi-to-vpto with a + deinterleaved=2/4 tail value, the direct lowering may still materialize it to + contiguous physical chunks first, but only when every deinterleaved part has + the same physical chunk count and therefore forms complete intlv groups. + +shuffle/channel_split/channel_merge: + default result layout contiguous unless target registry provides direct layout-preserving path + current implementation supports pto.vmi.shuffle when every result physical + chunk forwards one source physical chunk with identical lane positions for + all non-padding result lanes. Result padding lanes are ignored by the + forwarding proof and remain unobservable after physicalization. This allows + whole-chunk projection/reordering under contiguous or explicit deinterleaved + layouts, including tail-prefix projections such as `[0, 1, 2, 3] -> + !pto.vmi.vreg<4xf32>`. Arbitrary lane permutation remains unsupported unless + the vselr index-vector path below can materialize it. + current implementation supports channel_split/channel_merge for 2 or 4 + channels. channel_split consumes a natural deinterleaved=C source and produces + contiguous per-channel results; channel_merge consumes contiguous per-channel + inputs and produces a natural deinterleaved=C result. The direct path also + accepts partial/tail channel groups when the virtual deinterleaved=C channel + layout has the same physical arity as the source/result representation, so + every physical group can be materialized with complete intlv/dintlv pairs. + Arity-changing partial groups such as splitting 4xf32 into two 2xf32 channels + remain unsupported. If a producer/consumer + requires dense contiguous layout, pto.vmi.ensure_layout materializes the + pto.vdintlv/pto.vintlv tree explicitly. Non-matching layouts and other channel + counts remain unsupported. +``` + +### 5.3 Solver Order + +Implement deterministic solving: + +```text +1. Collect region/SCC constraints, including scf/cf/function/call boundaries. +2. Propagate impossible layouts and required mask granularities. +3. Pick a layout per node using minimum cost. +4. Tie-break: explicit layout already present on the VMI type, then natural layout, then contiguous. +5. Rewrite result/block/function types to layout-assigned VMI types. +6. Insert ensure_layout / ensure_mask_layout / ensure_mask_granularity at uses that need conversion. +7. Clone rematerializable producers per use when cheaper than conversion. +8. Run verifier gate. +``` + +Current implementation status: + +```text +implemented: + extf source -> contiguous use-site request for supported f16/fp8-like to f32 paths + truncf f32->f16 source -> deinterleaved=2 use-site request + truncf f32->fp8-like source -> deinterleaved=4 use-site request + single-use pto.vmi.load / tile_read results can adopt a consumer-requested + layout before type rewrite; this covers direct memory producers such as + load -> truncf without inserting a redundant ensure_layout + vmi.store data operand -> contiguous use-site request + explicit VMI vreg layout is preserved as an initial solver constraint + explicit concrete VMI mask layout/granularity is preserved as an initial solver constraint + channel_split source -> deinterleaved=C use-site request + channel_split results -> contiguous natural layout + channel_merge inputs -> contiguous use-site request + channel_merge result -> deinterleaved=C natural layout + shuffle without explicit layouts -> contiguous source use-site request and contiguous result natural layout + shuffle with explicit source/result layouts -> preserve explicit layouts and let vmi-to-vpto prove chunk forwarding + pto.vmi.ensure_layout insertion for non-contiguous store operands + pto.vmi.ensure_layout insertion for truncf source materialization + pto.vmi.ensure_mask_layout / ensure_mask_granularity insertion for select mask operands + pto.vmi.create_mask / constant_mask rematerialization for select mask operands when the consumer needs a + different mask layout/granularity + splat pto.vmi.constant rematerialization for data operands when the consumer needs + a different layout + pto.vmi.broadcast rematerialization for data operands when the consumer needs + a different layout + scf.execute_region result/yield layout equivalence + scf.index_switch result/yield layout equivalence + scf.while state layout equivalence + +not yet implemented: + generic per-consumer layout request table for every VMI op + producer rematerialization for non-splat data constants and other cheap producers + cost model / target capability registry +``` + +Do not implement a local greedy pattern pass that ignores block arguments or function signatures. + +### 5.4 CFG Rules + +CFG 处理分两层。第一层是必须做的 layout equivalence:同一个控制流值在 +result、yield、region/block argument 之间必须形成同一个 layout/mask 约束组。第二层才是 +layout conflict resolution:当同一个 producer 的不同 consumers 希望不同 layout 时,插入 +`ensure_layout`、`ensure_mask_layout` 或 rematerialize producer。 + +当前可落地的最小实现先做第一层。它不尝试在 branch 边界自动插入 conversion,因此下面这些 +关系一旦因为 natural layout 或 mask granularity 冲突无法合并,必须报 `VMI-LAYOUT-CONTRACT`, +不能默默选择某一边。 + +`scf.if` equivalence: + +```text +for each result index i: + scf.if result[i] + == then scf.yield operand[i] + == else scf.yield operand[i] +``` + +如果 value 是 `!pto.vmi.vreg`,合并 data layout 约束;如果 value 是 +`!pto.vmi.mask`,合并 mask layout 和 granularity 请求。这样 `%m = scf.if ... -> +!pto.vmi.mask` 后被 `vmi.select` 消费时,select 对 `%m` 推出的 `b8/b16/b32 + layout` +会传播回两边 yield 的 mask producer。 + +`scf.for` equivalence: + +```text +for each iter_arg index i: + init_arg[i] + == region_iter_arg[i] + == scf.yield operand[i] + == scf.for result[i] +``` + +这条规则避免 loop-carried value 每次迭代改变 layout。对于 `extf f16->f32` 作为 init、 +loop body 内部 `addf` 并 yield 的 case,`extf` 的 natural layout `deinterleaved=2` +必须稳定传递到 `%acc` region arg、`scf.yield` 和 loop result。 + +`cf.br` / `cf.cond_br` equivalence: + +```text +for each successor operand index i: + branch successor operand[i] + == successor block argument[i] +``` + +当前实现覆盖标准 `cf.br`、`cf.cond_br` 和 `cf.switch`。其中 `cf.switch` 的 default operands +与 default destination block arguments 按 index 建 layout 等价关系;每个 case operand segment +与对应 case destination block arguments 按 index 建 layout 等价关系。更泛化的 +`BranchOpInterface` op 如果携带 VMI type,后续要么补对应 mapping,要么在 layout assignment +阶段明确 diagnostic,不能让 hidden default layout 穿过去。 + +当前实现支持携带 VMI value 的 `scf.execute_region`:execute_region result 与直属 region terminator +`scf.yield` operands 按 result index 合并到同一个 layout 等价类。嵌套 region 内属于其他 op 的 +`scf.yield` 不参与 execute_region 的等价关系。 + +当前实现支持携带 VMI value 的 `scf.index_switch`:default/case region `scf.yield` operands 与 +index_switch results 按 result index 合并到同一个 layout 等价类。 + +当前实现支持携带 VMI value 的 `scf.while`:init operand、before region argument、`scf.condition` +forwarded operand、after region argument、after region `scf.yield` operand 和 while result 按状态 +index 合并到同一个 layout 等价类。`scf.condition` 的 i1 condition 本身不参与 VMI layout 约束。 + +Function boundary: + +```text +internal functions may get specialized layouted signatures +external ABI must not expose VMI layout +recursive SCC requires fixed-point signature layout +``` + +当前实现支持 direct `func.call` 到同一 module 内带 body 的 `func.func`: + +```text +call operand[i] == callee argument[i] +call result[i] == every callee return operand[i] +same-result-index return operands inside one callee are equivalent +``` + +如果携带 VMI type 的 call 无法解析到带 body 的 direct callee,layout assignment 必须报 +`VMI-LAYOUT-CONTRACT`。后续如需支持 public/external ABI,必须先定义 VMI 值如何在 ABI +边界 materialize,不能把 layouted VMI type 暴露出去。 +当前实现明确拒绝携带 VMI type 的 `func.call_indirect`,因为它没有可解析的 direct internal +callee signature/body 可参与 layout constraint solving。 + +当前实现对携带 VMI type 的 external function declaration 报 `VMI-LAYOUT-CONTRACT`,因为还没有 +定义 VMI value 的外部 ABI materialization plan。没有 VMI type 的 external declaration 必须在 +`rewriteFunctionType` 中保持原签名,不能因为没有 entry block arguments 被改写成空签名。 + +`ptoas --enable-vmi` 额外拒绝 public `func.func` 的 VMI-typed signature: + +```text +VMI-LAYOUT-CONTRACT: public VMI typed function requires an explicit external ABI materialization plan +``` + +这样 test-opt 仍可覆盖 internal/private function signature physicalization,用户入口则不会把 +layout-assigned VMI 值隐式暴露成 public ABI。 + +Slice 3 完成条件: + +```text +1. All VMI values have layout-assigned types after the pass. +2. All masks have b8/b16/b32 granularity after the pass. +3. CFG and call tests prove branch/yield/signature layout equality. +4. Multi-use rematerializable producer tests prove broadcast, constant, iota, + create_mask, and constant_mask rematerialization vs ensure_layout / + ensure_mask_* is deterministic. +5. The pass runs the layout-assigned VMI hard gate before returning, including + recursive TypeAttr/TypedAttr rejection; covered by + vmi_layout_assignment_post_gate_type_attr_invalid.pto. +``` + +## 6. Slice 4: `vmi-to-vpto` + +推荐实现为 pass: + +```text +recommended pass name: vmi-to-vpto +anchor: ModuleOp +source file: lib/PTO/Transforms/VMIToVPTO.cpp +``` + +第一步实现必须先落地 MLIR OneToN conversion 框架: + +```text +VMIToVPTOTypeConverter : OneToNTypeConverter: + !pto.vmi.vreg -> ordered !pto.vreg list + !pto.vmi.mask -> ordered !pto.mask list + +Structural patterns: + populateFuncTypeConversionPatterns + scf::populateSCFStructuralOneToNTypeConversions + project-local OneToN patterns for cf.br/cf.cond_br/cf.switch + project-local OneToN patterns for scf.execute_region/scf.index_switch + +VMI patterns: + OneToNOpConversionPattern for pack/unpack/ensure_*/semantic ops + +Final residual gate: + reject pto.vmi.*, !pto.vmi.*, unrealized_conversion_cast + scan SSA types, block argument types, function signatures, and op/module TypeAttr or TypedAttr payloads +``` + +这一步可以先支持 type-only physicalization 和 `pack/unpack` helper physicalization,但不能让未实现的 VMI semantic op 静默通过。 +如果还有 `pto.vmi.*` 或 VMI type 残留,必须报 `VMI-RESIDUAL-OP`。 + +当前 slice 支持 VMI function/input/block argument 展开成 physical arguments,并支持: + +```text +pto.vmi.unpack(layouted VMI aggregate) -> physical parts: + replace with OneToN adaptor source parts + +pto.vmi.pack(physical parts) -> layouted VMI aggregate: + replace with the physical parts through resultMapping + +pto.vmi.ensure_layout / ensure_mask_layout / ensure_mask_granularity: + ensure_layout must compare the original VMI source/result layout attrs, not only the converted physical type list. + If source/result layouts are identical, replace with source parts. This identity case supports partial/tail physical + chunks because no lane reordering or packing is performed. + If deinterleaved=2 -> contiguous, emit one pto.vintlv. + If contiguous -> deinterleaved=2, emit one pto.vdintlv. + If deinterleaved=4 -> contiguous, emit the two-level pto.vintlv tree. + If contiguous -> deinterleaved=4, emit the reverse two-level pto.vdintlv tree. + ensure_mask_layout supports the same contiguous <-> deinterleaved=2/4 layout conversions with predicate + rearrange ops: + deinterleaved=2 -> contiguous: pto.pintlv_b8/b16/b32 + contiguous -> deinterleaved=2: pto.pdintlv_b8/b16/b32 + deinterleaved=4 -> contiguous: two-level pto.pintlv_b8/b16/b32 tree + contiguous -> deinterleaved=4: two-level pto.pdintlv_b8/b16/b32 tree + ensure_mask_granularity supports concrete b8/b16/b32 logical predicate-preserving conversion: + widening b8 -> b16 -> b32: split each physical chunk with pto.punpack LOWER/HIGHER + narrowing b32 -> b16 -> b8: pack physical chunk pairs with pto.ppack LOWER/HIGHER and merge halves with pto.por + b8 <-> b32 conversions are lowered as two adjacent steps through b16. + +pto.vmi.broadcast: + current direct lowering requires the physical result element width to be 8, + 16, or 32 bits, because the vdup is predicated by pto.mask. + Other semantic element types need a dedicated materialization contract before + vmi-to-vpto may lower them. + for each physical result part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" from the physical result element width + emit pto.vdup(scalar, all_true_mask) + This is layout-independent because every logical lane has the same scalar value. A deinterleaved layout simply + receives one identical vdup per partition/chunk; no vintlv/vdintlv is needed. + +pto.vmi.iota: + semantics: + ASC: result[lane] = base + lane + DESC: result[lane] = base - lane + supported element types follow pto.vci: + integer 8/16/32 and f16/f32 + contiguous full-chunk direct path: + for each physical chunk c: + chunk_base = base +/- c * lanes_per_part + emit pto.vci chunk_base {order = ASC|DESC} + deinterleaved layout requires strided index materialization because physical part p contains logical lanes: + p, p + factor, p + 2 * factor, ... + The required formula is: + ASC: base + p + factor * local_lane + DESC: base - p - factor * local_lane + The current lowering materializes this per physical chunk: + local = pto.vci 0 + scaled = pto.vmuls local, factor + ASC: result = pto.vadds scaled, base + part_offset + DESC: result = pto.vsub pto.vdup(base - part_offset), scaled + Partial/tail chunks are allowed. The physical padding lanes receive the natural continuation of the generated iota + sequence and remain padding/undef at the VMI semantic level; memory writes, masks, reductions, and other + externally-visible consumers must still obey the VMI padding rules. + +pto.vmi.constant_mask: + support dense bool constants for concrete b8/b16/b32 masks. For each physical chunk: + if the active lanes form a prefix: + emit pto.pset_b8/b16/b32 PAT_ALL, PAT_ALLF, or supported PAT_VL* + if a prefix count has no supported PAT_VL token, fall back to pto.plt_b8/b16/b32 with a constant i32 count + otherwise decompose the static bitset into active runs: + run [lo, hi) = prefix(hi) & ~prefix(lo) + combine runs with pto.por under an all-true predicate + pred-only masks remain unsupported until they have a concrete b8/b16/b32 consumer granularity. + +pto.vmi.mask_and / mask_or / mask_xor / mask_not: + for each physical predicate part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" from the physical mask granularity + mask_and emits pto.pand(lhs_part, rhs_part, all_true_mask) + mask_or emits pto.por(lhs_part, rhs_part, all_true_mask) + mask_xor emits pto.pxor(lhs_part, rhs_part, all_true_mask) + mask_not emits pto.pnot(source_part, all_true_mask) + +pto.vmi.addf / addi / subf / subi / mulf / muli / divf / minf / maxf / negf / absf / absi / sqrt / exp / ln / relu / andi / ori / xori / shli / shrui / not: + current direct lowering requires the physical element width to be 8, 16, or + 32 bits, because every emitted VPTO op is predicated by a materialized + pto.mask. VMI types such as index or f64 remain valid semantic + surface types only after a dedicated lowering contract exists; until then + vmi-to-vpto must report VMI-UNSUPPORTED before OneToN conversion. + This common predicate-maskability rule is necessary but not sufficient for + every target op. Direct lowering must also preflight the concrete VPTO/VISA + element contract before OneToN rewriting: + addf/subf/mulf -> pto.vadd/vsub/vmul support f16/bf16/f32 floating types + divf -> pto.vdiv supports f16/f32 floating types + minf/maxf -> pto.vmin/vmax support f16/bf16/f32 floating types + negf/absf/sqrt/exp/ln/relu -> pto.vneg/vabs/vsqrt/vexp/vln/vrelu support f16/f32 floating types + absi -> pto.vabs supports signless/signed i8/i16/i32 integer types + bf16/f8 remain legal VMI float-like semantic types for the ops whose VMI + semantics allow them, but vmi-to-vpto must report VMI-UNSUPPORTED until a + materialization plan or wider target contract exists. + for each physical part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" from the physical element width + addf/addi emit pto.vadd(lhs_part, rhs_part, all_true_mask) + subf/subi emit pto.vsub(lhs_part, rhs_part, all_true_mask) + mulf/muli emit pto.vmul(lhs_part, rhs_part, all_true_mask) + divf emits pto.vdiv(lhs_part, rhs_part, all_true_mask) + minf emits pto.vmin(lhs_part, rhs_part, all_true_mask) + maxf emits pto.vmax(lhs_part, rhs_part, all_true_mask) + negf emits pto.vneg(source_part, all_true_mask) + absf/absi emit pto.vabs(source_part, all_true_mask) + sqrt emits pto.vsqrt(source_part, all_true_mask) + exp emits pto.vexp(source_part, all_true_mask) + ln emits pto.vln(source_part, all_true_mask) + relu emits pto.vrelu(source_part, all_true_mask) + andi emits pto.vand(lhs_part, rhs_part, all_true_mask) + ori emits pto.vor(lhs_part, rhs_part, all_true_mask) + xori emits pto.vxor(lhs_part, rhs_part, all_true_mask) + shli emits pto.vshl(lhs_part, rhs_part, all_true_mask) + shrui emits pto.vshr(lhs_part, rhs_part, all_true_mask) + not emits pto.vnot(source_part, all_true_mask) + +pto.vmi.fma: + semantic: + result = fused_multiply_add(lhs, rhs, acc) + It must not be decomposed to pto.vmi.mulf + pto.vmi.addf because VPTO VMULA + may produce different floating-point results from separate multiply and add. + layout assignment: + lhs, rhs, acc, and result belong to one data layout equivalence class. + current direct lowering: + source/result element type must be f16, bf16, or f32 + for each physical part: + materialize pto.pset_b16/b32 "PAT_ALL" from the physical element width + emit pto.vmula(acc_part, lhs_part, rhs_part, all_true_mask) + The VMI operand order is lhs, rhs, acc; the VPTO operand order is acc, lhs, rhs. + +pto.vmi.cmpf / cmpi: + current direct lowering has the same 8/16/32-bit physical element-width + precondition as elementwise arithmetic, so the result predicate can be + materialized as b8/b16/b32. + target element contract: + cmpf: f16/bf16/f32, matching VISA VCMP floating-point element types + cmpi: signless/signed/unsigned i8/i16/i32, matching VISA VCMP integer element types + for each physical part: + materialize pto.pset_b8/b16/b32 "PAT_ALL" as the seed predicate + canonicalize predicate to VPTO cmp_mode eq/ne/lt/le/gt/ge + emit pto.vcmp(lhs_part, rhs_part, seed_mask, cmp_mode) + supported cmpf ordered aliases: + oeq -> eq + one -> ne + olt -> lt + ole -> le + ogt -> gt + oge -> ge + supported cmpi signed aliases: + slt -> lt + sle -> le + sgt -> gt + sge -> ge + unsupported floating-point predicates such as ord/uno/ult/ule/ugt/uge must emit VMI-UNSUPPORTED until NaN-aware + predicate construction is designed. + unsupported unsigned integer predicates ult/ule/ugt/uge must emit VMI-UNSUPPORTED until VPTO integer signedness + materialization is explicit. + +pto.vmi.active_prefix_index: + semantic: + idx[i] = popcount(mask[0 .. i)) + result element type must be signless i8/i16/i32, and concrete mask granularity must match the result element width. + current direct lowering: + only contiguous layout + only one physical result/mask chunk + result and mask chunks must be full, with no padding logical lanes + materialize a zero vreg carrier with pto.vdup + emit pto.vusqz(carrier, mask) + unsupported cases: + partial/tail chunks because padding mask lanes could affect the observable prefix + multi-chunk contiguous values need cross-chunk prefix carry + deinterleaved layouts need logical-lane-order prefix reconstruction + both must report VMI-UNSUPPORTED before OneToN conversion + +pto.vmi.compress: + semantic: + keep source lanes whose mask lane is true and compact them in logical lane order; inactive tail lanes are zero/undef + at the VMI semantic level unless consumed by an operation that defines them. + current direct lowering: + source/result/mask must be contiguous + source/result/mask must each materialize to one physical chunk + source chunk must be full, with no padding logical lanes + emit pto.vsqz(source, mask) + unsupported cases: + partial/tail chunks because padding mask lanes could be squeezed into the observable result prefix + multi-chunk values need cross-chunk compaction and SQZN/carry planning + deinterleaved layouts need logical-lane-order compaction before physical part placement + compress_store is not implied by register compress; store-coupled VSQZ #st=1 and VSTUR require a separate + producer/consumer pairing plan + +pto.vmi.compress_store: + semantic: + store source lanes whose mask lane is true as a dense logical memory stream: + k = 0 + for lane in logical order: + if mask[lane]: + base[offset + k] = value[lane] + k += 1 + layout assignment: + value use is requested as contiguous + mask use is requested as contiguous with granularity derived from value element width + current direct lowering: + value and mask must be contiguous + value and mask must each materialize to one physical chunk + the value chunk must be full, with no padding logical lanes + destination must be a UB !pto.ptr because pto.vstur is pointer-only and UB-only + lower as: + store_base = pto.addptr destination, offset + squeezed = pto.vsqz(value, mask) + align0 = pto.init_align + align1 = pto.vstur align0, squeezed, store_base, "POST_UPDATE" + pto.vstar align1, store_base + The pto.vstur user is the required consumer that lets the VPTO LLVM emitter + set VSQZ #st=1. A plain register pto.vsqz must not be assumed to enqueue + SQZN for store. + unsupported cases: + memref or GM destination until an explicit pointer/materialization plan exists + partial/tail physical chunks, because padding mask lanes could be squeezed into memory + multi-chunk values, because they need cross-chunk active-count compaction and SQZN/VSTUR state planning + deinterleaved layouts, because compaction must be in logical lane order + +pto.vmi.reduce_addi: + semantic: + acc = init[0] + for lane in logical order: + if mask[lane]: + acc = acc + source[lane] // integer wraparound addition + result[0] = acc + layout assignment: + source use is requested as contiguous + init use is requested as contiguous + result natural layout is contiguous + mask use is requested as contiguous with granularity derived from source element width + current direct lowering: + source element width must be 32 bits; narrower vcadd widens its result and needs a separate result type plan + source must materialize to one or more full physical chunks with no padding logical lanes + init/result must be rank-0 VMI vectors and each materialize to one physical chunk + mask must materialize to the same number of physical chunks as source + lower as: + first_lane = pto.pge_b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcadd(source_chunk, mask_chunk) + acc = pto.vadd(reduced, acc, first_lane) + result = acc + unsupported cases: + i8/i16 until widening result and init conversion are designed + partial/tail source chunks because padding lanes must not participate + floating-point add reduction without pto.vmi.reduce_addf {reassoc} + +pto.vmi.reduce_addf: + semantic: + requires {reassoc}; without it the verifier rejects the op + acc = init[0] + for lane in any reassociated tree over active logical lanes: + acc = acc + source[lane] + result[0] = acc + layout assignment: + source use is requested as contiguous + init use is requested as contiguous + result natural layout is contiguous + mask use is requested as contiguous with granularity derived from source element width + current direct lowering: + source element type must be f32 + source must materialize to one or more full physical chunks with no padding logical lanes + init/result must be rank-0 VMI vectors and each materialize to one physical chunk + mask must materialize to the same number of b32 physical chunks as source + lower as: + first_lane = pto.pge_b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcadd(source_chunk, mask_chunk) + acc = pto.vadd(reduced, acc, first_lane) + result = acc + unsupported cases: + missing reassoc attr + f16 until accumulator precision and rounding contract are designed + partial/tail source chunks because padding lanes must not participate + +pto.vmi.reduce_maxf / pto.vmi.reduce_minf: + semantic: + acc = init[0] + for each active logical lane in logical lane order: + reduce_maxf: acc = max(acc, source[lane]) + reduce_minf: acc = min(acc, source[lane]) + result[0] = acc + inactive lanes inside each physical chunk follow VPTO identities: + reduce_maxf uses pto.vcmax, where inactive FP lanes behave as -INF + reduce_minf uses pto.vcmin, where inactive FP lanes behave as +INF + NaN and signed-zero behavior follows pto.vcmax/pto.vcmin for the chunk + reduction and pto.vmax/pto.vmin for serial chunk accumulation. The index + lane produced by pto.vcmax/pto.vcmin is ignored because VMI exposes only the + rank-0 value result. + layout assignment: + source use is requested as contiguous + init use is requested as contiguous + result natural layout is contiguous + mask use is requested as contiguous with granularity derived from source element width + current direct lowering: + source element type must be f16 or f32 + source must materialize to one or more full physical chunks with no padding logical lanes + init/result must be rank-0 VMI vectors and each materialize to one physical chunk + mask must materialize to the same number of physical chunks as source + lower reduce_maxf as: + first_lane = pto.pge_b16/b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcmax(source_chunk, mask_chunk) + acc = pto.vmax(reduced, acc, first_lane) + result = acc + lower reduce_minf as: + first_lane = pto.pge_b16/b32 "PAT_VL1" + acc = init + for each source_chunk, mask_chunk in physical order: + reduced = pto.vcmin(source_chunk, mask_chunk) + acc = pto.vmin(reduced, acc, first_lane) + result = acc + unsupported cases: + bf16/fp8/f64 until VPTO reduction and combine semantics are designed + partial/tail source chunks because padding lanes must not participate + integer min/max until signed/unsigned and inactive identity contracts are explicit + +pto.vmi.select: + current direct lowering is a storage-width select rather than a semantic + arithmetic op: source/result physical elements must be b8/b16/b32-maskable, + but signedness and float-vs-integer interpretation are not inspected. + for each physical part: + consume the corresponding physical predicate part + emit pto.vsel(true_part, false_part, predicate_part) + +pto.vmi.extf, direct path: + support 16-bit float-like contiguous source part -> f32 deinterleaved=2 result parts + materialize pto.pset_b16 "PAT_ALL" + emit pto.vcvt(source_part, mask, part=EVEN/ODD) + partial/tail is valid when the logical lanes fit in the one physical source + part; PAT_ALL may convert padding lanes, but those lanes remain padding in + the deinterleaved result + support 8-bit contiguous source part -> f32 deinterleaved=4 result parts + materialize pto.pset_b8 "PAT_ALL" + emit pto.vcvt(source_part, mask, part=P0/P1/P2/P3) + the same padding rule applies + reject other extf width/layout shapes until their exact part plan is implemented + +pto.vmi.truncf, direct path: + support f32 deinterleaved=2 source parts -> 16-bit contiguous result part + materialize pto.pset_b32 "PAT_ALL" for the source conversion + emit pto.vcvt(even_f32_part, mask, rnd=R, sat=SAT, part=EVEN) + emit pto.vcvt(odd_f32_part, mask, rnd=R, sat=SAT, part=ODD) + materialize pto.pset_b16 "PAT_ALL" + merge mutually exclusive part results with pto.vor + partial/tail is valid when the two source parts pack into one physical + result part; converted padding lanes remain result padding + support f32 deinterleaved=4 source parts -> 8-bit contiguous result part + materialize pto.pset_b32 "PAT_ALL" for the source conversion + emit pto.vcvt(p0_f32_part, mask, rnd=R, sat=SAT, part=P0) + emit pto.vcvt(p1_f32_part, mask, rnd=R, sat=SAT, part=P1) + emit pto.vcvt(p2_f32_part, mask, rnd=R, sat=SAT, part=P2) + emit pto.vcvt(p3_f32_part, mask, rnd=R, sat=SAT, part=P3) + materialize pto.pset_b8 "PAT_ALL" + merge mutually exclusive part results with pto.vor + partial/tail is valid when the four source parts pack into one physical + result part; converted padding lanes remain result padding + reject other truncf width/layout shapes until their exact pack plan is implemented + +pto.vmi.bitcast: + for each physical part: + emit pto.vbitcast(source_part) -> result_part_type + source/result layouts must match, physical arity must match, and every + corresponding physical chunk must carry the same number of logical bits. + Padding bits may map only to result padding bits; any shape where source + padding would become result logical data remains unsupported. + +pto.vmi.channel_split / pto.vmi.channel_merge: + support 2-way and 4-way channel transforms for contiguous per-channel values + and matching deinterleaved=C merged values. + + channel_split C=2: + if the source layout is already deinterleaved=2, forward physical chunks + directly to the two contiguous channel results. + if the source layout is contiguous, source logical vector must physicalize + as 2*N contiguous chunks. For each pair of dense chunks: + %ch0_i, %ch1_i = pto.vdintlv %dense_2i, %dense_2i_plus_1 + Results are returned in per-channel order: + channel0 chunks..., channel1 chunks... + + channel_split C=4: + if the source layout is already deinterleaved=4, forward physical chunks + directly to the four contiguous channel results. + if the source layout is contiguous, source logical vector must physicalize + as 4*N contiguous chunks. The lowering is the same two-level pto.vdintlv + tree used by contiguous -> deinterleaved=4 materialization, but the + partition-major output is interpreted as four separate contiguous channel + results. + + channel_merge C=2/C=4: + inputs are consumed as per-channel contiguous chunks. + If the result layout is deinterleaved=C, the physical chunks are forwarded + directly in partition-major order. + If the result layout is contiguous, the lowering uses the reverse + pto.vintlv tree and returns dense contiguous chunks for the merged result. + + Unsupported: + channel counts other than 2 or 4 + non-matching channel input/result layouts + arity-changing or uneven partial physical channel groups that cannot form + complete intlv/dintlv groups + +pto.vmi.shuffle: + first try whole physical chunk forwarding cases: + source/result layouts are assigned + every non-padding lane in a result physical chunk maps to the same source physical chunk + source lane number equals result lane number inside the physical chunk + result padding lanes are ignored and remain semantically unobservable + + If forwarding fails, try vci-materializable vselr per physical chunk: + every result physical chunk has no padding lane + every lane in a result physical chunk maps to the same source physical chunk + source lane indices inside the chunk form one ASC or DESC consecutive sequence + materialize the index vector with pto.vci(base_lane, ASC|DESC) + emit pto.vselr(source_chunk, index_vector) + + Examples: + identity 128xf32 -> 128xf32: + indices = [0, 1, ..., 127] + forward dense chunks 0 and 1 + + second physical chunk 128xf32 -> 64xf32: + indices = [64, 65, ..., 127] + forward dense chunk 1 + + tail prefix 128xf32 -> 4xf32: + indices = [0, 1, 2, 3] + forward dense chunk 0 + lanes 4..63 of the physical result are padding lanes and are not part of + the logical vmi value + + chunk swap 128xf32 -> 128xf32: + indices = [64, 65, ..., 127, 0, 1, ..., 63] + forward dense chunks in order 1, 0 + + reverse one 64xf32 chunk: + indices = [63, 62, ..., 0] + index = pto.vci 63 {order = DESC} : i32 -> !pto.vreg<64xi32> + result = pto.vselr source_chunk, index + + Unsupported: + partial physical chunk projection whose observable result lanes are not + padding-safe forwarding, e.g. [1, 2, 3, 4] -> 4xf32 when it would require + shifting lanes rather than forwarding a whole physical chunk + broadcast, duplicate lanes, arbitrary non-affine permutation + current implementation emits VMI-UNSUPPORTED for these cases before + OneToN conversion, instead of leaving a generic residual VMI op. +``` + +`func.return` 携带 VMI operand 时必须通过 OneToN func/return structural pattern 展开成 physical +return operands。不能只取第一个 physical part;这种错误会导致函数类型已经返回两个 physical value, +但 `func.return` 只返回一个 value。 + +### 6.1 Type Conversion + +Use one shared physicalization helper: + +```text +VMIVRegType -> N physical !pto.vreg +VMIMaskType -> N physical !pto.mask +``` + +Physical result ordering must be: + +```text +contiguous: + chunk0, chunk1, ... + +deinterleaved=K: + p0_chunk0, p0_chunk1, ..., p1_chunk0, ..., p(K-1)_chunkN +``` + +### 6.2 Structural Conversion + +The pass must convert: + +```text +operation results +block arguments +branch operands +cf.br / cf.cond_br successor block signatures +scf.if results and yields +scf.for iter_args and yields +func arguments/results +call operands/results +return operands +cf.br / cf.cond_br / cf.switch block arguments and successor operands +scf.execute_region results and yields: + current implementation uses a project-local OneToN structural pattern. +scf.index_switch results and yields: + current implementation uses a project-local OneToN structural pattern. +``` + +Do not rely on a defining op to recover parts. Any VMI value may come from a block argument or function +argument, so `unpack` must be valid on arbitrary layout-assigned VMI SSA values before final lowering. + +### 6.3 Op Lowering + +Internal helper lowering: + +```text +unpack: + replace with physical values in helper ordering + +pack: + materialize one logical VMI aggregate before it is immediately consumed by another VMI helper + must not remain after final gate + +ensure_layout: + preflight: + source/result must have computable physical arity + source/result physical arity must match + identity source/result layouts do not require full chunks + if source/result layouts differ, either: + every source/result physical chunk is full, with no padding lanes; or + source/result both have complete contiguous/deinterleaved=2/4 materialization groups and their materialized + physical arity still equals the original VMI physical arity + arity-changing partial/tail layout conversion remains unsupported because it would need an explicit padding + packing/drop plan + otherwise report VMI-UNSUPPORTED before OneToN conversion + + compare the original VMI source/result layout attrs: + same layout: + forward the converted source parts + deinterleaved=2 -> contiguous: + %d0, %d1 = pto.vintlv %p0, %p1 + contiguous -> deinterleaved=2: + %p0, %p1 = pto.vdintlv %d0, %d1 + deinterleaved=4 -> contiguous: + %a0, %a1 = pto.vintlv %p0, %p2 + %b0, %b1 = pto.vintlv %p1, %p3 + %d0, %d1 = pto.vintlv %a0, %b0 + %d2, %d3 = pto.vintlv %a1, %b1 + contiguous -> deinterleaved=4: + %a0, %b0 = pto.vdintlv %d0, %d1 + %a1, %b1 = pto.vdintlv %d2, %d3 + %p0, %p2 = pto.vdintlv %a0, %a1 + %p1, %p3 = pto.vdintlv %b0, %b1 + + It is a bug to treat layout conversion as identity merely because both sides convert to the same + number of physical !pto.vreg values with the same type. For example: + !pto.vmi.vreg<128xf32, deinterleaved=2> + !pto.vmi.vreg<128xf32, contiguous> + both physicalize to two !pto.vreg<64xf32> values, but their logical lane order differs. + +ensure_mask_layout: + preflight: + source/result must have computable physical arity + source/result physical arity must match + if source/result layouts differ, every source/result physical predicate chunk must be full, with no padding lanes + identity source/result layouts do not require full chunks + otherwise report VMI-UNSUPPORTED before OneToN conversion + + same-layout: + forward source parts + deinterleaved=2 -> contiguous: + use pto.pintlv_b8/b16/b32 on each partition pair + contiguous -> deinterleaved=2: + use pto.pdintlv_b8/b16/b32 on each dense pair + deinterleaved=4 -> contiguous: + use the same two-level tree as data layout conversion, replacing pto.vintlv with pto.pintlv_b8/b16/b32 + contiguous -> deinterleaved=4: + use the reverse two-level tree, replacing pto.vdintlv with pto.pdintlv_b8/b16/b32 + source/result granularity must be identical; granularity conversion belongs to ensure_mask_granularity. + +ensure_mask_granularity: + source/result layout and logical lane count must match. + source/result granularity must be concrete b8/b16/b32. + identity conversion forwards physical parts. + widening conversion: + b8 -> b16 or b16 -> b32 uses pto.punpack LOWER/HIGHER for each source physical chunk. + each source physical mask chunk can produce up to two result chunks in logical order. + narrowing conversion: + b32 -> b16 or b16 -> b8 uses pto.ppack LOWER for the low source chunk. + if a high source chunk exists, use pto.ppack HIGHER and merge the two partial masks with pto.por under PAT_ALL. + this handles odd tail groups because the missing high half is padding and remains zero. + multi-step conversion: + b8 -> b32 is b8 -> b16 -> b32. + b32 -> b8 is b32 -> b16 -> b8. +``` + +Elementwise lowering: + +```text +for each physical part: + lower add/cmp/select to corresponding VPTO op sequence + preserve source/result physical ordering + cmp predicates must be canonicalized before creating pto.vcmp: + eq/ne/lt/le/gt/ge pass through + ordered FP aliases oeq/one/olt/ole/ogt/oge map to eq/ne/lt/le/gt/ge + signed integer aliases slt/sle/sgt/sge map to lt/le/gt/ge + unordered/NaN-sensitive FP predicates are unsupported until represented explicitly + unsigned integer predicates are unsupported until signedness is represented explicitly +``` + +Producer lowering: + +```text +broadcast: + TypeConverter gives the ordered result physical types. + For each result physical vreg: + create all-true mask with the vreg element width + emit pto.vdup scalar -> that physical vreg + + This is valid for contiguous and deinterleaved layouts because splat has no lane-order dependence. + +constant: + Splat dense constants use the same path as broadcast: + create scalar arith.constant from the splat attribute + emit pto.vdup per physical result part + require the same 8/16/32-bit physical result element-width precondition as + broadcast + Non-splat dense constants need an explicit constant materialization strategy or must remain unsupported with a + precise diagnostic; do not synthesize an arbitrary lane sequence by scalar inserts unless that path is designed. + +create_mask / constant_mask: + constant active_lanes create_mask lowers per physical mask part: + clamp active_lanes to [0, logical lane count] + compute active prefix count for each physical mask chunk with the VMI lane-map helper + emit pto.pge_b8/b16/b32 PAT_ALL, PAT_ALLF, or supported PAT_VL* + if a chunk prefix count has no supported PAT_VL token, fall back to pto.plt_b8/b16/b32 with a constant i32 count + Dynamic active_lanes with contiguous layout lowers by chaining pto.plt_b8/b16/b32 over the physical chunks: + active_i32 = arith.index_cast active_lanes : index to i32 + active_i32 = minui(maxsi(active_i32, 0), logical_lane_count) + mask0, remaining0 = pto.plt_b* active_i32 + mask1, remaining1 = pto.plt_b* remaining0 + ... + Dynamic active_lanes with deinterleaved layout remaps one logical prefix into per-part dynamic lane counts before + chaining pto.plt_b*: + active_i32 = minui(maxsi(index_cast(active_lanes), 0), logical_lane_count) + part_count(part) = (active_i32 + factor - 1 - part) / factor + then chain pto.plt_b* independently for each partition in VMI physical order: + p0 chunks..., p1 chunks..., ... + dense constant_mask lowers per physical mask part: + first map logical lanes to physical predicate lanes using the assigned VMI layout + prefix chunks emit pto.pset_b8/b16/b32 PAT_ALL, PAT_ALLF, or supported PAT_VL* + if a prefix count has no supported PAT_VL token, emit pto.plt_b8/b16/b32 with a constant i32 count + non-prefix chunks are decomposed into static active runs: + prefix(hi) = pto.pge/plt for the run end + prefix(lo) = pto.pge/plt for the run begin + run = prefix(hi) & ~prefix(lo) using pto.pnot + pto.pand + chunk = run0 | run1 | ... using pto.por + +Unsupported diagnostics: + unexpected residual dynamic pto.vmi.create_mask after OneToN conversion: + VMI-UNSUPPORTED: dynamic pto.vmi.create_mask active_lanes could not be lowered by the current runtime predicate + generation plan + This is a final-gate diagnostic for malformed or newly unsupported dynamic shapes. The supported dynamic + contiguous/deinterleaved=2/deinterleaved=4 paths above must lower before this residual gate. + + non-splat pto.vmi.constant: + VMI-UNSUPPORTED: non-splat pto.vmi.constant requires a vreg immediate or scratch materialization plan + + partial/tail pto.vmi.load/tile_read: + VMI-UNSUPPORTED: pto.vmi. requires full physical chunks without padding lanes or a statically safe + full-read footprint (...; safe-read proof failed: ...) + GM-backed direct pto.vmi.load/masked_load/expand_load/tile_read: + VMI-UNSUPPORTED: pto.vmi. ... (source is GM-backed, but current direct VMI-to-VPTO memory lowering + emits pto.vlds/pto.vsts and requires UB-backed memory) + unsupported partial/tail pto.vmi.store/masked_store/tile_write: + VMI-UNSUPPORTED: pto.vmi. requires an 8/16/32-bit predicate-maskable element type and either full + physical chunks or contiguous/deinterleaved tail-store materialization, with UB-backed destination; unsupported + cases include values such as f64/index that have no b64 predicate representation, GM-backed destinations that + still need a memory movement/materialization plan, and uneven deinterleaved physical groups that cannot form + complete intlv groups + + unsupported non-identity partial/tail pto.vmi.ensure_layout: + VMI-UNSUPPORTED: pto.vmi.ensure_layout cannot materialize the requested data layout conversion; unsupported cases + include arity-changing partial/tail conversion and uneven deinterleaved groups that cannot form complete intlv + groups + If the helper has a single consumer, the main diagnostic is emitted on the + consumer op and operand, including both the actual operand VMI type and the + required VMI type. For example, pto.vmi.truncf operand #0 can report + `!pto.vmi.vreg<128xf32, contiguous>` vs. + `!pto.vmi.vreg<128xf32, deinterleaved=4>` for f32->fp8. The failed + pto.vmi.ensure_layout conversion is attached as a note. + + unsupported non-identity partial/tail pto.vmi.ensure_mask_layout: + VMI-UNSUPPORTED: pto.vmi.ensure_mask_layout cannot materialize the requested mask layout conversion; unsupported + cases include arity-changing partial/tail conversion and uneven deinterleaved groups that cannot form complete + predicate intlv groups + + unsupported pto.vmi.ensure_mask_granularity: + VMI-UNSUPPORTED: non-identity mask granularity materialization requires concrete b8/b16/b32 masks with matching + lane count and layout (...) + + unsupported pto.vmi.extf direct path shape: + VMI-UNSUPPORTED: pto.vmi.extf supports only one contiguous 16-bit float-like or fp8-like physical source chunk to f32 + deinterleaved=2/4 results; partial/tail is allowed only when source padding maps to result padding + + unsupported pto.vmi.truncf direct path shape: + VMI-UNSUPPORTED: pto.vmi.truncf supports only f32 deinterleaved=2 source parts to one contiguous f16 result chunk + or f32 deinterleaved=4 source parts to one contiguous fp8-like result chunk + + unsupported pto.vmi.bitcast shape: + VMI-UNSUPPORTED: pto.vmi.bitcast requires matching source/result layouts with identical physical arity and matching + per-chunk logical bit footprints (...) + + unsupported pto.vmi.channel_split / pto.vmi.channel_merge channel count: + VMI-UNSUPPORTED: pto.vmi.channel_split supports only 2 or 4 channels + VMI-UNSUPPORTED: pto.vmi.channel_merge supports only 2 or 4 channels + unsupported pto.vmi.channel_split / pto.vmi.channel_merge layout: + VMI-UNSUPPORTED: pto.vmi.channel_split requires source layout to be contiguous or matching deinterleaved channel + layout, and every result layout to be contiguous + VMI-UNSUPPORTED: pto.vmi.channel_merge requires every input layout to be contiguous and result layout to be + contiguous or matching deinterleaved channel layout +``` + +Width conversion lowering: + +```text +f16 -> f32: + supported direct path when source is contiguous and result is deinterleaved=2: + pto.vcvt part=EVEN produces logical lanes 0,2,4,... + pto.vcvt part=ODD produces logical lanes 1,3,5,... + source/result physical arity must be 1 -> 2 + +f8 -> f32: + supported direct path when source is contiguous and result is deinterleaved=4: + pto.vcvt part=P0/P1/P2/P3 produces the four modulo-4 lane partitions + source/result physical arity must be 1 -> 4 + +f32 -> f16: + supported direct path when source is deinterleaved=2 and result is contiguous: + pto.vcvt part=EVEN consumes even/source part 0 + pto.vcvt part=ODD consumes odd/source part 1 + pto.vor merges mutually exclusive f16 part results into one contiguous vreg + source/result physical arity must be 2 -> 1 + current default conversion attrs are rnd=R, sat=SAT +``` + +Memory lowering: + +```text +vmi.load: + current direct memory path first reads contiguous physical chunks. The logical lane count must be an exact multiple + of the physical vreg lane count. + For each contiguous physical chunk i: + offset_i = base_offset + i * lanesPerPart + dense_i = pto.vlds base[offset_i] + + If the requested VMI result layout is contiguous, return the dense chunks directly. + If the requested VMI result layout is deinterleaved=2: + prefer pto.vldsx2 "DINTLV_B8/B16/B32" per physical chunk group: + %p0_i, %p1_i = pto.vldsx2 base[offset_i], "DINTLV_B*" + return results in VMI partition-major order: + p0_chunk0, p0_chunk1, ..., p1_chunk0, p1_chunk1, ... + If the requested VMI result layout is deinterleaved=4 with exactly four physical parts: + use dense pto.vlds chunks followed by the reverse two-level pto.vdintlv tree. + + For larger multi-chunk deinterleaved=4 loads, apply the same conversion per contiguous chunk group and return + physical parts in VMI partition-major order: + deinterleaved=4: p0_chunks..., p1_chunks..., p2_chunks..., p3_chunks... + +vmi.store: + direct lowering requires value element width to be 8, 16, or 32 bits so the + emitted pto.vsts/pto.vstsx2 predicate can be materialized as b8/b16/b32. + contiguous layout with full physical chunks: + offset_i = base_offset + i * lanesPerPart + mask_i = pto.pset_b8/b16/b32 "PAT_ALL" + pto.vsts value_i, base[offset_i], mask_i + contiguous layout with a final partial physical chunk: + full chunks still use PAT_ALL + the final chunk computes valid_lanes = logical_lane_count - chunk_i * lanesPerPart + tail_mask_i = pto.plt_b8/b16/b32(valid_lanes) + pto.vsts tail_value_i, base[offset_i], tail_mask_i + padding lanes therefore have no externally visible store effect. + +deinterleaved store: + deinterleaved=2 with full physical chunks: + prefer pto.vstsx2 "INTLV_B8/B16/B32" per physical chunk group: + pto.vstsx2 p0_i, p1_i, base[offset_i], "INTLV_B*", all_true_mask + offset_i = base_offset + i * 2 * lanesPerPart + the vstsx2 dist mode writes logical lane 0,1,2,3,... order externally. + + current safe path lowers through proven register materialization before store: + deinterleaved=4 with exactly four physical parts: + use the two-level pto.vintlv tree, then store %d0/%d1/%d2/%d3 as contiguous chunks + + Larger multi-chunk deinterleaved=4 values use the same conversion per chunk group. The final store order is dense + chunk order, so external memory observes logical lane 0,1,2,... order. + +vmi.masked_load: + semantics: + if mask[lane] is true, result[lane] = memory[base + lane] + if mask[lane] is false, result[lane] = passthru[lane] + inactive mask lanes do not by themselves permit unsafe memory reads + current direct path: + result, passthru, and mask are requested as contiguous + full physical chunks can always use pto.vlds because every loaded lane is logical + partial/tail chunks require the same statically safe full-read proof as vmi.load + for each contiguous physical chunk i: + loaded_i = pto.vlds base[offset_i] + result_i = pto.vsel loaded_i, passthru_i, mask_i + unsupported cases: + non-contiguous layouts + unsafe partial/tail read footprints + target true masked/non-faulting load and guarded/scratch fallback + +vmi.gather: + semantics: + if mask[lane] is true, result[lane] = memory[base + indices[lane]] + if mask[lane] is false, result[lane] = passthru[lane] and no memory read occurs for that lane + indices are interpreted in element units, not bytes + layout assignment: + result natural layout is contiguous + indices and passthru uses are requested as contiguous + mask use is requested as contiguous with granularity derived from result element width + current direct path: + source must be !pto.ptr + T must be a 32-bit element type + indices must be signless or unsigned i32 + result / indices / passthru / mask must be contiguous full physical chunks + mask granularity must be b32 + for each physical chunk i: + gathered_i = pto.vgather2_bc source, indices_i, mask_i + result_i = pto.vsel gathered_i, passthru_i, mask_i + reason for vsel: + VGATHER2_BC false predicate lanes do not read memory but produce zero; VMI false lanes preserve passthru. + unsupported cases: + f16/b16/f8/i8 result element types + partial/tail chunks + non-contiguous layouts + memref/gm source + guarded/scratch fallback + +vmi.scatter: + semantics: + if mask[lane] is true, memory[base + indices[lane]] = value[lane] + if mask[lane] is false, no memory write occurs for that lane + indices are interpreted in element units, not bytes + if two active lanes have the same index, VMI logical semantics require an ordered conflict policy or an explicit + no-conflict proof before direct target lowering + layout assignment: + value and indices uses are requested as contiguous + mask use is requested as contiguous with granularity derived from value element width + current direct path: + op must carry {indices_unique} + destination must be !pto.ptr + T must be a 32-bit element type + indices must be signless or unsigned i32 + value / indices / mask must be contiguous full physical chunks + mask granularity must be b32 + for each physical chunk i: + pto.vscatter value_i, destination, indices_i, mask_i + reason for indices_unique: + VSCATTER false predicate lanes do not write, but duplicate active indices have target-defined/undefined grant + behavior. VMI cannot lower duplicate-index logical order semantics to VSCATTER without a proof or fallback. + unsupported cases: + missing indices_unique proof + f16/b16/f8/i8 value element types + partial/tail chunks + non-contiguous layouts + memref/gm destination + ordered duplicate-index fallback + +vmi.expand_load: + semantics: + k = 0 + for lane in logical order: + if mask[lane]: + result[lane] = memory[base + k] + k += 1 + else: + result[lane] = passthru[lane] + layout assignment: + result natural layout is contiguous + passthru use is requested as contiguous + mask use is requested as contiguous with granularity derived from result element width + current direct path: + static all-active path: + pto.vmi.create_mask with constant active_lanes >= logical lane count + dense all-true pto.vmi.constant_mask + in that case expand_load degenerates to ordinary vmi.load: + for each contiguous physical chunk i: + loaded_i = pto.vlds base[offset_i] + result_i = loaded_i + partial/tail chunks still require the same statically safe full-read proof as vmi.load. + runtime-mask path: + source must be !pto.ptr + T must be a 32-bit element type + result / passthru / mask must be contiguous one full physical chunk + mask granularity must be b32 + base_i = pto.addptr source, offset + indices_i = pto.vusqz(zero_i32_carrier, mask_i) + loaded_i = pto.vgather2_bc base_i, indices_i, mask_i + result_i = pto.vsel loaded_i, passthru_i, mask_i + unsupported cases: + runtime masks across multiple physical chunks + runtime masks on non-32-bit element types + non-contiguous layouts + unsafe partial/tail read footprints + guarded load or scratch fallback + +vmi.masked_store: + semantics: + if mask[lane] is true, store value[lane] + if mask[lane] is false, no memory write occurs for that logical lane + current full-footprint path: + value and mask are requested as contiguous at the use site + mask granularity is derived from value element width + for each contiguous physical chunk i: + offset_i = base_offset + i * lanesPerPart + pto.vsts value_i, base[offset_i], mask_i + contiguous layout with a final partial physical chunk: + full chunks store with the user mask directly + the final chunk computes tail_valid_i with pto.plt_b8/b16/b32(valid_lanes) + store_mask_i = pto.pand user_mask_i, tail_valid_i, all_true_mask_i + pto.vsts tail_value_i, base[offset_i], store_mask_i + padding lanes and user-inactive lanes therefore both have no write effect. + If the incoming value/mask are deinterleaved, layout assignment inserts + ensure_layout/ensure_mask_layout or the vmi-to-vpto pattern materializes the same contiguous representation before + emitting stores. This preserves logical memory order and keeps inactive lanes write-free. + +non-full chunks: + vmi.store, vmi.masked_store, and vmi.tile_write support contiguous tail chunks by predicating the final pto.vsts with + a prefix valid mask. masked_store additionally ANDs the user mask with the tail-valid mask. + deinterleaved=2/4 tail store/masked_store/tile_write is supported only through explicit layout materialization to + contiguous chunks first. This requires every deinterleaved part to have the same physical chunk count, so the + materializer can build complete vintlv/pintlv groups. After materialization, each contiguous chunk is predicated by + the logical tail-valid mask; chunks whose active logical lane count is zero are not emitted as stores. Uneven + deinterleaved groups, such as 129xf32 with deinterleaved=2, remain unsupported until a padding/scratch plan can + assemble only the observable contiguous chunks. + vmi.load and tile_read support partial/tail chunks only when the direct full physical read is statically safe: + statically shaped memref source, constant non-negative offset (or tile_read offset 0), and enough elements for the + whole physical read footprint. Padding lanes must never become observable. Other partial/tail load cases still need + scratch/guarded/true-masked load planning. + +vmi.tile_read / vmi.tile_write, current direct full-footprint path: + This is not transfer_read padding lowering. It is only the tile/memref equivalent of the full-chunk direct memory + path above. + + tile_read: + source must lower to one VPTO buffer-like value. + logical lane count must be an exact multiple of the physical lanes per part. + use offset 0 as the tile base offset. + contiguous result layout reads physical chunks with pto.vlds. + deinterleaved=2 result layout prefers pto.vldsx2 "DINTLV_B8/B16/B32" with offset 0. + other supported layouts materialize the requested result layout after contiguous reads. + + tile_write: + destination must lower to one VPTO buffer-like value. + use offset 0 as the tile base offset. + value element width must be 8, 16, or 32 bits so pto.vsts/pto.vstsx2 can receive a materialized predicate. + contiguous source layout stores every physical chunk with pto.vsts and an all-true mask. + if the final contiguous chunk is partial, store it with a prefix valid-lane mask. + deinterleaved=2 source layout prefers pto.vstsx2 "INTLV_B8/B16/B32" with offset 0. + other supported layouts materialize the source value to contiguous layout first. + deinterleaved=2/4 tail source layouts are supported through this materialization path only when every + deinterleaved part has the same physical chunk count; zero-active materialized chunks are skipped. + + Unsupported: + padding value semantics + partial/tail tile footprints + transfer_read-style out-of-bounds reads + write masks + non-identity tile indexing/permutation + any path that would expose padding lanes or reorder externally visible memory +``` + +Final hard gate: + +```text +no pto.vmi op remains +no !pto.vmi.* type remains, including in function signatures +no UnrealizedConversionCastOp remains +physical arity matches helper for every lowered value +``` + +Slice 4 完成条件: + +```text +1. `f16 -> f32 -> add -> store` lowers with deinterleaved=2 and stores contiguous logical order. + Covered by vmi_to_vpto_e2e_widen_add_store.pto. +2. `f8 -> f32 -> add -> store` lowers with deinterleaved=4 and stores contiguous logical order. + Covered by vmi_to_vpto_e2e_widen_add_store.pto. +3. Non-full memory physical arity and valid lane map are tested. + Covered by vmi_to_vpto_load_nonfull_invalid.pto, vmi_to_vpto_store_deint_invalid.pto, + vmi_to_vpto_load_safe_tail_memref.pto, + vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto, + vmi_to_vpto_masked_load_safe_tail_memref.pto, + vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto, + vmi_to_vpto_expand_load_all_active.pto, + vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto, and multi-chunk load/store layout tests. +4. Full-footprint tile_read/tile_write direct path lowers through pto.vlds/pto.vsts or deinterleaved=2 x2 dist + instructions with offset 0. + Covered by vmi_to_vpto_tile_read_write.pto. +5. Internal func.call boundaries expand callee signatures, call operands/results, and returned VMI values together. + Covered by vmi_layout_assignment_call_boundary.pto, vmi_layout_assignment_indirect_call_invalid.pto, + and vmi_to_vpto_call_boundary.pto. +6. Structured control-flow carrying VMI values expands iter args, yields, results, masks, and returns together. + Covered by vmi_layout_assignment_cf_switch.pto, + vmi_layout_assignment_scf_execute_region.pto, + vmi_layout_assignment_scf_index_switch.pto, + vmi_layout_assignment_scf_while.pto, vmi_to_vpto_cf_branch.pto, + vmi_to_vpto_scf_for.pto, vmi_to_vpto_scf_if.pto, and the user-facing + vmi_ptoas_cli_control_flow.pto. +7. Final gate rejects residual VMI helper and unrealized casts. + Covered by vmi_to_vpto_ensure_identity.pto, + vmi_to_vpto_ensure_layout_partial_invalid.pto, + vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto, + vmi_to_vpto_ensure_mask_layout_partial_invalid.pto, + vmi_to_vpto_unsupported_op_invalid.pto, + vmi_to_vpto_unrealized_cast_residual_invalid.pto, + vmi_to_vpto_type_attr_residual_invalid.pto, and per-feature unsupported + tests. +8. Same-family indirect memory ops reject unsupported direct-lowering shapes consistently. + Covered by vmi_to_vpto_gather_scatter_shape_invalid.pto together with the existing gather/scatter positive and + per-feature negative tests. +9. Same-family reduction ops reject unsupported direct-lowering shapes consistently. + Covered by vmi_to_vpto_reduce_shape_invalid.pto together with the existing reduce add/min/max positive and + per-feature negative tests, including vmi_to_vpto_reduce_addi_i16_invalid.pto and + vmi_to_vpto_reduce_addf_f16_invalid.pto. +10. Target-specific element contracts are checked before OneToN rewriting for direct VPTO ops. + Covered by vmi_to_vpto_bf16_arith.pto, vmi_to_vpto_math_element_type_invalid.pto, + vmi_to_vpto_cmp_select.pto, vmi_to_vpto_cmp_element_type_invalid.pto, + vmi_to_vpto_fma.pto, vmi_to_vpto_fma_element_type_invalid.pto, and + vmi_to_vpto_unary_math.pto for negf/absf/absi/sqrt/exp/ln/relu, plus + vmi_to_vpto_relu_element_type_invalid.pto. +11. Same-family mask logic ops lower through the physical mask granularity instead of assuming b32 masks. + Covered by vmi_to_vpto_mask_logic.pto for mask_and/mask_or/mask_xor/mask_not on b32 masks produced by + cmpf and on direct b8/b16 mask operands. +``` + +## 7. Slice 5: Tile Memory And Padding + +The Slice 4 direct path may lower full-footprint `tile_read/tile_write` with offset 0. For partial `tile_read`, it may +also lower to plain `pto.vlds` only when the static safe-read proof above succeeds. Do not lower any other partial or +padded `tile_read` as a plain load until a richer access plan proves it is safe. + +Implement an internal `VMIMemoryAccessPlan`: + +```text +base +logical lane count +logical_shape +permutation_map +lane-to-address map in element units +validMask +paddingValue +safeReadProof +writeMask +target capability decision +fallback resource decision +``` + +Current implementation status: + +```text +lib/PTO/Transforms/VMIToVPTO.cpp + VMIMemoryAccessPlan + VMIMemorySafeReadProof + VMIMemoryLogicalShape + VMIMemoryLaneAddressMap + VMIMemoryFallbackDecision + +currently routed through the plan: + contiguous identity logical_shape/permutation/lane-to-address map in element units + explicit rejection of non-identity memref layouts until subview/affine lane maps are represented + covered by vmi_to_vpto_memref_layout_invalid.pto, including a memref.subview-produced strided view + subview diagnostics name the missing normalized base/offset/stride lane-to-address plan + target true masked/non-faulting load capability query + current result is missing capability because pto.vlds has no mask operand + covered by vmi_to_vpto_masked_load_nonfull_invalid.pto + stable gather masked-load option + covered by vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto + currently emits a TODO diagnostic instead of lowering through VGATHER2 + direct pto.vmi.load partial/tail safe full-read proof + pto.vmi.masked_load partial/tail safe full-read proof + pto.vmi.expand_load static all-active safe full-read proof + VMI-to-VPTO rewrite match guard for load/tile_read full-or-safe reads + pto.vmi.store/tile_write direct write target decision with all-true writeMask kind + pto.vmi.masked_store direct write target decision with explicit writeMask kind + unsafe partial/tail read fallback decision as RequiredUnavailable diagnostic + covered by vmi_to_vpto_load_nonfull_invalid.pto, + vmi_to_vpto_masked_load_nonfull_invalid.pto, and + vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto + +currently not implemented by the plan: + paddingValue materialization (intentionally unsupported in the first implementation stage) + non-all-true validMask direct masked/non-faulting load lowering + scratch/guarded fallback lowering or allocation + lowering for non-identity logical_shape/permutation_map/lane-to-address maps, including subview or affine lane maps + writeMask fallback planning beyond the existing contiguous tail-store predicate path +``` + +Important first-stage contract: + +```text +VMI physical tail lanes and transfer paddingValue are different concepts. + +Physical tail lanes: + arise because pto.vreg is fixed at 256 bytes + are outside the logical VMI lane count + may be read/computed only when the extra lanes remain unobservable + +transfer_read-style paddingValue: + is an observable logical result for invalid/OOB transfer lanes + cannot be dropped or replaced by arbitrary physical tail contents + is not materialized by the first-stage VMI implementation + +Therefore any frontend path that still needs transfer_read paddingValue +semantics must stop before direct VMI-to-VPTO lowering with VMI-UNSUPPORTED, +unless it has already canonicalized to an all-valid load/masked_load subset +whose invalid lanes are proven absent. +``` + +`tile_read` decision tree: + +```text +safeReadProof full && validMask all true: + direct load + +safeReadProof full && validMask not all true: + first-stage: VMI-UNSUPPORTED because paddingValue materialization is not implemented + future: full load + padding materialization + select + +target true masked/non-faulting load: + first-stage: VMI-UNSUPPORTED because true masked/non-faulting load and paddingValue materialization are not implemented + future: masked load + padding materialization + +otherwise: + first-stage: VMI-UNSUPPORTED with the missing fallback reason + future: split safe regions, scratch fill/copy/load, guarded fallback, or diagnostic +``` + +`tile_write` decision tree: + +```text +writeMask all true && full footprint safe-writable: + direct store + +target true masked store: + masked store + +otherwise: + split/guarded/scatter-like fallback or diagnostic +``` + +Slice 5 完成条件: + +```text +1. Unsafe partial/tail read-like ops never lower to a potentially invalid full + read unless the physical footprint is statically proven safe. +2. PaddingValue materialization is not required in the first implementation + stage. Any path that would require paddingValue, true masked/non-faulting + load, scratch fill/copy/load, or guarded fallback must report + `VMI-UNSUPPORTED` with the missing fallback reason. +3. Non-identity logical_shape/permutation_map/lane-to-address maps, including + subview or affine lane maps, are explicitly rejected before lowering. +4. Store-like partial/tail writes are supported only by the existing + full-chunk or contiguous/deinterleaved tail-store predicate paths. Other + writeMask fallback paths must report `VMI-UNSUPPORTED`. +``` + +## 8. Target Capability Registry + +Add one explicit registry object, passed into layout assignment and VMI-to-VPTO: + +```text +supportsElementType(type, purpose) +getNaturalLayout(op) +supportsLayoutConversion(srcLayout, dstLayout, elementType) +getLayoutMaterializationPlan(srcLayout, dstLayout, elementType) +supportsMaskGranularityConversion(srcG, dstG) +supportsMemoryAccessPlan(plan) +supportsPrefixPopcount(maskType) +supportsReductionScanContract(op) +getScratchResource(plan) +``` + +The registry returns structured results: + +```text +supported +unsupported_missing_capability +unsupported_disabled_by_option +unsupported_resource +``` + +Diagnostics must expose that reason. A pass must not silently choose scalar fallback when fallback is disabled. + +Current implementation status: + +```text +include/PTO/Transforms/VMITargetCapabilities.h + VMITargetCapabilityRegistry + VMICapabilityResult { status, reason } + +currently routed through the registry: + element-type purpose checks for predicate-maskable vregs and direct elementwise/cmp/fma/relu VPTO lowering + reduction-family element-type contracts for reduce_addi/reduce_addf/reduce_maxf/reduce_minf + direct pto.vlds/vsts memory source/destination support + missing target true masked/non-faulting load capability for the current pto.vlds surface + pointer-only UB memory support for pto.vgather2_bc/pto.vscatter/pto.vstur based VMI paths + supported source/result layout conversion pairs + supported b8/b16/b32 mask granularity conversion pairs + pto.vmi.channel_split/channel_merge supported channel count + +still legacy helper-based and should migrate into the registry as follow-up: + full layout materialization plans and padding-safety checks + adjacent ppack/punpack mask granularity materialization plans + prefix popcount and full reduction/scan/contract shape capability checks +``` + +## 9. Diagnostics + +Centralize diagnostic codes in one header or utility file: + +```text +VMI-UNSUPPORTED +VMI-LAYOUT-CONTRACT +VMI-PASS-INVARIANT +VMI-RESIDUAL-OP +``` + +Current implementation defines these codes and their `": "` prefixes in `include/PTO/IR/VMIUtils.h`. Transform and +CLI code must reference those constants instead of spelling the diagnostic code strings locally; a source grep for the +four code strings should find only the central definitions. + +Every diagnostic should include: + +```text +source op +logical VMI type +producer natural layout, if any +consumer required layout, if any +missing capability or disabled option +available materialization paths, if known +``` + +## 10. Lit Test Layout + +Use a dedicated directory: + +```text +test/lit/vmi/ +``` + +Minimum test files: + +```text +vmi_type_attr_parse.mlir +vmi_type_attr_invalid.mlir +vmi_op_verifier_basic.mlir +vmi_producer_boundary.mlir +vmi_layout_assignment_widen.mlir +vmi_layout_assignment_cfg.mlir +vmi_layout_assignment_broadcast_remat.mlir +vmi_layout_assignment_iota_remat.mlir +vmi_layout_assignment_mask_remat.mlir +vmi_to_vpto_deinterleaved2.mlir +vmi_to_vpto_deinterleaved4.mlir +vmi_to_vpto_compaction_deint_invalid.mlir +vmi_to_vpto_non_full_tile.mlir +vmi_tile_read_padding.mlir +vmi_tile_write_mask.mlir +vmi_pipeline_hard_gates.mlir +``` + +Each pass test must use `FileCheck` to prove both positive output and negative absence: + +```text +CHECK: pto.vmi.addf +CHECK-NOT: pto.vadd +CHECK-NOT: unrealized_conversion_cast +``` + +Final lowering tests must check: + +```text +CHECK-NOT: pto.vmi. +CHECK-NOT: unrealized_conversion_cast +``` + +## 11. Implementation Order + +Recommended merge order: + +```text +1. VMI type/attr + helper + parse/verify tests. +2. Slice 1 op shells + verifier tests. +3. VMI producer boundary verifier. +4. layout assignment for straight-line code. +5. layout assignment for scf/cf/function boundaries. +6. vmi-to-vpto type conversion + pack/unpack/unpackable block args. +7. deinterleaved=2 f16 widen end-to-end. +8. deinterleaved=4 f8 widen end-to-end. +9. tile_read/tile_write padding-safe lowering. +10. remaining semantic op families. +``` + +Do not merge a pass that leaves hidden side tables as a required interpretation mechanism. Temporary internal +analysis structures are fine only if the pass materializes the final state into IR before returning. + +## 12. Review Checklist Before Coding Each Slice + +Before implementation: + +```text +1. Is the op/type syntax written in ODS and tested by parser round-trip? +2. Does every verifier rule have a negative test? +3. Does every pass have a post-pass hard gate? +4. Are CFG block arguments and function signatures covered? +5. Does any lowering rely on a defining op that block arguments do not have? +6. Does memory lowering prove safe footprint separately from valid lane mask? +7. Does mask granularity follow consumer element width? +8. Does final VPTO lowering leave zero VMI op/type/helper or unrealized-cast residuals? +``` + +If any answer is no, the slice is not ready to be treated as complete. + +## 13. Adding One VMI Op End To End + +新增一个 `pto.vmi.*` op 时,不要只补 ODS 和 lowering pattern。它必须穿过固定的六个落点, +否则很容易出现 verifier 能过、layout pass 不知道怎么约束、或控制流 physicalization 后残留 VMI type。 + +```text +1. ODS surface: + include/PTO/IR/VMIOps.td + +2. semantic verifier: + lib/PTO/IR/VMI.cpp + +3. layout facts: + lib/PTO/Transforms/VMILayoutAssignment.cpp + +4. vmi-to-vpto preflight: + lib/PTO/Transforms/VMIToVPTO.cpp::verifySupportedVMIToVPTOOps + +5. OneToN lowering pattern: + lib/PTO/Transforms/VMIToVPTO.cpp::populateVMIOneToNConversionPatterns + +6. focused lit tests: + test/lit/vmi/ +``` + +这六个落点的职责不同: + +```text +ODS: + 只定义 op 形状、operand/result type 类别、assembly format、interface 和 verifier hook。 + +VMI.cpp verifier: + 检查局部语义,例如元素类型、rank、lane count、predicate 字符串、source/result bit 数关系。 + 不能依赖 def-use 图,不能决定 layout。 + +LayoutAssignment: + 只收集 value-level layout/granularity 事实: + - producer natural layout + - operands that must share layout with result + - consumer required layout + - mask consumer required granularity + 不能在 collect 阶段改 IR。 + +VMIToVPTO preflight: + 在 rewrite 前拒绝当前 lowering 不支持但语义合法的 case。 + 典型例子是 partial physical chunk、non-prefix mask constant、dynamic create_mask、unsupported shuffle。 + +OneToN pattern: + 从 adaptor 读取 physical parts,按已经确定的 layout 发 VPTO op。 + 不能重新推断 layout,也不能通过 defining op 找 physical parts。 + +lit: + 至少覆盖 parser/verify、layout assignment、positive lowering、negative unsupported diagnostic。 +``` + +### Layout Fact Template + +新增 op 时先给它归类,再写 layout 约束。不要从 VPTO 指令形态反推 VMI layout;layout 的来源必须是 +logical vector 语义和当前物理指令的天然限制。 + +```text +elementwise same-shape op: + examples: + addf/addi/subf/mulf/andi/shli/shrui/absf/absi/sqrt + layout rule: + all data operands and result are in one equivalence class + lowering rule: + emit one VPTO op per physical part + +compare op: + examples: + cmpf/cmpi + layout rule: + lhs/rhs data layout unified + result mask requested to the same data layout + result mask granularity comes from lhs/rhs element width + lowering rule: + emit one vcmp per data part, producing corresponding mask part + +mask logical op: + examples: + mask_and/mask_or/mask_xor/mask_not + layout rule: + all mask operands/results share layout and granularity + lowering rule: + emit one predicate op per physical mask part + +layout-changing producer: + examples: + extf f16->f32, extf f8->f32, truncf f32->f16, truncf f32->fp8-like + layout rule: + source/request side follows instruction input contract + result natural layout follows instruction output contract + lowering rule: + emit the instruction sequence that preserves logical lane order under that layout + +memory consumer/producer: + examples: + load/store/tile_read/tile_write + layout rule: + load/tile_read result natural layout is chosen by memory dist capability + store/tile_write value operand requests the layout that memory dist can consume + lowering rule: + direct path only when every physical chunk has no padding lane and footprint is safe + +structural boundary: + examples: + scf.if result/yield, scf.for iter args, cf.br successor operands, func.call + layout rule: + semantically identical incoming/outgoing values are unified + lowering rule: + handled by OneToN structural patterns, not by op semantic lowering +``` + +代码里 `LayoutSolver::addConstraints()` 应该只表达上面的事实。例如一个普通 elementwise binary op +只需要: + +```cpp +if (auto addf = dyn_cast(op)) { + if (failed(unite(addf.getLhs(), addf.getRhs(), op)) || + failed(unite(addf.getLhs(), addf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); +} +``` + +一个 layout-changing op 不应该把 source/result 直接 `unite`,而是明确写 producer/consumer 合同: + +```cpp +if (auto extf = dyn_cast(op)) { + requestDataUse(extf.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extf.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, factor), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); +} +``` + +### OneToN Pattern Template + +`vmi-to-vpto` pattern 的输入不再是 logical VMI value,而是 adaptor 里已经 flatten 好的 physical parts。 +pattern 只做三件事: + +```text +1. 从 adaptor 取每个 logical operand 的 physical part list。 +2. 从 resultMapping 取每个 logical result 对应的 physical result type list。 +3. 按 part 顺序创建 VPTO op,并用 resultMapping replace 原 op。 +``` + +普通 elementwise binary op 的代码形态应该接近: + +```cpp +LogicalResult matchAndRewrite(VMIAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + if (lhsParts.size() != rhsParts.size() || lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical arity mismatch"); + + SmallVector results; + for (auto [lhs, rhs, resultType] : llvm::zip_equal(lhsParts, rhsParts, resultTypes)) + results.push_back(rewriter.create(op.getLoc(), resultType, lhs, rhs)); + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); +} +``` + +这里不能调用 `op.getLhs().getDefiningOp()` 去找物理寄存器。原因是 VMI value 可以来自: + +```text +function argument +block argument +scf.for iter arg +scf.if result +cf.br successor argument +func.call result +``` + +这些 value 很多没有 VMI defining op。physical parts 的唯一合法来源是 OneToN adaptor 和 +OneToNTypeMapping。 + +### Control-Flow Checklist + +每新增一个 op,不一定要写新的控制流 pattern;但必须检查它的结果或 operand 是否可能跨边界。 +如果只是普通 VMI value,那么已有 structural OneToN pattern 应该负责边界 physicalization: + +```text +func.func / func.call / func.return: + upstream func OneToN conversion + +scf.if / scf.for / scf.while / scf.yield: + upstream SCF OneToN structural conversion plus layout solver equivalence constraints + +cf.br / cf.cond_br / cf.switch: + project-local OneToN patterns flatten successor operands and rewrite destination block signatures + +scf.execute_region / scf.index_switch: + project-local OneToN patterns flatten region results +``` + +新增 op 的测试要至少放一个跨边界用例,证明 op 的 result 不是只在 straight-line IR 中工作: + +```mlir +%r = scf.if %cond -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.addf %a, %b : ... -> !pto.vmi.vreg<128xf32> + scf.yield %x : !pto.vmi.vreg<128xf32> +} else { + scf.yield %c : !pto.vmi.vreg<128xf32> +} +pto.vmi.store %r, %ptr, %off : ... +``` + +对应 lowering test 必须检查: + +```text +CHECK-NOT: pto.vmi. +CHECK-NOT: !pto.vmi. +CHECK-NOT: unrealized_conversion_cast +``` + +如果这个测试失败,通常不是该 op 的 VPTO pattern 本身错,而是 layout assignment 没有把 yield/result/consumer +约束统一,或者 OneToN structural pattern 漏了某种 region/control-flow op。 + +### Preflight Versus Pattern Failure + +语义合法但当前还没有物理实现的 case,应该在 `verifySupportedVMIToVPTOOps()` 里给稳定 diagnostic, +不要让 pattern 随机 `notifyMatchFailure()` 后落成 generic conversion failure。 + +```text +use verifier failure: + op 本身语义非法,任何 target 都不应该接受。 + examples: + absf on integer element + shrui on signed integer element + bitcast total bits mismatch + +use VMI-LAYOUT-CONTRACT: + 多个 producer/consumer/control-flow 约束互相冲突。 + examples: + one value simultaneously required as contiguous and deinterleaved=2 + one mask simultaneously required as b16 and b32 + +use VMI-UNSUPPORTED in preflight: + VMI semantics are valid, but current VPTO materialization is not implemented. + examples: + partial/tail memory access + pred-only constant mask without concrete b8/b16/b32 granularity + shuffle that requires vselr index-vector materialization + bitcast across partial physical chunks + +use VMI-RESIDUAL-OP: + conversion framework finished but VMI op/type/helper/cast remains. + This is a pass bug or missing pattern, not a user semantic error. +``` + +Pattern-local `notifyMatchFailure()` is still useful for debugging competing patterns, but it must not be the only +user-visible explanation for a known unsupported VMI semantic case. diff --git a/include/PTO/IR/PTOAttrs.td b/include/PTO/IR/PTOAttrs.td index 6ae245f815..6e46ddbf7d 100644 --- a/include/PTO/IR/PTOAttrs.td +++ b/include/PTO/IR/PTOAttrs.td @@ -38,6 +38,8 @@ class PTO_Attr traits = []> let mnemonic = attrMnemonic; } +include "PTO/IR/VMIAttrs.td" + //===----------------------------------------------------------------------===// // Address Space //===----------------------------------------------------------------------===// diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 70ade82bbf..3e2a20f408 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -76,6 +76,7 @@ class PTO_DpsOp traits = []> class PTO_Op traits = []> : Op; +include "PTO/IR/VMIOps.td" include "PTO/IR/VPTOOps.td" //===----------------------------------------------------------------------===// diff --git a/include/PTO/IR/PTOTypeDefs.td b/include/PTO/IR/PTOTypeDefs.td index 69003cf5a4..7310f2a8e2 100644 --- a/include/PTO/IR/PTOTypeDefs.td +++ b/include/PTO/IR/PTOTypeDefs.td @@ -377,4 +377,5 @@ def F4E2M1x2Type : TypeDef { + let summary = "VMI logical vector register layout"; + let parameters = (ins + StringRefParameter<"layout kind">:$kind, + "int64_t":$factor + ); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + static VMILayoutAttr getContiguous(::mlir::MLIRContext *context); + static VMILayoutAttr getDeinterleaved(::mlir::MLIRContext *context, + int64_t factor); + + bool isContiguous() const { return getKind() == "contiguous"; } + bool isDeinterleaved() const { return getKind() == "deinterleaved"; } + }]; +} + +#endif // MLIR_DIALECT_PTO_IR_VMIATTRS diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td new file mode 100644 index 0000000000..6f567bb8a5 --- /dev/null +++ b/include/PTO/IR/VMIOps.td @@ -0,0 +1,562 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VMIOps.td - PTO VMI semantic operations -------------*- tablegen -*-===// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_IR_VMIOPS +#define MLIR_DIALECT_PTO_IR_VMIOPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def VMI_VRegTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VMIVRegType>($_self)">, + "VMI logical vector register type">; + +def VMI_MaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VMIMaskType>($_self)">, + "VMI logical mask type">; + +def VMI_ValueTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VMIVRegType, ::mlir::pto::VMIMaskType>($_self)">, + "VMI logical vector or mask type">; + +def PTO_PhysicalVRegTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::VRegType>($_self)">, + "PTO physical vector register type">; + +def PTO_PhysicalMaskTypeConstraint : Type< + CPred<"::llvm::isa<::mlir::pto::MaskType>($_self)">, + "PTO physical mask type">; + +def PTO_PhysicalVMIPartTypeConstraint : AnyTypeOf< + [PTO_PhysicalVRegTypeConstraint, PTO_PhysicalMaskTypeConstraint], + "PTO physical vector register or mask type">; + +class VMI_Op traits = []> + : PTO_Op<"vmi." # mnemonic, traits>; + +def VMIConstantOp : VMI_Op<"constant"> { + let summary = "VMI logical vector constant"; + let arguments = (ins AnyAttr:$value); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; +} + +def VMIBroadcastOp : VMI_Op<"broadcast"> { + let summary = "Broadcast one scalar or rank-0 VMI vector to a VMI logical vector"; + let arguments = (ins AnyType:$value); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$value attr-dict `:` type($value) `->` type($result)"; +} + +def VMIIotaOp : VMI_Op<"iota"> { + let summary = "Create a VMI logical index vector from a scalar base"; + let arguments = (ins + AnyTypeOf<[AnyInteger, AnyFloat], "integer/float scalar">:$base, + OptionalAttr:$order + ); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$base attr-dict `:` type($base) `->` type($result)"; +} + +def VMICreateMaskOp : VMI_Op<"create_mask"> { + let summary = "Create a VMI logical prefix predicate mask"; + let arguments = (ins Index:$active_lanes); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$active_lanes attr-dict `:` type($active_lanes) `->` type($result)"; +} + +def VMIConstantMaskOp : VMI_Op<"constant_mask"> { + let summary = "VMI logical predicate mask constant"; + let arguments = (ins AnyAttr:$value); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; +} + +def VMIMaskAndOp : VMI_Op<"mask_and"> { + let summary = "VMI logical predicate mask and"; + let arguments = (ins VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaskOrOp : VMI_Op<"mask_or"> { + let summary = "VMI logical predicate mask or"; + let arguments = (ins VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaskXOrOp : VMI_Op<"mask_xor"> { + let summary = "VMI logical predicate mask xor"; + let arguments = (ins VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaskNotOp : VMI_Op<"mask_not"> { + let summary = "VMI logical predicate mask not"; + let arguments = (ins VMI_MaskTypeConstraint:$source); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAddFOp : VMI_Op<"addf"> { + let summary = "VMI floating-point elementwise add"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIAddIOp : VMI_Op<"addi"> { + let summary = "VMI integer elementwise add"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMISubFOp : VMI_Op<"subf"> { + let summary = "VMI floating-point elementwise subtract"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMISubIOp : VMI_Op<"subi"> { + let summary = "VMI integer elementwise subtract"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMulFOp : VMI_Op<"mulf"> { + let summary = "VMI floating-point elementwise multiply"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMulIOp : VMI_Op<"muli"> { + let summary = "VMI integer elementwise multiply"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIFmaOp : VMI_Op<"fma"> { + let summary = "VMI fused floating-point multiply-add"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs, + VMI_VRegTypeConstraint:$acc); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result)"; +} + +def VMIDivFOp : VMI_Op<"divf"> { + let summary = "VMI floating-point elementwise divide"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMinFOp : VMI_Op<"minf"> { + let summary = "VMI floating-point elementwise minimum"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIMaxFOp : VMI_Op<"maxf"> { + let summary = "VMI floating-point elementwise maximum"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMINegFOp : VMI_Op<"negf"> { + let summary = "VMI floating-point elementwise negate"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAbsFOp : VMI_Op<"absf"> { + let summary = "VMI floating-point elementwise absolute value"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAbsIOp : VMI_Op<"absi"> { + let summary = "VMI integer elementwise absolute value"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMISqrtOp : VMI_Op<"sqrt"> { + let summary = "VMI floating-point elementwise square root"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIExpOp : VMI_Op<"exp"> { + let summary = "VMI floating-point elementwise exponential"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMILnOp : VMI_Op<"ln"> { + let summary = "VMI floating-point elementwise natural logarithm"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIReluOp : VMI_Op<"relu"> { + let summary = "VMI floating-point elementwise ReLU"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIAndIOp : VMI_Op<"andi"> { + let summary = "VMI integer elementwise bitwise and"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIOrIOp : VMI_Op<"ori"> { + let summary = "VMI integer elementwise bitwise or"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIXOrIOp : VMI_Op<"xori"> { + let summary = "VMI integer elementwise bitwise xor"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIShLIOp : VMI_Op<"shli"> { + let summary = "VMI integer elementwise left shift"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMIShRUIOp : VMI_Op<"shrui"> { + let summary = "VMI unsigned integer elementwise right shift"; + let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMINotOp : VMI_Op<"not"> { + let summary = "VMI integer elementwise bitwise not"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMICmpFOp : VMI_Op<"cmpf"> { + let summary = "VMI floating-point elementwise compare"; + let arguments = (ins StrAttr:$predicate, VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMICmpIOp : VMI_Op<"cmpi"> { + let summary = "VMI integer elementwise compare"; + let arguments = (ins StrAttr:$predicate, VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; +} + +def VMISelectOp : VMI_Op<"select"> { + let summary = "VMI elementwise select"; + let arguments = (ins VMI_MaskTypeConstraint:$mask, VMI_VRegTypeConstraint:$true_value, + VMI_VRegTypeConstraint:$false_value); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$mask `,` $true_value `,` $false_value attr-dict `:` type($mask) `,` type($true_value) `,` type($false_value) `->` type($result)"; +} + +def VMIActivePrefixIndexOp : VMI_Op<"active_prefix_index"> { + let summary = "VMI per-lane active-prefix index from a predicate mask"; + let arguments = (ins VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$mask attr-dict `:` type($mask) `->` type($result)"; +} + +def VMICompressOp : VMI_Op<"compress"> { + let summary = "VMI compact active source lanes according to a predicate mask"; + let arguments = (ins VMI_VRegTypeConstraint:$source, VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + +def VMICompressStoreOp : VMI_Op<"compress_store", [DeclareOpInterfaceMethods]> { + let summary = "VMI store active source lanes contiguously according to a predicate mask"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, + Index:$offset, VMI_MaskTypeConstraint:$mask); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask)"; +} + +def VMIReduceAddIOp : VMI_Op<"reduce_addi"> { + let summary = "VMI masked integer add reduction with a rank-0 vector init"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIReduceAddFOp : VMI_Op<"reduce_addf"> { + let summary = "VMI masked floating-point add reduction with explicit reassociation permission"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask, + OptionalAttr:$reassoc); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIReduceMaxFOp : VMI_Op<"reduce_maxf"> { + let summary = "VMI masked floating-point maximum reduction with a rank-0 vector init"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIReduceMinFOp : VMI_Op<"reduce_minf"> { + let summary = "VMI masked floating-point minimum reduction with a rank-0 vector init"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_VRegTypeConstraint:$init, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; +} + +def VMIExtFOp : VMI_Op<"extf"> { + let summary = "VMI floating-point elementwise extension"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMITruncFOp : VMI_Op<"truncf"> { + let summary = "VMI floating-point elementwise truncation"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIBitcastOp : VMI_Op<"bitcast"> { + let summary = "VMI bitwise vector reinterpretation"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMILoadOp : VMI_Op<"load", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical vector load"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` attr-dict `:` type($source) `->` type($result)"; +} + +def VMIMaskedLoadOp : VMI_Op<"masked_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked vector load with passthrough lanes"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, + VMI_MaskTypeConstraint:$mask, + VMI_VRegTypeConstraint:$passthru); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $mask `,` $passthru attr-dict `:` type($source) `,` type($mask) `,` type($passthru) `->` type($result)"; +} + +def VMIGatherOp : VMI_Op<"gather", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked indexed gather with passthrough lanes"; + let arguments = (ins PtrOrMemRef:$source, + VMI_VRegTypeConstraint:$indices, + VMI_MaskTypeConstraint:$mask, + VMI_VRegTypeConstraint:$passthru); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $indices `]` `,` $mask `,` $passthru attr-dict `:` type($source) `,` type($indices) `,` type($mask) `,` type($passthru) `->` type($result)"; +} + +def VMIExpandLoadOp : VMI_Op<"expand_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI load a dense active-lane stream into masked logical lanes"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, + VMI_MaskTypeConstraint:$mask, + VMI_VRegTypeConstraint:$passthru); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $mask `,` $passthru attr-dict `:` type($source) `,` type($mask) `,` type($passthru) `->` type($result)"; +} + +def VMIStoreOp : VMI_Op<"store", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical vector store"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, Index:$offset); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` attr-dict `:` type($value) `,` type($destination)"; +} + +def VMIMaskedStoreOp : VMI_Op<"masked_store", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked vector store"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, + Index:$offset, VMI_MaskTypeConstraint:$mask); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($mask)"; +} + +def VMIScatterOp : VMI_Op<"scatter", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical masked indexed scatter"; + let arguments = (ins VMI_VRegTypeConstraint:$value, + PtrOrMemRef:$destination, + VMI_VRegTypeConstraint:$indices, + VMI_MaskTypeConstraint:$mask, + OptionalAttr:$indices_unique); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $indices `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($indices) `,` type($mask)"; +} + +def VMITileReadOp : VMI_Op<"tile_read", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical tile read"; + let arguments = (ins AnyType:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMITileWriteOp : VMI_Op<"tile_write", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical tile write"; + let arguments = (ins VMI_VRegTypeConstraint:$value, AnyType:$destination); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination attr-dict `:` type($value) `,` type($destination)"; +} + +def VMIShuffleOp : VMI_Op<"shuffle"> { + let summary = "VMI static lane shuffle"; + let arguments = (ins VMI_VRegTypeConstraint:$source, DenseI64ArrayAttr:$indices); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $indices `]` attr-dict `:` type($source) `->` type($result)"; +} + +def VMIChannelSplitOp : VMI_Op<"channel_split"> { + let summary = "VMI split interleaved logical channels"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs Variadic:$results); + let hasVerifier = 1; +} + +def VMIChannelMergeOp : VMI_Op<"channel_merge"> { + let summary = "VMI merge logical channels by interleaving"; + let arguments = (ins Variadic:$inputs); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; +} + +def VMIEnsureLayoutOp : VMI_Op<"ensure_layout"> { + let summary = "Internal VMI data layout materialization helper"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIEnsureMaskLayoutOp : VMI_Op<"ensure_mask_layout"> { + let summary = "Internal VMI mask layout materialization helper"; + let arguments = (ins VMI_MaskTypeConstraint:$source); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIEnsureMaskGranularityOp : VMI_Op<"ensure_mask_granularity"> { + let summary = "Internal VMI mask granularity materialization helper"; + let arguments = (ins VMI_MaskTypeConstraint:$source); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIUnpackOp : VMI_Op<"unpack"> { + let summary = "Internal VMI value projection to physical parts"; + let arguments = (ins VMI_ValueTypeConstraint:$source); + let results = (outs Variadic:$parts); + let hasVerifier = 1; +} + +def VMIPackOp : VMI_Op<"pack"> { + let summary = "Internal physical parts materialized as one VMI value"; + let arguments = (ins Variadic:$parts); + let results = (outs VMI_ValueTypeConstraint:$result); + let hasVerifier = 1; +} + +#endif // MLIR_DIALECT_PTO_IR_VMIOPS diff --git a/include/PTO/IR/VMITypeDefs.td b/include/PTO/IR/VMITypeDefs.td new file mode 100644 index 0000000000..4ec6bb5009 --- /dev/null +++ b/include/PTO/IR/VMITypeDefs.td @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VMITypeDefs.td - PTO VMI type definitions -----------*- tablegen -*-===// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PTO_IR_VMITYPEDEFS +#define MLIR_DIALECT_PTO_IR_VMITYPEDEFS + +include "PTO/IR/PTODialect.td" +include "PTO/IR/PTOAttrs.td" + +def VMIVRegType : TypeDef { + let mnemonic = "vmi.vreg"; + let summary = "A VMI logical vector register value"; + + let parameters = (ins + "int64_t":$elementCount, + "Type":$elementType, + "mlir::Attribute":$layout + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + bool hasLayout() const { return static_cast(getLayout()); } + VMILayoutAttr getLayoutAttr() const { + return ::llvm::dyn_cast_or_null(getLayout()); + } + }]; +} + +def VMIMaskType : TypeDef { + let mnemonic = "vmi.mask"; + let summary = "A VMI logical predicate mask value"; + + let parameters = (ins + "int64_t":$elementCount, + StringRefParameter<"mask granularity view">:$granularity, + "mlir::Attribute":$layout + ); + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + static bool isSupportedGranularity(::llvm::StringRef granularity); + static bool isConcreteGranularity(::llvm::StringRef granularity); + + bool hasLayout() const { return static_cast(getLayout()); } + bool isPred() const { return getGranularity() == "pred"; } + bool isB8() const { return getGranularity() == "b8"; } + bool isB16() const { return getGranularity() == "b16"; } + bool isB32() const { return getGranularity() == "b32"; } + VMILayoutAttr getLayoutAttr() const { + return ::llvm::dyn_cast_or_null(getLayout()); + } + }]; +} + +#endif // MLIR_DIALECT_PTO_IR_VMITYPEDEFS diff --git a/include/PTO/IR/VMIUtils.h b/include/PTO/IR/VMIUtils.h new file mode 100644 index 0000000000..e55e558034 --- /dev/null +++ b/include/PTO/IR/VMIUtils.h @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VMIUtils.h - PTO VMI shared helpers ----------------------*- C++ -*-===// +//===----------------------------------------------------------------------===// + +#ifndef PTO_IR_VMIUTILS_H +#define PTO_IR_VMIUTILS_H + +#include "PTO/IR/PTO.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir::pto { + +inline constexpr StringLiteral kVMIDiagUnsupported = "VMI-UNSUPPORTED"; +inline constexpr StringLiteral kVMIDiagLayoutContract = + "VMI-LAYOUT-CONTRACT"; +inline constexpr StringLiteral kVMIDiagPassInvariant = "VMI-PASS-INVARIANT"; +inline constexpr StringLiteral kVMIDiagResidualOp = "VMI-RESIDUAL-OP"; + +inline constexpr StringLiteral kVMIDiagUnsupportedPrefix = + "VMI-UNSUPPORTED: "; +inline constexpr StringLiteral kVMIDiagLayoutContractPrefix = + "VMI-LAYOUT-CONTRACT: "; +inline constexpr StringLiteral kVMIDiagPassInvariantPrefix = + "VMI-PASS-INVARIANT: "; +inline constexpr StringLiteral kVMIDiagResidualOpPrefix = "VMI-RESIDUAL-OP: "; + +struct VMIPhysicalLane { + int64_t part = 0; + int64_t chunk = 0; + int64_t lane = 0; +}; + +FailureOr getDataLanesPerPart(Type elementType); +FailureOr getMaskLanesPerPart(StringRef granularity); +FailureOr getVMIPhysicalArity(Type type); +FailureOr mapLogicalLaneToPhysical(Type type, + int64_t logicalLane); +FailureOr mapPhysicalLaneToLogical(Type type, int64_t part, + int64_t chunk, int64_t lane); +FailureOr isPaddingLane(Type type, int64_t part, int64_t chunk, + int64_t lane); + +} // namespace mlir::pto + +#endif // PTO_IR_VMIUTILS_H diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 3de31a89bf..b83bdbc195 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -107,6 +107,14 @@ LogicalResult validateVPTOEmissionIR(ModuleOp module, llvm::raw_ostream *diagOS = nullptr); std::unique_ptr createPTOValidateVPTOIRPass(); std::unique_ptr createPTOValidateVPTOEmissionIRPass(); +LogicalResult validateVMIProducerBoundaryIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr); +LogicalResult validateVMILayoutAssignedIR(ModuleOp module, + llvm::raw_ostream *diagOS = nullptr); +std::unique_ptr createPTOValidateVMIIRPass(); +std::unique_ptr createPTOValidateVMILayoutIRPass(); +std::unique_ptr createVMILayoutAssignmentPass(); +std::unique_ptr createVMIToVPTOPass(); std::unique_ptr createExpandTileOpPass(); std::unique_ptr createExpandTileOpPass(const ExpandTileOpOptions &options); std::unique_ptr createFoldTileBufIntrinsicsPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 63b06b6dbf..435b70a328 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -787,6 +787,77 @@ def PTOValidateVPTOIR : Pass<"pto-validate-vpto-ir", "ModuleOp"> { "mlir::scf::SCFDialect"]; } +def PTOValidateVMIIR : Pass<"pto-validate-vmi-ir", "ModuleOp"> { + let summary = "Validate VMI producer-boundary semantic IR"; + let description = [{ + Checks that VMI producer-boundary IR uses only surface VMI data/mask types, + native pto.vmi semantic ops, and structural control-flow/function ops. This + pass runs before layout assignment, so layout-assigned VMI types, VMI helper + ops, and physical VPTO register types are rejected. + }]; + let constructor = "mlir::pto::createPTOValidateVMIIRPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def PTOValidateVMILayoutIR + : Pass<"pto-validate-vmi-layout-ir", "ModuleOp"> { + let summary = "Validate layout-assigned VMI IR"; + let description = [{ + Checks the post-layout-assignment VMI stage: every VMI data value must have + a concrete VMI layout, every VMI mask must have concrete b8/b16/b32 + granularity and layout, physical VPTO register values must not appear yet, + and VMI typed values must stay inside VMI semantic/helper or structural ops. + }]; + let constructor = "mlir::pto::createPTOValidateVMILayoutIRPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMILayoutAssignment : Pass<"vmi-layout-assignment", "ModuleOp"> { + let summary = "Assign concrete VMI layouts and mask granularities"; + let description = [{ + Solves VMI layout constraints and materializes the chosen layout and mask + granularity into VMI types. This pass is the boundary between surface VMI + semantic IR and layout-assigned VMI IR. + }]; + let constructor = "mlir::pto::createVMILayoutAssignmentPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMIToVPTO : Pass<"vmi-to-vpto", "ModuleOp"> { + let summary = "Convert layout-assigned VMI IR to physical VPTO IR"; + let description = [{ + Converts layout-assigned VMI aggregate data/mask types to ordered physical + VPTO register and mask value lists using MLIR OneToNTypeConversion. This + pass is responsible for VMI 1:N type conversion, structural control-flow + and function/call signature conversion, and VMI semantic op physicalization. + }]; + let constructor = "mlir::pto::createVMIToVPTOPass()"; + let options = [ + Option<"enableStableGatherMaskedLoad", + "enable-stable-gather-masked-load", "bool", + /*default=*/"false", + "Reserve the stable VGATHER-based lowering path for VMI masked " + "loads; currently emits a TODO diagnostic when used."> + ]; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + def PTOValidateVPTOEmissionIR : Pass<"pto-validate-vpto-emission-ir", "ModuleOp"> { let summary = diff --git a/include/PTO/Transforms/VMITargetCapabilities.h b/include/PTO/Transforms/VMITargetCapabilities.h new file mode 100644 index 0000000000..15b4f19f1d --- /dev/null +++ b/include/PTO/Transforms/VMITargetCapabilities.h @@ -0,0 +1,318 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VMITargetCapabilities.h - VMI target capability registry -*- C++ -*-===// +//===----------------------------------------------------------------------===// + +#ifndef PTO_TRANSFORMS_VMITARGETCAPABILITIES_H +#define PTO_TRANSFORMS_VMITARGETCAPABILITIES_H + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/Twine.h" + +#include + +namespace mlir::pto { + +enum class VMICapabilityStatus { + supported, + unsupported_missing_capability, + unsupported_disabled_by_option, + unsupported_resource, +}; + +enum class VMIElementPurpose { + PredicateMask, + F16F32, + F16BF16F32, + SignlessOrSignedI8I16I32, + AnyI8I16I32, + VMula, + VRelu, +}; + +enum class VMIReductionKind { + AddI, + AddF, + MaxF, + MinF, +}; + +enum class VMIFallbackResourceKind { + ScratchMemory, + GuardedControlFlow, +}; + +struct VMICapabilityResult { + VMICapabilityStatus status = VMICapabilityStatus::supported; + std::string reason; + + static VMICapabilityResult supported() { return {}; } + + static VMICapabilityResult missingCapability(const Twine &reason) { + VMICapabilityResult result; + result.status = VMICapabilityStatus::unsupported_missing_capability; + result.reason = reason.str(); + return result; + } + + bool isSupported() const { + return status == VMICapabilityStatus::supported; + } + + LogicalResult toLogicalResult(std::string *outReason = nullptr) const { + if (isSupported()) + return success(); + if (outReason) + *outReason = reason; + return failure(); + } +}; + +class VMITargetCapabilityRegistry { +public: + VMICapabilityResult supportsElementType(Type type, + VMIElementPurpose purpose) const { + switch (purpose) { + case VMIElementPurpose::PredicateMask: { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type); + if (elementBits == 8 || elementBits == 16 || elementBits == 32) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires an 8/16/32-bit element type so VPTO b8/b16/b32 " + "predicate masks can be materialized"); + } + case VMIElementPurpose::F16F32: + if (type.isF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires f16/f32 element type for direct VPTO lowering"); + case VMIElementPurpose::F16BF16F32: + if (type.isF16() || type.isBF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires f16/bf16/f32 element type for direct VPTO lowering"); + case VMIElementPurpose::SignlessOrSignedI8I16I32: + if (isSignlessOrSignedI8I16I32(type)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires signless/signed i8/i16/i32 element type for direct VPTO " + "lowering"); + case VMIElementPurpose::AnyI8I16I32: + if (isAnyI8I16I32(type)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires signless/signed/unsigned i8/i16/i32 element type for " + "direct VPTO lowering"); + case VMIElementPurpose::VMula: + if (type.isF16() || type.isBF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "requires f16, bf16, or f32 element type for pto.vmula"); + case VMIElementPurpose::VRelu: + if (type.isF16() || type.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "pto.vrelu direct lowering supports only f16/f32 VMI " + "floating-point element types"); + } + llvm_unreachable("unhandled VMI element purpose"); + } + + VMICapabilityResult supportsDirectMemory(Type type, StringRef role) const { + switch (classifyDirectMemoryRole(type)) { + case DirectMemoryRole::UB: + case DirectMemoryRole::Unknown: + return VMICapabilityResult::supported(); + case DirectMemoryRole::GM: + return VMICapabilityResult::missingCapability( + Twine(role) + + " is GM-backed, but current direct VMI-to-VPTO memory lowering " + "emits pto.vlds/pto.vsts and requires UB-backed memory"); + case DirectMemoryRole::Other: + return VMICapabilityResult::missingCapability( + Twine(role) + + " is not UB-backed memory supported by pto.vlds/pto.vsts"); + } + llvm_unreachable("unhandled direct memory role"); + } + + VMICapabilityResult supportsUBPointerMemory(Type type, StringRef role, + StringRef physicalOp, + StringRef ubReason) const { + auto ptrType = dyn_cast(type); + if (!ptrType) + return VMICapabilityResult::missingCapability( + Twine("requires a !pto.ptr ") + role + " because " + physicalOp + + " is pointer-only"); + if (ptrType.getMemorySpace().getAddressSpace() != AddressSpace::VEC) + return VMICapabilityResult::missingCapability( + Twine("requires a UB ") + role + " because " + ubReason); + return VMICapabilityResult::supported(); + } + + VMICapabilityResult supportsChannelCount(StringRef opName, + int64_t channels) const { + if (channels == 2 || channels == 4) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + Twine(opName) + " supports only 2 or 4 channels"); + } + + VMICapabilityResult supportsLayoutConversion(VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + Type elementType) const { + (void)elementType; + if (!sourceLayout || !resultLayout) + return VMICapabilityResult::missingCapability( + "requires assigned source/result layouts"); + if (sourceLayout == resultLayout) + return VMICapabilityResult::supported(); + if (sourceLayout.isContiguous() && resultLayout.isDeinterleaved() && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) + return VMICapabilityResult::supported(); + if (sourceLayout.isDeinterleaved() && resultLayout.isContiguous() && + (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "unsupported source/result layout pair"); + } + + VMICapabilityResult supportsMaskGranularityConversion( + StringRef sourceGranularity, StringRef resultGranularity) const { + if (!VMIMaskType::isConcreteGranularity(sourceGranularity) || + !VMIMaskType::isConcreteGranularity(resultGranularity)) + return VMICapabilityResult::missingCapability( + "requires concrete b8/b16/b32 source and result granularities"); + return VMICapabilityResult::supported(); + } + + VMICapabilityResult supportsTrueMaskedLoad(Type sourceType, Type resultType, + Type maskType) const { + (void)sourceType; + (void)resultType; + (void)maskType; + return VMICapabilityResult::missingCapability( + "target true masked/non-faulting load is unavailable because the " + "current VPTO pto.vlds surface has no mask operand"); + } + + VMICapabilityResult supportsFallbackResource( + VMIFallbackResourceKind kind) const { + switch (kind) { + case VMIFallbackResourceKind::ScratchMemory: + return VMICapabilityResult::missingCapability( + "scratch memory fallback resource allocation is not implemented"); + case VMIFallbackResourceKind::GuardedControlFlow: + return VMICapabilityResult::missingCapability( + "guarded memory fallback control-flow lowering is not implemented"); + } + llvm_unreachable("unhandled VMI fallback resource kind"); + } + + VMICapabilityResult supportsReductionElementType( + VMIReductionKind kind, Type elementType) const { + switch (kind) { + case VMIReductionKind::AddI: + if (pto::getPTOStorageElemBitWidth(elementType) == 32 && + isa(elementType)) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "currently supports only 32-bit integer elements because narrow " + "vcadd widens its result"); + case VMIReductionKind::AddF: + if (elementType.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "currently supports only f32 elements; f16 requires an explicit " + "accumulator precision and rounding contract"); + case VMIReductionKind::MaxF: + case VMIReductionKind::MinF: + if (elementType.isF16() || elementType.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "currently supports only f16/f32 elements because pto.vcmax/" + "pto.vcmin support only those floating-point element types"); + } + llvm_unreachable("unhandled VMI reduction kind"); + } + +private: + enum class DirectMemoryRole { Unknown, UB, GM, Other }; + + DirectMemoryRole classifyDirectMemoryRole(Type type) const { + if (auto ptrType = dyn_cast(type)) { + switch (ptrType.getMemorySpace().getAddressSpace()) { + case AddressSpace::GM: + case AddressSpace::Zero: + return DirectMemoryRole::GM; + case AddressSpace::VEC: + return DirectMemoryRole::UB; + default: + return DirectMemoryRole::Other; + } + } + + auto memrefType = dyn_cast(type); + if (!memrefType) + return DirectMemoryRole::Other; + + Attribute memorySpace = memrefType.getMemorySpace(); + if (!memorySpace) + return DirectMemoryRole::Unknown; + + if (auto addressSpace = dyn_cast(memorySpace)) { + switch (addressSpace.getAddressSpace()) { + case AddressSpace::GM: + case AddressSpace::Zero: + return DirectMemoryRole::GM; + case AddressSpace::VEC: + return DirectMemoryRole::UB; + default: + return DirectMemoryRole::Other; + } + } + + if (auto integerSpace = dyn_cast(memorySpace)) { + switch (integerSpace.getInt()) { + case static_cast(AddressSpace::GM): + case static_cast(AddressSpace::Zero): + return DirectMemoryRole::GM; + case static_cast(AddressSpace::VEC): + return DirectMemoryRole::UB; + default: + return DirectMemoryRole::Other; + } + } + + return DirectMemoryRole::Other; + } + + static bool isSignlessOrSignedI8I16I32(Type type) { + auto intType = dyn_cast(type); + if (!intType || intType.isUnsigned()) + return false; + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + } + + static bool isAnyI8I16I32(Type type) { + auto intType = dyn_cast(type); + if (!intType) + return false; + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + } +}; + +} // namespace mlir::pto + +#endif // PTO_TRANSFORMS_VMITARGETCAPABILITIES_H diff --git a/lib/PTO/IR/CMakeLists.txt b/lib/PTO/IR/CMakeLists.txt index 74b9e0bd68..4f8d995796 100644 --- a/lib/PTO/IR/CMakeLists.txt +++ b/lib/PTO/IR/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_dialect_library(PTOIR PTO.cpp VPTO.cpp + VMI.cpp PTOAttrs.cpp PTOSyncUtils.cpp PTOTypeDefs.cpp diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp new file mode 100644 index 0000000000..1f9a43f51a --- /dev/null +++ b/lib/PTO/IR/VMI.cpp @@ -0,0 +1,1407 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VMI.cpp - PTO VMI type and attribute support -----------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static std::string formatVMIVRegType(int64_t elementCount, Type elementType, + Attribute layout) { + std::string result; + llvm::raw_string_ostream os(result); + os << "!pto.vmi.vreg<" << elementCount << "x" << elementType; + if (layout) + os << ", " << layout; + os << ">"; + return result; +} + +static std::string formatVMIMaskType(int64_t elementCount, + StringRef granularity, + Attribute layout) { + std::string result; + llvm::raw_string_ostream os(result); + os << "!pto.vmi.mask<" << elementCount << "x" << granularity; + if (layout) + os << ", " << layout; + os << ">"; + return result; +} + +static bool isSupportedVMIElementType(Type type) { + return isa(type) || + pto::isPTOLowPrecisionType(type); +} + +static bool isVMIFloatLikeType(Type type) { + return isa(type) || pto::isPTOLowPrecisionType(type); +} + +static bool isVMIIntegerLikeType(Type type) { + return isa(type); +} + +static bool isVMIIotaElementType(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth() == 8 || intType.getWidth() == 16 || + intType.getWidth() == 32; + return type.isF16() || type.isF32(); +} + +static bool isCompatibleScalarForSemanticType(Type semanticType, + Type scalarType) { + if (semanticType == scalarType) + return true; + + auto semanticInt = dyn_cast(semanticType); + auto scalarInt = dyn_cast(scalarType); + if (!semanticInt || !scalarInt || + semanticInt.getWidth() != scalarInt.getWidth()) + return false; + + if (semanticInt.isSigned()) + return scalarInt.isSigned() || scalarInt.isSignless(); + if (semanticInt.isUnsigned()) + return scalarInt.isUnsigned() || scalarInt.isSignless(); + return scalarInt.isSignless(); +} + +static unsigned getVMIElementBitWidth(Type type) { + if (isa(type)) + return 64; + return pto::getPTOStorageElemBitWidth(type); +} + +static std::optional getVMIIntegerOrFloatBitWidth(Type type) { + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (auto floatType = dyn_cast(type)) + return floatType.getWidth(); + return std::nullopt; +} + +static int64_t divideCeilNonNegative(int64_t value, int64_t divisor) { + return value == 0 ? 0 : (value + divisor - 1) / divisor; +} + +static LogicalResult parseOptionalVMILayout(AsmParser &parser, + Attribute &layout) { + if (failed(parser.parseOptionalComma())) + return success(); + + if (failed(parser.parseAttribute(layout))) + return failure(); + if (!mlir::isa(layout)) + return parser.emitError(parser.getCurrentLocation(), + "expected #pto.vmi.layout attribute"); + return success(); +} + +static FailureOr getVMIElementCount(Type type) { + if (auto vregType = dyn_cast(type)) + return vregType.getElementCount(); + if (auto maskType = dyn_cast(type)) + return maskType.getElementCount(); + return failure(); +} + +static FailureOr getAssignedVMILayout(Type type) { + Attribute layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayout(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayout(); + else + return failure(); + + auto layoutAttr = dyn_cast_or_null(layout); + if (!layoutAttr) + return failure(); + return layoutAttr; +} + +static FailureOr getLayoutFactor(Type type) { + FailureOr layout = getAssignedVMILayout(type); + if (failed(layout)) + return failure(); + return (*layout).isContiguous() ? 1 : (*layout).getFactor(); +} + +static FailureOr getPhysicalLanesPerPart(Type type) { + if (auto vregType = dyn_cast(type)) + return getDataLanesPerPart(vregType.getElementType()); + if (auto maskType = dyn_cast(type)) + return getMaskLanesPerPart(maskType.getGranularity()); + return failure(); +} + +static int64_t getMaskGranularityBitWidth(StringRef granularity) { + if (granularity == "b8") + return 8; + if (granularity == "b16") + return 16; + if (granularity == "b32") + return 32; + return 0; +} + +static bool isLayoutAssigned(VMIVRegType type) { + return static_cast(type.getLayoutAttr()); +} + +static bool isLayoutAssigned(VMIMaskType type) { + return static_cast(type.getLayoutAttr()); +} + +static LogicalResult verifyAllSameVRegShapeAndLayout(Operation *op, + ArrayRef types, + bool requireSameElement) { + if (types.empty()) + return success(); + + VMIVRegType first = types.front(); + bool anyLayout = llvm::any_of(types, [](VMIVRegType type) { + return isLayoutAssigned(type); + }); + + for (VMIVRegType type : types) { + if (type.getElementCount() != first.getElementCount()) + return op->emitOpError("requires all VMI data values to have the same logical lane count"); + if (requireSameElement && type.getElementType() != first.getElementType()) + return op->emitOpError("requires all VMI data values to have the same element type"); + if (anyLayout && !isLayoutAssigned(type)) + return op->emitOpError("requires either all or no VMI data values to carry layout"); + if (anyLayout && type.getLayout() != first.getLayout()) + return op->emitOpError("requires all layout-assigned VMI data values to have the same layout"); + } + return success(); +} + +static LogicalResult verifyElementwiseVRegOp(Operation *op, VMIVRegType lhs, + VMIVRegType rhs, + VMIVRegType result) { + return verifyAllSameVRegShapeAndLayout(op, {lhs, rhs, result}, + /*requireSameElement=*/true); +} + +static LogicalResult verifyFloatUnaryVRegOp(Operation *op, + VMIVRegType source, + VMIVRegType result) { + if (!isVMIFloatLikeType(source.getElementType())) + return op->emitOpError("requires floating-point-like VMI element type"); + return verifyAllSameVRegShapeAndLayout(op, {source, result}, + /*requireSameElement=*/true); +} + +static LogicalResult verifyFloatTernaryVRegOp(Operation *op, VMIVRegType lhs, + VMIVRegType rhs, VMIVRegType acc, + VMIVRegType result) { + if (!isVMIFloatLikeType(lhs.getElementType())) + return op->emitOpError("requires floating-point-like VMI element type"); + return verifyAllSameVRegShapeAndLayout(op, {lhs, rhs, acc, result}, + /*requireSameElement=*/true); +} + +static LogicalResult verifyAllSameMaskShapeLayoutAndGranularity( + Operation *op, ArrayRef types) { + if (types.empty()) + return success(); + + VMIMaskType first = types.front(); + bool anyLayout = llvm::any_of(types, [](VMIMaskType type) { + return isLayoutAssigned(type); + }); + + for (VMIMaskType type : types) { + if (type.getElementCount() != first.getElementCount()) + return op->emitOpError( + "requires all VMI mask values to have the same logical lane count"); + if (type.getGranularity() != first.getGranularity()) + return op->emitOpError( + "requires all VMI mask values to have the same granularity"); + if (anyLayout && !isLayoutAssigned(type)) + return op->emitOpError( + "requires either all or no VMI mask values to carry layout"); + if (anyLayout && type.getLayout() != first.getLayout()) + return op->emitOpError( + "requires all layout-assigned VMI mask values to have the same " + "layout"); + } + return success(); +} + +static LogicalResult verifyMaskMatchesData(Operation *op, VMIMaskType maskType, + VMIVRegType dataType) { + if (maskType.getElementCount() != dataType.getElementCount()) + return op->emitOpError("requires mask logical lane count to match data lane count"); + + if (isLayoutAssigned(maskType) || isLayoutAssigned(dataType)) { + if (!isLayoutAssigned(maskType) || !isLayoutAssigned(dataType)) + return op->emitOpError("requires either both mask and data to carry layout or neither to carry layout"); + if (maskType.getLayout() != dataType.getLayout()) + return op->emitOpError("requires mask layout to match data layout"); + } + + if (maskType.isPred()) + return success(); + + unsigned elementBitWidth = getVMIElementBitWidth(dataType.getElementType()); + int64_t maskBitWidth = getMaskGranularityBitWidth(maskType.getGranularity()); + if (elementBitWidth != 0 && maskBitWidth != 0 && + elementBitWidth != static_cast(maskBitWidth)) + return op->emitOpError("requires mask granularity to match data element width"); + + return success(); +} + +static Type getMemoryElementType(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + return {}; +} + +static LogicalResult verifyMemoryElementMatches(Operation *op, Type memoryType, + VMIVRegType dataType, + StringRef role) { + Type memoryElementType = getMemoryElementType(memoryType); + if (!memoryElementType) + return success(); + if (memoryElementType != dataType.getElementType()) + return op->emitOpError() + << "requires memory " << role + << " element type to match VMI data element type"; + return success(); +} + +static LogicalResult verifyPhysicalParts(Operation *op, Type vmiType, + TypeRange physicalTypes) { + FailureOr expectedArity = getVMIPhysicalArity(vmiType); + if (failed(expectedArity)) + return op->emitOpError("requires a layout-assigned VMI type with computable physical arity"); + if (static_cast(physicalTypes.size()) != *expectedArity) + return op->emitOpError() + << "requires " << *expectedArity << " physical parts, got " + << physicalTypes.size(); + + if (auto vregType = dyn_cast(vmiType)) { + FailureOr lanesPerPart = + getDataLanesPerPart(vregType.getElementType()); + if (failed(lanesPerPart)) + return op->emitOpError("requires data element type with known physical lane count"); + for (Type physicalType : physicalTypes) { + auto partType = dyn_cast(physicalType); + if (!partType) + return op->emitOpError("requires physical data parts to be !pto.vreg"); + if (partType.getElementCount() != *lanesPerPart || + partType.getElementType() != vregType.getElementType()) + return op->emitOpError("requires physical data part type to match VMI lane-map helper"); + } + return success(); + } + + auto maskType = dyn_cast(vmiType); + if (!maskType) + return op->emitOpError("requires VMI data or mask type"); + if (maskType.isPred()) + return op->emitOpError("requires layout-assigned mask with concrete granularity"); + + for (Type physicalType : physicalTypes) { + auto partType = dyn_cast(physicalType); + if (!partType) + return op->emitOpError("requires physical mask parts to be !pto.mask"); + if (partType.getGranularity() != maskType.getGranularity()) + return op->emitOpError("requires physical mask part granularity to match VMI mask"); + } + return success(); +} + +static int64_t getLogicalLanesInPart(int64_t elementCount, int64_t factor, + int64_t part) { + if (part < 0 || part >= factor || part >= elementCount) + return 0; + return ((elementCount - 1 - part) / factor) + 1; +} + +} // namespace + +VMILayoutAttr VMILayoutAttr::getContiguous(MLIRContext *context) { + return VMILayoutAttr::get(context, "contiguous", 1); +} + +VMILayoutAttr VMILayoutAttr::getDeinterleaved(MLIRContext *context, + int64_t factor) { + return VMILayoutAttr::get(context, "deinterleaved", factor); +} + +Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { + SMLoc loc = parser.getCurrentLocation(); + StringRef kind; + int64_t factor = 1; + + if (failed(parser.parseLess()) || failed(parser.parseKeyword(&kind))) + return {}; + + if (kind == "contiguous") { + factor = 1; + } else if (kind == "deinterleaved") { + if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) + return {}; + } else { + parser.emitError(parser.getCurrentLocation(), + "expected VMI layout kind 'contiguous' or " + "'deinterleaved'"); + return {}; + } + + if (failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), kind, + factor); +} + +void VMILayoutAttr::print(AsmPrinter &printer) const { + printer << "<" << getKind(); + if (isDeinterleaved()) + printer << " = " << getFactor(); + printer << ">"; +} + +LogicalResult +VMILayoutAttr::verify(function_ref emitError, + StringRef kind, int64_t factor) { + if (kind == "contiguous") { + if (factor != 1) + return emitError() + << "#pto.vmi.layout requires factor to be 1"; + return success(); + } + + if (kind == "deinterleaved") { + if (factor != 2 && factor != 4) + return emitError() + << "#pto.vmi.layout expected factor to be 2 or 4"; + return success(); + } + + return emitError() << "expected VMI layout kind to be 'contiguous' or " + "'deinterleaved'"; +} + +Type VMIVRegType::parse(AsmParser &parser) { + SmallVector shape; + Type elementType; + Attribute layout; + SMLoc loc = parser.getCurrentLocation(); + + if (failed(parser.parseLess()) || + failed(parser.parseDimensionList(shape, /*allowDynamic=*/false, + /*withTrailingX=*/true)) || + shape.size() != 1 || failed(parser.parseType(elementType)) || + failed(parseOptionalVMILayout(parser, layout)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), + shape.front(), elementType, layout); +} + +void VMIVRegType::print(AsmPrinter &printer) const { + printer << "<" << getElementCount() << "x"; + printer.printType(getElementType()); + if (getLayout()) + printer << ", " << getLayout(); + printer << ">"; +} + +LogicalResult VMIVRegType::verify(function_ref emitError, + int64_t elementCount, Type elementType, + Attribute layout) { + if (elementCount <= 0) + return emitError() << "'" << formatVMIVRegType(elementCount, elementType, + layout) + << "' expected a positive element count"; + + if (!isSupportedVMIElementType(elementType)) + return emitError() << "'" << formatVMIVRegType(elementCount, elementType, + layout) + << "' expected an integer, index, floating-point, or " + "PTO low-precision element type"; + + if (layout && !mlir::isa(layout)) + return emitError() << "'" << formatVMIVRegType(elementCount, elementType, + layout) + << "' expected layout to be #pto.vmi.layout"; + + return success(); +} + +bool VMIMaskType::isSupportedGranularity(StringRef granularity) { + return granularity == "pred" || isConcreteGranularity(granularity); +} + +bool VMIMaskType::isConcreteGranularity(StringRef granularity) { + return granularity == "b8" || granularity == "b16" || granularity == "b32"; +} + +Type VMIMaskType::parse(AsmParser &parser) { + SmallVector shape; + StringRef granularity; + Attribute layout; + SMLoc loc = parser.getCurrentLocation(); + + if (failed(parser.parseLess()) || + failed(parser.parseDimensionList(shape, /*allowDynamic=*/false, + /*withTrailingX=*/true)) || + shape.size() != 1 || failed(parser.parseKeyword(&granularity)) || + failed(parseOptionalVMILayout(parser, layout)) || + failed(parser.parseGreater())) + return {}; + + return parser.getChecked(loc, parser.getContext(), + shape.front(), granularity, layout); +} + +void VMIMaskType::print(AsmPrinter &printer) const { + printer << "<" << getElementCount() << "x" << getGranularity(); + if (getLayout()) + printer << ", " << getLayout(); + printer << ">"; +} + +LogicalResult VMIMaskType::verify(function_ref emitError, + int64_t elementCount, StringRef granularity, + Attribute layout) { + if (elementCount <= 0) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' expected a positive element count"; + + if (!isSupportedGranularity(granularity)) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' expected granularity to be one of pred, b8, b16, " + "b32"; + + if (layout && !mlir::isa(layout)) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' expected layout to be #pto.vmi.layout"; + + if (granularity == "pred" && layout) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' pred mask must not carry layout"; + + if (granularity != "pred" && !layout) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' concrete mask granularity requires layout"; + + return success(); +} + +LogicalResult VMIConstantOp::verify() { + auto resultType = cast(getResult().getType()); + auto denseAttr = dyn_cast(getValue()); + if (!denseAttr) + return emitOpError("requires dense elements constant attribute"); + if (denseAttr.getElementType() != resultType.getElementType()) + return emitOpError("requires dense constant element type to match result element type"); + if (denseAttr.getNumElements() != resultType.getElementCount()) + return emitOpError("requires dense constant element count to match result logical lane count"); + return success(); +} + +LogicalResult VMIBroadcastOp::verify() { + auto resultType = cast(getResult().getType()); + Type valueType = getValue().getType(); + if (valueType == resultType.getElementType()) + return success(); + if (auto vregType = dyn_cast(valueType)) { + if (vregType.getElementCount() != 1) + return emitOpError("requires VMI vector input to have one logical lane"); + if (vregType.getElementType() != resultType.getElementType()) + return emitOpError("requires VMI vector input element type to match " + "result element type"); + return success(); + } + return emitOpError("requires scalar or VMI vector input element type to " + "match result element type"); +} + +LogicalResult VMIIotaOp::verify() { + auto resultType = cast(getResult().getType()); + Type elementType = resultType.getElementType(); + if (!isVMIIotaElementType(elementType)) + return emitOpError("requires result element type to be integer 8/16/32 " + "or f16/f32"); + if (!isCompatibleScalarForSemanticType(elementType, getBase().getType())) + return emitOpError("requires base type to match result element type"); + + if (std::optional order = getOrder()) { + if (*order != "ASC" && *order != "DESC") + return emitOpError("requires order to be ASC or DESC"); + } + return success(); +} + +LogicalResult VMICreateMaskOp::verify() { + auto resultType = cast(getResult().getType()); + if (!resultType.isPred() && !isLayoutAssigned(resultType)) + return emitOpError("requires concrete mask result to carry layout"); + return success(); +} + +LogicalResult VMIConstantMaskOp::verify() { + auto resultType = cast(getResult().getType()); + auto denseAttr = dyn_cast(getValue()); + if (!denseAttr) + return emitOpError("requires dense elements mask constant attribute"); + if (!denseAttr.getElementType().isInteger(1)) + return emitOpError("requires dense mask constant element type to be i1"); + if (denseAttr.getNumElements() != resultType.getElementCount()) + return emitOpError("requires dense mask constant element count to match result logical lane count"); + return success(); +} + +LogicalResult VMIMaskAndOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + return verifyAllSameMaskShapeLayoutAndGranularity( + getOperation(), {lhsType, rhsType, resultType}); +} + +LogicalResult VMIMaskOrOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + return verifyAllSameMaskShapeLayoutAndGranularity( + getOperation(), {lhsType, rhsType, resultType}); +} + +LogicalResult VMIMaskXOrOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + return verifyAllSameMaskShapeLayoutAndGranularity( + getOperation(), {lhsType, rhsType, resultType}); +} + +LogicalResult VMIMaskNotOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyAllSameMaskShapeLayoutAndGranularity( + getOperation(), {sourceType, resultType}); +} + +LogicalResult VMIAddFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIAddIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMISubFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMISubIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIMulFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIMulIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIFmaOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto accType = cast(getAcc().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatTernaryVRegOp(getOperation(), lhsType, rhsType, accType, + resultType); +} + +LogicalResult VMIDivFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIMinFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIMaxFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMINegFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIAbsFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIAbsIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(sourceType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyAllSameVRegShapeAndLayout(getOperation(), + {sourceType, resultType}, + /*requireSameElement=*/true); +} + +LogicalResult VMISqrtOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIExpOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMILnOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIReluOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + return verifyFloatUnaryVRegOp(getOperation(), sourceType, resultType); +} + +LogicalResult VMIAndIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIOrIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIXOrIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIShLIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMIShRUIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + auto integerType = dyn_cast(lhsType.getElementType()); + if (!integerType || integerType.isSigned()) + return emitOpError( + "requires signless or unsigned integer VMI element type"); + return verifyElementwiseVRegOp(getOperation(), lhsType, rhsType, resultType); +} + +LogicalResult VMINotOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(sourceType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + return verifyAllSameVRegShapeAndLayout(getOperation(), {sourceType, resultType}, + /*requireSameElement=*/true); +} + +LogicalResult VMICmpFOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIFloatLikeType(lhsType.getElementType())) + return emitOpError("requires floating-point-like VMI element type"); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), {lhsType, rhsType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), resultType, lhsType); +} + +LogicalResult VMICmpIOp::verify() { + auto lhsType = cast(getLhs().getType()); + auto rhsType = cast(getRhs().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(lhsType.getElementType())) + return emitOpError("requires integer-like VMI element type"); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), {lhsType, rhsType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), resultType, lhsType); +} + +LogicalResult VMISelectOp::verify() { + auto maskType = cast(getMask().getType()); + auto trueType = cast(getTrueValue().getType()); + auto falseType = cast(getFalseValue().getType()); + auto resultType = cast(getResult().getType()); + if (failed(verifyAllSameVRegShapeAndLayout( + getOperation(), {trueType, falseType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +LogicalResult VMIActivePrefixIndexOp::verify() { + auto maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + auto resultIntType = dyn_cast(resultType.getElementType()); + if (!resultIntType || !resultIntType.isSignless()) + return emitOpError("requires signless integer result element type"); + unsigned resultWidth = resultIntType.getWidth(); + if (resultWidth != 8 && resultWidth != 16 && resultWidth != 32) + return emitOpError("requires i8, i16, or i32 result element type"); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +LogicalResult VMICompressOp::verify() { + auto sourceType = cast(getSource().getType()); + auto maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + if (failed(verifyAllSameVRegShapeAndLayout( + getOperation(), {sourceType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, sourceType); +} + +LogicalResult VMICompressStoreOp::verify() { + auto valueType = cast(getValue().getType()); + auto maskType = cast(getMask().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), valueType, + "destination"))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, valueType); +} + +void VMICompressStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VMIReduceAddIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto initType = cast(getInit().getType()); + auto maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(sourceType.getElementType())) + return emitOpError("requires integer-like VMI source element type"); + if (sourceType.getElementType() != initType.getElementType() || + sourceType.getElementType() != resultType.getElementType()) + return emitOpError( + "requires source, init, and result element types to match"); + if (initType.getElementCount() != 1 || resultType.getElementCount() != 1) + return emitOpError("requires init and result to be rank-0 VMI vectors"); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {initType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, sourceType); +} + +LogicalResult VMIReduceAddFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto initType = cast(getInit().getType()); + auto maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + if (!getOperation()->hasAttr("reassoc")) + return emitOpError( + "requires reassoc attr because VPTO vcadd performs pair-wise " + "floating-point reduction"); + if (!isVMIFloatLikeType(sourceType.getElementType())) + return emitOpError("requires floating-point-like VMI source element type"); + if (sourceType.getElementType() != initType.getElementType() || + sourceType.getElementType() != resultType.getElementType()) + return emitOpError( + "requires source, init, and result element types to match"); + if (initType.getElementCount() != 1 || resultType.getElementCount() != 1) + return emitOpError("requires init and result to be rank-0 VMI vectors"); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {initType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, sourceType); +} + +template +LogicalResult verifyReduceMinMaxFOp(OpTy op) { + auto sourceType = cast(op.getSource().getType()); + auto initType = cast(op.getInit().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + if (!isVMIFloatLikeType(sourceType.getElementType())) + return op.emitOpError("requires floating-point-like VMI source element type"); + if (sourceType.getElementType() != initType.getElementType() || + sourceType.getElementType() != resultType.getElementType()) + return op.emitOpError( + "requires source, init, and result element types to match"); + if (initType.getElementCount() != 1 || resultType.getElementCount() != 1) + return op.emitOpError("requires init and result to be rank-0 VMI vectors"); + if (failed(verifyAllSameVRegShapeAndLayout(op.getOperation(), + {initType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(op.getOperation(), maskType, sourceType); +} + +LogicalResult VMIReduceMaxFOp::verify() { + return verifyReduceMinMaxFOp(*this); +} + +LogicalResult VMIReduceMinFOp::verify() { + return verifyReduceMinMaxFOp(*this); +} + +LogicalResult VMIExtFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError("requires source and result logical lane counts to match"); + if (!isVMIFloatLikeType(sourceType.getElementType()) || + !isVMIFloatLikeType(resultType.getElementType())) + return emitOpError("requires floating-point-like source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) >= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError("requires result element type to be wider than source element type"); + return success(); +} + +LogicalResult VMITruncFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError("requires source and result logical lane counts to match"); + if (!isVMIFloatLikeType(sourceType.getElementType()) || + !isVMIFloatLikeType(resultType.getElementType())) + return emitOpError("requires floating-point-like source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) <= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError("requires result element type to be narrower than source element type"); + return success(); +} + +LogicalResult VMIBitcastOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + std::optional sourceBits = + getVMIIntegerOrFloatBitWidth(sourceType.getElementType()); + std::optional resultBits = + getVMIIntegerOrFloatBitWidth(resultType.getElementType()); + if (!sourceBits || !resultBits) + return emitOpError( + "requires integer or floating-point source and result element types"); + if (sourceType.getElementCount() * static_cast(*sourceBits) != + resultType.getElementCount() * static_cast(*resultBits)) + return emitOpError( + "requires source and result to carry the same total number of bits"); + + if (isLayoutAssigned(sourceType) || isLayoutAssigned(resultType)) { + if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) + return emitOpError( + "requires either both source and result to carry layout or neither " + "to carry layout"); + if (sourceType.getLayout() != resultType.getLayout()) + return emitOpError("requires source and result layouts to match"); + } + + return success(); +} + +LogicalResult VMILoadOp::verify() { + return verifyMemoryElementMatches(getOperation(), getSource().getType(), + cast(getResult().getType()), + "source"); +} + +void VMILoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIMaskedLoadOp::verify() { + auto maskType = cast(getMask().getType()); + auto passthruType = cast(getPassthru().getType()); + auto resultType = cast(getResult().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + if (failed(verifyAllSameVRegShapeAndLayout( + getOperation(), {passthruType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +void VMIMaskedLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIGatherOp::verify() { + auto indicesType = cast(getIndices().getType()); + auto maskType = cast(getMask().getType()); + auto passthruType = cast(getPassthru().getType()); + auto resultType = cast(getResult().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + + auto indexElementType = dyn_cast(indicesType.getElementType()); + if (!indexElementType || indexElementType.getWidth() != 32 || + indexElementType.isSigned()) + return emitOpError("requires signless or unsigned 32-bit integer indices"); + + if (failed(verifyAllSameVRegShapeAndLayout( + getOperation(), {indicesType, passthruType, resultType}, + /*requireSameElement=*/false))) + return failure(); + if (failed(verifyAllSameVRegShapeAndLayout( + getOperation(), {passthruType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +void VMIGatherOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIExpandLoadOp::verify() { + auto maskType = cast(getMask().getType()); + auto passthruType = cast(getPassthru().getType()); + auto resultType = cast(getResult().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + if (failed(verifyAllSameVRegShapeAndLayout( + getOperation(), {passthruType, resultType}, + /*requireSameElement=*/true))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +void VMIExpandLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMIStoreOp::verify() { + return verifyMemoryElementMatches(getOperation(), + getDestination().getType(), + cast(getValue().getType()), + "destination"); +} + +void VMIStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VMIMaskedStoreOp::verify() { + auto valueType = cast(getValue().getType()); + auto maskType = cast(getMask().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), + valueType, "destination"))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, valueType); +} + +void VMIMaskedStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VMIScatterOp::verify() { + auto valueType = cast(getValue().getType()); + auto indicesType = cast(getIndices().getType()); + auto maskType = cast(getMask().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), + valueType, "destination"))) + return failure(); + + auto indexElementType = dyn_cast(indicesType.getElementType()); + if (!indexElementType || indexElementType.getWidth() != 32 || + indexElementType.isSigned()) + return emitOpError("requires signless or unsigned 32-bit integer indices"); + + if (failed(verifyAllSameVRegShapeAndLayout( + getOperation(), {valueType, indicesType}, + /*requireSameElement=*/false))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, valueType); +} + +void VMIScatterOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VMITileReadOp::verify() { + return verifyMemoryElementMatches(getOperation(), getSource().getType(), + cast(getResult().getType()), + "source"); +} + +void VMITileReadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + +LogicalResult VMITileWriteOp::verify() { + return verifyMemoryElementMatches(getOperation(), + getDestination().getType(), + cast(getValue().getType()), + "destination"); +} + +void VMITileWriteOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + +LogicalResult VMIShuffleOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementType() != resultType.getElementType()) + return emitOpError("requires result element type to match source element type"); + if (static_cast(getIndices().size()) != resultType.getElementCount()) + return emitOpError("requires shuffle index count to match result logical lane count"); + for (int64_t index : getIndices()) { + if (index < 0 || index >= sourceType.getElementCount()) + return emitOpError("requires every shuffle index to select an existing source logical lane"); + } + if (isLayoutAssigned(sourceType) || isLayoutAssigned(resultType)) { + if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) + return emitOpError("requires either both source and result to carry layout or neither to carry layout"); + } + return success(); +} + +LogicalResult VMIChannelSplitOp::verify() { + auto sourceType = cast(getSource().getType()); + if (getResults().size() < 2) + return emitOpError("requires at least two channel results"); + auto firstResultType = cast(getResults().front().getType()); + if (sourceType.getElementCount() != + static_cast(getResults().size()) * firstResultType.getElementCount()) + return emitOpError("requires source lane count to equal result count times per-channel lane count"); + for (Value result : getResults()) { + auto resultType = cast(result.getType()); + if (resultType.getElementCount() != firstResultType.getElementCount() || + resultType.getElementType() != sourceType.getElementType()) + return emitOpError("requires every channel result to have equal lane count and source element type"); + } + bool anyLayout = isLayoutAssigned(sourceType); + for (Value result : getResults()) + anyLayout |= isLayoutAssigned(cast(result.getType())); + if (anyLayout) { + if (!isLayoutAssigned(sourceType)) + return emitOpError("requires layout-assigned channel_split source when any channel result has layout"); + for (Value result : getResults()) { + auto resultType = cast(result.getType()); + if (!isLayoutAssigned(resultType)) + return emitOpError("requires every channel_split result to carry layout when source has layout"); + if (!cast(resultType.getLayout()).isContiguous()) + return emitOpError("requires layout-assigned channel_split results to be contiguous"); + } + int64_t channels = getResults().size(); + if (channels == 2 || channels == 4) { + auto sourceLayout = cast(sourceType.getLayout()); + auto expectedLayout = + VMILayoutAttr::getDeinterleaved(getContext(), channels); + if (!sourceLayout.isContiguous() && sourceLayout != expectedLayout) + return emitOpError("requires layout-assigned channel_split source to be contiguous or deinterleaved by result count"); + } + } + return success(); +} + +LogicalResult VMIChannelMergeOp::verify() { + if (getInputs().size() < 2) + return emitOpError("requires at least two channel inputs"); + auto firstInputType = cast(getInputs().front().getType()); + auto resultType = cast(getResult().getType()); + for (Value input : getInputs()) { + auto inputType = cast(input.getType()); + if (inputType.getElementCount() != firstInputType.getElementCount() || + inputType.getElementType() != firstInputType.getElementType()) + return emitOpError("requires all channel inputs to have the same lane count and element type"); + } + if (resultType.getElementCount() != + static_cast(getInputs().size()) * firstInputType.getElementCount() || + resultType.getElementType() != firstInputType.getElementType()) + return emitOpError("requires result lane count and element type to match merged channels"); + bool anyLayout = isLayoutAssigned(resultType); + for (Value input : getInputs()) + anyLayout |= isLayoutAssigned(cast(input.getType())); + if (anyLayout) { + if (!isLayoutAssigned(resultType)) + return emitOpError("requires layout-assigned channel_merge result when any channel input has layout"); + for (Value input : getInputs()) { + auto inputType = cast(input.getType()); + if (!isLayoutAssigned(inputType)) + return emitOpError("requires every channel_merge input to carry layout when result has layout"); + if (!cast(inputType.getLayout()).isContiguous()) + return emitOpError("requires layout-assigned channel_merge inputs to be contiguous"); + } + int64_t channels = getInputs().size(); + if (channels == 2 || channels == 4) { + auto resultLayout = cast(resultType.getLayout()); + auto expectedLayout = + VMILayoutAttr::getDeinterleaved(getContext(), channels); + if (!resultLayout.isContiguous() && resultLayout != expectedLayout) + return emitOpError("requires layout-assigned channel_merge result to be contiguous or deinterleaved by input count"); + } + } + return success(); +} + +LogicalResult VMIEnsureLayoutOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount() || + sourceType.getElementType() != resultType.getElementType()) + return emitOpError("requires source and result to preserve VMI data shape and element type"); + if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) + return emitOpError("requires source and result to be layout-assigned"); + return success(); +} + +LogicalResult VMIEnsureMaskLayoutOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount() || + sourceType.getGranularity() != resultType.getGranularity()) + return emitOpError("requires source and result to preserve VMI mask shape and granularity"); + if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) + return emitOpError("requires source and result to be layout-assigned"); + return success(); +} + +LogicalResult VMIEnsureMaskGranularityOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError("requires source and result to preserve VMI mask lane count"); + if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) + return emitOpError("requires source and result to be layout-assigned"); + if (sourceType.getLayout() != resultType.getLayout()) + return emitOpError("requires source and result mask layouts to match"); + if (sourceType.isPred() || resultType.isPred()) + return emitOpError("requires concrete source and result mask granularities"); + return success(); +} + +LogicalResult VMIUnpackOp::verify() { + return verifyPhysicalParts(getOperation(), getSource().getType(), + getParts().getTypes()); +} + +LogicalResult VMIPackOp::verify() { + return verifyPhysicalParts(getOperation(), getResult().getType(), + getParts().getTypes()); +} + +FailureOr mlir::pto::getDataLanesPerPart(Type elementType) { + unsigned elementBitWidth = pto::getPTOStorageElemBitWidth(elementType); + if (elementBitWidth == 0) + return failure(); + constexpr int64_t kPhysicalVRegBits = 256 * 8; + if (kPhysicalVRegBits % elementBitWidth != 0) + return failure(); + return kPhysicalVRegBits / elementBitWidth; +} + +FailureOr mlir::pto::getMaskLanesPerPart(StringRef granularity) { + if (granularity == "b8") + return 256; + if (granularity == "b16") + return 128; + if (granularity == "b32") + return 64; + return failure(); +} + +FailureOr mlir::pto::getVMIPhysicalArity(Type type) { + FailureOr elementCount = getVMIElementCount(type); + FailureOr factor = getLayoutFactor(type); + FailureOr lanesPerPart = getPhysicalLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + return failure(); + + int64_t arity = 0; + for (int64_t part = 0; part < *factor; ++part) { + int64_t lanesInPart = getLogicalLanesInPart(*elementCount, *factor, part); + arity += divideCeilNonNegative(lanesInPart, *lanesPerPart); + } + return arity; +} + +FailureOr +mlir::pto::mapLogicalLaneToPhysical(Type type, int64_t logicalLane) { + FailureOr elementCount = getVMIElementCount(type); + FailureOr factor = getLayoutFactor(type); + FailureOr lanesPerPart = getPhysicalLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + return failure(); + if (logicalLane < 0 || logicalLane >= *elementCount) + return failure(); + + int64_t part = logicalLane % *factor; + int64_t indexInPart = logicalLane / *factor; + return VMIPhysicalLane{part, indexInPart / *lanesPerPart, + indexInPart % *lanesPerPart}; +} + +FailureOr mlir::pto::mapPhysicalLaneToLogical(Type type, int64_t part, + int64_t chunk, + int64_t lane) { + FailureOr elementCount = getVMIElementCount(type); + FailureOr factor = getLayoutFactor(type); + FailureOr lanesPerPart = getPhysicalLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + return failure(); + if (part < 0 || part >= *factor || chunk < 0 || lane < 0 || + lane >= *lanesPerPart) + return failure(); + + int64_t indexInPart = chunk * *lanesPerPart + lane; + int64_t logicalLane = indexInPart * *factor + part; + if (logicalLane >= *elementCount) + return failure(); + return logicalLane; +} + +FailureOr mlir::pto::isPaddingLane(Type type, int64_t part, + int64_t chunk, int64_t lane) { + FailureOr elementCount = getVMIElementCount(type); + FailureOr factor = getLayoutFactor(type); + FailureOr lanesPerPart = getPhysicalLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + return failure(); + if (part < 0 || part >= *factor || chunk < 0 || lane < 0 || + lane >= *lanesPerPart) + return failure(); + + int64_t lanesInPart = getLogicalLanesInPart(*elementCount, *factor, part); + int64_t indexInPart = chunk * *lanesPerPart + lane; + return indexInPart >= lanesInPart; +} diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 3a6c04aed0..f2fe3ece10 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -35,6 +35,9 @@ add_mlir_dialect_library(PTOTransforms VPTOBufferMaterialization.cpp PTOValidateVPTOIR.cpp PTOUnrollSIMTForPass.cpp + PTOValidateVMIIR.cpp + VMILayoutAssignment.cpp + VMIToVPTO.cpp PTOInferVPTOVecScope.cpp InsertSync/PTOInsertSync.cpp diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp new file mode 100644 index 0000000000..6ce3e8eecd --- /dev/null +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -0,0 +1,445 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- PTOValidateVMIIR.cpp - VMI boundary verifier ----------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOVALIDATEVMIIR +#define GEN_PASS_DEF_PTOVALIDATEVMILAYOUTIR +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +bool isVMIType(Type type) { return isa(type); } + +bool isPhysicalVPTOType(Type type) { + return isa(type); +} + +bool containsVMIOrPhysicalType(Type type) { + if (isVMIType(type) || isPhysicalVPTOType(type)) + return true; + + if (auto functionType = dyn_cast(type)) { + return llvm::any_of(functionType.getInputs(), [](Type input) { + return containsVMIOrPhysicalType(input); + }) || + llvm::any_of(functionType.getResults(), [](Type result) { + return containsVMIOrPhysicalType(result); + }); + } + + if (auto shapedType = dyn_cast(type)) + return containsVMIOrPhysicalType(shapedType.getElementType()); + + return false; +} + +bool containsVMIOrPhysicalType(Attribute attr) { + if (!attr) + return false; + + if (auto typeAttr = dyn_cast(attr)) + if (containsVMIOrPhysicalType(typeAttr.getValue())) + return true; + + if (auto typedAttr = dyn_cast(attr)) + if (containsVMIOrPhysicalType(typedAttr.getType())) + return true; + + if (auto arrayAttr = dyn_cast(attr)) + return llvm::any_of(arrayAttr, [](Attribute element) { + return containsVMIOrPhysicalType(element); + }); + + if (auto dictAttr = dyn_cast(attr)) + return llvm::any_of(dictAttr, [](NamedAttribute namedAttr) { + return containsVMIOrPhysicalType(namedAttr.getValue()); + }); + + return false; +} + +bool isSurfaceVMIType(Type type) { + if (auto vregType = dyn_cast(type)) + return !vregType.getLayout(); + if (auto maskType = dyn_cast(type)) + return maskType.isPred() && !maskType.getLayout(); + return false; +} + +bool isLayoutAssignedVMIType(Type type) { + if (auto vregType = dyn_cast(type)) + return static_cast(vregType.getLayoutAttr()); + if (auto maskType = dyn_cast(type)) + return maskType.getLayoutAttr() && + VMIMaskType::isConcreteGranularity(maskType.getGranularity()); + return false; +} + +bool isVMIHelperOp(Operation *op) { + StringRef name = op->getName().getStringRef(); + return name == "pto.vmi.ensure_layout" || + name == "pto.vmi.ensure_mask_layout" || + name == "pto.vmi.ensure_mask_granularity" || + name == "pto.vmi.pack" || name == "pto.vmi.unpack"; +} + +bool isVMILayoutHelperOp(Operation *op) { + StringRef name = op->getName().getStringRef(); + return name == "pto.vmi.ensure_layout" || + name == "pto.vmi.ensure_mask_layout" || + name == "pto.vmi.ensure_mask_granularity"; +} + +bool isVMISemanticOp(Operation *op) { + StringRef name = op->getName().getStringRef(); + return name.starts_with("pto.vmi.") && !isVMIHelperOp(op); +} + +bool isStructuralOp(Operation *op) { + StringRef name = op->getName().getStringRef(); + return name == "builtin.module" || name.starts_with("func.") || + name.starts_with("scf.") || name.starts_with("cf."); +} + +bool hasVMIOrPhysicalType(Operation *op) { + auto hasInterestingType = [](Type type) { + return isVMIType(type) || isPhysicalVPTOType(type); + }; + if (llvm::any_of(op->getOperandTypes(), hasInterestingType) || + llvm::any_of(op->getResultTypes(), hasInterestingType)) + return true; + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + if (llvm::any_of(block.getArgumentTypes(), hasInterestingType)) + return true; + } + } + return false; +} + +void mirrorDiagnostic(llvm::raw_ostream *diagOS, Twine message) { + if (diagOS) + *diagOS << message << "\n"; +} + +LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, + Twine message) { + InFlightDiagnostic diag = + op->emitError() << kVMIDiagPassInvariantPrefix << message; + (void)diag; + mirrorDiagnostic(diagOS, Twine(kVMIDiagPassInvariantPrefix) + message); + return failure(); +} + +LogicalResult verifyBoundaryType(Operation *owner, Type type, + llvm::raw_ostream *diagOS) { + if (isPhysicalVPTOType(type)) + return emitInvariant( + owner, diagOS, + "physical VPTO register type appears before VMI-to-VPTO"); + + if (isVMIType(type) && !isSurfaceVMIType(type)) + return emitInvariant( + owner, diagOS, + "VMI producer boundary requires surface !pto.vmi.vreg or " + "!pto.vmi.mask type"); + + return success(); +} + +LogicalResult verifyBoundaryTypeTree(Operation *owner, Type type, + llvm::raw_ostream *diagOS) { + if (failed(verifyBoundaryType(owner, type, diagOS))) + return failure(); + + if (auto functionType = dyn_cast(type)) { + for (Type input : functionType.getInputs()) + if (failed(verifyBoundaryTypeTree(owner, input, diagOS))) + return failure(); + for (Type result : functionType.getResults()) + if (failed(verifyBoundaryTypeTree(owner, result, diagOS))) + return failure(); + } + + if (auto shapedType = dyn_cast(type)) + return verifyBoundaryTypeTree(owner, shapedType.getElementType(), diagOS); + + return success(); +} + +LogicalResult verifyLayoutAssignedType(Operation *owner, Type type, + llvm::raw_ostream *diagOS) { + if (isPhysicalVPTOType(type)) + return emitInvariant( + owner, diagOS, + "physical VPTO register type appears before VMI-to-VPTO"); + + if (isVMIType(type) && !isLayoutAssignedVMIType(type)) + return emitInvariant( + owner, diagOS, + "layout-assigned VMI IR requires !pto.vmi.vreg with layout and " + "!pto.vmi.mask with b8/b16/b32 granularity plus layout"); + + return success(); +} + +LogicalResult verifyLayoutAssignedTypeTree(Operation *owner, Type type, + llvm::raw_ostream *diagOS) { + if (failed(verifyLayoutAssignedType(owner, type, diagOS))) + return failure(); + + if (auto functionType = dyn_cast(type)) { + for (Type input : functionType.getInputs()) + if (failed(verifyLayoutAssignedTypeTree(owner, input, diagOS))) + return failure(); + for (Type result : functionType.getResults()) + if (failed(verifyLayoutAssignedTypeTree(owner, result, diagOS))) + return failure(); + } + + if (auto shapedType = dyn_cast(type)) + return verifyLayoutAssignedTypeTree(owner, shapedType.getElementType(), + diagOS); + + return success(); +} + +template +LogicalResult verifyAttributeTypes(Operation *owner, Attribute attr, + llvm::raw_ostream *diagOS, + TypeVerifier verifyType) { + if (!attr) + return success(); + + if (auto typeAttr = dyn_cast(attr)) + if (failed(verifyType(owner, typeAttr.getValue(), diagOS))) + return failure(); + + if (auto typedAttr = dyn_cast(attr)) + if (failed(verifyType(owner, typedAttr.getType(), diagOS))) + return failure(); + + if (auto arrayAttr = dyn_cast(attr)) { + for (Attribute element : arrayAttr) + if (failed(verifyAttributeTypes(owner, element, diagOS, verifyType))) + return failure(); + } + + if (auto dictAttr = dyn_cast(attr)) { + for (NamedAttribute namedAttr : dictAttr) + if (failed(verifyAttributeTypes(owner, namedAttr.getValue(), diagOS, + verifyType))) + return failure(); + } + + return success(); +} + +bool isFunctionTypeAttr(Operation *op, NamedAttribute attr) { + return isa(op) && attr.getName() == "function_type"; +} + +LogicalResult verifyNoHiddenVMIAttributeType(Operation *op, + NamedAttribute attr, + llvm::raw_ostream *diagOS) { + if (isFunctionTypeAttr(op, attr)) + return success(); + if (containsVMIOrPhysicalType(attr.getValue())) + return emitInvariant( + op, diagOS, + "VMI or physical VPTO type appears in a non-signature attribute"); + return success(); +} + +LogicalResult verifyOperationTypes(Operation *op, llvm::raw_ostream *diagOS) { + if (auto funcOp = dyn_cast(op)) { + FunctionType functionType = funcOp.getFunctionType(); + for (Type type : functionType.getInputs()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + for (Type type : functionType.getResults()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + } + + for (Type type : op->getOperandTypes()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + for (Type type : op->getResultTypes()) + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Type type : block.getArgumentTypes()) { + if (failed(verifyBoundaryTypeTree(op, type, diagOS))) + return failure(); + } + } + } + for (NamedAttribute attr : op->getAttrs()) { + if (failed(verifyNoHiddenVMIAttributeType(op, attr, diagOS))) + return failure(); + if (failed(verifyAttributeTypes(op, attr.getValue(), diagOS, + verifyBoundaryTypeTree))) + return failure(); + } + return success(); +} + +LogicalResult verifyLayoutAssignedOperationTypes(Operation *op, + llvm::raw_ostream *diagOS) { + if (auto funcOp = dyn_cast(op)) { + FunctionType functionType = funcOp.getFunctionType(); + for (Type type : functionType.getInputs()) + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + for (Type type : functionType.getResults()) + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + } + + for (Type type : op->getOperandTypes()) + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + for (Type type : op->getResultTypes()) + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Type type : block.getArgumentTypes()) { + if (failed(verifyLayoutAssignedTypeTree(op, type, diagOS))) + return failure(); + } + } + } + for (NamedAttribute attr : op->getAttrs()) { + if (failed(verifyNoHiddenVMIAttributeType(op, attr, diagOS))) + return failure(); + if (failed(verifyAttributeTypes(op, attr.getValue(), diagOS, + verifyLayoutAssignedTypeTree))) + return failure(); + } + return success(); +} + +LogicalResult verifyOperationBoundary(Operation *op, + llvm::raw_ostream *diagOS) { + if (failed(verifyOperationTypes(op, diagOS))) + return failure(); + + if (!hasVMIOrPhysicalType(op)) + return success(); + + if (isVMIHelperOp(op)) + return emitInvariant( + op, diagOS, + "VMI helper op appears before layout assignment or VMI-to-VPTO"); + + if (isVMISemanticOp(op) || isStructuralOp(op)) + return success(); + + return emitInvariant(op, diagOS, + "VMI typed value is used by a non-VMI semantic op"); +} + +LogicalResult verifyLayoutAssignedOperation(Operation *op, + llvm::raw_ostream *diagOS) { + if (failed(verifyLayoutAssignedOperationTypes(op, diagOS))) + return failure(); + + if (!hasVMIOrPhysicalType(op)) + return success(); + + if (isVMIHelperOp(op)) { + if (isVMILayoutHelperOp(op)) + return success(); + return emitInvariant( + op, diagOS, + "VMI pack/unpack helper appears before VMI-to-VPTO physicalization"); + } + + if (isVMISemanticOp(op) || isStructuralOp(op)) + return success(); + + return emitInvariant(op, diagOS, + "VMI typed value is used by a non-VMI semantic op"); +} + +struct PTOValidateVMIIRPass + : public mlir::pto::impl::PTOValidateVMIIRBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVMIIRPass) + + void runOnOperation() override { + if (failed(validateVMIProducerBoundaryIR(getOperation(), &llvm::errs()))) + signalPassFailure(); + } +}; + +struct PTOValidateVMILayoutIRPass + : public mlir::pto::impl::PTOValidateVMILayoutIRBase< + PTOValidateVMILayoutIRPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVMILayoutIRPass) + + void runOnOperation() override { + if (failed(validateVMILayoutAssignedIR(getOperation(), &llvm::errs()))) + signalPassFailure(); + } +}; + +} // namespace + +LogicalResult mlir::pto::validateVMIProducerBoundaryIR( + ModuleOp module, llvm::raw_ostream *diagOS) { + WalkResult result = module.walk([&](Operation *op) { + if (failed(verifyOperationBoundary(op, diagOS))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +LogicalResult mlir::pto::validateVMILayoutAssignedIR( + ModuleOp module, llvm::raw_ostream *diagOS) { + WalkResult result = module.walk([&](Operation *op) { + if (failed(verifyLayoutAssignedOperation(op, diagOS))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +std::unique_ptr mlir::pto::createPTOValidateVMIIRPass() { + return std::make_unique(); +} + +std::unique_ptr mlir::pto::createPTOValidateVMILayoutIRPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp new file mode 100644 index 0000000000..e4d201d45c --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -0,0 +1,1330 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VMILayoutAssignment.cpp - Assign VMI layouts ----------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMITargetCapabilities.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTASSIGNMENT +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +struct DataNode { + Value value; + VMIVRegType type; + unsigned parent = 0; + VMILayoutAttr naturalLayout; +}; + +struct MaskNode { + Value value; + VMIMaskType type; + unsigned parent = 0; + VMILayoutAttr requestedLayout; + std::string requestedGranularity; +}; + +struct DataUseRequest { + OpOperand *operand; + VMILayoutAttr layout; +}; + +struct MaskUseRequest { + OpOperand *operand; + VMILayoutAttr layout; + std::string granularity; +}; + +static unsigned getElementBitWidth(Type type) { + if (isa(type)) + return 64; + return pto::getPTOStorageElemBitWidth(type); +} + +static StringRef getMaskGranularityForElement(Type elementType) { + switch (getElementBitWidth(elementType)) { + case 8: + return "b8"; + case 16: + return "b16"; + case 32: + return "b32"; + default: + return ""; + } +} + +static bool isLane0SplatShuffle(VMIShuffleOp op) { + auto sourceType = cast(op.getSource().getType()); + ArrayRef indices = op.getIndices(); + return sourceType.getElementCount() == 1 && !indices.empty() && + llvm::all_of(indices, [](int64_t index) { return index == 0; }); +} + +bool containsVMIType(Type type) { + if (isa(type)) + return true; + if (auto functionType = dyn_cast(type)) { + return llvm::any_of(functionType.getInputs(), [](Type input) { + return containsVMIType(input); + }) || + llvm::any_of(functionType.getResults(), [](Type result) { + return containsVMIType(result); + }); + } + if (auto shapedType = dyn_cast(type)) + return containsVMIType(shapedType.getElementType()); + return false; +} + +struct LayoutSolver { + explicit LayoutSolver(ModuleOp module, + const VMITargetCapabilityRegistry &capabilities) + : module(module), ctx(module.getContext()), capabilities(capabilities) {} + + unsigned addDataValue(Value value) { + auto type = dyn_cast(value.getType()); + if (!type) + return ~0u; + auto [it, inserted] = dataIds.try_emplace(value, dataNodes.size()); + if (inserted) + dataNodes.push_back( + DataNode{value, type, it->second, type.getLayoutAttr()}); + return it->second; + } + + unsigned addMaskValue(Value value) { + auto type = dyn_cast(value.getType()); + if (!type) + return ~0u; + auto [it, inserted] = maskIds.try_emplace(value, maskNodes.size()); + if (inserted) { + std::string granularity; + if (VMIMaskType::isConcreteGranularity(type.getGranularity())) + granularity = type.getGranularity().str(); + maskNodes.push_back( + MaskNode{value, type, it->second, type.getLayoutAttr(), granularity}); + } + return it->second; + } + + unsigned find(unsigned id) { + if (dataNodes[id].parent == id) + return id; + dataNodes[id].parent = find(dataNodes[id].parent); + return dataNodes[id].parent; + } + + unsigned findMask(unsigned id) { + if (maskNodes[id].parent == id) + return id; + maskNodes[id].parent = findMask(maskNodes[id].parent); + return maskNodes[id].parent; + } + + LogicalResult unite(Value lhs, Value rhs, Operation *op) { + unsigned lhsId = addDataValue(lhs); + unsigned rhsId = addDataValue(rhs); + if (lhsId == ~0u || rhsId == ~0u) + return success(); + unsigned lhsRoot = find(lhsId); + unsigned rhsRoot = find(rhsId); + if (lhsRoot == rhsRoot) + return success(); + dataNodes[rhsRoot].parent = lhsRoot; + VMILayoutAttr lhsNatural = dataNodes[lhsRoot].naturalLayout; + VMILayoutAttr rhsNatural = dataNodes[rhsRoot].naturalLayout; + if (lhsNatural && rhsNatural && lhsNatural != rhsNatural) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting natural layouts " + << lhsNatural << " and " << rhsNatural; + if (!lhsNatural) + dataNodes[lhsRoot].naturalLayout = rhsNatural; + return success(); + } + + LogicalResult uniteMask(Value lhs, Value rhs, Operation *op) { + unsigned lhsId = addMaskValue(lhs); + unsigned rhsId = addMaskValue(rhs); + if (lhsId == ~0u || rhsId == ~0u) + return success(); + unsigned lhsRoot = findMask(lhsId); + unsigned rhsRoot = findMask(rhsId); + if (lhsRoot == rhsRoot) + return success(); + + MaskNode &lhsNode = maskNodes[lhsRoot]; + MaskNode &rhsNode = maskNodes[rhsRoot]; + if (lhsNode.requestedLayout && rhsNode.requestedLayout && + lhsNode.requestedLayout != rhsNode.requestedLayout) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting mask layouts " + << lhsNode.requestedLayout << " and " << rhsNode.requestedLayout; + if (!lhsNode.requestedGranularity.empty() && + !rhsNode.requestedGranularity.empty() && + lhsNode.requestedGranularity != rhsNode.requestedGranularity) + return op->emitError() + << kVMIDiagLayoutContractPrefix + << "conflicting mask granularities " + << lhsNode.requestedGranularity << " and " + << rhsNode.requestedGranularity; + + rhsNode.parent = lhsRoot; + if (!lhsNode.requestedLayout) + lhsNode.requestedLayout = rhsNode.requestedLayout; + if (lhsNode.requestedGranularity.empty()) + lhsNode.requestedGranularity = rhsNode.requestedGranularity; + return success(); + } + + LogicalResult setNaturalLayout(Value value, VMILayoutAttr layout, + Operation *op) { + unsigned id = addDataValue(value); + if (id == ~0u || !layout) + return success(); + unsigned root = find(id); + VMILayoutAttr existing = dataNodes[root].naturalLayout; + if (existing && existing != layout) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting natural layouts " + << existing << " and " << layout; + dataNodes[root].naturalLayout = layout; + return success(); + } + + VMILayoutAttr getContiguousLayout() { + return VMILayoutAttr::getContiguous(ctx); + } + + VMILayoutAttr getDataLayout(Value value) { + unsigned id = addDataValue(value); + if (id == ~0u) + return {}; + unsigned root = find(id); + if (dataNodes[root].naturalLayout) + return dataNodes[root].naturalLayout; + return getContiguousLayout(); + } + + LogicalResult requestMask(Value mask, VMILayoutAttr layout, + StringRef granularity, Operation *op) { + unsigned id = addMaskValue(mask); + if (id == ~0u) + return success(); + if (!layout || granularity.empty()) + return op->emitError() + << kVMIDiagLayoutContractPrefix + << "cannot infer concrete mask layout or granularity"; + MaskNode &node = maskNodes[findMask(id)]; + if (node.requestedLayout && node.requestedLayout != layout) + return op->emitError() + << kVMIDiagLayoutContractPrefix << "conflicting mask layouts " + << node.requestedLayout << " and " << layout; + if (!node.requestedGranularity.empty() && + node.requestedGranularity != granularity) + return op->emitError() + << kVMIDiagLayoutContractPrefix + << "conflicting mask granularities " + << node.requestedGranularity << " and " << granularity; + node.requestedLayout = layout; + node.requestedGranularity = granularity.str(); + return success(); + } + + void requestDataUse(OpOperand &operand, VMILayoutAttr layout) { + if (isa(operand.get().getType())) + dataUseRequests.push_back(DataUseRequest{&operand, layout}); + } + + bool canAdoptConsumerRequestedLayout(Value value) { + if (!value.hasOneUse()) + return false; + Operation *definingOp = value.getDefiningOp(); + return definingOp && isa(definingOp); + } + + LogicalResult applyConsumerDrivenDataLayouts() { + for (DataUseRequest request : dataUseRequests) { + Value value = request.operand->get(); + if (!canAdoptConsumerRequestedLayout(value)) + continue; + unsigned id = addDataValue(value); + if (id == ~0u) + continue; + unsigned root = find(id); + VMILayoutAttr existing = dataNodes[root].naturalLayout; + if (existing && existing != request.layout) + return request.operand->getOwner()->emitError() + << kVMIDiagLayoutContractPrefix + << "conflicting natural layouts " + << existing << " and " << request.layout; + dataNodes[root].naturalLayout = request.layout; + } + return success(); + } + + LogicalResult requestMaskUse(OpOperand &operand, VMILayoutAttr layout, + StringRef granularity, Operation *op) { + if (!isa(operand.get().getType())) + return success(); + if (!layout || granularity.empty()) + return op->emitError() + << kVMIDiagLayoutContractPrefix + << "cannot infer concrete mask use layout or granularity"; + maskUseRequests.push_back( + MaskUseRequest{&operand, layout, granularity.str()}); + return success(); + } + + LogicalResult collect() { + module.walk([&](Operation *op) { + for (Value result : op->getResults()) { + addDataValue(result); + addMaskValue(result); + } + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (BlockArgument arg : block.getArguments()) { + addDataValue(arg); + addMaskValue(arg); + } + }); + return success(); + } + + LogicalResult addConstraints() { + WalkResult result = module.walk([&](Operation *op) -> WalkResult { + if (auto maskAnd = dyn_cast(op)) { + if (failed(uniteMask(maskAnd.getLhs(), maskAnd.getRhs(), op)) || + failed(uniteMask(maskAnd.getLhs(), maskAnd.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maskOr = dyn_cast(op)) { + if (failed(uniteMask(maskOr.getLhs(), maskOr.getRhs(), op)) || + failed(uniteMask(maskOr.getLhs(), maskOr.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maskXor = dyn_cast(op)) { + if (failed(uniteMask(maskXor.getLhs(), maskXor.getRhs(), op)) || + failed(uniteMask(maskXor.getLhs(), maskXor.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maskNot = dyn_cast(op)) { + if (failed(uniteMask(maskNot.getSource(), maskNot.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto addf = dyn_cast(op)) { + if (failed(unite(addf.getLhs(), addf.getRhs(), op)) || + failed(unite(addf.getLhs(), addf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto addi = dyn_cast(op)) { + if (failed(unite(addi.getLhs(), addi.getRhs(), op)) || + failed(unite(addi.getLhs(), addi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto subf = dyn_cast(op)) { + if (failed(unite(subf.getLhs(), subf.getRhs(), op)) || + failed(unite(subf.getLhs(), subf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto subi = dyn_cast(op)) { + if (failed(unite(subi.getLhs(), subi.getRhs(), op)) || + failed(unite(subi.getLhs(), subi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto mulf = dyn_cast(op)) { + if (failed(unite(mulf.getLhs(), mulf.getRhs(), op)) || + failed(unite(mulf.getLhs(), mulf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto muli = dyn_cast(op)) { + if (failed(unite(muli.getLhs(), muli.getRhs(), op)) || + failed(unite(muli.getLhs(), muli.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto fma = dyn_cast(op)) { + if (failed(unite(fma.getLhs(), fma.getRhs(), op)) || + failed(unite(fma.getLhs(), fma.getAcc(), op)) || + failed(unite(fma.getLhs(), fma.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto divf = dyn_cast(op)) { + if (failed(unite(divf.getLhs(), divf.getRhs(), op)) || + failed(unite(divf.getLhs(), divf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto minf = dyn_cast(op)) { + if (failed(unite(minf.getLhs(), minf.getRhs(), op)) || + failed(unite(minf.getLhs(), minf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto maxf = dyn_cast(op)) { + if (failed(unite(maxf.getLhs(), maxf.getRhs(), op)) || + failed(unite(maxf.getLhs(), maxf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto negf = dyn_cast(op)) { + if (failed(unite(negf.getSource(), negf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto absf = dyn_cast(op)) { + if (failed(unite(absf.getSource(), absf.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto absi = dyn_cast(op)) { + if (failed(unite(absi.getSource(), absi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto sqrt = dyn_cast(op)) { + if (failed(unite(sqrt.getSource(), sqrt.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto exp = dyn_cast(op)) { + if (failed(unite(exp.getSource(), exp.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto ln = dyn_cast(op)) { + if (failed(unite(ln.getSource(), ln.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto relu = dyn_cast(op)) { + if (failed(unite(relu.getSource(), relu.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto andi = dyn_cast(op)) { + if (failed(unite(andi.getLhs(), andi.getRhs(), op)) || + failed(unite(andi.getLhs(), andi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto ori = dyn_cast(op)) { + if (failed(unite(ori.getLhs(), ori.getRhs(), op)) || + failed(unite(ori.getLhs(), ori.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto xori = dyn_cast(op)) { + if (failed(unite(xori.getLhs(), xori.getRhs(), op)) || + failed(unite(xori.getLhs(), xori.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto shli = dyn_cast(op)) { + if (failed(unite(shli.getLhs(), shli.getRhs(), op)) || + failed(unite(shli.getLhs(), shli.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto shrui = dyn_cast(op)) { + if (failed(unite(shrui.getLhs(), shrui.getRhs(), op)) || + failed(unite(shrui.getLhs(), shrui.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto notOp = dyn_cast(op)) { + if (failed(unite(notOp.getSource(), notOp.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto cmpf = dyn_cast(op)) { + if (failed(unite(cmpf.getLhs(), cmpf.getRhs(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto cmpi = dyn_cast(op)) { + if (failed(unite(cmpi.getLhs(), cmpi.getRhs(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto select = dyn_cast(op)) { + if (failed(unite(select.getTrueValue(), select.getFalseValue(), op)) || + failed(unite(select.getTrueValue(), select.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto activePrefix = dyn_cast(op)) { + if (failed(setNaturalLayout(activePrefix.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto compress = dyn_cast(op)) { + requestDataUse(compress.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(compress.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + requestDataUse(reduce.getInitMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto extf = dyn_cast(op)) { + auto sourceType = cast(extf.getSource().getType()); + auto resultType = cast(extf.getResult().getType()); + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (sourceBits == 16 && resultBits == 32) { + requestDataUse(extf.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extf.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 2), + op))) + return WalkResult::interrupt(); + } else if (sourceBits == 8 && resultBits == 32) { + requestDataUse(extf.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extf.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 4), + op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto truncf = dyn_cast(op)) { + auto sourceType = cast(truncf.getSource().getType()); + auto resultType = cast(truncf.getResult().getType()); + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (sourceBits == 32 && resultBits == 16) + requestDataUse(truncf.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, 2)); + else if (sourceBits == 32 && resultBits == 8) + requestDataUse(truncf.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, 4)); + if (failed(setNaturalLayout(truncf.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto bitcast = dyn_cast(op)) { + if (failed(unite(bitcast.getSource(), bitcast.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + requestDataUse(load.getPassthruMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto gather = dyn_cast(op)) { + auto resultType = cast(gather.getResult().getType()); + requestDataUse(gather.getIndicesMutable(), getContiguousLayout()); + requestDataUse(gather.getPassthruMutable(), getContiguousLayout()); + if (failed(requestMaskUse(gather.getMaskMutable(), + getContiguousLayout(), + getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout(gather.getResult(), + getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + requestDataUse(load.getPassthruMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + requestDataUse(store.getValueMutable(), getContiguousLayout()); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + requestDataUse(store.getValueMutable(), getContiguousLayout()); + if (failed(requestMaskUse(store.getMaskMutable(), + getContiguousLayout(), + getMaskGranularityForElement( + valueType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto scatter = dyn_cast(op)) { + auto valueType = cast(scatter.getValue().getType()); + requestDataUse(scatter.getValueMutable(), getContiguousLayout()); + requestDataUse(scatter.getIndicesMutable(), getContiguousLayout()); + if (failed(requestMaskUse(scatter.getMaskMutable(), + getContiguousLayout(), + getMaskGranularityForElement( + valueType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + requestDataUse(store.getValueMutable(), getContiguousLayout()); + if (failed(requestMaskUse(store.getMaskMutable(), + getContiguousLayout(), + getMaskGranularityForElement( + valueType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto tileWrite = dyn_cast(op)) { + requestDataUse(tileWrite.getValueMutable(), getContiguousLayout()); + return WalkResult::advance(); + } + if (auto split = dyn_cast(op)) { + int64_t channels = split.getNumResults(); + VMICapabilityResult capability = + capabilities.supportsChannelCount("pto.vmi.channel_split", + channels); + if (!capability.isSupported()) { + split.emitError() << kVMIDiagUnsupportedPrefix << capability.reason; + return WalkResult::interrupt(); + } + requestDataUse( + split.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, channels)); + for (Value result : split.getResults()) + if (failed(setNaturalLayout(result, getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto merge = dyn_cast(op)) { + int64_t channels = merge.getInputs().size(); + VMICapabilityResult capability = + capabilities.supportsChannelCount("pto.vmi.channel_merge", + channels); + if (!capability.isSupported()) { + merge.emitError() << kVMIDiagUnsupportedPrefix << capability.reason; + return WalkResult::interrupt(); + } + for (OpOperand &input : merge.getInputsMutable()) + requestDataUse(input, getContiguousLayout()); + if (failed(setNaturalLayout( + merge.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, channels), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto shuffle = dyn_cast(op)) { + auto sourceType = cast(shuffle.getSource().getType()); + auto resultType = cast(shuffle.getResult().getType()); + if (sourceType.hasLayout() || resultType.hasLayout()) + return WalkResult::advance(); + + requestDataUse(shuffle.getSourceMutable(), getContiguousLayout()); + if (isLane0SplatShuffle(shuffle)) + return WalkResult::advance(); + if (failed(setNaturalLayout(shuffle.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto ifOp = dyn_cast(op)) { + if (failed(addIfConstraints(ifOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto executeRegionOp = dyn_cast(op)) { + if (failed(addExecuteRegionConstraints(executeRegionOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto indexSwitchOp = dyn_cast(op)) { + if (failed(addIndexSwitchConstraints(indexSwitchOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto whileOp = dyn_cast(op)) { + if (failed(addWhileConstraints(whileOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto forOp = dyn_cast(op)) { + if (failed(addForConstraints(forOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto branchOp = dyn_cast(op)) { + if (failed(addBranchConstraints(branchOp.getDest(), + branchOp.getDestOperands(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto condBranchOp = dyn_cast(op)) { + if (failed(addBranchConstraints(condBranchOp.getTrueDest(), + condBranchOp.getTrueDestOperands(), + op)) || + failed(addBranchConstraints(condBranchOp.getFalseDest(), + condBranchOp.getFalseDestOperands(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto switchOp = dyn_cast(op)) { + if (failed(addBranchConstraints(switchOp.getDefaultDestination(), + switchOp.getDefaultOperands(), op))) + return WalkResult::interrupt(); + for (auto [dest, operands] : + llvm::zip(switchOp.getCaseDestinations(), + switchOp.getCaseOperands())) { + if (failed(addBranchConstraints(dest, operands, op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto returnOp = dyn_cast(op)) { + if (failed(addReturnConstraints(returnOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto callOp = dyn_cast(op)) { + if (failed(addCallConstraints(callOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (op->getName().getStringRef() == "func.call_indirect") { + if (hasVMIValueTypes(op)) { + op->emitError() + << kVMIDiagLayoutContractPrefix + << "VMI typed call requires a direct internal callee with a body"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto funcOp = dyn_cast(op)) { + if (funcOp.empty() && hasVMIFunctionType(funcOp)) { + funcOp.emitError() + << kVMIDiagLayoutContractPrefix + << "VMI typed function declaration requires an explicit " + "external ABI materialization plan"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); + } + + LogicalResult uniteEquivalentValues(Value lhs, Value rhs, Operation *op) { + if (failed(unite(lhs, rhs, op))) + return failure(); + return uniteMask(lhs, rhs, op); + } + + LogicalResult addIfConstraints(scf::IfOp ifOp) { + for (OpResult result : ifOp->getResults()) { + unsigned resultNo = result.getResultNumber(); + for (Region *region : {&ifOp.getThenRegion(), &ifOp.getElseRegion()}) { + if (region->empty()) + continue; + auto yieldOp = + dyn_cast(region->front().getTerminator()); + if (!yieldOp || resultNo >= yieldOp.getNumOperands()) + continue; + if (failed(uniteEquivalentValues(result, yieldOp.getOperand(resultNo), + ifOp))) + return failure(); + } + } + return success(); + } + + LogicalResult addYieldConstraints(ResultRange results, scf::YieldOp yieldOp, + Operation *op) { + for (auto [index, result] : llvm::enumerate(results)) { + if (index >= yieldOp.getNumOperands()) + break; + if (failed(uniteEquivalentValues(result, yieldOp.getOperand(index), op))) + return failure(); + } + return success(); + } + + LogicalResult addExecuteRegionConstraints(scf::ExecuteRegionOp executeOp) { + WalkResult result = executeOp.getRegion().walk([&](scf::YieldOp yieldOp) { + if (yieldOp->getParentOp() != executeOp.getOperation()) + return WalkResult::advance(); + if (failed(addYieldConstraints(executeOp->getResults(), yieldOp, + executeOp))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); + } + + LogicalResult addIndexSwitchConstraints(scf::IndexSwitchOp indexSwitchOp) { + auto addBlockTerminator = [&](Block &block) -> LogicalResult { + auto yieldOp = dyn_cast(block.getTerminator()); + if (!yieldOp) + return success(); + return addYieldConstraints(indexSwitchOp->getResults(), yieldOp, + indexSwitchOp); + }; + + if (failed(addBlockTerminator(indexSwitchOp.getDefaultBlock()))) + return failure(); + for (unsigned idx = 0, e = indexSwitchOp.getNumCases(); idx < e; ++idx) + if (failed(addBlockTerminator(indexSwitchOp.getCaseBlock(idx)))) + return failure(); + return success(); + } + + LogicalResult addWhileConstraints(scf::WhileOp whileOp) { + auto inits = whileOp.getInits(); + auto beforeArgs = whileOp.getBeforeArguments(); + Block &afterBlock = whileOp.getAfter().front(); + auto conditionOp = + dyn_cast(whileOp.getBefore().front().getTerminator()); + auto yieldOp = dyn_cast(afterBlock.getTerminator()); + + for (auto [index, init] : llvm::enumerate(inits)) { + Value anchor = init; + if (index < beforeArgs.size() && + failed(uniteEquivalentValues(anchor, beforeArgs[index], whileOp))) + return failure(); + if (conditionOp && index < conditionOp.getArgs().size() && + failed(uniteEquivalentValues(anchor, conditionOp.getArgs()[index], + whileOp))) + return failure(); + if (index < afterBlock.getNumArguments() && + failed(uniteEquivalentValues(anchor, afterBlock.getArgument(index), + whileOp))) + return failure(); + if (yieldOp && index < yieldOp.getNumOperands() && + failed(uniteEquivalentValues(anchor, yieldOp.getOperand(index), + whileOp))) + return failure(); + if (index < whileOp.getNumResults() && + failed(uniteEquivalentValues(anchor, whileOp.getResult(index), + whileOp))) + return failure(); + } + return success(); + } + + LogicalResult addForConstraints(scf::ForOp forOp) { + auto initArgs = forOp.getInitArgs(); + auto regionIterArgs = forOp.getRegionIterArgs(); + auto results = forOp.getResults(); + scf::YieldOp yieldOp = nullptr; + if (Block *body = forOp.getBody()) + yieldOp = dyn_cast(body->getTerminator()); + + for (auto [index, initArg] : llvm::enumerate(initArgs)) { + Value anchor = initArg; + if (index < regionIterArgs.size() && + failed(uniteEquivalentValues(anchor, regionIterArgs[index], forOp))) + return failure(); + if (index < results.size() && + failed(uniteEquivalentValues(anchor, results[index], forOp))) + return failure(); + if (yieldOp && index < yieldOp.getNumOperands() && + failed(uniteEquivalentValues(anchor, yieldOp.getOperand(index), + forOp))) + return failure(); + } + return success(); + } + + LogicalResult addBranchConstraints(Block *dest, OperandRange operands, + Operation *op) { + if (!dest) + return success(); + for (auto [index, operand] : llvm::enumerate(operands)) { + if (index >= dest->getNumArguments()) + break; + if (failed(uniteEquivalentValues(operand, dest->getArgument(index), op))) + return failure(); + } + return success(); + } + + LogicalResult addReturnConstraints(func::ReturnOp returnOp) { + auto func = returnOp->getParentOfType(); + if (!func) + return success(); + + auto it = firstReturnOperandsByFunc.find(func); + if (it == firstReturnOperandsByFunc.end()) { + SmallVector operands(returnOp.getOperands()); + firstReturnOperandsByFunc.try_emplace(func, std::move(operands)); + return success(); + } + + ArrayRef firstOperands = it->second; + for (auto [index, operand] : llvm::enumerate(returnOp.getOperands())) { + if (index >= firstOperands.size()) + break; + if (failed(uniteEquivalentValues(firstOperands[index], operand, returnOp))) + return failure(); + } + return success(); + } + + bool hasVMIValueTypes(Operation *op) { + return llvm::any_of(op->getOperandTypes(), containsVMIType) || + llvm::any_of(op->getResultTypes(), containsVMIType); + } + + bool hasVMIFunctionType(func::FuncOp func) { + FunctionType type = func.getFunctionType(); + return llvm::any_of(type.getInputs(), containsVMIType) || + llvm::any_of(type.getResults(), containsVMIType); + } + + LogicalResult addCallConstraints(func::CallOp callOp) { + if (!hasVMIValueTypes(callOp)) + return success(); + + auto callee = SymbolTable::lookupNearestSymbolFrom( + callOp, callOp.getCalleeAttr()); + if (!callee || callee.empty()) + return callOp.emitError() + << kVMIDiagLayoutContractPrefix + << "VMI typed call requires a direct internal callee with a body"; + + for (auto [operand, argument] : + llvm::zip(callOp.getOperands(), callee.getArguments())) { + if (failed(uniteEquivalentValues(operand, argument, callOp))) + return failure(); + } + + SmallVector returns; + callee.walk([&](func::ReturnOp returnOp) { returns.push_back(returnOp); }); + for (func::ReturnOp returnOp : returns) { + for (auto [index, result] : llvm::enumerate(callOp.getResults())) { + if (index >= returnOp.getNumOperands()) + break; + if (failed(uniteEquivalentValues(result, returnOp.getOperand(index), + callOp))) + return failure(); + } + } + return success(); + } + + void rewriteDataTypes() { + for (DataNode &node : dataNodes) { + VMILayoutAttr layout = getDataLayout(node.value); + node.value.setType(VMIVRegType::get(ctx, node.type.getElementCount(), + node.type.getElementType(), layout)); + } + } + + std::optional rematerializeDataUse(Value value, VMIVRegType resultType, + Location loc, + OpBuilder &builder) { + if (auto constant = value.getDefiningOp()) { + auto denseAttr = dyn_cast(constant.getValue()); + if (denseAttr && denseAttr.isSplat()) + return builder.create(loc, resultType, + constant.getValue()) + .getResult(); + } + if (auto broadcast = value.getDefiningOp()) + return builder + .create(loc, resultType, broadcast.getValue()) + .getResult(); + if (auto iota = value.getDefiningOp()) + return builder.create(loc, resultType, iota.getBase(), + iota.getOrderAttr()) + .getResult(); + return std::nullopt; + } + + LogicalResult insertDataUseMaterializations() { + OpBuilder builder(ctx); + for (DataUseRequest request : dataUseRequests) { + Value value = request.operand->get(); + auto sourceType = dyn_cast(value.getType()); + if (!sourceType) + continue; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return request.operand->getOwner()->emitError() + << kVMIDiagLayoutContractPrefix + << "data use materialization requires layout-assigned source " + "type"; + if (sourceLayout == request.layout) + continue; + + auto resultType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), request.layout); + builder.setInsertionPoint(request.operand->getOwner()); + std::optional rematerialized = + rematerializeDataUse(value, resultType, + request.operand->getOwner()->getLoc(), builder); + if (rematerialized) { + request.operand->set(*rematerialized); + continue; + } + auto ensure = builder.create( + request.operand->getOwner()->getLoc(), resultType, value); + request.operand->set(ensure.getResult()); + } + return success(); + } + + LogicalResult inferMaskRequests() { + WalkResult result = module.walk([&](Operation *op) -> WalkResult { + if (auto cmpf = dyn_cast(op)) { + auto lhsType = cast(cmpf.getLhs().getType()); + if (failed(requestMask(cmpf.getResult(), lhsType.getLayoutAttr(), + getMaskGranularityForElement( + lhsType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto cmpi = dyn_cast(op)) { + auto lhsType = cast(cmpi.getLhs().getType()); + if (failed(requestMask(cmpi.getResult(), lhsType.getLayoutAttr(), + getMaskGranularityForElement( + lhsType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto select = dyn_cast(op)) { + auto resultType = cast(select.getResult().getType()); + if (failed(requestMaskUse(select.getMaskMutable(), + resultType.getLayoutAttr(), + getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto activePrefix = dyn_cast(op)) { + auto resultType = + cast(activePrefix.getResult().getType()); + if (failed(requestMaskUse(activePrefix.getMaskMutable(), + resultType.getLayoutAttr(), + getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto compress = dyn_cast(op)) { + auto resultType = cast(compress.getResult().getType()); + if (failed(requestMaskUse(compress.getMaskMutable(), + resultType.getLayoutAttr(), + getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse(reduce.getMaskMutable(), + sourceType.getLayoutAttr(), + getMaskGranularityForElement( + sourceType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse(reduce.getMaskMutable(), + sourceType.getLayoutAttr(), + getMaskGranularityForElement( + sourceType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse(reduce.getMaskMutable(), + sourceType.getLayoutAttr(), + getMaskGranularityForElement( + sourceType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse(reduce.getMaskMutable(), + sourceType.getLayoutAttr(), + getMaskGranularityForElement( + sourceType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + if (failed(requestMaskUse(load.getMaskMutable(), + resultType.getLayoutAttr(), + getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + if (failed(requestMaskUse(load.getMaskMutable(), + resultType.getLayoutAttr(), + getMaskGranularityForElement( + resultType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); + } + + void rewriteMaskTypes() { + for (MaskNode &node : maskNodes) { + MaskNode &root = maskNodes[findMask(maskIds.lookup(node.value))]; + VMILayoutAttr layout = root.requestedLayout ? root.requestedLayout + : getContiguousLayout(); + StringRef granularity = root.requestedGranularity.empty() + ? StringRef("b32") + : StringRef(root.requestedGranularity); + node.value.setType(VMIMaskType::get(ctx, node.type.getElementCount(), + granularity, layout)); + } + } + + std::optional rematerializeMaskUse(Value value, VMIMaskType resultType, + Location loc, + OpBuilder &builder) { + if (auto createMask = value.getDefiningOp()) + return builder.create(loc, resultType, + createMask.getActiveLanes()) + .getResult(); + if (auto constantMask = value.getDefiningOp()) + return builder + .create(loc, resultType, + constantMask.getValueAttr()) + .getResult(); + return std::nullopt; + } + + LogicalResult insertMaskUseMaterializations() { + OpBuilder builder(ctx); + for (MaskUseRequest request : maskUseRequests) { + Value value = request.operand->get(); + auto sourceType = dyn_cast(value.getType()); + if (!sourceType) + continue; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return request.operand->getOwner()->emitError() + << kVMIDiagLayoutContractPrefix + << "mask use materialization requires layout-assigned source " + "type"; + + builder.setInsertionPoint(request.operand->getOwner()); + Value current = value; + VMIMaskType currentType = sourceType; + auto requestedType = VMIMaskType::get(ctx, sourceType.getElementCount(), + request.granularity, + request.layout); + if (sourceType != requestedType) { + std::optional rematerialized = rematerializeMaskUse( + value, requestedType, request.operand->getOwner()->getLoc(), + builder); + if (rematerialized) { + request.operand->set(*rematerialized); + continue; + } + } + + if (sourceLayout != request.layout) { + auto layoutType = VMIMaskType::get(ctx, currentType.getElementCount(), + currentType.getGranularity(), + request.layout); + auto ensureLayout = builder.create( + request.operand->getOwner()->getLoc(), layoutType, current); + current = ensureLayout.getResult(); + currentType = layoutType; + } + + if (currentType.getGranularity() != request.granularity) { + auto granularityType = + VMIMaskType::get(ctx, currentType.getElementCount(), + request.granularity, request.layout); + auto ensureGranularity = + builder.create( + request.operand->getOwner()->getLoc(), granularityType, + current); + current = ensureGranularity.getResult(); + } + + if (current != value) + request.operand->set(current); + } + return success(); + } + + void rewriteFunctionType() { + module.walk([&](func::FuncOp func) { + if (func.empty()) + return; + + SmallVector inputs; + inputs.reserve(func.getNumArguments()); + for (BlockArgument arg : func.getArguments()) + inputs.push_back(arg.getType()); + + SmallVector results; + auto it = firstReturnOperandsByFunc.find(func); + if (it != firstReturnOperandsByFunc.end()) { + for (Value operand : it->second) + results.push_back(operand.getType()); + } else { + for (Type type : func.getFunctionType().getResults()) { + if (auto vregType = dyn_cast(type)) { + results.push_back(VMIVRegType::get(ctx, vregType.getElementCount(), + vregType.getElementType(), + getContiguousLayout())); + } else if (auto maskType = dyn_cast(type)) { + results.push_back(VMIMaskType::get(ctx, maskType.getElementCount(), + "b32", getContiguousLayout())); + } else { + results.push_back(type); + } + } + } + + func.setFunctionType(FunctionType::get(ctx, inputs, results)); + }); + } + + LogicalResult run() { + if (failed(collect())) + return failure(); + if (failed(addConstraints())) + return failure(); + if (failed(applyConsumerDrivenDataLayouts())) + return failure(); + rewriteDataTypes(); + if (failed(insertDataUseMaterializations())) + return failure(); + if (failed(inferMaskRequests())) + return failure(); + rewriteMaskTypes(); + if (failed(insertMaskUseMaterializations())) + return failure(); + rewriteFunctionType(); + return validateVMILayoutAssignedIR(module); + } + + ModuleOp module; + MLIRContext *ctx; + const VMITargetCapabilityRegistry &capabilities; + DenseMap dataIds; + DenseMap maskIds; + DenseMap> firstReturnOperandsByFunc; + SmallVector dataNodes; + SmallVector maskNodes; + SmallVector dataUseRequests; + SmallVector maskUseRequests; +}; + +struct VMILayoutAssignmentPass + : public mlir::pto::impl::VMILayoutAssignmentBase< + VMILayoutAssignmentPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutAssignmentPass) + + void runOnOperation() override { + VMITargetCapabilityRegistry capabilities; + if (failed(LayoutSolver(getOperation(), capabilities).run())) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILayoutAssignmentPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp new file mode 100644 index 0000000000..db19c2846b --- /dev/null +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -0,0 +1,6269 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- VMIToVPTO.cpp - Convert VMI to physical VPTO IR -------------------===// +//===----------------------------------------------------------------------===// + +// https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933/8 +#pragma GCC diagnostic ignored "-Woverloaded-virtual" + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMITargetCapabilities.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/OneToNTypeConversion.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMITOVPTO +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +bool isVMIType(Type type) { return isa(type); } + +bool containsVMIType(Type type) { + if (isVMIType(type)) + return true; + + if (auto functionType = dyn_cast(type)) + return llvm::any_of(functionType.getInputs(), + [](Type input) { return containsVMIType(input); }) || + llvm::any_of(functionType.getResults(), + [](Type result) { return containsVMIType(result); }); + + if (auto shapedType = dyn_cast(type)) + return containsVMIType(shapedType.getElementType()); + + return false; +} + +bool hasVMIType(TypeRange types) { + return llvm::any_of(types, [](Type type) { return containsVMIType(type); }); +} + +bool hasVMIType(FunctionType type) { + return hasVMIType(type.getInputs()) || hasVMIType(type.getResults()); +} + +bool hasVMIType(Attribute attr) { + if (!attr) + return false; + + if (auto typeAttr = dyn_cast(attr)) + if (containsVMIType(typeAttr.getValue())) + return true; + + if (auto typedAttr = dyn_cast(attr)) + if (containsVMIType(typedAttr.getType())) + return true; + + if (auto arrayAttr = dyn_cast(attr)) + return llvm::any_of(arrayAttr, [](Attribute element) { + return hasVMIType(element); + }); + + if (auto dictAttr = dyn_cast(attr)) + return llvm::any_of(dictAttr, [](NamedAttribute namedAttr) { + return hasVMIType(namedAttr.getValue()); + }); + + return false; +} + +bool hasVMIType(Operation *op) { + if (auto func = dyn_cast(op)) + if (hasVMIType(func.getFunctionType())) + return true; + if (hasVMIType(op->getOperandTypes()) || hasVMIType(op->getResultTypes())) + return true; + for (Region ®ion : op->getRegions()) + for (Block &block : region) + if (hasVMIType(block.getArgumentTypes())) + return true; + for (NamedAttribute attr : op->getAttrs()) + if (hasVMIType(attr.getValue())) + return true; + return false; +} + +bool isVMIOp(Operation *op) { + return op->getName().getStringRef().starts_with("pto.vmi."); +} + +bool isLayoutAssignedVMIType(Type type) { + if (auto vregType = dyn_cast(type)) + return static_cast(vregType.getLayoutAttr()); + if (auto maskType = dyn_cast(type)) + return maskType.getLayoutAttr() && + VMIMaskType::isConcreteGranularity(maskType.getGranularity()); + return true; +} + +LogicalResult verifyLayoutAssignedVMITypeTree(Operation *op, Type type) { + if (!isLayoutAssignedVMIType(type)) + return op->emitError() + << kVMIDiagPassInvariantPrefix + << "vmi-to-vpto requires layout-assigned VMI types"; + + if (auto functionType = dyn_cast(type)) { + for (Type input : functionType.getInputs()) + if (failed(verifyLayoutAssignedVMITypeTree(op, input))) + return failure(); + for (Type result : functionType.getResults()) + if (failed(verifyLayoutAssignedVMITypeTree(op, result))) + return failure(); + } + + if (auto shapedType = dyn_cast(type)) + return verifyLayoutAssignedVMITypeTree(op, shapedType.getElementType()); + + return success(); +} + +LogicalResult verifyVMIToVPTOInputAttribute(Operation *op, Attribute attr) { + if (!attr) + return success(); + + if (auto typeAttr = dyn_cast(attr)) + if (failed(verifyLayoutAssignedVMITypeTree(op, typeAttr.getValue()))) + return failure(); + + if (auto typedAttr = dyn_cast(attr)) + if (failed(verifyLayoutAssignedVMITypeTree(op, typedAttr.getType()))) + return failure(); + + if (auto arrayAttr = dyn_cast(attr)) { + for (Attribute element : arrayAttr) + if (failed(verifyVMIToVPTOInputAttribute(op, element))) + return failure(); + } + + if (auto dictAttr = dyn_cast(attr)) { + for (NamedAttribute namedAttr : dictAttr) + if (failed(verifyVMIToVPTOInputAttribute(op, namedAttr.getValue()))) + return failure(); + } + + return success(); +} + +LogicalResult verifyVMIToVPTOInputTypes(Operation *op) { + for (Type type : op->getOperandTypes()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + for (Type type : op->getResultTypes()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + if (auto func = dyn_cast(op)) { + for (Type type : func.getFunctionType().getInputs()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + for (Type type : func.getFunctionType().getResults()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + } + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (Type type : block.getArgumentTypes()) + if (failed(verifyLayoutAssignedVMITypeTree(op, type))) + return failure(); + for (NamedAttribute attr : op->getAttrs()) + if (failed(verifyVMIToVPTOInputAttribute(op, attr.getValue()))) + return failure(); + return success(); +} + +LogicalResult verifyVMIToVPTOInputIR(ModuleOp module) { + WalkResult result = module.walk([&](Operation *op) { + if (failed(verifyVMIToVPTOInputTypes(op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +static std::optional materializeVPTOToVMI(OpBuilder &builder, + Type resultType, + ValueRange inputs, + Location loc) { + if (!isVMIType(resultType)) + return std::nullopt; + return builder.create(loc, resultType, inputs).getResult(); +} + +static std::optional> +materializeVMIToVPTO(OpBuilder &builder, TypeRange resultTypes, Value input, + Location loc) { + if (!isVMIType(input.getType())) + return std::nullopt; + auto unpackOp = builder.create(loc, resultTypes, input); + return SmallVector(unpackOp->getResults()); +} + +class VMIToVPTOTypeConverter final : public OneToNTypeConverter { +public: + VMIToVPTOTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion([](VMIVRegType type, SmallVectorImpl &results) + -> LogicalResult { + FailureOr arity = getVMIPhysicalArity(type); + FailureOr lanesPerPart = + getDataLanesPerPart(type.getElementType()); + if (failed(arity) || failed(lanesPerPart)) + return failure(); + for (int64_t i = 0; i < *arity; ++i) + results.push_back(VRegType::get(type.getContext(), *lanesPerPart, + type.getElementType())); + return success(); + }); + addConversion([](VMIMaskType type, SmallVectorImpl &results) + -> LogicalResult { + FailureOr arity = getVMIPhysicalArity(type); + if (failed(arity)) + return failure(); + for (int64_t i = 0; i < *arity; ++i) + results.push_back(MaskType::get(type.getContext(), + type.getGranularity())); + return success(); + }); + TypeConverter::addSourceMaterialization(materializeVPTOToVMI); + TypeConverter::addArgumentMaterialization(materializeVPTOToVMI); + OneToNTypeConverter::addTargetMaterialization(materializeVMIToVPTO); + } +}; + +FailureOr createAllTrueMaskForVReg(Location loc, VRegType vregType, + PatternRewriter &rewriter) { + MLIRContext *ctx = rewriter.getContext(); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(vregType.getElementType()); + if (elementBits == 8) + return rewriter + .create(loc, MaskType::get(ctx, "b8"), + rewriter.getStringAttr("PAT_ALL")) + .getResult(); + if (elementBits == 16) + return rewriter + .create(loc, MaskType::get(ctx, "b16"), + rewriter.getStringAttr("PAT_ALL")) + .getResult(); + if (elementBits == 32) + return rewriter + .create(loc, MaskType::get(ctx, "b32"), + rewriter.getStringAttr("PAT_ALL")) + .getResult(); + return failure(); +} + +FailureOr getMaskTypeForVReg(VRegType vregType, + MLIRContext *ctx) { + unsigned elementBits = + pto::getPTOStorageElemBitWidth(vregType.getElementType()); + if (elementBits == 8) + return MaskType::get(ctx, "b8"); + if (elementBits == 16) + return MaskType::get(ctx, "b16"); + if (elementBits == 32) + return MaskType::get(ctx, "b32"); + return failure(); +} + +FailureOr createAllTrueMask(Location loc, MaskType maskType, + PatternRewriter &rewriter) { + StringAttr pattern = rewriter.getStringAttr("PAT_ALL"); + MLIRContext *ctx = rewriter.getContext(); + if (maskType.isB8()) + return rewriter.create(loc, MaskType::get(ctx, "b8"), pattern) + .getResult(); + if (maskType.isB16()) + return rewriter.create(loc, MaskType::get(ctx, "b16"), pattern) + .getResult(); + if (maskType.isB32()) + return rewriter.create(loc, MaskType::get(ctx, "b32"), pattern) + .getResult(); + return failure(); +} + +FailureOr createPatternMask(Location loc, MaskType maskType, + StringRef pattern, + PatternRewriter &rewriter) { + StringAttr patternAttr = rewriter.getStringAttr(pattern); + MLIRContext *ctx = rewriter.getContext(); + if (maskType.isB8()) + return rewriter.create(loc, MaskType::get(ctx, "b8"), patternAttr) + .getResult(); + if (maskType.isB16()) + return rewriter + .create(loc, MaskType::get(ctx, "b16"), patternAttr) + .getResult(); + if (maskType.isB32()) + return rewriter + .create(loc, MaskType::get(ctx, "b32"), patternAttr) + .getResult(); + return failure(); +} + +FailureOr createPrefixMask(Location loc, MaskType maskType, + StringRef pattern, + PatternRewriter &rewriter) { + StringAttr patternAttr = rewriter.getStringAttr(pattern); + MLIRContext *ctx = rewriter.getContext(); + if (maskType.isB8()) + return rewriter.create(loc, MaskType::get(ctx, "b8"), patternAttr) + .getResult(); + if (maskType.isB16()) + return rewriter.create(loc, MaskType::get(ctx, "b16"), patternAttr) + .getResult(); + if (maskType.isB32()) + return rewriter.create(loc, MaskType::get(ctx, "b32"), patternAttr) + .getResult(); + return failure(); +} + +FailureOr> +createRuntimePrefixMask(Location loc, MaskType maskType, Value activeLanes, + PatternRewriter &rewriter) { + MLIRContext *ctx = rewriter.getContext(); + Type scalarType = activeLanes.getType(); + if (maskType.isB8()) { + auto op = rewriter.create(loc, MaskType::get(ctx, "b8"), + scalarType, activeLanes); + return std::make_pair(Value(op.getMask()), Value(op.getScalarOut())); + } + if (maskType.isB16()) { + auto op = rewriter.create(loc, MaskType::get(ctx, "b16"), + scalarType, activeLanes); + return std::make_pair(Value(op.getMask()), Value(op.getScalarOut())); + } + if (maskType.isB32()) { + auto op = rewriter.create(loc, MaskType::get(ctx, "b32"), + scalarType, activeLanes); + return std::make_pair(Value(op.getMask()), Value(op.getScalarOut())); + } + return failure(); +} + +LogicalResult checkSupportedMaskableVReg( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMICapabilityResult elementCapability = + capabilities.supportsElementType(type.getElementType(), + VMIElementPurpose::PredicateMask); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + FailureOr arity = getVMIPhysicalArity(type); + if (failed(lanesPerPart) || failed(arity) || *arity < 1) + return fail("requires computable non-empty physical vreg parts"); + + return success(); +} + +LogicalResult checkSupportedTargetElementVReg( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, + VMIElementPurpose purpose, StringRef elementContract, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (failed(checkSupportedMaskableVReg(capabilities, type, reason))) + return failure(); + + VMICapabilityResult elementCapability = + capabilities.supportsElementType(type.getElementType(), purpose); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + return success(); +} + +Value createI32Constant(Location loc, int64_t value, PatternRewriter &rewriter) { + return rewriter.create(loc, value, 32); +} + +Value clampDynamicActiveLanes(Location loc, Value activeLanes, + int64_t maxActiveLanes, + PatternRewriter &rewriter) { + Value activeI32 = + rewriter.create(loc, rewriter.getI32Type(), + activeLanes); + Value zeroI32 = createI32Constant(loc, 0, rewriter); + Value nonNegative = + rewriter.create(loc, activeI32, zeroI32); + Value maxI32 = createI32Constant(loc, maxActiveLanes, rewriter); + return rewriter.create(loc, nonNegative, maxI32); +} + +Value createPartitionActiveLanes(Location loc, Value activeLanesI32, + int64_t factor, int64_t part, + PatternRewriter &rewriter) { + if (factor == 1) + return activeLanesI32; + int64_t bias = factor - 1 - part; + Value biased = activeLanesI32; + if (bias != 0) + biased = + rewriter.create(loc, biased, + createI32Constant(loc, bias, rewriter)); + return rewriter.create( + loc, biased, createI32Constant(loc, factor, rewriter)); +} + +std::optional getPrefixPattern(int64_t activeLanes, + int64_t lanesPerPart) { + if (activeLanes <= 0) + return std::string("PAT_ALLF"); + if (activeLanes >= lanesPerPart) + return std::string("PAT_ALL"); + switch (activeLanes) { + case 1: + case 2: + case 3: + case 4: + case 8: + case 16: + case 32: + case 64: + case 128: + return std::string("PAT_VL") + std::to_string(activeLanes); + default: + return std::nullopt; + } +} + +FailureOr getSingleValue(Operation *op, ValueRange values, + StringRef description, + PatternRewriter &rewriter) { + if (values.size() != 1) { + (void)rewriter.notifyMatchFailure(op, description); + return failure(); + } + return values.front(); +} + +static int64_t ceilDivNonNegative(int64_t lhs, int64_t rhs) { + return (lhs + rhs - 1) / rhs; +} + +FailureOr getDataLayoutFactor(VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return failure(); + return layout.isContiguous() ? 1 : layout.getFactor(); +} + +FailureOr getDataChunksInPart(VMIVRegType type, int64_t part) { + FailureOr factor = getDataLayoutFactor(type); + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(factor) || failed(lanesPerPart) || part < 0 || part >= *factor) + return failure(); + + int64_t logicalLanesInPart = + (type.getElementCount() + *factor - 1 - part) / *factor; + return ceilDivNonNegative(logicalLanesInPart, *lanesPerPart); +} + +FailureOr getDataFlatPartIndex(VMIVRegType type, int64_t part, + int64_t chunk) { + FailureOr factor = getDataLayoutFactor(type); + if (failed(factor) || part < 0 || part >= *factor || chunk < 0) + return failure(); + + int64_t flatIndex = 0; + for (int64_t currentPart = 0; currentPart < part; ++currentPart) { + FailureOr chunks = getDataChunksInPart(type, currentPart); + if (failed(chunks)) + return failure(); + flatIndex += *chunks; + } + + FailureOr chunks = getDataChunksInPart(type, part); + if (failed(chunks) || chunk >= *chunks) + return failure(); + return flatIndex + chunk; +} + +FailureOr checkFullDataPhysicalChunks(VMIVRegType type, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known physical lanes per part"); + + FailureOr factor = getDataLayoutFactor(type); + if (failed(factor)) + return fail("requires assigned layout"); + + for (int64_t part = 0; part < *factor; ++part) { + FailureOr chunks = getDataChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks"); + for (int64_t chunk = 0; chunk < *chunks; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) + return fail("found padding lane in physical chunk"); + } + } + } + + return *lanesPerPart; +} + +FailureOr getVMITypeLayoutFactor(Type type) { + Attribute layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayout(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayout(); + else + return failure(); + + auto layoutAttr = dyn_cast_or_null(layout); + if (!layoutAttr) + return failure(); + return layoutAttr.isContiguous() ? 1 : layoutAttr.getFactor(); +} + +FailureOr getVMITypeElementCount(Type type) { + if (auto vregType = dyn_cast(type)) + return vregType.getElementCount(); + if (auto maskType = dyn_cast(type)) + return maskType.getElementCount(); + return failure(); +} + +FailureOr getVMITypeLanesPerPart(Type type) { + if (auto vregType = dyn_cast(type)) + return getDataLanesPerPart(vregType.getElementType()); + if (auto maskType = dyn_cast(type)) + return getMaskLanesPerPart(maskType.getGranularity()); + return failure(); +} + +FailureOr getVMITypeChunksInPart(Type type, int64_t part) { + FailureOr elementCount = getVMITypeElementCount(type); + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(lanesPerPart) || + part < 0 || part >= *factor) + return failure(); + + int64_t logicalLanesInPart = (*elementCount + *factor - 1 - part) / *factor; + return ceilDivNonNegative(logicalLanesInPart, *lanesPerPart); +} + +LogicalResult checkFullVMIPhysicalChunks(Type type, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(factor) || failed(lanesPerPart)) + return fail("requires assigned layout with known physical lanes per part"); + + for (int64_t part = 0; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks"); + for (int64_t chunk = 0; chunk < *chunks; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) + return fail("found padding lane in physical chunk"); + } + } + } + + return success(); +} + +FailureOr getContiguousMaterializationPartCount(Type type, + std::string *reason); + +LogicalResult checkSupportedLayoutMaterialization( + const VMITargetCapabilityRegistry &capabilities, Type sourceType, + Type resultType, VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMICapabilityResult layoutCapability = + capabilities.supportsLayoutConversion(sourceLayout, resultLayout, + Type{}); + if (!layoutCapability.isSupported()) + return fail(layoutCapability.reason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity != *resultArity) + return fail("requires source and result to have the same physical arity"); + + if (sourceLayout == resultLayout) + return success(); + + std::string sourceReason; + std::string resultReason; + LogicalResult sourceFull = + checkFullVMIPhysicalChunks(sourceType, &sourceReason); + LogicalResult resultFull = + checkFullVMIPhysicalChunks(resultType, &resultReason); + if (succeeded(sourceFull) && succeeded(resultFull)) + return success(); + + std::string sourceMaterializationReason; + FailureOr sourceMaterializedParts = + getContiguousMaterializationPartCount(sourceType, + &sourceMaterializationReason); + std::string resultMaterializationReason; + FailureOr resultMaterializedParts = + getContiguousMaterializationPartCount(resultType, + &resultMaterializationReason); + if (succeeded(sourceMaterializedParts) && + succeeded(resultMaterializedParts) && + *sourceMaterializedParts == *sourceArity && + *resultMaterializedParts == *resultArity) + return success(); + + if (failed(sourceFull)) + return fail(Twine("source ") + sourceReason + + "; source materialization " + sourceMaterializationReason); + return fail(Twine("result ") + resultReason + + "; result materialization " + resultMaterializationReason); +} + +FailureOr getContiguousMaterializationPartCount(Type type, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr arity = getVMIPhysicalArity(type); + FailureOr factor = getVMITypeLayoutFactor(type); + if (failed(arity) || failed(factor)) + return fail("requires computable physical arity and assigned layout"); + + Attribute layoutAttr; + if (auto vregType = dyn_cast(type)) + layoutAttr = vregType.getLayout(); + else if (auto maskType = dyn_cast(type)) + layoutAttr = maskType.getLayout(); + else + return fail("requires VMI data or mask type"); + + auto layout = dyn_cast_or_null(layoutAttr); + if (!layout) + return fail("requires assigned layout"); + if (layout.isContiguous()) + return *arity; + if (!layout.isDeinterleaved() || + (layout.getFactor() != 2 && layout.getFactor() != 4)) + return fail("requires contiguous or deinterleaved=2/4 layout"); + + FailureOr chunksPerGroup = getVMITypeChunksInPart(type, 0); + if (failed(chunksPerGroup)) + return fail("requires known physical chunks per part"); + if (*chunksPerGroup == 0) + return fail("requires at least one physical chunk per part"); + + for (int64_t part = 1; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks per part"); + if (*chunks != *chunksPerGroup) + return fail("requires every deinterleaved part to have the same " + "physical chunk count"); + } + + return *arity; +} + +LogicalResult checkCanMaterializeToContiguous(Type type, std::string *reason) { + return succeeded(getContiguousMaterializationPartCount(type, reason)) + ? success() + : failure(); +} + +std::optional getConstantIndexValue(Value value) { + if (auto constant = value.getDefiningOp()) + return constant.value(); + if (auto constant = value.getDefiningOp()) { + if (auto integerAttr = dyn_cast(constant.getValue())) + return integerAttr.getInt(); + } + return std::nullopt; +} + +FailureOr getStaticMemRefElementCount(Type type) { + auto memrefType = dyn_cast(type); + if (!memrefType || !memrefType.hasStaticShape()) + return failure(); + + int64_t elements = 1; + for (int64_t dim : memrefType.getShape()) + elements *= dim; + return elements; +} + +enum class VMIMemoryValidMaskKind { + AllTrue, + ExplicitMask, +}; + +enum class VMIMemoryWriteMaskKind { + AllTrue, + ExplicitMask, +}; + +enum class VMIMemoryPermutationKind { + Identity, +}; + +enum class VMIMemoryFallbackDecisionKind { + NotRequired, + RequiredUnavailable, +}; + +struct VMIMemoryLogicalShape { + int64_t elementCount = 0; +}; + +struct VMIMemoryLaneAddressMap { + VMIMemoryPermutationKind permutation = VMIMemoryPermutationKind::Identity; + int64_t baseElementOffset = 0; + int64_t elementStride = 1; + int64_t physicalLaneFootprint = 0; + + int64_t getExclusiveEndElement() const { + return baseElementOffset + physicalLaneFootprint * elementStride; + } +}; + +struct VMIMemoryFallbackDecision { + VMIMemoryFallbackDecisionKind kind = + VMIMemoryFallbackDecisionKind::NotRequired; + std::string reason = "not required"; + + static VMIMemoryFallbackDecision notRequired() { return {}; } + + static VMIMemoryFallbackDecision requiredUnavailable(const Twine &reason) { + VMIMemoryFallbackDecision decision; + decision.kind = VMIMemoryFallbackDecisionKind::RequiredUnavailable; + decision.reason = reason.str(); + return decision; + } +}; + +struct VMIMemorySafeReadProof { + bool proven = false; + std::string reason; + std::optional constantOffset; + std::optional staticElementCount; + std::optional laneAddressMap; + int64_t physicalFootprint = 0; +}; + +struct VMIMemoryAccessPlan { + Type baseType; + VMIVRegType valueType; + std::optional constantOffset; + VMIMemoryLogicalShape logicalShape; + VMIMemoryValidMaskKind validMask = VMIMemoryValidMaskKind::AllTrue; + VMIMemoryPermutationKind permutation = VMIMemoryPermutationKind::Identity; + std::optional laneAddressMap; + Attribute paddingValue; + VMIMemoryWriteMaskKind writeMask = VMIMemoryWriteMaskKind::AllTrue; + VMIMemorySafeReadProof safeReadProof; + VMICapabilityResult targetCapability; + VMICapabilityResult trueMaskedLoadCapability; + VMICapabilityResult scratchFallbackCapability; + VMICapabilityResult guardedFallbackCapability; + VMIMemoryFallbackDecision fallbackDecision; +}; + +FailureOr +buildContiguousIdentityLaneAddressMap(int64_t constantOffset, + VMIVRegType resultType, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr lanesPerPart = + getDataLanesPerPart(resultType.getElementType()); + FailureOr arity = getVMIPhysicalArity(resultType); + if (failed(lanesPerPart) || failed(arity)) + return fail("requires computable physical read footprint"); + + VMIMemoryLaneAddressMap map; + map.baseElementOffset = constantOffset; + map.physicalLaneFootprint = *arity * *lanesPerPart; + return map; +} + +VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, + StringRef role, + Value memoryValue = {}) { + auto memrefType = dyn_cast(memoryType); + if (!memrefType || memrefType.getLayout().isIdentity()) + return VMICapabilityResult::supported(); + std::string reason = + (Twine(role) + + " memref layout is non-identity; current VMI memory access plan " + "supports only contiguous identity lane-to-address maps") + .str(); + if (memoryValue && memoryValue.getDefiningOp()) + reason += "; memref.subview requires normalized base/offset/stride " + "lane-to-address planning"; + return VMICapabilityResult::missingCapability(reason); +} + +VMIMemorySafeReadProof +computeSafeFullReadProof(Type sourceType, std::optional constantOffset, + VMIVRegType resultType) { + VMIMemorySafeReadProof proof; + proof.constantOffset = constantOffset; + + auto fail = [&](const Twine &message) { + proof.proven = false; + proof.reason = message.str(); + return proof; + }; + + if (!constantOffset) + return fail("requires constant index offset"); + + FailureOr elements = getStaticMemRefElementCount(sourceType); + if (failed(elements)) + return fail("requires statically shaped memref source"); + proof.staticElementCount = *elements; + + if (*constantOffset < 0) + return fail("requires non-negative offset"); + + std::string addressMapReason; + FailureOr addressMap = + buildContiguousIdentityLaneAddressMap(*constantOffset, resultType, + &addressMapReason); + if (failed(addressMap)) + return fail(addressMapReason); + proof.laneAddressMap = *addressMap; + + proof.physicalFootprint = addressMap->physicalLaneFootprint; + if (addressMap->getExclusiveEndElement() > *elements) + return fail(Twine("full physical read footprint [") + + Twine(addressMap->baseElementOffset) + ", " + + Twine(addressMap->getExclusiveEndElement()) + + ") exceeds static memref element count " + Twine(*elements)); + + proof.proven = true; + return proof; +} + +VMIMemoryAccessPlan +buildReadAccessPlan(const VMITargetCapabilityRegistry &capabilities, + Value source, Type sourceType, VMIVRegType resultType, + std::optional constantOffset, + VMIMemoryValidMaskKind validMask) { + VMIMemoryAccessPlan plan; + plan.baseType = sourceType; + plan.valueType = resultType; + plan.constantOffset = constantOffset; + plan.logicalShape.elementCount = resultType.getElementCount(); + plan.validMask = validMask; + plan.permutation = VMIMemoryPermutationKind::Identity; + plan.writeMask = VMIMemoryWriteMaskKind::AllTrue; + plan.safeReadProof = + computeSafeFullReadProof(sourceType, constantOffset, resultType); + plan.laneAddressMap = plan.safeReadProof.laneAddressMap; + plan.targetCapability = capabilities.supportsDirectMemory(sourceType, + "source"); + if (plan.targetCapability.isSupported()) + plan.targetCapability = + requireIdentityMemRefLayout(sourceType, "source", source); + if (validMask == VMIMemoryValidMaskKind::ExplicitMask) + plan.trueMaskedLoadCapability = + capabilities.supportsTrueMaskedLoad(sourceType, resultType, Type{}); + plan.scratchFallbackCapability = + capabilities.supportsFallbackResource(VMIFallbackResourceKind::ScratchMemory); + plan.guardedFallbackCapability = capabilities.supportsFallbackResource( + VMIFallbackResourceKind::GuardedControlFlow); + return plan; +} + +VMIMemoryAccessPlan +buildWriteAccessPlan(const VMITargetCapabilityRegistry &capabilities, + Value destination, Type destinationType, + VMIVRegType valueType, + VMIMemoryWriteMaskKind writeMask) { + VMIMemoryAccessPlan plan; + plan.baseType = destinationType; + plan.valueType = valueType; + plan.logicalShape.elementCount = valueType.getElementCount(); + plan.validMask = VMIMemoryValidMaskKind::AllTrue; + plan.permutation = VMIMemoryPermutationKind::Identity; + plan.writeMask = writeMask; + plan.targetCapability = + capabilities.supportsDirectMemory(destinationType, "destination"); + if (plan.targetCapability.isSupported()) + plan.targetCapability = + requireIdentityMemRefLayout(destinationType, "destination", + destination); + return plan; +} + +void requireUnavailableReadFallback(VMIMemoryAccessPlan &plan) { + std::string maskedLoadReason; + if (plan.validMask == VMIMemoryValidMaskKind::ExplicitMask && + !plan.trueMaskedLoadCapability.isSupported()) + maskedLoadReason = + (Twine("; ") + plan.trueMaskedLoadCapability.reason).str(); + std::string scratchReason; + if (!plan.scratchFallbackCapability.isSupported()) + scratchReason = (Twine("; ") + plan.scratchFallbackCapability.reason).str(); + std::string guardedReason; + if (!plan.guardedFallbackCapability.isSupported()) + guardedReason = (Twine("; ") + plan.guardedFallbackCapability.reason).str(); + plan.fallbackDecision = VMIMemoryFallbackDecision::requiredUnavailable( + Twine("partial/tail read needs a scratch, guarded, or true " + "masked/non-faulting load fallback, but no such fallback resource " + "plan is implemented") + + maskedLoadReason + scratchReason + guardedReason); +} + +FailureOr +verifyFullOrSafeReadVRegChunks(Operation *op, VMIVRegType type, + Type sourceType, Value offset, + PatternRewriter &rewriter) { + std::string fullChunkReason; + FailureOr lanesPerPart = + checkFullDataPhysicalChunks(type, &fullChunkReason); + if (succeeded(lanesPerPart)) + return *lanesPerPart; + + VMIMemorySafeReadProof safeReadProof = + computeSafeFullReadProof(sourceType, getConstantIndexValue(offset), type); + if (safeReadProof.proven) { + lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (succeeded(lanesPerPart)) + return *lanesPerPart; + } + + (void)rewriter.notifyMatchFailure( + op, Twine("memory lowering ") + fullChunkReason + + "; safe full-read proof failed: " + safeReadProof.reason); + return failure(); +} + +LogicalResult checkSupportedLoadShape( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, + Value source, Type sourceType, std::optional constantOffset, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMIMemoryAccessPlan accessPlan = + buildReadAccessPlan(capabilities, source, sourceType, type, + constantOffset, VMIMemoryValidMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + + std::string fullChunkReason; + if (succeeded(checkFullDataPhysicalChunks(type, &fullChunkReason))) + return success(); + + if (accessPlan.safeReadProof.proven) + return success(); + requireUnavailableReadFallback(accessPlan); + return fail(Twine(fullChunkReason) + + "; safe-read proof failed: " + + accessPlan.safeReadProof.reason + + "; fallback decision: " + accessPlan.fallbackDecision.reason); +} + +LogicalResult checkSupportedStoreShape( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, + Value destination, Type destinationType, std::string *reason) { + VMIMemoryAccessPlan accessPlan = + buildWriteAccessPlan(capabilities, destination, destinationType, type, + VMIMemoryWriteMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) { + if (reason) + *reason = accessPlan.targetCapability.reason; + return failure(); + } + + if (failed(checkSupportedMaskableVReg(capabilities, type, reason))) + return failure(); + + std::string fullChunkReason; + if (succeeded(checkFullDataPhysicalChunks(type, &fullChunkReason))) + return success(); + + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return fail("requires assigned layout"); + if (failed(getDataLanesPerPart(type.getElementType()))) + return fail("requires known physical lanes per part"); + if (layout.isContiguous()) + return success(); + + std::string materializationReason; + if (succeeded(checkCanMaterializeToContiguous(type, &materializationReason))) + return success(); + return fail(Twine("partial/tail store requires contiguous layout or " + "deinterleaved layout that can materialize to contiguous; " + "value ") + + fullChunkReason + ", materialization " + + materializationReason); +} + +LogicalResult +checkSupportedMaskedLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIMaskedLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + auto passthruType = cast(op.getPassthru().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr passthruLayout = passthruType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( + capabilities, op.getSource(), op.getSource().getType(), resultType, + getConstantIndexValue(op.getOffset()), + VMIMemoryValidMaskKind::ExplicitMask); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + if (!resultLayout || !passthruLayout || !maskLayout) + return fail("requires assigned result, passthru, and mask layouts"); + if (!resultLayout.isContiguous() || !passthruLayout.isContiguous() || + !maskLayout.isContiguous()) + return fail("requires contiguous result, passthru, and mask layouts"); + + std::string fullChunkReason; + if (succeeded(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return success(); + + if (accessPlan.safeReadProof.proven) + return success(); + requireUnavailableReadFallback(accessPlan); + return fail(Twine("partial/tail masked_load requires statically safe " + "full-read footprint; value ") + + fullChunkReason + ", safe-read proof " + + accessPlan.safeReadProof.reason + + "; fallback decision: " + accessPlan.fallbackDecision.reason); +} + +LogicalResult checkSupportedGatherShape( + const VMITargetCapabilityRegistry &capabilities, VMIGatherOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + auto indicesType = cast(op.getIndices().getType()); + auto passthruType = cast(op.getPassthru().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr indicesLayout = indicesType.getLayoutAttr(); + VMILayoutAttr passthruLayout = passthruType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!resultLayout || !indicesLayout || !passthruLayout || !maskLayout) + return fail("requires assigned result, indices, passthru, and mask " + "layouts"); + if (!resultLayout.isContiguous() || !indicesLayout.isContiguous() || + !passthruLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("requires contiguous result, indices, passthru, and mask " + "layouts"); + + VMICapabilityResult sourceCapability = capabilities.supportsUBPointerMemory( + op.getSource().getType(), "source", "pto.vgather2_bc", + "pto.vgather2_bc reads only UB"); + if (!sourceCapability.isSupported()) + return fail(sourceCapability.reason); + + if (pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 32) + return fail("currently requires 32-bit result element type so physical " + "offset and result lane counts match pto.vgather2_bc"); + auto indexElementType = dyn_cast(indicesType.getElementType()); + if (!indexElementType || indexElementType.getWidth() != 32 || + indexElementType.isSigned()) + return fail("requires signless or unsigned 32-bit indices"); + if (maskType.getGranularity() != "b32") + return fail("requires b32 mask granularity"); + + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr indicesArity = getVMIPhysicalArity(indicesType); + FailureOr passthruArity = getVMIPhysicalArity(passthruType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(resultArity) || failed(indicesArity) || failed(passthruArity) || + failed(maskArity)) + return fail("requires computable physical arity"); + if (*resultArity != *indicesArity || *resultArity != *passthruArity || + *resultArity != *maskArity) + return fail("requires result, indices, passthru, and mask to have the " + "same physical arity"); + + std::string resultReason; + std::string indicesReason; + std::string passthruReason; + std::string maskReason; + if (failed(checkFullDataPhysicalChunks(resultType, &resultReason))) + return fail(Twine("result requires full physical chunks; ") + + resultReason); + if (failed(checkFullDataPhysicalChunks(indicesType, &indicesReason))) + return fail(Twine("indices require full physical chunks; ") + + indicesReason); + if (failed(checkFullDataPhysicalChunks(passthruType, &passthruReason))) + return fail(Twine("passthru requires full physical chunks; ") + + passthruReason); + if (failed(checkFullVMIPhysicalChunks(maskType, &maskReason))) + return fail(Twine("mask requires full physical chunks; ") + maskReason); + + return success(); +} + +LogicalResult checkSupportedScatterShape( + const VMITargetCapabilityRegistry &capabilities, VMIScatterOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!op->hasAttr("indices_unique")) + return fail("requires indices_unique proof because pto.vscatter does not " + "define logical-lane-order duplicate-index semantics"); + + auto valueType = cast(op.getValue().getType()); + auto indicesType = cast(op.getIndices().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr valueLayout = valueType.getLayoutAttr(); + VMILayoutAttr indicesLayout = indicesType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!valueLayout || !indicesLayout || !maskLayout) + return fail("requires assigned value, indices, and mask layouts"); + if (!valueLayout.isContiguous() || !indicesLayout.isContiguous() || + !maskLayout.isContiguous()) + return fail("requires contiguous value, indices, and mask layouts"); + + VMICapabilityResult destinationCapability = + capabilities.supportsUBPointerMemory( + op.getDestination().getType(), "destination", "pto.vscatter", + "pto.vscatter writes only UB"); + if (!destinationCapability.isSupported()) + return fail(destinationCapability.reason); + + if (pto::getPTOStorageElemBitWidth(valueType.getElementType()) != 32) + return fail("currently requires 32-bit value element type so physical " + "index and value lane counts match pto.vscatter"); + auto indexElementType = dyn_cast(indicesType.getElementType()); + if (!indexElementType || indexElementType.getWidth() != 32 || + indexElementType.isSigned()) + return fail("requires signless or unsigned 32-bit indices"); + if (maskType.getGranularity() != "b32") + return fail("requires b32 mask granularity"); + + FailureOr valueArity = getVMIPhysicalArity(valueType); + FailureOr indicesArity = getVMIPhysicalArity(indicesType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(valueArity) || failed(indicesArity) || failed(maskArity)) + return fail("requires computable physical arity"); + if (*valueArity != *indicesArity || *valueArity != *maskArity) + return fail("requires value, indices, and mask to have the same physical " + "arity"); + + std::string valueReason; + std::string indicesReason; + std::string maskReason; + if (failed(checkFullDataPhysicalChunks(valueType, &valueReason))) + return fail(Twine("value requires full physical chunks; ") + valueReason); + if (failed(checkFullDataPhysicalChunks(indicesType, &indicesReason))) + return fail(Twine("indices require full physical chunks; ") + + indicesReason); + if (failed(checkFullVMIPhysicalChunks(maskType, &maskReason))) + return fail(Twine("mask requires full physical chunks; ") + maskReason); + + return success(); +} + +Value stripMaskMaterialization(Value value) { + while (true) { + if (auto ensure = value.getDefiningOp()) { + value = ensure.getSource(); + continue; + } + if (auto ensure = value.getDefiningOp()) { + value = ensure.getSource(); + continue; + } + return value; + } +} + +bool isStaticAllActiveMask(Value mask, int64_t expectedLanes, + std::string *reason = nullptr) { + mask = stripMaskMaterialization(mask); + auto fail = [&](const Twine &message) { + if (reason) + *reason = message.str(); + return false; + }; + + if (auto createMask = mask.getDefiningOp()) { + auto activeConstant = + createMask.getActiveLanes().getDefiningOp(); + if (!activeConstant) + return fail("create_mask active_lanes is dynamic"); + auto activeAttr = dyn_cast(activeConstant.getValue()); + if (!activeAttr) + return fail("create_mask active_lanes is not an integer constant"); + return activeAttr.getInt() >= expectedLanes + ? true + : fail("create_mask active_lanes is smaller than the logical " + "lane count"); + } + + if (auto constantMask = mask.getDefiningOp()) { + auto denseAttr = dyn_cast(constantMask.getValue()); + if (!denseAttr) + return fail("constant_mask is not a dense integer mask"); + if (denseAttr.getNumElements() != expectedLanes) + return fail("constant_mask element count does not match the logical " + "lane count"); + auto values = denseAttr.getValues(); + for (bool value : values) + if (!value) + return fail("constant_mask contains an inactive lane"); + return true; + } + + return fail("mask is not a static all-active create_mask or constant_mask"); +} + +LogicalResult +checkSupportedExpandLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIExpandLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + auto passthruType = cast(op.getPassthru().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr passthruLayout = passthruType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( + capabilities, op.getSource(), op.getSource().getType(), resultType, + getConstantIndexValue(op.getOffset()), + VMIMemoryValidMaskKind::ExplicitMask); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + if (!resultLayout || !passthruLayout || !maskLayout) + return fail("requires assigned result, passthru, and mask layouts"); + if (!resultLayout.isContiguous() || !passthruLayout.isContiguous() || + !maskLayout.isContiguous()) + return fail("requires contiguous result, passthru, and mask layouts"); + + std::string maskReason; + bool staticAllActive = + isStaticAllActiveMask(op.getMask(), resultType.getElementCount(), + &maskReason); + + std::string fullChunkReason; + if (staticAllActive && + succeeded(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return success(); + + if (staticAllActive && accessPlan.safeReadProof.proven) + return success(); + + std::string allActivePathReason; + if (!staticAllActive) { + allActivePathReason = maskReason.empty() ? "requires static all-active mask" + : maskReason; + } else { + requireUnavailableReadFallback(accessPlan); + allActivePathReason = + (Twine("requires full physical chunks or statically safe full-read " + "footprint; value ") + + fullChunkReason + ", safe-read proof " + + accessPlan.safeReadProof.reason + "; fallback decision: " + + accessPlan.fallbackDecision.reason) + .str(); + } + + VMICapabilityResult sourceCapability = capabilities.supportsUBPointerMemory( + op.getSource().getType(), "source", "pto.vgather2_bc", + "pto.vgather2_bc reads only UB"); + if (!sourceCapability.isSupported()) { + if (!isa(op.getSource().getType())) + return fail(Twine("runtime-mask path ") + sourceCapability.reason + + "; all-active path " + allActivePathReason); + return fail(Twine("runtime-mask path ") + sourceCapability.reason); + } + if (pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 32) + return fail("runtime-mask path currently requires 32-bit result element " + "type so prefix indices and gather result lane counts match"); + if (maskType.getGranularity() != "b32") + return fail("runtime-mask path requires b32 mask granularity"); + + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr passthruArity = getVMIPhysicalArity(passthruType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(resultArity) || failed(passthruArity) || failed(maskArity)) + return fail("runtime-mask path requires computable physical arity"); + if (*resultArity != 1 || *passthruArity != 1 || *maskArity != 1) + return fail("runtime-mask path currently supports only one physical " + "chunk because prefix indices must not reset across chunks"); + + std::string passthruReason; + std::string maskFullReason; + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("runtime-mask result requires full physical chunks; ") + + fullChunkReason); + if (failed(checkFullDataPhysicalChunks(passthruType, &passthruReason))) + return fail(Twine("runtime-mask passthru requires full physical chunks; ") + + passthruReason); + if (failed(checkFullVMIPhysicalChunks(maskType, &maskFullReason))) + return fail(Twine("runtime-mask mask requires full physical chunks; ") + + maskFullReason); + + return success(); +} + +LogicalResult checkSupportedMaskedStoreShape( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType valueType, + VMIMaskType maskType, Value destination, Type destinationType, + std::string *reason) { + VMIMemoryAccessPlan accessPlan = + buildWriteAccessPlan(capabilities, destination, destinationType, valueType, + VMIMemoryWriteMaskKind::ExplicitMask); + if (!accessPlan.targetCapability.isSupported()) { + if (reason) + *reason = accessPlan.targetCapability.reason; + return failure(); + } + + std::string valueReason; + std::string maskReason; + if (succeeded(checkFullDataPhysicalChunks(valueType, &valueReason)) && + succeeded(checkFullVMIPhysicalChunks(maskType, &maskReason))) + return success(); + + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr valueLayout = valueType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!valueLayout || !maskLayout) + return fail("requires assigned value and mask layouts"); + + FailureOr valueArity = getVMIPhysicalArity(valueType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(valueArity) || failed(maskArity) || *valueArity != *maskArity) + return fail("requires matching value/mask physical arity"); + + std::string valueMaterializationReason; + FailureOr valueParts = getContiguousMaterializationPartCount( + valueType, &valueMaterializationReason); + if (failed(valueParts)) + return fail(Twine("value cannot materialize to contiguous; value ") + + valueReason + ", materialization " + + valueMaterializationReason); + + std::string maskMaterializationReason; + FailureOr maskParts = getContiguousMaterializationPartCount( + maskType, &maskMaterializationReason); + if (failed(maskParts)) + return fail(Twine("mask cannot materialize to contiguous; mask ") + + maskReason + ", materialization " + + maskMaterializationReason); + if (*valueParts != *maskParts) + return fail("requires value/mask contiguous materialization arity to match"); + return success(); +} + +FailureOr getContiguousActiveDataLanes(VMIVRegType vmiType, + int64_t chunk) { + FailureOr lanesPerPart = + getDataLanesPerPart(vmiType.getElementType()); + if (failed(lanesPerPart)) + return failure(); + + int64_t remaining = vmiType.getElementCount() - chunk * *lanesPerPart; + return std::clamp(remaining, 0, *lanesPerPart); +} + +FailureOr createContiguousStoreMask(Location loc, VMIVRegType vmiType, + int64_t chunk, VRegType vregType, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getDataLanesPerPart(vmiType.getElementType()); + if (failed(lanesPerPart)) + return failure(); + + FailureOr activeLanes = + getContiguousActiveDataLanes(vmiType, chunk); + if (failed(activeLanes)) + return failure(); + if (*activeLanes == *lanesPerPart) + return createAllTrueMaskForVReg(loc, vregType, rewriter); + + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return failure(); + FailureOr> maskAndRemaining = + createRuntimePrefixMask(loc, *maskType, + createI32Constant(loc, *activeLanes, rewriter), + rewriter); + if (failed(maskAndRemaining)) + return failure(); + return maskAndRemaining->first; +} + +FailureOr createMaskedStorePredicate(Location loc, VMIVRegType vmiType, + int64_t chunk, Value userMask, + VRegType vregType, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getDataLanesPerPart(vmiType.getElementType()); + if (failed(lanesPerPart)) + return failure(); + + FailureOr activeLanes = + getContiguousActiveDataLanes(vmiType, chunk); + if (failed(activeLanes)) + return failure(); + if (*activeLanes == *lanesPerPart) + return userMask; + + auto maskType = dyn_cast(userMask.getType()); + if (!maskType) + return failure(); + FailureOr tailMask = + createContiguousStoreMask(loc, vmiType, chunk, vregType, rewriter); + FailureOr allTrue = createAllTrueMask(loc, maskType, rewriter); + if (failed(tailMask) || failed(allTrue)) + return failure(); + return rewriter.create(loc, maskType, userMask, *tailMask, *allTrue) + .getResult(); +} + +FailureOr> +computeShuffleForwardingSourceParts(VMIShuffleOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known lanes per physical part"); + + ArrayRef indices = op.getIndices(); + if (indices.empty()) + return fail("requires non-empty indices"); + + FailureOr resultFactor = getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires assigned result layout"); + + SmallVector sourceFlatIndices; + for (int64_t resultPart = 0; resultPart < *resultFactor; ++resultPart) { + FailureOr resultChunks = + getDataChunksInPart(resultType, resultPart); + if (failed(resultChunks)) + return fail("requires known result physical chunks"); + + for (int64_t resultChunk = 0; resultChunk < *resultChunks; ++resultChunk) { + std::optional sourcePart; + std::optional sourceChunk; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultType, resultPart, resultChunk, lane); + if (failed(padding)) + return fail("failed to classify result padding lanes"); + if (*padding) + continue; + + FailureOr resultLogicalLane = + mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, + lane); + if (failed(resultLogicalLane) || + *resultLogicalLane >= static_cast(indices.size())) + return fail("failed to map result lane"); + + FailureOr sourcePhysical = + mapLogicalLaneToPhysical(sourceType, indices[*resultLogicalLane]); + if (failed(sourcePhysical)) + return fail("failed to map source lane"); + if (sourcePhysical->lane != lane) + return fail("requires same-lane physical chunks"); + + if (!sourcePart) { + sourcePart = sourcePhysical->part; + sourceChunk = sourcePhysical->chunk; + continue; + } + if (*sourcePart != sourcePhysical->part || + *sourceChunk != sourcePhysical->chunk) + return fail("requires one source chunk per result chunk"); + } + + if (!sourcePart || !sourceChunk) + return fail("requires at least one logical lane per result chunk"); + FailureOr sourceFlatIndex = + getDataFlatPartIndex(sourceType, *sourcePart, *sourceChunk); + if (failed(sourceFlatIndex)) + return fail("source part range is out of bounds"); + sourceFlatIndices.push_back(*sourceFlatIndex); + } + } + + return sourceFlatIndices; +} + +struct ShuffleVselrPlan { + int64_t sourceFlatIndex = 0; + int64_t baseLane = 0; + bool descending = false; +}; + +FailureOr computeShuffleLane0SplatSourcePart(VMIShuffleOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + ArrayRef indices = op.getIndices(); + if (indices.empty()) + return fail("requires non-empty indices"); + if (!llvm::all_of(indices, [](int64_t index) { return index == 0; })) + return fail("requires every result lane to select source lane 0"); + + auto sourceType = cast(op.getSource().getType()); + FailureOr sourceLane = + mapLogicalLaneToPhysical(sourceType, 0); + if (failed(sourceLane)) + return fail("failed to map source lane 0"); + FailureOr sourceFlatIndex = + getDataFlatPartIndex(sourceType, sourceLane->part, sourceLane->chunk); + if (failed(sourceFlatIndex)) + return fail("source lane 0 part range is out of bounds"); + return *sourceFlatIndex; +} + +FailureOr> +computeShuffleVselrPlans(VMIShuffleOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known lanes per physical part"); + + ArrayRef indices = op.getIndices(); + if (indices.empty()) + return fail("requires non-empty indices"); + + FailureOr resultFactor = getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires assigned result layout"); + + SmallVector plans; + for (int64_t resultPart = 0; resultPart < *resultFactor; ++resultPart) { + FailureOr resultChunks = + getDataChunksInPart(resultType, resultPart); + if (failed(resultChunks)) + return fail("requires known result physical chunks"); + + for (int64_t resultChunk = 0; resultChunk < *resultChunks; ++resultChunk) { + std::optional sourcePart; + std::optional sourceChunk; + std::optional baseLane; + std::optional descending; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultType, resultPart, resultChunk, lane); + if (failed(padding) || *padding) + return fail("requires full physical result chunks"); + + FailureOr resultLogicalLane = + mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, + lane); + if (failed(resultLogicalLane) || + *resultLogicalLane >= static_cast(indices.size())) + return fail("failed to map result lane"); + + FailureOr sourcePhysical = + mapLogicalLaneToPhysical(sourceType, indices[*resultLogicalLane]); + if (failed(sourcePhysical)) + return fail("failed to map source lane"); + + if (!sourcePart) { + sourcePart = sourcePhysical->part; + sourceChunk = sourcePhysical->chunk; + baseLane = sourcePhysical->lane; + continue; + } + + if (*sourcePart != sourcePhysical->part || + *sourceChunk != sourcePhysical->chunk) + return fail("requires one source chunk per result chunk"); + + int64_t ascExpected = *baseLane + lane; + int64_t descExpected = *baseLane - lane; + bool asc = sourcePhysical->lane == ascExpected; + bool desc = sourcePhysical->lane == descExpected; + if (!asc && !desc) + return fail("requires ASC or DESC affine source lane indices"); + + bool laneDescending = desc && !asc; + if (!descending) { + descending = laneDescending; + continue; + } + if (*descending != laneDescending) + return fail("requires one index order per result chunk"); + } + + FailureOr sourceFlatIndex = + getDataFlatPartIndex(sourceType, *sourcePart, *sourceChunk); + if (failed(sourceFlatIndex)) + return fail("source part range is out of bounds"); + plans.push_back(ShuffleVselrPlan{*sourceFlatIndex, *baseLane, + descending.value_or(false)}); + } + } + + return plans; +} + +struct ConstantMaskChunkMaterialization { + SmallVector activeLanes; +}; + +FailureOr> +computeConstantMaskMaterialization(VMIConstantMaskOp op, std::string *reason) { + auto fail = [&](const Twine &message) + -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto denseAttr = dyn_cast(op.getValue()); + if (!denseAttr) + return fail("only dense integer mask constants are supported"); + + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || + !VMIMaskType::isConcreteGranularity(resultVMIType.getGranularity())) + return fail("requires concrete layout and granularity"); + + FailureOr lanesPerPart = + getMaskLanesPerPart(resultVMIType.getGranularity()); + if (failed(lanesPerPart)) + return fail("requires known physical mask lanes per part"); + + auto boolValues = denseAttr.getValues(); + int64_t factor = layout.isContiguous() ? 1 : layout.getFactor(); + SmallVector materializations; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0;; ++chunk) { + bool anyLane = false; + ConstantMaskChunkMaterialization materialization; + materialization.activeLanes.reserve(*lanesPerPart); + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultVMIType, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) { + materialization.activeLanes.push_back(0); + continue; + } + anyLane = true; + + FailureOr logicalLane = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, lane); + if (failed(logicalLane)) + return fail("failed to map physical lane"); + materialization.activeLanes.push_back(boolValues[*logicalLane] ? 1 : 0); + } + if (!anyLane) + break; + materializations.push_back(std::move(materialization)); + } + } + + return materializations; +} + +std::optional getPrefixActiveLaneCount(ArrayRef activeLanes) { + bool seenInactive = false; + int64_t activeCount = 0; + for (int8_t active : activeLanes) { + if (active) { + if (seenInactive) + return std::nullopt; + ++activeCount; + continue; + } + seenInactive = true; + } + return activeCount; +} + +FailureOr materializePrefixMask(Location loc, MaskType maskType, + int64_t activeLanes, + int64_t lanesPerPart, + PatternRewriter &rewriter) { + std::optional pattern = + getPrefixPattern(activeLanes, lanesPerPart); + if (pattern) + return createPatternMask(loc, maskType, *pattern, rewriter); + + FailureOr> maskAndRemaining = + createRuntimePrefixMask(loc, maskType, + createI32Constant(loc, activeLanes, rewriter), + rewriter); + if (failed(maskAndRemaining)) + return failure(); + return maskAndRemaining->first; +} + +FailureOr +materializeConstantMaskChunk(Location loc, MaskType maskType, + ArrayRef activeLanes, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getMaskLanesPerPart(maskType.getGranularity()); + if (failed(lanesPerPart) || + static_cast(activeLanes.size()) != *lanesPerPart) + return failure(); + + if (std::optional prefixCount = + getPrefixActiveLaneCount(activeLanes)) + return materializePrefixMask(loc, maskType, *prefixCount, *lanesPerPart, + rewriter); + + FailureOr allTrue = createAllTrueMask(loc, maskType, rewriter); + if (failed(allTrue)) + return failure(); + + Value result; + int64_t lane = 0; + while (lane < *lanesPerPart) { + while (lane < *lanesPerPart && !activeLanes[lane]) + ++lane; + if (lane >= *lanesPerPart) + break; + + int64_t runBegin = lane; + while (lane < *lanesPerPart && activeLanes[lane]) + ++lane; + int64_t runEnd = lane; + + FailureOr prefixEnd = + materializePrefixMask(loc, maskType, runEnd, *lanesPerPart, rewriter); + if (failed(prefixEnd)) + return failure(); + + Value runMask = *prefixEnd; + if (runBegin != 0) { + FailureOr prefixBegin = materializePrefixMask( + loc, maskType, runBegin, *lanesPerPart, rewriter); + if (failed(prefixBegin)) + return failure(); + Value notPrefixBegin = + rewriter.create(loc, maskType, *prefixBegin, *allTrue) + .getResult(); + runMask = + rewriter.create(loc, maskType, *prefixEnd, notPrefixBegin, + *allTrue) + .getResult(); + } + + if (!result) { + result = runMask; + continue; + } + result = rewriter.create(loc, maskType, result, runMask, *allTrue) + .getResult(); + } + + if (result) + return result; + return materializePrefixMask(loc, maskType, 0, *lanesPerPart, rewriter); +} + +Value createChunkOffset(Location loc, Value baseOffset, int64_t laneOffset, + PatternRewriter &rewriter) { + if (laneOffset == 0) + return baseOffset; + Value delta = rewriter.create(loc, laneOffset); + return rewriter.create(loc, baseOffset, delta).getResult(); +} + +std::optional getX2MemoryDistToken(Type elementType, + StringRef prefix) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + if (elementBits != 8 && elementBits != 16 && elementBits != 32) + return std::nullopt; + return (Twine(prefix) + "_B" + Twine(elementBits)).str(); +} + +std::optional getVPTOCmpMode(StringRef predicate) { + if (predicate == "eq" || predicate == "ne" || predicate == "lt" || + predicate == "le" || predicate == "gt" || predicate == "ge") + return predicate; + if (predicate == "oeq") + return StringRef("eq"); + if (predicate == "one") + return StringRef("ne"); + if (predicate == "olt") + return StringRef("lt"); + if (predicate == "ole") + return StringRef("le"); + if (predicate == "ogt") + return StringRef("gt"); + if (predicate == "oge") + return StringRef("ge"); + if (predicate == "slt") + return StringRef("lt"); + if (predicate == "sle") + return StringRef("le"); + if (predicate == "sgt") + return StringRef("gt"); + if (predicate == "sge") + return StringRef("ge"); + return std::nullopt; +} + +LogicalResult checkSupportedComparePredicate(Operation *op, + StringRef predicate) { + if (getVPTOCmpMode(predicate)) + return success(); + return op->emitError() + << kVMIDiagUnsupportedPrefix << "compare predicate " << predicate + << " cannot be lowered to pto.vcmp; supported predicates are " + "eq/ne/lt/le/gt/ge, ordered FP forms oeq/one/olt/ole/ogt/oge, " + "and signed integer forms slt/sle/sgt/sge"; +} + +struct OneToNVMIUnpackOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIUnpackOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + if (sourceParts.size() != op->getNumResults()) + return rewriter.notifyMatchFailure( + op, "converted source part count must match unpack results"); + rewriter.replaceOp(op, sourceParts, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIPackOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIPackOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr arity = getVMIPhysicalArity(op.getResult().getType()); + if (failed(arity) || + static_cast(adaptor.getFlatOperands().size()) != *arity) + return rewriter.notifyMatchFailure( + op, "pack part count must match converted VMI result arity"); + rewriter.replaceOp(op, adaptor.getFlatOperands(), + adaptor.getResultMapping()); + return success(); + } +}; + +LogicalResult verifyIdentityPartForwarding(Operation *op, ValueRange sourceParts, + TypeRange resultTypes, + PatternRewriter &rewriter) { + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "source and result physical arity mismatch"); + for (auto [part, resultType] : llvm::zip_equal(sourceParts, resultTypes)) { + if (part.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "helper requires non-identity physical materialization"); + } + return success(); +} + +FailureOr> +materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, + TypeRange resultTypes, + VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + PatternRewriter &rewriter) { + if (!sourceLayout || !resultLayout) { + (void)rewriter.notifyMatchFailure( + op, "layout materialization requires assigned source/result layouts"); + return failure(); + } + + if (sourceLayout == resultLayout) { + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + return SmallVector(sourceParts.begin(), sourceParts.end()); + } + + bool deint2ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 2 && + resultLayout.isContiguous(); + bool contiguousToDeint2 = sourceLayout.isContiguous() && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 2; + if (deint2ToContiguous || contiguousToDeint2) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 2 != 0) { + (void)rewriter.notifyMatchFailure( + op, + "deinterleaved=2 layout materialization requires 2*N parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + int64_t groups = sourceParts.size() / 2; + SmallVector results; + results.reserve(sourceParts.size()); + if (deint2ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + auto materialize = rewriter.create( + op->getLoc(), resultTypes[2 * i], resultTypes[2 * i + 1], + sourceParts[i], sourceParts[groups + i]); + results.append({materialize.getLow(), materialize.getHigh()}); + } + } else { + SmallVector part0; + SmallVector part1; + part0.reserve(groups); + part1.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + auto materialize = rewriter.create( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[2 * i], sourceParts[2 * i + 1]); + part0.push_back(materialize.getLow()); + part1.push_back(materialize.getHigh()); + } + results.append(part0); + results.append(part1); + } + return results; + } + + bool deint4ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 4 && + resultLayout.isContiguous(); + bool contiguousToDeint4 = sourceLayout.isContiguous() && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 4; + if (deint4ToContiguous || contiguousToDeint4) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 4 != 0) { + (void)rewriter.notifyMatchFailure( + op, + "deinterleaved=4 layout materialization requires 4*N parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + SmallVector results; + results.reserve(sourceParts.size()); + int64_t groups = sourceParts.size() / 4; + if (deint4ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + Value p0 = sourceParts[i]; + Value p1 = sourceParts[groups + i]; + Value p2 = sourceParts[2 * groups + i]; + Value p3 = sourceParts[3 * groups + i]; + auto even = + rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p0, p2); + auto odd = + rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p1, p3); + auto low = + rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], even.getLow(), + odd.getLow()); + auto high = + rewriter.create(op->getLoc(), resultTypes[4 * i + 2], + resultTypes[4 * i + 3], even.getHigh(), + odd.getHigh()); + results.append( + {low.getLow(), low.getHigh(), high.getLow(), high.getHigh()}); + } + } else { + SmallVector part0; + SmallVector part1; + SmallVector part2; + SmallVector part3; + part0.reserve(groups); + part1.reserve(groups); + part2.reserve(groups); + part3.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + auto low = + rewriter.create(op->getLoc(), resultTypes[i], + resultTypes[groups + i], + sourceParts[4 * i], + sourceParts[4 * i + 1]); + auto high = rewriter.create( + op->getLoc(), resultTypes[2 * groups + i], + resultTypes[3 * groups + i], sourceParts[4 * i + 2], + sourceParts[4 * i + 3]); + auto even = rewriter.create( + op->getLoc(), resultTypes[i], resultTypes[2 * groups + i], + low.getLow(), high.getLow()); + auto odd = rewriter.create( + op->getLoc(), resultTypes[groups + i], + resultTypes[3 * groups + i], low.getHigh(), high.getHigh()); + part0.push_back(even.getLow()); + part1.push_back(odd.getLow()); + part2.push_back(even.getHigh()); + part3.push_back(odd.getHigh()); + } + results.append(part0); + results.append(part1); + results.append(part2); + results.append(part3); + } + return results; + } + + (void)rewriter.notifyMatchFailure( + op, "unsupported VMI data layout materialization"); + return failure(); +} + +FailureOr> +createPredicateDintlv(Location loc, Type lowType, Type highType, Value lhs, + Value rhs, PatternRewriter &rewriter) { + auto maskType = dyn_cast(lowType); + if (!maskType || highType != lowType) + return failure(); + if (maskType.isB8()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB16()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB32()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + return failure(); +} + +FailureOr> +createPredicateIntlv(Location loc, Type lowType, Type highType, Value lhs, + Value rhs, PatternRewriter &rewriter) { + auto maskType = dyn_cast(lowType); + if (!maskType || highType != lowType) + return failure(); + if (maskType.isB8()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB16()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + if (maskType.isB32()) { + auto op = rewriter.create(loc, lowType, highType, lhs, rhs); + return std::make_pair(op.getLow(), op.getHigh()); + } + return failure(); +} + +FailureOr> +materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, + TypeRange resultTypes, + VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + PatternRewriter &rewriter) { + if (!sourceLayout || !resultLayout) { + (void)rewriter.notifyMatchFailure( + op, "mask layout materialization requires assigned source/result " + "layouts"); + return failure(); + } + + if (sourceLayout == resultLayout) { + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + return SmallVector(sourceParts.begin(), sourceParts.end()); + } + + bool deint2ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 2 && + resultLayout.isContiguous(); + bool contiguousToDeint2 = sourceLayout.isContiguous() && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 2; + if (deint2ToContiguous || contiguousToDeint2) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 2 != 0) { + (void)rewriter.notifyMatchFailure( + op, "deinterleaved=2 mask layout materialization requires 2*N " + "parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + int64_t groups = sourceParts.size() / 2; + SmallVector results; + results.reserve(sourceParts.size()); + if (deint2ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + FailureOr> materialize = + createPredicateIntlv(op->getLoc(), resultTypes[2 * i], + resultTypes[2 * i + 1], sourceParts[i], + sourceParts[groups + i], rewriter); + if (failed(materialize)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate intlv mask type"); + results.append({materialize->first, materialize->second}); + } + } else { + SmallVector part0; + SmallVector part1; + part0.reserve(groups); + part1.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + FailureOr> materialize = + createPredicateDintlv(op->getLoc(), resultTypes[i], + resultTypes[groups + i], sourceParts[2 * i], + sourceParts[2 * i + 1], rewriter); + if (failed(materialize)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate dintlv mask type"); + part0.push_back(materialize->first); + part1.push_back(materialize->second); + } + results.append(part0); + results.append(part1); + } + return results; + } + + bool deint4ToContiguous = sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 4 && + resultLayout.isContiguous(); + bool contiguousToDeint4 = sourceLayout.isContiguous() && + resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 4; + if (deint4ToContiguous || contiguousToDeint4) { + if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || + sourceParts.size() % 4 != 0) { + (void)rewriter.notifyMatchFailure( + op, "deinterleaved=4 mask layout materialization requires 4*N " + "parts"); + return failure(); + } + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + + SmallVector results; + results.reserve(sourceParts.size()); + int64_t groups = sourceParts.size() / 4; + if (deint4ToContiguous) { + for (int64_t i = 0; i < groups; ++i) { + Value p0 = sourceParts[i]; + Value p1 = sourceParts[groups + i]; + Value p2 = sourceParts[2 * groups + i]; + Value p3 = sourceParts[3 * groups + i]; + FailureOr> even = + createPredicateIntlv(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p0, p2, rewriter); + FailureOr> odd = + createPredicateIntlv(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p1, p3, rewriter); + if (failed(even) || failed(odd)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate intlv mask type"); + FailureOr> low = + createPredicateIntlv(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], even->first, + odd->first, rewriter); + FailureOr> high = + createPredicateIntlv(op->getLoc(), resultTypes[4 * i + 2], + resultTypes[4 * i + 3], even->second, + odd->second, rewriter); + if (failed(low) || failed(high)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate intlv mask type"); + results.append({low->first, low->second, high->first, high->second}); + } + } else { + SmallVector part0; + SmallVector part1; + SmallVector part2; + SmallVector part3; + part0.reserve(groups); + part1.reserve(groups); + part2.reserve(groups); + part3.reserve(groups); + for (int64_t i = 0; i < groups; ++i) { + FailureOr> low = + createPredicateDintlv(op->getLoc(), resultTypes[i], + resultTypes[groups + i], + sourceParts[4 * i], sourceParts[4 * i + 1], + rewriter); + FailureOr> high = createPredicateDintlv( + op->getLoc(), resultTypes[2 * groups + i], + resultTypes[3 * groups + i], sourceParts[4 * i + 2], + sourceParts[4 * i + 3], rewriter); + if (failed(low) || failed(high)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate dintlv mask type"); + FailureOr> even = + createPredicateDintlv(op->getLoc(), resultTypes[i], + resultTypes[2 * groups + i], low->first, + high->first, rewriter); + FailureOr> odd = + createPredicateDintlv(op->getLoc(), resultTypes[groups + i], + resultTypes[3 * groups + i], low->second, + high->second, rewriter); + if (failed(even) || failed(odd)) + return rewriter.notifyMatchFailure( + op, "unsupported predicate dintlv mask type"); + part0.push_back(even->first); + part1.push_back(odd->first); + part2.push_back(even->second); + part3.push_back(odd->second); + } + results.append(part0); + results.append(part1); + results.append(part2); + results.append(part3); + } + return results; + } + + (void)rewriter.notifyMatchFailure( + op, "unsupported VMI mask layout materialization"); + return failure(); +} + +int getMaskGranularityRank(StringRef granularity) { + if (granularity == "b8") + return 0; + if (granularity == "b16") + return 1; + if (granularity == "b32") + return 2; + return -1; +} + +StringRef getMaskGranularityForRank(int rank) { + switch (rank) { + case 0: + return "b8"; + case 1: + return "b16"; + case 2: + return "b32"; + default: + return ""; + } +} + +LogicalResult checkSupportedMaskGranularityMaterialization( + const VMITargetCapabilityRegistry &capabilities, VMIMaskType sourceType, + VMIMaskType resultType, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source and result mask lane counts to match"); + if (sourceType.getLayoutAttr() != resultType.getLayoutAttr()) + return fail("requires source and result mask layouts to match"); + + VMICapabilityResult granularityCapability = + capabilities.supportsMaskGranularityConversion( + sourceType.getGranularity(), resultType.getGranularity()); + if (!granularityCapability.isSupported()) + return fail(granularityCapability.reason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity < 1 || *resultArity < 1) + return fail("requires non-empty source/result physical arity"); + + return success(); +} + +FailureOr> materializeAdjacentMaskGranularityConversion( + Operation *op, VMIMaskType sourceType, VMIMaskType resultType, + ValueRange sourceParts, PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) -> FailureOr> { + (void)rewriter.notifyMatchFailure(op, message); + return failure(); + }; + + int sourceRank = getMaskGranularityRank(sourceType.getGranularity()); + int resultRank = getMaskGranularityRank(resultType.getGranularity()); + if (std::abs(sourceRank - resultRank) != 1) + return fail("mask granularity conversion must be adjacent"); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr factor = getVMITypeLayoutFactor(sourceType); + if (failed(sourceArity) || failed(factor) || + static_cast(sourceParts.size()) != *sourceArity) + return fail("source mask part count does not match source VMI type"); + + MLIRContext *ctx = op->getContext(); + auto partAttr = [&](StringRef part) { return StringAttr::get(ctx, part); }; + auto resultMaskType = MaskType::get(ctx, resultType.getGranularity()); + SmallVector results; + + int64_t sourceOffset = 0; + for (int64_t part = 0; part < *factor; ++part) { + FailureOr sourceChunks = getVMITypeChunksInPart(sourceType, part); + FailureOr resultChunks = getVMITypeChunksInPart(resultType, part); + if (failed(sourceChunks) || failed(resultChunks)) + return fail("requires computable source/result chunks per layout part"); + + if (resultRank > sourceRank) { + int64_t produced = 0; + for (int64_t chunk = 0; chunk < *sourceChunks && produced < *resultChunks; + ++chunk) { + Value source = sourceParts[sourceOffset + chunk]; + results.push_back( + rewriter + .create(op->getLoc(), resultMaskType, source, + partAttr("LOWER")) + .getResult()); + ++produced; + if (produced >= *resultChunks) + break; + results.push_back( + rewriter + .create(op->getLoc(), resultMaskType, source, + partAttr("HIGHER")) + .getResult()); + ++produced; + } + if (produced != *resultChunks) + return fail("widening mask granularity conversion produced the wrong " + "number of result chunks"); + } else { + Value allTrue; + int64_t consumed = 0; + for (int64_t chunk = 0; chunk < *resultChunks; ++chunk) { + if (consumed >= *sourceChunks) + return fail("narrowing mask granularity conversion ran out of " + "source chunks"); + Value lowerSource = sourceParts[sourceOffset + consumed++]; + Value packed = + rewriter + .create(op->getLoc(), resultMaskType, lowerSource, + partAttr("LOWER")) + .getResult(); + if (consumed < *sourceChunks) { + Value higherSource = sourceParts[sourceOffset + consumed++]; + Value higher = + rewriter + .create(op->getLoc(), resultMaskType, higherSource, + partAttr("HIGHER")) + .getResult(); + if (!allTrue) { + FailureOr mask = + createAllTrueMask(op->getLoc(), resultMaskType, rewriter); + if (failed(mask)) + return fail("failed to create all-true mask for ppack merge"); + allTrue = *mask; + } + packed = rewriter + .create(op->getLoc(), resultMaskType, packed, + higher, allTrue) + .getResult(); + } + results.push_back(packed); + } + if (consumed != *sourceChunks) + return fail("narrowing mask granularity conversion left unused source " + "chunks"); + } + + sourceOffset += *sourceChunks; + } + + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(resultArity) || + static_cast(results.size()) != *resultArity) + return fail("mask granularity conversion result count mismatch"); + return results; +} + +FailureOr> materializeMaskGranularityConversion( + Operation *op, const VMITargetCapabilityRegistry &capabilities, + VMIMaskType sourceType, VMIMaskType resultType, ValueRange sourceParts, + PatternRewriter &rewriter) { + std::string reason; + if (failed(checkSupportedMaskGranularityMaterialization(capabilities, + sourceType, + resultType, &reason))) { + (void)rewriter.notifyMatchFailure(op, reason); + return failure(); + } + + int currentRank = getMaskGranularityRank(sourceType.getGranularity()); + int resultRank = getMaskGranularityRank(resultType.getGranularity()); + VMIMaskType currentType = sourceType; + SmallVector currentParts(sourceParts.begin(), sourceParts.end()); + + while (currentRank != resultRank) { + currentRank += currentRank < resultRank ? 1 : -1; + StringRef nextGranularity = getMaskGranularityForRank(currentRank); + if (nextGranularity.empty()) { + (void)rewriter.notifyMatchFailure(op, + "invalid target mask granularity rank"); + return failure(); + } + VMIMaskType nextType = + VMIMaskType::get(op->getContext(), currentType.getElementCount(), + nextGranularity, currentType.getLayoutAttr()); + FailureOr> nextParts = + materializeAdjacentMaskGranularityConversion( + op, currentType, nextType, currentParts, rewriter); + if (failed(nextParts)) + return failure(); + currentType = nextType; + currentParts = std::move(*nextParts); + } + + return currentParts; +} + +struct OneToNVMIEnsureLayoutOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIEnsureLayoutOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return rewriter.notifyMatchFailure( + op, "ensure_layout requires assigned source/result layouts"); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + FailureOr> results = materializeDataLayoutConversion( + op, sourceParts, resultTypes, sourceLayout, resultLayout, rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIEnsureMaskLayoutOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIEnsureMaskLayoutOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIEnsureMaskLayoutOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + if (sourceType.getGranularity() != resultType.getGranularity()) + return rewriter.notifyMatchFailure( + op, "mask layout helper cannot also change granularity"); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + FailureOr> results = materializeMaskLayoutConversion( + op, sourceParts, resultTypes, sourceLayout, resultLayout, rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIEnsureMaskGranularityOpPattern + : OneToNOpConversionPattern { + OneToNVMIEnsureMaskGranularityOpPattern( + TypeConverter &typeConverter, MLIRContext *context, + const VMITargetCapabilityRegistry &capabilities) + : OneToNOpConversionPattern(typeConverter, + context), + capabilities(capabilities) {} + + LogicalResult + matchAndRewrite(VMIEnsureMaskGranularityOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + if (sourceType.getLayout() != resultType.getLayout()) + return rewriter.notifyMatchFailure( + op, "mask granularity helper cannot also change layout"); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceType.getGranularity() != resultType.getGranularity()) { + FailureOr> results = + materializeMaskGranularityConversion(op, capabilities, sourceType, + resultType, sourceParts, + rewriter); + if (failed(results)) + return failure(); + if (results->size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "mask granularity result arity mismatch"); + for (auto [result, type] : llvm::zip_equal(*results, resultTypes)) + if (result.getType() != type) + return rewriter.notifyMatchFailure( + op, "mask granularity result type mismatch"); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } + + if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, + rewriter))) + return failure(); + rewriter.replaceOp(op, sourceParts, adaptor.getResultMapping()); + return success(); + } + +private: + const VMITargetCapabilityRegistry &capabilities; +}; + +struct OneToNVMIBroadcastOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIBroadcastOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange inputParts = adaptor.getValue(); + if (inputParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "broadcast input must convert to one value"); + bool inputIsVReg = isa(op.getValue().getType()); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "broadcast result must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for broadcast mask"); + StringAttr position = + inputIsVReg ? rewriter.getStringAttr("LOWEST") : StringAttr{}; + results.push_back( + rewriter + .create(op.getLoc(), resultType, inputParts.front(), + *mask, position) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +FailureOr createScalarOffsetConstant(Location loc, Type type, + int64_t value, + PatternRewriter &rewriter) { + if (auto intType = dyn_cast(type)) { + return rewriter + .create(loc, IntegerAttr::get(intType, value)) + .getResult(); + } + if (auto floatType = dyn_cast(type)) { + return rewriter + .create( + loc, FloatAttr::get(floatType, + llvm::APFloat(static_cast(value)))) + .getResult(); + } + return failure(); +} + +FailureOr createIotaChunkBase(Location loc, Value base, + int64_t laneOffset, + StringRef order, + PatternRewriter &rewriter) { + if (laneOffset == 0) + return base; + + FailureOr offset = + createScalarOffsetConstant(loc, base.getType(), laneOffset, rewriter); + if (failed(offset)) + return failure(); + + if (isa(base.getType())) { + if (order == "DESC") + return rewriter.create(loc, base, *offset).getResult(); + return rewriter.create(loc, base, *offset).getResult(); + } + if (isa(base.getType())) { + if (order == "DESC") + return rewriter.create(loc, base, *offset).getResult(); + return rewriter.create(loc, base, *offset).getResult(); + } + + return failure(); +} + +FailureOr createIotaContiguousChunk(Location loc, Type resultType, + Value base, int64_t laneOffset, + StringAttr orderAttr, + PatternRewriter &rewriter) { + StringRef order = orderAttr ? orderAttr.getValue() : StringRef("ASC"); + FailureOr chunkBase = + createIotaChunkBase(loc, base, laneOffset, order, rewriter); + if (failed(chunkBase)) + return failure(); + return rewriter.create(loc, resultType, *chunkBase, orderAttr) + .getResult(); +} + +FailureOr createIotaDeinterleavedChunk(Location loc, Type resultType, + Value base, int64_t factor, + int64_t part, int64_t chunk, + int64_t lanesPerPart, + StringAttr orderAttr, + PatternRewriter &rewriter) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return failure(); + + FailureOr mask = createAllTrueMaskForVReg(loc, vregType, rewriter); + FailureOr zero = createScalarOffsetConstant(loc, base.getType(), 0, + rewriter); + FailureOr factorScalar = + createScalarOffsetConstant(loc, base.getType(), factor, rewriter); + if (failed(mask) || failed(zero) || failed(factorScalar)) + return failure(); + + Value local = + rewriter.create(loc, resultType, *zero, StringAttr{}).getResult(); + Value scaled = + rewriter.create(loc, resultType, local, *factorScalar, *mask) + .getResult(); + + StringRef order = orderAttr ? orderAttr.getValue() : StringRef("ASC"); + int64_t partOffset = part + factor * chunk * lanesPerPart; + FailureOr biasedBase = + createIotaChunkBase(loc, base, partOffset, order, rewriter); + if (failed(biasedBase)) + return failure(); + + if (order == "DESC") { + Value baseVector = + rewriter + .create(loc, resultType, *biasedBase, *mask, + /*position=*/nullptr) + .getResult(); + return rewriter.create(loc, resultType, baseVector, scaled, *mask) + .getResult(); + } + + return rewriter.create(loc, resultType, scaled, *biasedBase, *mask) + .getResult(); +} + +struct OneToNVMIIotaOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIIotaOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout) + return rewriter.notifyMatchFailure(op, + "iota requires assigned layout"); + + FailureOr lanesPerPart = + getDataLanesPerPart(resultVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "iota requires known physical lanes per part"); + + FailureOr base = + getSingleValue(op, adaptor.getBase(), + "iota base must convert to one value", rewriter); + if (failed(base)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + SmallVector results; + results.reserve(resultTypes.size()); + + if (layout.isContiguous()) { + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + if (!isa(resultType)) + return rewriter.notifyMatchFailure(op, "iota result must be vreg"); + FailureOr result = createIotaContiguousChunk( + op.getLoc(), resultType, *base, + static_cast(index) * *lanesPerPart, op.getOrderAttr(), + rewriter); + if (failed(result)) + return rewriter.notifyMatchFailure( + op, "failed to materialize contiguous iota chunk"); + results.push_back(*result); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + int64_t factor = layout.getFactor(); + if (resultTypes.size() % factor != 0) + return rewriter.notifyMatchFailure( + op, "deinterleaved iota physical result count does not match " + "layout factor"); + int64_t chunksPerPart = resultTypes.size() / factor; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { + Type resultType = resultTypes[part * chunksPerPart + chunk]; + FailureOr result = createIotaDeinterleavedChunk( + op.getLoc(), resultType, *base, factor, part, chunk, + *lanesPerPart, op.getOrderAttr(), rewriter); + if (failed(result)) + return rewriter.notifyMatchFailure( + op, "failed to materialize deinterleaved iota chunk"); + results.push_back(*result); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIConstantOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIConstantOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto denseAttr = dyn_cast(op.getValue()); + if (!denseAttr || !denseAttr.isSplat()) + return rewriter.notifyMatchFailure( + op, "only splat dense data constants are supported"); + auto splatAttr = dyn_cast(denseAttr.getSplatValue()); + if (!splatAttr) + return rewriter.notifyMatchFailure(op, + "splat constant must be typed"); + + Value scalar = + rewriter.create(op.getLoc(), splatAttr).getResult(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, "constant result must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for constant mask"); + results.push_back( + rewriter + .create(op.getLoc(), resultType, scalar, *mask, + /*position=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIConstantMaskOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIConstantMaskOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIConstantMaskOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + std::string reason; + FailureOr> materializations = + computeConstantMaskMaterialization(op, &reason); + if (failed(materializations)) + return rewriter.notifyMatchFailure( + op, Twine("constant_mask ") + reason); + + SmallVector results; + results.reserve(resultTypes.size()); + for (const ConstantMaskChunkMaterialization &materialization : + *materializations) { + if (results.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "constant_mask produced too many physical masks"); + auto maskType = dyn_cast(resultTypes[results.size()]); + if (!maskType) + return rewriter.notifyMatchFailure(op, + "constant_mask result must be mask"); + FailureOr mask = materializeConstantMaskChunk( + op.getLoc(), maskType, materialization.activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize constant_mask physical chunk"); + results.push_back(*mask); + } + + if (results.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "constant_mask physical result count mismatch"); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMICreateMaskOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMICreateMaskOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto activeConstant = + op.getActiveLanes().getDefiningOp(); + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || !VMIMaskType::isConcreteGranularity( + resultVMIType.getGranularity())) + return rewriter.notifyMatchFailure( + op, "create_mask requires concrete layout and granularity"); + FailureOr lanesPerPart = + getMaskLanesPerPart(resultVMIType.getGranularity()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "create_mask requires known physical mask lanes per part"); + + if (!activeConstant) { + FailureOr active = getSingleValue( + op, adaptor.getActiveLanes(), + "create_mask active_lanes must convert to one value", rewriter); + if (failed(active)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t factor = layout.isContiguous() ? 1 : layout.getFactor(); + if (resultTypes.size() % factor != 0) + return rewriter.notifyMatchFailure( + op, "dynamic create_mask physical result count does not match " + "layout factor"); + int64_t chunksPerPart = resultTypes.size() / factor; + Value activeI32 = clampDynamicActiveLanes( + op.getLoc(), *active, resultVMIType.getElementCount(), rewriter); + + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t part = 0; part < factor; ++part) { + Value remaining = createPartitionActiveLanes( + op.getLoc(), activeI32, factor, part, rewriter); + for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { + Type resultType = resultTypes[part * chunksPerPart + chunk]; + auto maskType = dyn_cast(resultType); + if (!maskType) + return rewriter.notifyMatchFailure( + op, "create_mask result must be mask"); + FailureOr> maskAndRemaining = + createRuntimePrefixMask(op.getLoc(), maskType, remaining, + rewriter); + if (failed(maskAndRemaining)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for dynamic create_mask"); + results.push_back(maskAndRemaining->first); + remaining = maskAndRemaining->second; + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + auto activeAttr = dyn_cast(activeConstant.getValue()); + if (!activeAttr) + return rewriter.notifyMatchFailure( + op, "create_mask active_lanes must be an integer constant"); + + int64_t activeLanes = activeAttr.getInt(); + if (activeLanes < 0) + activeLanes = 0; + if (activeLanes > resultVMIType.getElementCount()) + activeLanes = resultVMIType.getElementCount(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t factor = layout.isContiguous() ? 1 : layout.getFactor(); + SmallVector results; + results.reserve(resultTypes.size()); + + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0;; ++chunk) { + bool anyLane = false; + int64_t activeInChunk = 0; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultVMIType, part, chunk, lane); + if (failed(padding)) + return rewriter.notifyMatchFailure( + op, "failed to map create_mask physical padding lane"); + if (*padding) + continue; + anyLane = true; + FailureOr logicalLane = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, lane); + if (failed(logicalLane)) + return rewriter.notifyMatchFailure( + op, "failed to map create_mask physical lane"); + if (*logicalLane < activeLanes) + ++activeInChunk; + } + if (!anyLane) + break; + + if (results.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_mask produced too many physical masks"); + auto maskType = dyn_cast(resultTypes[results.size()]); + if (!maskType) + return rewriter.notifyMatchFailure(op, + "create_mask result must be mask"); + std::optional pattern = + getPrefixPattern(activeInChunk, *lanesPerPart); + if (pattern) { + FailureOr mask = + createPrefixMask(op.getLoc(), maskType, *pattern, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for create_mask"); + results.push_back(*mask); + continue; + } + + FailureOr> maskAndRemaining = + createRuntimePrefixMask( + op.getLoc(), maskType, + createI32Constant(op.getLoc(), activeInChunk, rewriter), + rewriter); + if (failed(maskAndRemaining)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for create_mask plt fallback"); + results.push_back(maskAndRemaining->first); + } + } + + if (results.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_mask physical result count mismatch"); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMILoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "load source must convert to one value", rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "load offset must convert to one value", rewriter); + if (failed(source) || failed(offset)) + return failure(); + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( + op, resultVMIType, (*source).getType(), *offset, rewriter); + if (failed(lanesPerPart)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 2) { + std::optional dist = + getX2MemoryDistToken(resultVMIType.getElementType(), "DINTLV"); + if (dist && !resultTypes.empty() && resultTypes.size() % 2 == 0) { + int64_t groups = resultTypes.size() / 2; + SmallVector lows; + SmallVector highs; + lows.reserve(groups); + highs.reserve(groups); + for (int64_t group = 0; group < groups; ++group) { + Type lowType = resultTypes[group]; + Type highType = resultTypes[groups + group]; + if (lowType != highType) + return rewriter.notifyMatchFailure( + op, "vldsx2 requires matching low/high result types"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, group * 2 * *lanesPerPart, rewriter); + auto load = rewriter.create( + op.getLoc(), lowType, highType, *source, chunkOffset, + rewriter.getStringAttr(*dist)); + lows.push_back(load.getLow()); + highs.push_back(load.getHigh()); + } + SmallVector results; + results.reserve(resultTypes.size()); + results.append(lows); + results.append(highs); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + } + + SmallVector contiguousParts; + contiguousParts.reserve(resultTypes.size()); + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, "load result must be vreg"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, index * *lanesPerPart, rewriter); + contiguousParts.push_back( + rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + FailureOr> results = materializeDataLayoutConversion( + op, contiguousParts, resultTypes, + VMILayoutAttr::getContiguous(rewriter.getContext()), + resultVMIType.getLayoutAttr(), rewriter); + if (failed(results)) + return failure(); + + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIMaskedLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIMaskedLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIMaskedLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "masked_load source must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "masked_load offset must convert to one value", + rewriter); + if (failed(source) || failed(offset)) + return failure(); + + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( + op, resultVMIType, (*source).getType(), *offset, rewriter); + if (failed(lanesPerPart)) + return failure(); + + ValueRange maskParts = adaptor.getMask(); + ValueRange passthruParts = adaptor.getPassthru(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (maskParts.size() != passthruParts.size() || + passthruParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "masked_load physical arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [index, maskPassthruAndType] : + llvm::enumerate(llvm::zip_equal(maskParts, passthruParts, + resultTypes))) { + auto [mask, passthru, resultType] = maskPassthruAndType; + if (!isa(mask.getType()) || passthru.getType() != resultType || + !isa(resultType)) + return rewriter.notifyMatchFailure( + op, "masked_load physical part type mismatch"); + + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, index * *lanesPerPart, rewriter); + Value loaded = + rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, chunkOffset, + /*dist=*/nullptr) + .getResult(); + results.push_back( + rewriter + .create(op.getLoc(), resultType, loaded, passthru, mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIGatherOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGatherOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "gather source must convert to one value", rewriter); + if (failed(source)) + return failure(); + + ValueRange indicesParts = adaptor.getIndices(); + ValueRange maskParts = adaptor.getMask(); + ValueRange passthruParts = adaptor.getPassthru(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (indicesParts.size() != maskParts.size() || + indicesParts.size() != passthruParts.size() || + indicesParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "gather physical arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [indices, mask, passthru, resultType] : + llvm::zip_equal(indicesParts, maskParts, passthruParts, + resultTypes)) { + if (!isa(indices.getType()) || !isa(mask.getType()) || + passthru.getType() != resultType || !isa(resultType)) + return rewriter.notifyMatchFailure(op, + "gather physical part type mismatch"); + + Value gathered = + rewriter + .create(op.getLoc(), resultType, *source, indices, + mask) + .getResult(); + results.push_back( + rewriter + .create(op.getLoc(), resultType, gathered, passthru, + mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIExpandLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIExpandLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIExpandLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "expand_load source must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "expand_load offset must convert to one value", + rewriter); + if (failed(source) || failed(offset)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (isStaticAllActiveMask(op.getMask(), resultVMIType.getElementCount())) { + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( + op, resultVMIType, (*source).getType(), *offset, rewriter); + if (failed(lanesPerPart)) + return failure(); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + if (!isa(resultType)) + return rewriter.notifyMatchFailure( + op, "expand_load result must be vreg"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, index * *lanesPerPart, rewriter); + results.push_back( + rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + ValueRange maskParts = adaptor.getMask(); + ValueRange passthruParts = adaptor.getPassthru(); + if (resultTypes.size() != 1 || maskParts.size() != 1 || + passthruParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "runtime expand_load supports only one physical chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || passthruParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "runtime expand_load requires physical result/passthru/mask"); + + auto baseType = dyn_cast((*source).getType()); + if (!baseType) + return rewriter.notifyMatchFailure(op, + "runtime expand_load requires ptr"); + Value gatherBase = + rewriter + .create(op.getLoc(), (*source).getType(), *source, + *offset) + .getResult(); + auto indexType = + VRegType::get(rewriter.getContext(), resultType.getElementCount(), + rewriter.getI32Type()); + FailureOr indexSeedMask = + createAllTrueMaskForVReg(op.getLoc(), indexType, rewriter); + if (failed(indexSeedMask)) + return rewriter.notifyMatchFailure( + op, "failed to create runtime expand_load index seed mask"); + Value zero = rewriter.create(op.getLoc(), 0, 32); + Value carrier = + rewriter + .create(op.getLoc(), indexType, zero, *indexSeedMask, + /*position=*/nullptr) + .getResult(); + Value indices = + rewriter + .create(op.getLoc(), indexType, carrier, + maskParts.front()) + .getResult(); + Value gathered = + rewriter + .create(op.getLoc(), resultType, gatherBase, indices, + maskParts.front()) + .getResult(); + Value result = + rewriter + .create(op.getLoc(), resultType, gathered, + passthruParts.front(), maskParts.front()) + .getResult(); + rewriter.replaceOp(op, SmallVector{result}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto valueVMIType = cast(op.getValue().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(valueVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "store requires known physical lanes per part"); + bool fullPhysicalChunks = + succeeded(checkFullDataPhysicalChunks(valueVMIType, nullptr)); + + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "store destination must convert to one value", rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "store offset must convert to one value", rewriter); + if (failed(destination) || failed(offset)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + VMILayoutAttr valueLayout = valueVMIType.getLayoutAttr(); + if (fullPhysicalChunks && valueLayout && valueLayout.isDeinterleaved() && + valueLayout.getFactor() == 2) { + std::optional dist = + getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); + if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { + int64_t groups = valueParts.size() / 2; + for (int64_t group = 0; group < groups; ++group) { + Value low = valueParts[group]; + Value high = valueParts[groups + group]; + if (low.getType() != high.getType()) + return rewriter.notifyMatchFailure( + op, "vstsx2 requires matching low/high value types"); + auto vregType = dyn_cast(low.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, "store value must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for store mask"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, group * 2 * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), low, high, *destination, + chunkOffset, rewriter.getStringAttr(*dist), + *mask); + } + rewriter.eraseOp(op); + return success(); + } + } + + SmallVector contiguousTypes; + contiguousTypes.reserve(valueParts.size()); + for (Value value : valueParts) + contiguousTypes.push_back(value.getType()); + + FailureOr> storeParts = materializeDataLayoutConversion( + op, valueParts, contiguousTypes, valueVMIType.getLayoutAttr(), + VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); + if (failed(storeParts)) + return failure(); + + for (auto [index, value] : llvm::enumerate(*storeParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, "store value must be vreg"); + if (!fullPhysicalChunks) { + FailureOr activeLanes = + getContiguousActiveDataLanes(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute store active lanes"); + if (*activeLanes == 0) + continue; + } + FailureOr mask = fullPhysicalChunks + ? createAllTrueMaskForVReg(op.getLoc(), + vregType, rewriter) + : createContiguousStoreMask(op.getLoc(), + valueVMIType, + index, vregType, + rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for store mask"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, index * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMIMaskedStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIMaskedStoreOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIMaskedStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto valueVMIType = cast(op.getValue().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(valueVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "masked_store requires known physical lanes per part"); + + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "masked_store destination must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "masked_store offset must convert to one value", + rewriter); + if (failed(destination) || failed(offset)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + ValueRange maskParts = adaptor.getMask(); + if (valueParts.size() != maskParts.size()) + return rewriter.notifyMatchFailure( + op, "masked_store value/mask physical arity mismatch"); + + SmallVector contiguousValueTypes; + contiguousValueTypes.reserve(valueParts.size()); + for (Value value : valueParts) + contiguousValueTypes.push_back(value.getType()); + FailureOr> storeParts = materializeDataLayoutConversion( + op, valueParts, contiguousValueTypes, valueVMIType.getLayoutAttr(), + VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); + if (failed(storeParts)) + return failure(); + + auto maskVMIType = cast(op.getMask().getType()); + SmallVector contiguousMaskTypes; + contiguousMaskTypes.reserve(maskParts.size()); + for (Value mask : maskParts) + contiguousMaskTypes.push_back(mask.getType()); + FailureOr> storeMasks = materializeMaskLayoutConversion( + op, maskParts, contiguousMaskTypes, maskVMIType.getLayoutAttr(), + VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); + if (failed(storeMasks)) + return failure(); + + if (storeParts->size() != storeMasks->size()) + return rewriter.notifyMatchFailure( + op, "masked_store converted value/mask arity mismatch"); + + for (auto [index, valueAndMask] : + llvm::enumerate(llvm::zip_equal(*storeParts, *storeMasks))) { + auto [value, mask] = valueAndMask; + auto vregType = dyn_cast(value.getType()); + if (!vregType || !isa(mask.getType())) + return rewriter.notifyMatchFailure( + op, "masked_store converted parts must be vreg/mask"); + FailureOr activeLanes = + getContiguousActiveDataLanes(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute masked_store active lanes"); + if (*activeLanes == 0) + continue; + FailureOr storeMask = createMaskedStorePredicate( + op.getLoc(), valueVMIType, index, mask, vregType, rewriter); + if (failed(storeMask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize masked_store predicate"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, index * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, /*dist=*/nullptr, *storeMask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMIScatterOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIScatterOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "scatter destination must convert to one value", + rewriter); + if (failed(destination)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + ValueRange indicesParts = adaptor.getIndices(); + ValueRange maskParts = adaptor.getMask(); + if (valueParts.size() != indicesParts.size() || + valueParts.size() != maskParts.size()) + return rewriter.notifyMatchFailure(op, + "scatter physical arity mismatch"); + + for (auto [value, indices, mask] : + llvm::zip_equal(valueParts, indicesParts, maskParts)) { + if (!isa(value.getType()) || + !isa(indices.getType()) || !isa(mask.getType())) + return rewriter.notifyMatchFailure( + op, "scatter physical part type mismatch"); + rewriter.create(op.getLoc(), value, *destination, indices, + mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMITileReadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMITileReadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "tile_read source must convert to one value", rewriter); + if (failed(source)) + return failure(); + + Value zero = rewriter.create(op.getLoc(), 0); + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( + op, resultVMIType, (*source).getType(), zero, rewriter); + if (failed(lanesPerPart)) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 2) { + std::optional dist = + getX2MemoryDistToken(resultVMIType.getElementType(), "DINTLV"); + if (dist && !resultTypes.empty() && resultTypes.size() % 2 == 0) { + int64_t groups = resultTypes.size() / 2; + SmallVector lows; + SmallVector highs; + lows.reserve(groups); + highs.reserve(groups); + for (int64_t group = 0; group < groups; ++group) { + Type lowType = resultTypes[group]; + Type highType = resultTypes[groups + group]; + if (lowType != highType) + return rewriter.notifyMatchFailure( + op, "vldsx2 requires matching low/high result types"); + Value chunkOffset = createChunkOffset( + op.getLoc(), zero, group * 2 * *lanesPerPart, rewriter); + auto load = rewriter.create( + op.getLoc(), lowType, highType, *source, chunkOffset, + rewriter.getStringAttr(*dist)); + lows.push_back(load.getLow()); + highs.push_back(load.getHigh()); + } + SmallVector results; + results.reserve(resultTypes.size()); + results.append(lows); + results.append(highs); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + } + + SmallVector contiguousParts; + contiguousParts.reserve(resultTypes.size()); + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, "tile_read result must be vreg"); + Value chunkOffset = createChunkOffset( + op.getLoc(), zero, index * *lanesPerPart, rewriter); + contiguousParts.push_back( + rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + FailureOr> results = materializeDataLayoutConversion( + op, contiguousParts, resultTypes, + VMILayoutAttr::getContiguous(rewriter.getContext()), + resultVMIType.getLayoutAttr(), rewriter); + if (failed(results)) + return failure(); + + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMITileWriteOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMITileWriteOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto valueVMIType = cast(op.getValue().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(valueVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "tile_write requires known physical lanes per part"); + bool fullPhysicalChunks = + succeeded(checkFullDataPhysicalChunks(valueVMIType, nullptr)); + + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "tile_write destination must convert to one value", rewriter); + if (failed(destination)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + Value zero = rewriter.create(op.getLoc(), 0); + VMILayoutAttr valueLayout = valueVMIType.getLayoutAttr(); + if (fullPhysicalChunks && valueLayout && valueLayout.isDeinterleaved() && + valueLayout.getFactor() == 2) { + std::optional dist = + getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); + if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { + int64_t groups = valueParts.size() / 2; + for (int64_t group = 0; group < groups; ++group) { + Value low = valueParts[group]; + Value high = valueParts[groups + group]; + if (low.getType() != high.getType()) + return rewriter.notifyMatchFailure( + op, "vstsx2 requires matching low/high value types"); + auto vregType = dyn_cast(low.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "tile_write value must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for tile_write mask"); + Value chunkOffset = createChunkOffset( + op.getLoc(), zero, group * 2 * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), low, high, *destination, + chunkOffset, rewriter.getStringAttr(*dist), + *mask); + } + rewriter.eraseOp(op); + return success(); + } + } + + SmallVector contiguousTypes; + contiguousTypes.reserve(valueParts.size()); + for (Value value : valueParts) + contiguousTypes.push_back(value.getType()); + + FailureOr> storeParts = materializeDataLayoutConversion( + op, valueParts, contiguousTypes, valueVMIType.getLayoutAttr(), + VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); + if (failed(storeParts)) + return failure(); + + for (auto [index, value] : llvm::enumerate(*storeParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, "tile_write value must be vreg"); + if (!fullPhysicalChunks) { + FailureOr activeLanes = + getContiguousActiveDataLanes(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute tile_write active lanes"); + if (*activeLanes == 0) + continue; + } + FailureOr mask = fullPhysicalChunks + ? createAllTrueMaskForVReg(op.getLoc(), + vregType, rewriter) + : createContiguousStoreMask(op.getLoc(), + valueVMIType, + index, vregType, + rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for tile_write mask"); + Value chunkOffset = createChunkOffset( + op.getLoc(), zero, index * *lanesPerPart, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +template +struct OneToNVMIBinaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< + SourceOp>::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical binary arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, resultType] : + llvm::zip_equal(lhsParts, rhsParts, resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType || lhs.getType() != resultType || + rhs.getType() != resultType) + return rewriter.notifyMatchFailure(op, + "physical binary part type mismatch"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for all-true binary mask"); + results.push_back( + rewriter.create(op.getLoc(), resultType, lhs, rhs, *mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIFmaOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIFmaOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + ValueRange accParts = adaptor.getAcc(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != accParts.size() || + lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "fma physical arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, acc, resultType] : + llvm::zip_equal(lhsParts, rhsParts, accParts, resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType || lhs.getType() != resultType || + rhs.getType() != resultType || acc.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "fma requires matching physical vreg parts"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, + "unsupported element type for fma"); + results.push_back( + rewriter.create(op.getLoc(), resultType, acc, lhs, rhs, + *mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +template +struct OneToNVMIUnaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< + SourceOp>::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical unary arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [source, resultType] : llvm::zip_equal(sourceParts, resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType || source.getType() != resultType) + return rewriter.notifyMatchFailure(op, + "physical unary part type mismatch"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for all-true unary mask"); + results.push_back( + rewriter.create(op.getLoc(), resultType, source, *mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +template +struct OneToNVMIMaskBinaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< + SourceOp>::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, + "physical mask binary arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, resultType] : + llvm::zip_equal(lhsParts, rhsParts, resultTypes)) { + auto maskType = dyn_cast(resultType); + if (!maskType || lhs.getType() != resultType || + rhs.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "physical mask binary part type mismatch"); + FailureOr seedMask = + createAllTrueMask(op.getLoc(), maskType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for all-true mask binary seed"); + results.push_back( + rewriter.create(op.getLoc(), resultType, lhs, rhs, + *seedMask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +template +struct OneToNVMIMaskUnaryOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< + SourceOp>::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, + "physical mask unary arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [source, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto maskType = dyn_cast(resultType); + if (!maskType || source.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "physical mask unary part type mismatch"); + FailureOr seedMask = + createAllTrueMask(op.getLoc(), maskType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for all-true mask unary seed"); + results.push_back( + rewriter.create(op.getLoc(), resultType, source, *seedMask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +template +struct OneToNVMICmpOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< + SourceOp>::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + std::optional cmpMode = getVPTOCmpMode(op.getPredicate()); + if (!cmpMode) + return op.emitOpError() + << kVMIDiagUnsupportedPrefix << "compare predicate " + << op.getPredicate() + << " cannot be lowered to pto.vcmp; supported predicates are " + "eq/ne/lt/le/gt/ge, ordered FP forms " + "oeq/one/olt/ole/ogt/oge, and signed integer forms " + "slt/sle/sgt/sge"; + + ValueRange lhsParts = adaptor.getLhs(); + ValueRange rhsParts = adaptor.getRhs(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (lhsParts.size() != rhsParts.size() || + lhsParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical cmp arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [lhs, rhs, resultType] : + llvm::zip_equal(lhsParts, rhsParts, resultTypes)) { + auto maskType = dyn_cast(resultType); + if (!maskType || lhs.getType() != rhs.getType() || + !isa(lhs.getType())) + return rewriter.notifyMatchFailure(op, + "physical cmp part type mismatch"); + FailureOr seedMask = + createAllTrueMask(op.getLoc(), maskType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported mask type for all-true cmp seed"); + results.push_back( + rewriter + .create(op.getLoc(), resultType, lhs, rhs, *seedMask, + rewriter.getStringAttr(*cmpMode)) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMISelectOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMISelectOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange maskParts = adaptor.getMask(); + ValueRange trueParts = adaptor.getTrueValue(); + ValueRange falseParts = adaptor.getFalseValue(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (maskParts.size() != trueParts.size() || + trueParts.size() != falseParts.size() || + trueParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, "physical select arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [mask, trueValue, falseValue, resultType] : + llvm::zip_equal(maskParts, trueParts, falseParts, resultTypes)) { + if (!isa(mask.getType()) || trueValue.getType() != resultType || + falseValue.getType() != resultType || !isa(resultType)) + return rewriter.notifyMatchFailure( + op, "physical select part type mismatch"); + results.push_back( + rewriter + .create(op.getLoc(), resultType, trueValue, falseValue, + mask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIActivePrefixIndexOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIActivePrefixIndexOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIActivePrefixIndexOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (maskParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "active_prefix_index supports only one physical part"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "active_prefix_index requires physical vreg/mask parts"); + + auto intType = dyn_cast(resultType.getElementType()); + if (!intType || !intType.isSignless()) + return rewriter.notifyMatchFailure( + op, "active_prefix_index requires signless integer result part"); + + FailureOr seedMask = + createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); + if (failed(seedMask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for active_prefix_index seed mask"); + + Value zero = rewriter.create( + op.getLoc(), 0, intType.getWidth()); + Value carrier = + rewriter + .create(op.getLoc(), resultType, zero, *seedMask, + /*position=*/nullptr) + .getResult(); + Value result = + rewriter + .create(op.getLoc(), resultType, carrier, + maskParts.front()) + .getResult(); + rewriter.replaceOp(op, SmallVector{result}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMICompressOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMICompressOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != 1 || maskParts.size() != 1 || + resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "compress supports only one physical part"); + + auto resultType = dyn_cast(resultTypes.front()); + if (!resultType || sourceParts.front().getType() != resultType || + !isa(maskParts.front().getType())) + return rewriter.notifyMatchFailure( + op, "compress requires physical source/mask/result parts"); + + Value result = + rewriter + .create(op.getLoc(), resultType, sourceParts.front(), + maskParts.front()) + .getResult(); + rewriter.replaceOp(op, SmallVector{result}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMICompressStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMICompressStoreOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMICompressStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "compress_store destination must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "compress_store offset must convert to one value", + rewriter); + if (failed(destination) || failed(offset)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + ValueRange maskParts = adaptor.getMask(); + if (valueParts.size() != 1 || maskParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "compress_store supports only one physical part"); + + auto valueType = dyn_cast(valueParts.front().getType()); + if (!valueType || !isa(maskParts.front().getType()) || + !isa((*destination).getType())) + return rewriter.notifyMatchFailure( + op, "compress_store requires physical value/mask and ptr " + "destination"); + + Value storeBase = + rewriter + .create(op.getLoc(), (*destination).getType(), + *destination, *offset) + .getResult(); + Value squeezed = + rewriter + .create(op.getLoc(), valueType, valueParts.front(), + maskParts.front()) + .getResult(); + auto align = + rewriter.create(op.getLoc(), + AlignType::get(rewriter.getContext())); + auto store = rewriter.create( + op.getLoc(), align.getResult().getType(), align.getResult(), squeezed, + storeBase, rewriter.getStringAttr("POST_UPDATE")); + rewriter.create(op.getLoc(), store.getAlignOut(), storeBase); + rewriter.eraseOp(op); + return success(); + } +}; + +struct OneToNVMIReduceAddIOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIReduceAddIOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIReduceAddIOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + ValueRange initParts = adaptor.getInit(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || sourceParts.size() != maskParts.size() || + initParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires matching source/mask chunks and one " + "init/result chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || initParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires matching physical source/init/result " + "vregs and one mask"); + + for (Value sourcePart : sourceParts) + if (sourcePart.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires every source chunk to match result " + "vreg type"); + for (Value maskPart : maskParts) + if (maskPart.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "reduce_addi requires every mask chunk to have the same " + "predicate type"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, "PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create reduce_addi first-lane mask"); + + Value accumulator = initParts.front(); + for (auto [sourcePart, maskPart] : llvm::zip_equal(sourceParts, maskParts)) { + Value reduced = + rewriter.create(op.getLoc(), resultType, sourcePart, + maskPart) + .getResult(); + accumulator = + rewriter + .create(op.getLoc(), resultType, reduced, accumulator, + *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{accumulator}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIReduceAddFOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIReduceAddFOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIReduceAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + ValueRange initParts = adaptor.getInit(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || sourceParts.size() != maskParts.size() || + initParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires matching source/mask chunks and one " + "init/result chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || initParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires matching physical source/init/result " + "vregs and one mask"); + + for (Value sourcePart : sourceParts) + if (sourcePart.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires every source chunk to match result " + "vreg type"); + for (Value maskPart : maskParts) + if (maskPart.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "reduce_addf requires every mask chunk to have the same " + "predicate type"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, "PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create reduce_addf first-lane mask"); + + Value accumulator = initParts.front(); + for (auto [sourcePart, maskPart] : llvm::zip_equal(sourceParts, maskParts)) { + Value reduced = + rewriter.create(op.getLoc(), resultType, sourcePart, + maskPart) + .getResult(); + accumulator = + rewriter + .create(op.getLoc(), resultType, reduced, accumulator, + *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{accumulator}, + adaptor.getResultMapping()); + return success(); + } +}; + +template +struct OneToNVMIReduceMinMaxFOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + ValueRange initParts = adaptor.getInit(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || sourceParts.size() != maskParts.size() || + initParts.size() != 1 || resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction requires matching source/mask chunks " + "and one init/result chunk"); + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType || initParts.front().getType() != resultType) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction requires matching physical source/" + "init/result vregs and one mask"); + + for (Value sourcePart : sourceParts) + if (sourcePart.getType() != resultType) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction requires every source chunk to " + "match result vreg type"); + for (Value maskPart : maskParts) + if (maskPart.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "floating min/max reduction requires every mask chunk to have " + "the same predicate type"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, "PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create floating min/max reduction first-lane mask"); + + Value accumulator = initParts.front(); + for (auto [sourcePart, maskPart] : + llvm::zip_equal(sourceParts, maskParts)) { + Value reduced = + rewriter.create(op.getLoc(), resultType, sourcePart, + maskPart) + .getResult(); + accumulator = + rewriter + .create(op.getLoc(), resultType, reduced, accumulator, + *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{accumulator}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIExtFOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIExtFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty()) + return rewriter.notifyMatchFailure( + op, "extf requires at least one physical source chunk"); + + auto sourceType = dyn_cast(sourceParts.front().getType()); + if (!sourceType) + return rewriter.notifyMatchFailure(op, "expected physical extf source"); + for (Value sourcePart : sourceParts) { + auto currentSourceType = dyn_cast(sourcePart.getType()); + if (!currentSourceType || currentSourceType != sourceType) + return rewriter.notifyMatchFailure( + op, "extf source physical parts must have matching type"); + } + + SmallVector resultVRegTypes; + resultVRegTypes.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || + (resultVRegTypes.empty() + ? !resultVRegType.getElementType().isF32() + : resultVRegType != resultVRegTypes.front())) + return rewriter.notifyMatchFailure( + op, "unsupported physical extf result type"); + resultVRegTypes.push_back(resultVRegType); + } + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + ArrayRef parts; + int64_t factor = 0; + if (sourceBits == 16 && resultTypes.size() == 2 * sourceParts.size()) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + parts = kEvenOddParts; + factor = 2; + } else if (sourceBits == 8 && + resultTypes.size() == 4 * sourceParts.size()) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + parts = kPacked4Parts; + factor = 4; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported physical extf source/result width relation"); + } + + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, + "failed to build extf seed mask"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t partIndex = 0; partIndex < factor; ++partIndex) { + for (auto [chunkIndex, sourcePart] : llvm::enumerate(sourceParts)) { + VRegType resultType = + resultVRegTypes[partIndex * sourceParts.size() + chunkIndex]; + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *mask, + /*rnd=*/nullptr, /*sat=*/nullptr, + rewriter.getStringAttr(parts[partIndex])) + .getResult()); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMITruncFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if ((sourceParts.size() != 2 && sourceParts.size() != 4) || + resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "only f32 deinterleaved=2/4 to 16/8-bit contiguous truncf is supported"); + + auto sourceType0 = dyn_cast(sourceParts.front().getType()); + auto resultType = dyn_cast(resultTypes.front()); + if (!sourceType0 || !sourceType0.getElementType().isF32() || !resultType) + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf source/result type"); + for (Value sourcePart : sourceParts) { + auto sourceType = dyn_cast(sourcePart.getType()); + if (!sourceType || sourceType != sourceType0) + return rewriter.notifyMatchFailure( + op, "truncf source physical parts must have matching f32 type"); + } + + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + ArrayRef parts; + if (sourceParts.size() == 2 && resultBits == 16) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + parts = kEvenOddParts; + } else if (sourceParts.size() == 4 && resultBits == 8) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + parts = kPacked4Parts; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf source/result width relation"); + } + + FailureOr sourceMask = + createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); + FailureOr resultMask = + createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); + if (failed(sourceMask) || failed(resultMask)) + return rewriter.notifyMatchFailure(op, + "failed to build truncf masks"); + + StringAttr rnd = rewriter.getStringAttr("R"); + StringAttr sat = rewriter.getStringAttr("SAT"); + SmallVector partials; + partials.reserve(parts.size()); + for (auto [sourcePart, part] : llvm::zip_equal(sourceParts, parts)) { + partials.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *sourceMask, + rnd, sat, rewriter.getStringAttr(part)) + .getResult()); + } + + Value merged = partials.front(); + for (Value partial : llvm::drop_begin(partials)) + merged = + rewriter + .create(op.getLoc(), resultType, merged, partial, + *resultMask) + .getResult(); + + rewriter.replaceOp(op, merged, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIBitcastOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIBitcastOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, + "physical bitcast arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + if (!isa(sourcePart.getType()) || !isa(resultType)) + return rewriter.notifyMatchFailure( + op, "physical bitcast part type mismatch"); + results.push_back( + rewriter.create(op.getLoc(), resultType, sourcePart) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIChannelSplitOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIChannelSplitOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIChannelSplitOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + int64_t channels = op.getNumResults(); + if (channels != 2 && channels != 4) + return rewriter.notifyMatchFailure( + op, "channel_split only supports 2 or 4 channels"); + + auto sourceType = cast(op.getSource().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + auto channelLayout = + VMILayoutAttr::getDeinterleaved(rewriter.getContext(), channels); + if (!sourceLayout || + (!sourceLayout.isContiguous() && sourceLayout != channelLayout)) + return rewriter.notifyMatchFailure( + op, + "channel_split requires contiguous or matching deinterleaved source " + "layout"); + for (Value result : op.getResults()) { + auto resultType = cast(result.getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout || !resultLayout.isContiguous()) + return rewriter.notifyMatchFailure( + op, "channel_split requires contiguous result layouts"); + } + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(); + FailureOr> results = materializeDataLayoutConversion( + op, adaptor.getSource(), resultTypes, sourceLayout, channelLayout, + rewriter); + if (failed(results)) + return failure(); + + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIChannelMergeOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIChannelMergeOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIChannelMergeOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + int64_t channels = op.getInputs().size(); + if (channels != 2 && channels != 4) + return rewriter.notifyMatchFailure( + op, "channel_merge only supports 2 or 4 channels"); + + for (Value input : op.getInputs()) { + auto inputType = cast(input.getType()); + VMILayoutAttr inputLayout = inputType.getLayoutAttr(); + if (!inputLayout || !inputLayout.isContiguous()) + return rewriter.notifyMatchFailure( + op, "channel_merge requires contiguous input layouts"); + } + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + auto channelLayout = + VMILayoutAttr::getDeinterleaved(rewriter.getContext(), channels); + if (!resultLayout || + (!resultLayout.isContiguous() && resultLayout != channelLayout)) + return rewriter.notifyMatchFailure( + op, + "channel_merge requires contiguous or matching deinterleaved result " + "layout"); + + FailureOr> results = materializeDataLayoutConversion( + op, adaptor.getFlatOperands(), + adaptor.getResultMapping().getConvertedTypes(0), channelLayout, + resultLayout, rewriter); + if (failed(results)) + return failure(); + + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIShuffleOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + std::string reason; + FailureOr> sourceFlatIndices = + computeShuffleForwardingSourceParts(op, &reason); + if (succeeded(sourceFlatIndices)) { + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t sourceFlatIndex : *sourceFlatIndices) { + if (sourceFlatIndex >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "shuffle forwarding source part range is out of bounds"); + results.push_back(sourceParts[sourceFlatIndex]); + } + + if (failed(verifyIdentityPartForwarding(op, results, resultTypes, + rewriter))) + return failure(); + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + std::string splatReason; + FailureOr splatSource = + computeShuffleLane0SplatSourcePart(op, &splatReason); + if (succeeded(splatSource)) { + if (*splatSource >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "shuffle lane0 splat source part range is out of bounds"); + + SmallVector results; + results.reserve(resultTypes.size()); + Value sourcePart = sourceParts[*splatSource]; + for (Type resultType : resultTypes) { + auto sourceVRegType = dyn_cast(sourcePart.getType()); + auto resultVRegType = dyn_cast(resultType); + if (!sourceVRegType || !resultVRegType || + sourceVRegType != resultVRegType) + return rewriter.notifyMatchFailure( + op, "shuffle lane0 splat requires matching physical vreg type"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), resultVRegType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to create shuffle lane0 splat mask"); + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *mask, + rewriter.getStringAttr("LOWEST")) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + std::string vselrReason; + FailureOr> vselrPlans = + computeShuffleVselrPlans(op, &vselrReason); + if (failed(vselrPlans)) + return rewriter.notifyMatchFailure( + op, Twine("shuffle vselr ") + vselrReason); + + if (vselrPlans->size() != resultTypes.size()) + return rewriter.notifyMatchFailure(op, + "shuffle vselr arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [plan, resultType] : llvm::zip_equal(*vselrPlans, resultTypes)) { + if (plan.sourceFlatIndex >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "shuffle vselr source part range is out of bounds"); + + auto sourceVRegType = + dyn_cast(sourceParts[plan.sourceFlatIndex].getType()); + auto resultVRegType = dyn_cast(resultType); + if (!sourceVRegType || !resultVRegType || + sourceVRegType.getElementCount() != + resultVRegType.getElementCount() || + sourceVRegType.getElementType() != resultVRegType.getElementType()) + return rewriter.notifyMatchFailure( + op, "shuffle vselr source/result type mismatch"); + + unsigned indexBits = + pto::getPTOStorageElemBitWidth(sourceVRegType.getElementType()); + if (indexBits != 8 && indexBits != 16 && indexBits != 32) + return rewriter.notifyMatchFailure( + op, "shuffle vselr requires 8/16/32-bit index elements"); + + auto indexElementType = + IntegerType::get(rewriter.getContext(), indexBits); + Type indexType = + VRegType::get(rewriter.getContext(), + sourceVRegType.getElementCount(), indexElementType); + FailureOr base = createScalarOffsetConstant( + op.getLoc(), indexElementType, plan.baseLane, rewriter); + if (failed(base)) + return rewriter.notifyMatchFailure( + op, "failed to materialize shuffle vselr index base"); + StringAttr orderAttr = + plan.descending ? rewriter.getStringAttr("DESC") : StringAttr{}; + Value indexVector = + rewriter.create(op.getLoc(), indexType, *base, orderAttr) + .getResult(); + results.push_back( + rewriter + .create(op.getLoc(), resultType, + sourceParts[plan.sourceFlatIndex], indexVector) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +Block *convertBranchDestBlock(Block *block, OneToNPatternRewriter &rewriter, + OneToNTypeConverter &typeConverter, + llvm::DenseMap &converted) { + auto [it, inserted] = converted.try_emplace(block, nullptr); + if (!inserted) + return it->second; + + OneToNTypeMapping argMapping(block->getArgumentTypes()); + if (failed(typeConverter.computeTypeMapping(block->getArgumentTypes(), + argMapping)) || + !argMapping.hasNonIdentityConversion()) { + it->second = block; + return block; + } + + Block *newBlock = rewriter.applySignatureConversion(block, argMapping); + it->second = newBlock; + return newBlock; +} + +struct OneToNCFBranchOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + llvm::DenseMap convertedBlocks; + Block *dest = convertBranchDestBlock(op.getDest(), rewriter, *converter, + convertedBlocks); + + if (!adaptor.getOperandMapping().hasNonIdentityConversion() && + dest == op.getDest()) + return failure(); + + rewriter.replaceOpWithNewOp(op, dest, + adaptor.getFlatOperands()); + return success(); + } +}; + +struct OneToNCFCondBranchOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + llvm::DenseMap convertedBlocks; + Block *trueDest = convertBranchDestBlock(op.getTrueDest(), rewriter, + *converter, convertedBlocks); + Block *falseDest = convertBranchDestBlock(op.getFalseDest(), rewriter, + *converter, convertedBlocks); + + if (!adaptor.getOperandMapping().hasNonIdentityConversion() && + trueDest == op.getTrueDest() && falseDest == op.getFalseDest()) + return failure(); + + ValueRange condition = adaptor.getCondition(); + if (condition.size() != 1) + return rewriter.notifyMatchFailure( + op, "condition converted to multiple values"); + + SmallVector trueOperands; + SmallVector falseOperands; + ValueRange flatOperands = adaptor.getFlatOperands(); + const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); + unsigned operandIndex = 1; + for (unsigned i = 0, e = op.getNumTrueOperands(); i < e; ++i) + llvm::append_range( + trueOperands, + operandMapping.getConvertedValues(flatOperands, operandIndex++)); + for (unsigned i = 0, e = op.getNumFalseOperands(); i < e; ++i) + llvm::append_range( + falseOperands, + operandMapping.getConvertedValues(flatOperands, operandIndex++)); + + rewriter.replaceOpWithNewOp( + op, condition.front(), trueDest, trueOperands, falseDest, + falseOperands); + return success(); + } +}; + +struct OneToNCFSwitchOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + llvm::DenseMap convertedBlocks; + Block *defaultDest = + convertBranchDestBlock(op.getDefaultDestination(), rewriter, + *converter, convertedBlocks); + + SmallVector caseDests; + caseDests.reserve(op.getCaseDestinations().size()); + for (Block *dest : op.getCaseDestinations()) + caseDests.push_back( + convertBranchDestBlock(dest, rewriter, *converter, convertedBlocks)); + + bool changed = defaultDest != op.getDefaultDestination(); + for (auto [oldDest, newDest] : + llvm::zip(op.getCaseDestinations(), caseDests)) + changed |= oldDest != newDest; + changed |= adaptor.getOperandMapping().hasNonIdentityConversion(); + if (!changed) + return failure(); + + ValueRange flag = adaptor.getFlag(); + if (flag.size() != 1) + return rewriter.notifyMatchFailure(op, "flag converted to multiple values"); + + SmallVector defaultOperands; + SmallVector> caseOperandStorage; + SmallVector caseOperands; + ValueRange flatOperands = adaptor.getFlatOperands(); + const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); + + unsigned operandIndex = 1; + for (unsigned i = 0, e = op.getDefaultOperands().size(); i < e; ++i) + llvm::append_range( + defaultOperands, + operandMapping.getConvertedValues(flatOperands, operandIndex++)); + + caseOperandStorage.reserve(op.getCaseOperandSegments().size()); + caseOperands.reserve(op.getCaseOperandSegments().size()); + for (int32_t segmentSize : op.getCaseOperandSegments()) { + SmallVector operands; + for (int32_t i = 0; i < segmentSize; ++i) + llvm::append_range( + operands, + operandMapping.getConvertedValues(flatOperands, operandIndex++)); + caseOperandStorage.push_back(std::move(operands)); + } + for (SmallVector &operands : caseOperandStorage) + caseOperands.push_back(operands); + + rewriter.replaceOpWithNewOp( + op, flag.front(), defaultDest, defaultOperands, op.getCaseValuesAttr(), + caseDests, caseOperands); + return success(); + } +}; + +struct OneToNSCFExecuteRegionOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + scf::ExecuteRegionOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ExecuteRegionOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + SmallVector resultTypes; + const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) + llvm::append_range(resultTypes, resultMapping.getConvertedTypes(i)); + if (resultTypes == op->getResultTypes()) + return failure(); + + auto newOp = + rewriter.create(op.getLoc(), resultTypes); + newOp->setAttrs(op->getAttrs()); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + rewriter.replaceOp(op, newOp->getResults(), resultMapping); + return success(); + } +}; + +struct OneToNSCFIndexSwitchOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(scf::IndexSwitchOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange arg = adaptor.getArg(); + if (arg.size() != 1) + return rewriter.notifyMatchFailure( + op, "index_switch selector converted to multiple values"); + + SmallVector resultTypes; + const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) + llvm::append_range(resultTypes, resultMapping.getConvertedTypes(i)); + if (resultTypes == op->getResultTypes()) + return failure(); + + auto newOp = rewriter.create( + op.getLoc(), resultTypes, arg.front(), op.getCases(), + op.getNumCases()); + newOp->setAttrs(op->getAttrs()); + rewriter.inlineRegionBefore(op.getDefaultRegion(), + newOp.getDefaultRegion(), + newOp.getDefaultRegion().end()); + for (auto [srcRegion, dstRegion] : + llvm::zip(op.getCaseRegions(), newOp.getCaseRegions())) + rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.end()); + rewriter.replaceOp(op, newOp->getResults(), resultMapping); + return success(); + } +}; + +void populateVMIOneToNConversionPatterns( + VMIToVPTOTypeConverter &typeConverter, RewritePatternSet &patterns, + const VMITargetCapabilityRegistry &capabilities) { + populateFuncTypeConversionPatterns(typeConverter, patterns); + scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); + patterns + .add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, + patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext()); + patterns.add, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskUnaryOpPattern, + OneToNVMILoadOpPattern, + OneToNVMIMaskedLoadOpPattern, + OneToNVMIGatherOpPattern, + OneToNVMIExpandLoadOpPattern, + OneToNVMIStoreOpPattern, + OneToNVMIMaskedStoreOpPattern, + OneToNVMIScatterOpPattern, + OneToNVMITileReadOpPattern, + OneToNVMITileWriteOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIFmaOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMICmpOpPattern, + OneToNVMICmpOpPattern, + OneToNVMISelectOpPattern, + OneToNVMIActivePrefixIndexOpPattern, + OneToNVMICompressOpPattern, + OneToNVMICompressStoreOpPattern, + OneToNVMIReduceAddIOpPattern, + OneToNVMIReduceAddFOpPattern, + OneToNVMIReduceMinMaxFOpPattern, + OneToNVMIReduceMinMaxFOpPattern, + OneToNVMIExtFOpPattern, + OneToNVMITruncFOpPattern, + OneToNVMIBitcastOpPattern, + OneToNVMIChannelSplitOpPattern, + OneToNVMIChannelMergeOpPattern, + OneToNVMIShuffleOpPattern>(typeConverter, + patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext(), capabilities); +} + +LogicalResult verifyNoResidualVMIIR(ModuleOp module) { + WalkResult result = module.walk([&](Operation *op) { + if (isa(op)) { + op->emitError() + << kVMIDiagResidualOpPrefix + << "unrealized conversion cast remains after vmi-to-vpto"; + return WalkResult::interrupt(); + } + if (auto createMask = dyn_cast(op)) { + if (!createMask.getActiveLanes().getDefiningOp()) { + createMask.emitError() + << kVMIDiagUnsupportedPrefix + << "dynamic pto.vmi.create_mask active_lanes could not be lowered " + "by the current runtime predicate generation plan"; + return WalkResult::interrupt(); + } + } + if (auto constant = dyn_cast(op)) { + auto denseAttr = dyn_cast(constant.getValue()); + if (denseAttr && !denseAttr.isSplat()) { + constant.emitError() + << kVMIDiagUnsupportedPrefix + << "non-splat pto.vmi.constant requires a vreg immediate or " + "scratch materialization plan"; + return WalkResult::interrupt(); + } + } + if (isVMIOp(op) || hasVMIType(op)) { + op->emitError() + << kVMIDiagResidualOpPrefix + << "failed to convert all VMI ops/types to VPTO"; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +LogicalResult checkSupportedExtFShape(VMIExtFOp op) { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity) || !sourceLayout.isContiguous() || + !resultLayout.isDeinterleaved() || + !resultType.getElementType().isF32()) + return failure(); + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + if (sourceBits == 16 && resultLayout.getFactor() == 2 && + *resultArity == 2 * *sourceArity) + return success(); + if (sourceBits == 8 && resultLayout.getFactor() == 4 && + *resultArity == 4 * *sourceArity) + return success(); + return failure(); +} + +LogicalResult checkSupportedTruncFShape(VMITruncFOp op) { + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity) || !sourceLayout.isDeinterleaved() || + !resultLayout.isContiguous() || !sourceType.getElementType().isF32() || + *resultArity != 1) + return failure(); + + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) + return success(); + if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8) + return success(); + return failure(); +} + +FailureOr> +getPhysicalLogicalBitFootprint(VMIVRegType type) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (elementBits == 0) + return failure(); + + FailureOr factor = getDataLayoutFactor(type); + FailureOr lanesPerPart = + getDataLanesPerPart(type.getElementType()); + if (failed(factor) || failed(lanesPerPart)) + return failure(); + + SmallVector bits; + for (int64_t part = 0; part < *factor; ++part) { + FailureOr chunks = getDataChunksInPart(type, part); + if (failed(chunks)) + return failure(); + for (int64_t chunk = 0; chunk < *chunks; ++chunk) { + int64_t activeLanes = 0; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return failure(); + if (!*padding) + ++activeLanes; + } + bits.push_back(activeLanes * static_cast(elementBits)); + } + } + return bits; +} + +LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source and result layouts"); + if (sourceLayout != resultLayout) + return fail("requires matching source and result layouts"); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source and result physical arity"); + if (*sourceArity != *resultArity) + return fail("requires source and result to have the same physical arity"); + + FailureOr> sourceBits = + getPhysicalLogicalBitFootprint(sourceType); + FailureOr> resultBits = + getPhysicalLogicalBitFootprint(resultType); + if (failed(sourceBits) || failed(resultBits)) + return fail("requires computable physical logical bit footprints"); + if (sourceBits->size() != resultBits->size()) + return fail("requires source and result physical footprint counts to " + "match"); + for (auto [source, result] : llvm::zip_equal(*sourceBits, *resultBits)) { + if (source != result) + return fail("requires matching logical bit footprint in every physical " + "chunk"); + } + + return success(); +} + +LogicalResult checkSupportedChannelSplitShape( + const VMITargetCapabilityRegistry &capabilities, VMIChannelSplitOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + int64_t channels = op.getNumResults(); + VMICapabilityResult channelCapability = + capabilities.supportsChannelCount("pto.vmi.channel_split", channels); + if (!channelCapability.isSupported()) + return fail(channelCapability.reason); + + auto sourceType = cast(op.getSource().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return fail("requires assigned source layout"); + auto expectedLayout = + VMILayoutAttr::getDeinterleaved(op.getContext(), channels); + if (!sourceLayout.isContiguous() && sourceLayout != expectedLayout) + return fail("requires source layout to be contiguous or matching " + "deinterleaved channel layout"); + + for (Value result : op.getResults()) { + VMILayoutAttr resultLayout = + cast(result.getType()).getLayoutAttr(); + if (!resultLayout || !resultLayout.isContiguous()) + return fail("requires every result layout to be contiguous"); + } + + auto channelType = + VMIVRegType::get(op.getContext(), sourceType.getElementCount(), + sourceType.getElementType(), expectedLayout); + std::string materializationReason; + if (failed(checkSupportedLayoutMaterialization( + capabilities, sourceType, channelType, sourceLayout, expectedLayout, + &materializationReason))) + return fail(Twine("cannot materialize source to channel layout; ") + + materializationReason); + + FailureOr channelArity = getVMIPhysicalArity(channelType); + int64_t resultArity = 0; + for (Value result : op.getResults()) { + FailureOr arity = + getVMIPhysicalArity(cast(result.getType())); + if (failed(arity)) + return fail("requires computable result physical arity"); + resultArity += *arity; + } + if (failed(channelArity) || *channelArity != resultArity) + return fail("requires channel physical arity to match all result parts"); + + return success(); +} + +LogicalResult checkSupportedChannelMergeShape( + const VMITargetCapabilityRegistry &capabilities, VMIChannelMergeOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + int64_t channels = op.getInputs().size(); + VMICapabilityResult channelCapability = + capabilities.supportsChannelCount("pto.vmi.channel_merge", channels); + if (!channelCapability.isSupported()) + return fail(channelCapability.reason); + + int64_t inputArity = 0; + for (Value input : op.getInputs()) { + auto inputType = cast(input.getType()); + VMILayoutAttr inputLayout = inputType.getLayoutAttr(); + if (!inputLayout || !inputLayout.isContiguous()) + return fail("requires every input layout to be contiguous"); + FailureOr arity = getVMIPhysicalArity(inputType); + if (failed(arity)) + return fail("requires computable input physical arity"); + inputArity += *arity; + } + + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout) + return fail("requires assigned result layout"); + auto expectedLayout = + VMILayoutAttr::getDeinterleaved(op.getContext(), channels); + if (!resultLayout.isContiguous() && resultLayout != expectedLayout) + return fail("requires result layout to be contiguous or matching " + "deinterleaved channel layout"); + + auto channelType = + VMIVRegType::get(op.getContext(), resultType.getElementCount(), + resultType.getElementType(), expectedLayout); + FailureOr channelArity = getVMIPhysicalArity(channelType); + if (failed(channelArity) || *channelArity != inputArity) + return fail("requires channel physical arity to match all input parts"); + + std::string materializationReason; + if (failed(checkSupportedLayoutMaterialization( + capabilities, channelType, resultType, expectedLayout, resultLayout, + &materializationReason))) + return fail(Twine("cannot materialize channel layout to result; ") + + materializationReason); + + return success(); +} + +LogicalResult +checkSupportedActivePrefixIndexShape(VMIActivePrefixIndexOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!maskLayout || !resultLayout) + return fail("requires assigned mask and result layouts"); + if (!maskLayout.isContiguous() || !resultLayout.isContiguous()) + return fail("requires contiguous mask and result layouts"); + + std::string resultFullReason; + if (failed(checkFullDataPhysicalChunks(resultType, &resultFullReason))) + return fail(Twine("requires full result physical chunks so padding mask " + "lanes cannot affect the observable prefix; ") + + resultFullReason); + + std::string maskFullReason; + if (failed(checkFullVMIPhysicalChunks(maskType, &maskFullReason))) + return fail(Twine("requires full mask physical chunks so padding mask " + "lanes cannot affect the observable prefix; ") + + maskFullReason); + + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(maskArity) || failed(resultArity)) + return fail("requires computable mask and result physical arity"); + if (*maskArity != 1 || *resultArity != 1) + return fail("requires a single physical chunk; multi-chunk prefix needs " + "cross-chunk carry"); + + return success(); +} + +LogicalResult checkSupportedCompressShape(VMICompressOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !maskLayout || !resultLayout) + return fail("requires assigned source, mask, and result layouts"); + if (!sourceLayout.isContiguous() || !maskLayout.isContiguous() || + !resultLayout.isContiguous()) + return fail("requires contiguous source, mask, and result layouts"); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) + return fail(Twine("requires full source physical chunks so padding mask " + "lanes cannot be squeezed into the result; ") + + fullChunkReason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("requires computable source, mask, and result physical arity"); + if (*sourceArity != 1 || *maskArity != 1 || *resultArity != 1) + return fail("requires a single physical chunk; multi-chunk compress needs " + "cross-chunk compaction"); + + return success(); +} + +LogicalResult checkSupportedCompressStoreShape( + const VMITargetCapabilityRegistry &capabilities, VMICompressStoreOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto valueType = cast(op.getValue().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr valueLayout = valueType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!valueLayout || !maskLayout) + return fail("requires assigned value and mask layouts"); + if (!valueLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("requires contiguous value and mask layouts"); + + VMICapabilityResult destinationCapability = + capabilities.supportsUBPointerMemory( + op.getDestination().getType(), "destination", "pto.vstur", + "pto.vstur stores only to UB"); + if (!destinationCapability.isSupported()) + return fail(destinationCapability.reason); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(valueType, &fullChunkReason))) + return fail(Twine("requires full physical chunks so padding mask lanes " + "cannot be squeezed into memory; ") + + fullChunkReason); + + FailureOr valueArity = getVMIPhysicalArity(valueType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(valueArity) || failed(maskArity)) + return fail("requires computable value and mask physical arity"); + if (*valueArity != 1 || *maskArity != 1) + return fail("requires a single physical chunk; multi-chunk " + "compress_store needs cross-chunk compaction and SQZN " + "state planning"); + + return success(); +} + +template +LogicalResult +checkSupportedReduceShape(const VMITargetCapabilityRegistry &capabilities, + OpTy op, VMIReductionKind kind, bool requiresReassoc, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (requiresReassoc && !op->hasAttr("reassoc")) + return fail("requires reassoc attr for pair-wise floating-point vcadd"); + + auto sourceType = cast(op.getSource().getType()); + auto initType = cast(op.getInit().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr initLayout = initType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !initLayout || !maskLayout || !resultLayout) + return fail("requires assigned source, init, mask, and result layouts"); + if (!sourceLayout.isContiguous() || !initLayout.isContiguous() || + !maskLayout.isContiguous() || !resultLayout.isContiguous()) + return fail("requires contiguous source, init, mask, and result layouts"); + + VMICapabilityResult elementCapability = + capabilities.supportsReductionElementType(kind, + sourceType.getElementType()); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) + return fail(Twine("requires full source physical chunks so padding lanes " + "do not participate in the reduction; ") + + fullChunkReason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr initArity = getVMIPhysicalArity(initType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(initArity) || failed(maskArity) || + failed(resultArity)) + return fail("requires computable physical arity"); + if (*sourceArity < 1 || *maskArity != *sourceArity) + return fail("requires source and mask physical arity to match and be " + "non-empty"); + if (*initArity != 1 || *resultArity != 1) + return fail("requires one init and result physical chunk"); + + return success(); +} + +LogicalResult +checkSupportedFmaShape(const VMITargetCapabilityRegistry &capabilities, + VMIFmaOp op, std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto lhsType = cast(op.getLhs().getType()); + VMICapabilityResult elementCapability = + capabilities.supportsElementType(lhsType.getElementType(), + VMIElementPurpose::VMula); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + FailureOr arity = getVMIPhysicalArity(lhsType); + if (failed(arity) || *arity < 1) + return fail("requires computable non-empty physical arity"); + + return success(); +} + +LogicalResult +checkSupportedReluShape(const VMITargetCapabilityRegistry &capabilities, + VMIReluOp op, std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + if (failed(checkSupportedMaskableVReg(capabilities, resultType, reason))) + return failure(); + + VMICapabilityResult elementCapability = + capabilities.supportsElementType(resultType.getElementType(), + VMIElementPurpose::VRelu); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + return success(); +} + +void emitEnsureLayoutMaterializationError(VMIEnsureLayoutOp ensure, + VMIVRegType sourceType, + VMIVRegType resultType, + StringRef reason) { + if (ensure.getResult().hasOneUse()) { + OpOperand &use = *ensure.getResult().use_begin(); + Operation *requester = use.getOwner(); + InFlightDiagnostic diag = + requester->emitError() + << kVMIDiagUnsupportedPrefix << requester->getName() << " operand #" + << use.getOperandNumber() << " has type " << sourceType + << " but requires " << resultType + << "; pto.vmi.ensure_layout cannot materialize this conversion"; + diag.attachNote(ensure.getLoc()) + << "failed helper conversion " << sourceType << " -> " << resultType + << " (" << reason + << "); partial/tail layout materialization requires an explicit " + "packing plan"; + return; + } + + ensure.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.ensure_layout cannot materialize the requested data " + "layout conversion (" + << reason + << "); partial/tail layout materialization requires an explicit " + "packing plan"; +} + +LogicalResult verifySupportedVMIToVPTOOps( + ModuleOp module, const VMITargetCapabilityRegistry &capabilities, + bool enableStableGatherMaskedLoad) { + auto emitMemoryUnsupported = [&](Operation *op, StringRef opName, + VMIVRegType type, Value source, + std::optional constantOffset) + -> WalkResult { + std::string reason; + if (succeeded(checkSupportedLoadShape(capabilities, type, source, + source.getType(), constantOffset, + &reason))) + return WalkResult::advance(); + + op->emitError() + << kVMIDiagUnsupportedPrefix << opName + << " requires full physical chunks without padding lanes or a " + "statically safe full-read footprint (" + << reason << ")"; + return WalkResult::interrupt(); + }; + + auto emitMaskableUnsupported = [&](Operation *op, StringRef opName, + VMIVRegType type) -> WalkResult { + std::string reason; + if (succeeded(checkSupportedMaskableVReg(capabilities, type, &reason))) + return WalkResult::advance(); + + op->emitError() + << kVMIDiagUnsupportedPrefix << opName + << " direct lowering requires physical vreg parts with b8/b16/b32 " + "predicate masks (" + << reason << ")"; + return WalkResult::interrupt(); + }; + + auto emitTargetElementUnsupported = + [&](Operation *op, StringRef opName, VMIVRegType type, + VMIElementPurpose purpose, StringRef elementContract) -> WalkResult { + std::string reason; + if (succeeded(checkSupportedTargetElementVReg( + capabilities, type, purpose, elementContract, &reason))) + return WalkResult::advance(); + + op->emitError() + << kVMIDiagUnsupportedPrefix << opName + << " direct lowering requires " << elementContract + << " and physical vreg parts with b8/b16/b32 predicate masks (" + << reason << ")"; + return WalkResult::interrupt(); + }; + + WalkResult result = module.walk([&](Operation *op) { + if (auto constant = dyn_cast(op)) { + auto denseAttr = dyn_cast(constant.getValue()); + if (!denseAttr || !denseAttr.isSplat()) { + constant.emitError() + << kVMIDiagUnsupportedPrefix + << "non-splat pto.vmi.constant requires a vreg immediate or " + "scratch materialization plan"; + return WalkResult::interrupt(); + } + return emitMaskableUnsupported( + op, "pto.vmi.constant", + cast(constant.getResult().getType())); + } + + if (auto broadcast = dyn_cast(op)) + return emitMaskableUnsupported( + op, "pto.vmi.broadcast", + cast(broadcast.getResult().getType())); + + if (auto load = dyn_cast(op)) + return emitMemoryUnsupported( + op, "pto.vmi.load", cast(load.getResult().getType()), + load.getSource(), getConstantIndexValue(load.getOffset())); + if (auto load = dyn_cast(op)) { + if (enableStableGatherMaskedLoad) { + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.masked_load stable VGATHER-based lowering is reserved " + "for strict masked/tail loads but is not implemented yet"; + return WalkResult::interrupt(); + } + std::string reason; + if (succeeded(checkSupportedMaskedLoadShape(capabilities, load, + &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.masked_load direct lowering requires a supported memory " + "source, contiguous result/passthru/mask layouts, and either " + "full physical chunks or a statically safe full-read footprint (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto gather = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGatherShape(capabilities, gather, &reason))) + return WalkResult::advance(); + gather.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only " + "for UB pointer sources, contiguous full physical chunks, " + "32-bit result elements, i32 indices, and b32 masks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedExpandLoadShape(capabilities, load, + &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.expand_load direct lowering is currently supported for " + "either a static all-active mask lowered as pto.vlds, or a " + "one-full-chunk 32-bit UB runtime mask lowered through pto.vusqz " + "+ pto.vgather2_bc + pto.vsel (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto store = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedStoreShape( + capabilities, cast(store.getValue().getType()), + store.getDestination(), store.getDestination().getType(), + &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.store requires an 8/16/32-bit predicate-maskable " + "element type and either full physical chunks or contiguous " + "tail-store layout, with UB-backed destination (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto store = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedMaskedStoreShape( + capabilities, cast(store.getValue().getType()), + cast(store.getMask().getType()), + store.getDestination(), store.getDestination().getType(), + &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.masked_store requires either full physical chunks or " + "contiguous tail-store value/mask layout, with UB-backed " + "destination (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto scatter = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedScatterShape(capabilities, scatter, + &reason))) + return WalkResult::advance(); + scatter.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.scatter lowers through pto.vscatter only with an " + "indices_unique proof, UB pointer destination, contiguous full " + "physical chunks, 32-bit value elements, i32 indices, and b32 " + "masks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto tileRead = dyn_cast(op)) + return emitMemoryUnsupported( + op, "pto.vmi.tile_read", + cast(tileRead.getResult().getType()), + tileRead.getSource(), 0); + if (auto tileWrite = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedStoreShape( + capabilities, + cast(tileWrite.getValue().getType()), + tileWrite.getDestination(), tileWrite.getDestination().getType(), + &reason))) + return WalkResult::advance(); + tileWrite.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable " + "element type and either full physical chunks or contiguous " + "tail-store layout, with UB-backed destination (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (succeeded(checkSupportedLayoutMaterialization( + capabilities, sourceType, resultType, sourceType.getLayoutAttr(), + resultType.getLayoutAttr(), &reason))) + return WalkResult::advance(); + + emitEnsureLayoutMaterializationError(ensure, sourceType, resultType, + reason); + return WalkResult::interrupt(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (succeeded(checkSupportedLayoutMaterialization( + capabilities, sourceType, resultType, sourceType.getLayoutAttr(), + resultType.getLayoutAttr(), &reason))) + return WalkResult::advance(); + + ensure.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.ensure_mask_layout cannot materialize the requested " + "mask layout conversion (" + << reason + << "); partial/tail predicate layout materialization requires an " + "explicit packing plan"; + return WalkResult::interrupt(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + if (sourceType.getGranularity() == resultType.getGranularity()) + return WalkResult::advance(); + + std::string reason; + if (succeeded(checkSupportedMaskGranularityMaterialization( + capabilities, sourceType, resultType, &reason))) + return WalkResult::advance(); + + ensure.emitError() + << kVMIDiagUnsupportedPrefix + << "non-identity mask granularity materialization requires concrete " + "b8/b16/b32 masks with matching lane count and layout (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto addf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.addf", + cast(addf.getResult().getType()), + VMIElementPurpose::F16BF16F32, + "f16/bf16/f32 element type"); + if (auto addi = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.addi", + cast( + addi.getResult().getType())); + if (auto subf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.subf", + cast(subf.getResult().getType()), + VMIElementPurpose::F16BF16F32, + "f16/bf16/f32 element type"); + if (auto subi = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.subi", + cast( + subi.getResult().getType())); + if (auto mulf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.mulf", + cast(mulf.getResult().getType()), + VMIElementPurpose::F16BF16F32, + "f16/bf16/f32 element type"); + if (auto muli = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.muli", + cast( + muli.getResult().getType())); + if (auto divf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.divf", + cast(divf.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto minf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.minf", + cast(minf.getResult().getType()), + VMIElementPurpose::F16BF16F32, + "f16/bf16/f32 element type"); + if (auto maxf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.maxf", + cast(maxf.getResult().getType()), + VMIElementPurpose::F16BF16F32, + "f16/bf16/f32 element type"); + if (auto negf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.negf", + cast(negf.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto absf = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.absf", + cast(absf.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto absi = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.absi", + cast(absi.getResult().getType()), + VMIElementPurpose::SignlessOrSignedI8I16I32, + "signless/signed i8/i16/i32 element type"); + if (auto sqrt = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.sqrt", + cast(sqrt.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto exp = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.exp", + cast(exp.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto ln = dyn_cast(op)) + return emitTargetElementUnsupported( + op, "pto.vmi.ln", + cast(ln.getResult().getType()), + VMIElementPurpose::F16F32, + "f16/f32 element type"); + if (auto relu = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReluShape(capabilities, relu, &reason))) + return WalkResult::advance(); + relu.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.relu direct lowering requires physical vreg parts with " + "b8/b16/b32 predicate masks and f16/f32 element type (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto andi = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.andi", + cast( + andi.getResult().getType())); + if (auto ori = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.ori", + cast(ori.getResult().getType())); + if (auto xori = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.xori", + cast( + xori.getResult().getType())); + if (auto shli = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.shli", + cast( + shli.getResult().getType())); + if (auto shrui = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.shrui", + cast( + shrui.getResult().getType())); + if (auto notOp = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.not", + cast( + notOp.getResult().getType())); + if (auto select = dyn_cast(op)) + return emitMaskableUnsupported(op, "pto.vmi.select", + cast( + select.getResult().getType())); + + if (auto cmpf = dyn_cast(op)) { + WalkResult target = emitTargetElementUnsupported( + op, "pto.vmi.cmpf", cast(cmpf.getLhs().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); + if (target.wasInterrupted()) + return target; + if (succeeded(checkSupportedComparePredicate(op, cmpf.getPredicate()))) + return WalkResult::advance(); + return WalkResult::interrupt(); + } + + if (auto cmpi = dyn_cast(op)) { + WalkResult target = emitTargetElementUnsupported( + op, "pto.vmi.cmpi", cast(cmpi.getLhs().getType()), + VMIElementPurpose::AnyI8I16I32, + "signless/signed/unsigned i8/i16/i32 element type"); + if (target.wasInterrupted()) + return target; + if (succeeded(checkSupportedComparePredicate(op, cmpi.getPredicate()))) + return WalkResult::advance(); + return WalkResult::interrupt(); + } + + if (auto activePrefix = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedActivePrefixIndexShape(activePrefix, + &reason))) + return WalkResult::advance(); + activePrefix.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.active_prefix_index lowers through pto.vusqz only for " + "one contiguous physical chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto compress = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedCompressShape(compress, &reason))) + return WalkResult::advance(); + compress.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.compress lowers through pto.vsqz only for one " + "contiguous full physical chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto compressStore = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedCompressStoreShape(capabilities, + compressStore, &reason))) + return WalkResult::advance(); + compressStore.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.compress_store lowers through pto.vsqz + pto.vstur " + "only for one contiguous full physical chunk with a UB pointer " + "destination (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::AddI, + /*requiresReassoc=*/false, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_addi lowers through pto.vcadd only for " + "contiguous full 32-bit integer source chunks with matching " + "mask chunks and one init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::AddF, + /*requiresReassoc=*/true, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_addf lowers through pto.vcadd only with " + "reassoc, f32 contiguous full source chunks, matching mask " + "chunks, and one init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::MaxF, + /*requiresReassoc=*/false, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_maxf lowers through pto.vcmax only for f16/f32 " + "contiguous full source chunks with matching mask chunks and one " + "init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedReduceShape( + capabilities, reduce, VMIReductionKind::MinF, + /*requiresReassoc=*/false, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.reduce_minf lowers through pto.vcmin only for f16/f32 " + "contiguous full source chunks with matching mask chunks and one " + "init/result chunk (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto fma = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedFmaShape(capabilities, fma, &reason))) + return WalkResult::advance(); + fma.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.fma lowers through pto.vmula only for f16/bf16/f32 " + "element types (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto extf = dyn_cast(op)) { + if (succeeded(checkSupportedExtFShape(extf))) + return WalkResult::advance(); + + extf.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.extf supports contiguous 16-bit float-like or fp8-like " + "physical source chunks to f32 deinterleaved=2/4 results; " + "partial/tail is allowed only when source padding maps to result " + "padding"; + return WalkResult::interrupt(); + } + + if (auto truncf = dyn_cast(op)) { + if (succeeded(checkSupportedTruncFShape(truncf))) + return WalkResult::advance(); + + truncf.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.truncf supports only f32 deinterleaved=2 source parts " + "to one contiguous f16 result chunk or f32 deinterleaved=4 " + "source parts to one contiguous fp8-like result chunk"; + return WalkResult::interrupt(); + } + + if (auto bitcast = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedBitcastShape(bitcast, &reason))) + return WalkResult::advance(); + + bitcast.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.bitcast requires matching source/result layouts with " + "identical physical arity and matching per-chunk logical bit " + "footprints (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto split = dyn_cast(op)) { + int64_t channels = split.getNumResults(); + std::string reason; + if (succeeded(checkSupportedChannelSplitShape(capabilities, split, + &reason))) + return WalkResult::advance(); + + if (channels != 2 && channels != 4) + split.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_split supports only 2 or 4 channels"; + else + split.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_split requires source layout to be contiguous " + "or matching deinterleaved channel layout, every result layout " + "to be contiguous, and complete physical channel groups (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto merge = dyn_cast(op)) { + int64_t channels = merge.getInputs().size(); + std::string reason; + if (succeeded(checkSupportedChannelMergeShape(capabilities, merge, + &reason))) + return WalkResult::advance(); + + if (channels != 2 && channels != 4) + merge.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_merge supports only 2 or 4 channels"; + else + merge.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.channel_merge requires every input layout to be " + "contiguous and result layout to be contiguous or matching " + "deinterleaved channel layout, with complete physical channel " + "groups (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto shuffle = dyn_cast(op)) { + std::string reason; + if (succeeded(computeShuffleForwardingSourceParts(shuffle, &reason))) + return WalkResult::advance(); + std::string splatReason; + if (succeeded(computeShuffleLane0SplatSourcePart(shuffle, &splatReason))) + return WalkResult::advance(); + std::string vselrReason; + if (succeeded(computeShuffleVselrPlans(shuffle, &vselrReason))) + return WalkResult::advance(); + + shuffle.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.shuffle requires physical chunk forwarding or " + "lane0 splat or vci-materializable vselr indices (forwarding: " + << reason << "; lane0 splat: " << splatReason + << "; vselr: " << vselrReason << ")"; + return WalkResult::interrupt(); + } + + if (auto constantMask = dyn_cast(op)) { + std::string reason; + if (succeeded(computeConstantMaskMaterialization(constantMask, &reason))) + return WalkResult::advance(); + + constantMask.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.constant_mask requires a dense bool constant with " + "concrete layout and b8/b16/b32 granularity (" + << reason << ")"; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +struct VMIToVPTOPass + : public mlir::pto::impl::VMIToVPTOBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMIToVPTOPass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + if (failed(verifyVMIToVPTOInputIR(module))) { + signalPassFailure(); + return; + } + VMITargetCapabilityRegistry capabilities; + if (failed(verifySupportedVMIToVPTOOps( + module, capabilities, enableStableGatherMaskedLoad))) { + signalPassFailure(); + return; + } + + MLIRContext *context = module.getContext(); + VMIToVPTOTypeConverter typeConverter; + RewritePatternSet patterns(context); + + populateVMIOneToNConversionPatterns(typeConverter, patterns, + capabilities); + if (failed(applyPartialOneToNConversion(module, typeConverter, + std::move(patterns)))) { + module.emitError() + << kVMIDiagResidualOpPrefix + << "failed to convert all VMI ops/types to VPTO"; + signalPassFailure(); + return; + } + if (failed(verifyNoResidualVMIIR(module))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMIToVPTOPass() { + return std::make_unique(); +} diff --git a/test/lit/CMakeLists.txt b/test/lit/CMakeLists.txt index 684ae0bf50..9fb0e63c5b 100644 --- a/test/lit/CMakeLists.txt +++ b/test/lit/CMakeLists.txt @@ -27,6 +27,7 @@ configure_lit_site_cfg( set(PTOIR_TEST_DEPENDS FileCheck count not pto-opt + pto-test-opt ) add_lit_testsuite(check-pto "Running the pto regression tests" diff --git a/test/lit/lit.cfg.py b/test/lit/lit.cfg.py index e429a8a6b8..6565d6b06c 100644 --- a/test/lit/lit.cfg.py +++ b/test/lit/lit.cfg.py @@ -40,6 +40,8 @@ # test_exec_root: The root path where tests should be run. config.test_exec_root = os.path.join(config.ptoir_obj_root, 'test/lit') config.ptoir_tools_dir = os.path.join(config.ptoir_obj_root, 'tools/ptoas') +config.ptoir_test_tools_dir = os.path.join(config.ptoir_obj_root, + 'tools/pto-test-opt') config.substitutions.append(('%PATH%', config.environment['PATH'])) config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) @@ -59,9 +61,11 @@ # Tweak the PATH to include the tools dir. llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) -tool_dirs = [config.ptoir_tools_dir, config.llvm_tools_dir] +tool_dirs = [config.ptoir_tools_dir, config.ptoir_test_tools_dir, + config.llvm_tools_dir] tools = [ 'ptoas', + 'pto-test-opt', ] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/test/lit/vmi/vmi_absf_integer_invalid.pto b/test/lit/vmi/vmi_absf_integer_invalid.pto new file mode 100644 index 0000000000..2a3900e4e5 --- /dev/null +++ b/test/lit/vmi/vmi_absf_integer_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_absf_integer_invalid(%value: !pto.vmi.vreg<128xi16>) { + %abs = pto.vmi.absf %value + : !pto.vmi.vreg<128xi16> -> !pto.vmi.vreg<128xi16> + return + } +} + +// CHECK: 'pto.vmi.absf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_absi_float_invalid.pto b/test/lit/vmi/vmi_absi_float_invalid.pto new file mode 100644 index 0000000000..0f2d556c1a --- /dev/null +++ b/test/lit/vmi/vmi_absi_float_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_absi_float_invalid(%value: !pto.vmi.vreg<64xf32>) { + %abs = pto.vmi.absi %value + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + return + } +} + +// CHECK: 'pto.vmi.absi' op requires integer-like VMI element type diff --git a/test/lit/vmi/vmi_active_prefix_index_result_type_invalid.pto b/test/lit/vmi/vmi_active_prefix_index_result_type_invalid.pto new file mode 100644 index 0000000000..c675b2e6e9 --- /dev/null +++ b/test/lit/vmi/vmi_active_prefix_index_result_type_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_active_prefix_index_result_type_invalid( + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.active_prefix_index' op requires signless integer result element type diff --git a/test/lit/vmi/vmi_addf_lane_mismatch_invalid.pto b/test/lit/vmi/vmi_addf_lane_mismatch_invalid.pto new file mode 100644 index 0000000000..bd6ed94bac --- /dev/null +++ b/test/lit/vmi/vmi_addf_lane_mismatch_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_addf_lane_mismatch_invalid( + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<64xf32>) { + %r = pto.vmi.addf %a, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires all VMI data values to have the same logical lane count diff --git a/test/lit/vmi/vmi_bitcast_total_bits_invalid.pto b/test/lit/vmi/vmi_bitcast_total_bits_invalid.pto new file mode 100644 index 0000000000..937889d014 --- /dev/null +++ b/test/lit/vmi/vmi_bitcast_total_bits_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_bitcast_total_bits_invalid(%value: !pto.vmi.vreg<128xf32>) { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xi16> + return + } +} + +// CHECK: 'pto.vmi.bitcast' op requires source and result to carry the same total number of bits diff --git a/test/lit/vmi/vmi_bitwise_float_invalid.pto b/test/lit/vmi/vmi_bitwise_float_invalid.pto new file mode 100644 index 0000000000..60f260d444 --- /dev/null +++ b/test/lit/vmi/vmi_bitwise_float_invalid.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_andi_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.andi %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.andi' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_ori_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.ori %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.ori' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_xori_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.xori %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.xori' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_not_float_invalid(%source: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.not %source + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.not' op requires integer-like VMI element type diff --git a/test/lit/vmi/vmi_broadcast_type_mismatch_invalid.pto b/test/lit/vmi/vmi_broadcast_type_mismatch_invalid.pto new file mode 100644 index 0000000000..9ecdc9469f --- /dev/null +++ b/test/lit/vmi/vmi_broadcast_type_mismatch_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_broadcast_type_mismatch_invalid(%value: f16) { + %result = pto.vmi.broadcast %value : f16 -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires scalar or VMI vector input element type to match result element type diff --git a/test/lit/vmi/vmi_channel_merge_input_mismatch_invalid.pto b/test/lit/vmi/vmi_channel_merge_input_mismatch_invalid.pto new file mode 100644 index 0000000000..1dbc569c4c --- /dev/null +++ b/test/lit/vmi/vmi_channel_merge_input_mismatch_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_merge_input_mismatch_invalid( + %ch0: !pto.vmi.vreg<2xf32>, + %ch1: !pto.vmi.vreg<3xf32>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<3xf32>) -> !pto.vmi.vreg<5xf32> + return + } +} + +// CHECK: requires all channel inputs to have the same lane count and element type diff --git a/test/lit/vmi/vmi_channel_merge_result_mismatch_invalid.pto b/test/lit/vmi/vmi_channel_merge_result_mismatch_invalid.pto new file mode 100644 index 0000000000..f5c7ad94b9 --- /dev/null +++ b/test/lit/vmi/vmi_channel_merge_result_mismatch_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_merge_result_mismatch_invalid( + %ch0: !pto.vmi.vreg<2xf32>, + %ch1: !pto.vmi.vreg<2xf32>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) -> !pto.vmi.vreg<5xf32> + return + } +} + +// CHECK: requires result lane count and element type to match merged channels diff --git a/test/lit/vmi/vmi_channel_split_lane_count_invalid.pto b/test/lit/vmi/vmi_channel_split_lane_count_invalid.pto new file mode 100644 index 0000000000..bbf923b079 --- /dev/null +++ b/test/lit/vmi/vmi_channel_split_lane_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_split_lane_count_invalid( + %src: !pto.vmi.vreg<5xf32>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<5xf32>) -> (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) + return + } +} + +// CHECK: requires source lane count to equal result count times per-channel lane count diff --git a/test/lit/vmi/vmi_channel_split_result_count_invalid.pto b/test/lit/vmi/vmi_channel_split_result_count_invalid.pto new file mode 100644 index 0000000000..bbe2b434d6 --- /dev/null +++ b/test/lit/vmi/vmi_channel_split_result_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_channel_split_result_count_invalid( + %src: !pto.vmi.vreg<4xf32>) { + %ch0 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<4xf32>) -> !pto.vmi.vreg<4xf32> + return + } +} + +// CHECK: requires at least two channel results diff --git a/test/lit/vmi/vmi_compress_result_mismatch_invalid.pto b/test/lit/vmi/vmi_compress_result_mismatch_invalid.pto new file mode 100644 index 0000000000..7e7e6bb66f --- /dev/null +++ b/test/lit/vmi/vmi_compress_result_mismatch_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_compress_result_mismatch_invalid( + %src: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.compress' op requires all VMI data values to have the same element type diff --git a/test/lit/vmi/vmi_constant_attr_kind_invalid.pto b/test/lit/vmi/vmi_constant_attr_kind_invalid.pto new file mode 100644 index 0000000000..c1ff60fe3b --- /dev/null +++ b/test/lit/vmi/vmi_constant_attr_kind_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_attr_kind_invalid() { + %value = "pto.vmi.constant"() { + value = 1.000000e+00 : f32 + } : () -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires dense elements constant attribute diff --git a/test/lit/vmi/vmi_constant_element_count_invalid.pto b/test/lit/vmi/vmi_constant_element_count_invalid.pto new file mode 100644 index 0000000000..b5e80ce364 --- /dev/null +++ b/test/lit/vmi/vmi_constant_element_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_element_count_invalid() { + %value = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<64xf32> + } : () -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires dense constant element count to match result logical lane count diff --git a/test/lit/vmi/vmi_constant_element_type_invalid.pto b/test/lit/vmi/vmi_constant_element_type_invalid.pto new file mode 100644 index 0000000000..29a5f2d22a --- /dev/null +++ b/test/lit/vmi/vmi_constant_element_type_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_element_type_invalid() { + %value = "pto.vmi.constant"() { + value = dense<1> : tensor<128xi32> + } : () -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires dense constant element type to match result element type diff --git a/test/lit/vmi/vmi_constant_mask_attr_kind_invalid.pto b/test/lit/vmi/vmi_constant_mask_attr_kind_invalid.pto new file mode 100644 index 0000000000..537d007f03 --- /dev/null +++ b/test/lit/vmi/vmi_constant_mask_attr_kind_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_mask_attr_kind_invalid() { + %mask = "pto.vmi.constant_mask"() { + value = true + } : () -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: requires dense elements mask constant attribute diff --git a/test/lit/vmi/vmi_constant_mask_element_count_invalid.pto b/test/lit/vmi/vmi_constant_mask_element_count_invalid.pto new file mode 100644 index 0000000000..f39f4ab00a --- /dev/null +++ b/test/lit/vmi/vmi_constant_mask_element_count_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_mask_element_count_invalid() { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<64xi1> + } : () -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: requires dense mask constant element count to match result logical lane count diff --git a/test/lit/vmi/vmi_constant_mask_element_type_invalid.pto b/test/lit/vmi/vmi_constant_mask_element_type_invalid.pto new file mode 100644 index 0000000000..7f97a4afd6 --- /dev/null +++ b/test/lit/vmi/vmi_constant_mask_element_type_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_constant_mask_element_type_invalid() { + %mask = "pto.vmi.constant_mask"() { + value = dense<1> : tensor<128xi32> + } : () -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: requires dense mask constant element type to be i1 diff --git a/test/lit/vmi/vmi_divf_integer_invalid.pto b/test/lit/vmi/vmi_divf_integer_invalid.pto new file mode 100644 index 0000000000..0c26d668b3 --- /dev/null +++ b/test/lit/vmi/vmi_divf_integer_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_divf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, + %rhs: !pto.vmi.vreg<128xi32>) { + %quotient = pto.vmi.divf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.divf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_elementwise_kind_invalid.pto b/test/lit/vmi/vmi_elementwise_kind_invalid.pto new file mode 100644 index 0000000000..46e8255de8 --- /dev/null +++ b/test/lit/vmi/vmi_elementwise_kind_invalid.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_subf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, %rhs: !pto.vmi.vreg<128xi32>) { + %out = pto.vmi.subf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.subf' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_subi_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.subi %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.subi' op requires integer-like VMI element type + +// ----- + +module { + func.func @vmi_mulf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, %rhs: !pto.vmi.vreg<128xi32>) { + %out = pto.vmi.mulf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.mulf' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_muli_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, %rhs: !pto.vmi.vreg<128xf32>) { + %out = pto.vmi.muli %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.muli' op requires integer-like VMI element type diff --git a/test/lit/vmi/vmi_ensure_layout_surface_invalid.pto b/test/lit/vmi/vmi_ensure_layout_surface_invalid.pto new file mode 100644 index 0000000000..09a92692de --- /dev/null +++ b/test/lit/vmi/vmi_ensure_layout_surface_invalid.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_ensure_layout_surface_invalid(%a: !pto.vmi.vreg<128xf32>) { + %r = pto.vmi.ensure_layout %a + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: requires source and result to be layout-assigned + +// ----- + +module { + func.func @vmi_ensure_mask_granularity_surface_invalid( + %a: !pto.vmi.mask<128xpred>) { + %r = pto.vmi.ensure_mask_granularity %a + : !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: requires source and result to be layout-assigned + +// ----- + +module { + func.func @vmi_ensure_mask_granularity_layout_mismatch_invalid( + %a: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %r = pto.vmi.ensure_mask_granularity %a + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: requires source and result mask layouts to match diff --git a/test/lit/vmi/vmi_extf_direction_invalid.pto b/test/lit/vmi/vmi_extf_direction_invalid.pto new file mode 100644 index 0000000000..e00280a69d --- /dev/null +++ b/test/lit/vmi/vmi_extf_direction_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_extf_direction_invalid(%source: !pto.vmi.vreg<128xf32>) { + %result = pto.vmi.extf %source + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: requires result element type to be wider than source element type diff --git a/test/lit/vmi/vmi_extf_lane_mismatch_invalid.pto b/test/lit/vmi/vmi_extf_lane_mismatch_invalid.pto new file mode 100644 index 0000000000..d1b64fc15d --- /dev/null +++ b/test/lit/vmi/vmi_extf_lane_mismatch_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_extf_lane_mismatch_invalid(%source: !pto.vmi.vreg<64xf16>) { + %result = pto.vmi.extf %source + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires source and result logical lane counts to match diff --git a/test/lit/vmi/vmi_fma_integer_invalid.pto b/test/lit/vmi/vmi_fma_integer_invalid.pto new file mode 100644 index 0000000000..e44d8879b3 --- /dev/null +++ b/test/lit/vmi/vmi_fma_integer_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_fma_integer_invalid( + %lhs: !pto.vmi.vreg<64xi32>, + %rhs: !pto.vmi.vreg<64xi32>, + %acc: !pto.vmi.vreg<64xi32>) -> !pto.vmi.vreg<64xi32> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<64xi32>, !pto.vmi.vreg<64xi32>, + !pto.vmi.vreg<64xi32> -> !pto.vmi.vreg<64xi32> + return %out : !pto.vmi.vreg<64xi32> + } +} + +// CHECK: 'pto.vmi.fma' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_gather_indices_invalid.pto b/test/lit/vmi/vmi_gather_indices_invalid.pto new file mode 100644 index 0000000000..057e3d1244 --- /dev/null +++ b/test/lit/vmi/vmi_gather_indices_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_gather_indices_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, !pto.vmi.vreg<64xf32>, + !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + return + } +} + +// CHECK: 'pto.vmi.gather' op requires signless or unsigned 32-bit integer indices diff --git a/test/lit/vmi/vmi_iota_element_type_invalid.pto b/test/lit/vmi/vmi_iota_element_type_invalid.pto new file mode 100644 index 0000000000..448fba485f --- /dev/null +++ b/test/lit/vmi/vmi_iota_element_type_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_iota_element_type_invalid(%base: i64) { + %value = pto.vmi.iota %base + : i64 -> !pto.vmi.vreg<64xi64> + return + } +} + +// CHECK: 'pto.vmi.iota' op requires result element type to be integer 8/16/32 or f16/f32 diff --git a/test/lit/vmi/vmi_iota_order_invalid.pto b/test/lit/vmi/vmi_iota_order_invalid.pto new file mode 100644 index 0000000000..93df56591c --- /dev/null +++ b/test/lit/vmi/vmi_iota_order_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_iota_order_invalid(%base: i32) { + %value = pto.vmi.iota %base {order = "DOWN"} + : i32 -> !pto.vmi.vreg<64xi32> + return + } +} + +// CHECK: 'pto.vmi.iota' op requires order to be ASC or DESC diff --git a/test/lit/vmi/vmi_layout_assignment_active_prefix_index.pto b/test/lit/vmi/vmi_layout_assignment_active_prefix_index.pto new file mode 100644 index 0000000000..5dabf59203 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_active_prefix_index.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_active_prefix_index(%mask: !pto.vmi.mask<64xpred>) + -> !pto.vmi.vreg<64xi32> { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xi32> + return %idx : !pto.vmi.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_active_prefix_index( +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout>) +// CHECK-SAME: -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK: %[[IDX:.*]] = pto.vmi.active_prefix_index %[[MASK]] +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK: return %[[IDX]] diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto new file mode 100644 index 0000000000..6e165de8a0 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_broadcast_remat( + %scalar: f32, + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf32> { + %broadcast = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %broadcast, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %broadcast, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_broadcast_remat( +// ASSIGN-SAME: %[[SCALAR:.*]]: f32 +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[BCAST_DEINT:.*]] = pto.vmi.broadcast %[[SCALAR]] +// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[BCAST_DEINT]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout %[[BCAST_DEINT]] +// ASSIGN: %[[BCAST_CONTIG:.*]] = pto.vmi.broadcast %[[SCALAR]] +// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[BCAST_CONTIG]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_broadcast_remat( +// LOWER-COUNT-4: pto.vdup %arg0 +// LOWER-NOT: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_call_boundary.pto b/test/lit/vmi/vmi_layout_assignment_call_boundary.pto new file mode 100644 index 0000000000..b7245ad00b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_call_boundary.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func private @callee(%x: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %sum = pto.vmi.addf %x, %x + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @caller(%a: !pto.vmi.vreg<128xf16>) { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %r = call @callee(%ea) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func private @callee( +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: return +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-LABEL: func.func @caller( +// CHECK: %[[EA:.*]] = pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[R:.*]] = call @callee(%[[EA]]) +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf %[[R]], %[[R]] +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_cf_branch.pto b/test/lit/vmi/vmi_layout_assignment_cf_branch.pto new file mode 100644 index 0000000000..f96962a580 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_cf_branch.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_cf_branch( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) { + cf.cond_br %cond, ^then(%a : !pto.vmi.vreg<128xf16>), + ^else(%b : !pto.vmi.vreg<128xf16>) + + ^then(%then_arg: !pto.vmi.vreg<128xf16>): + %then_value = pto.vmi.extf %then_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %then_mask = pto.vmi.cmpf "olt", %then_value, %then_value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + cf.br ^join(%then_value, %then_mask + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) + + ^else(%else_arg: !pto.vmi.vreg<128xf16>): + %else_value = pto.vmi.extf %else_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %else_mask = pto.vmi.cmpf "olt", %else_value, %else_value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + cf.br ^join(%else_value, %else_mask + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) + + ^join(%value: !pto.vmi.vreg<128xf32>, %mask: !pto.vmi.mask<128xpred>): + %selected = pto.vmi.select %mask, %value, %value + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_cf_branch( +// CHECK: cf.cond_br +// CHECK-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// CHECK: ^{{.*}}(%{{.*}}: !pto.vmi.vreg<128xf16, #pto.vmi.layout>): +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: cf.br +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: ^{{.*}}(%[[VALUE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %[[MASK:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout>): +// CHECK: pto.vmi.select %[[MASK]], %[[VALUE]], %[[VALUE]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_cf_switch.pto b/test/lit/vmi/vmi_layout_assignment_cf_switch.pto new file mode 100644 index 0000000000..6376a5502c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_cf_switch.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_cf_switch( + %flag: i32, + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + cf.switch %flag : i32, [ + default: ^join(%a : !pto.vmi.vreg<128xf32>), + 0: ^join(%b : !pto.vmi.vreg<128xf32>) + ] + + ^join(%value: !pto.vmi.vreg<128xf32>): + return %value : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_cf_switch( +// ASSIGN-SAME: %[[FLAG:[^:]+]]: i32 +// ASSIGN-SAME: %[[A:[^:]+]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[B:[^:]+]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: cf.switch %[[FLAG]] : i32, [ +// ASSIGN: default: ^bb1(%[[A]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout>), +// ASSIGN: 0: ^bb1(%[[B]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout>) +// ASSIGN: ^bb1(%[[VALUE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): +// ASSIGN: return %[[VALUE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_cf_switch( +// LOWER-SAME: %[[FLAG:[^:]+]]: i32 +// LOWER-SAME: %[[A0:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[A1:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[B0:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[B1:[^:]+]]: !pto.vreg<64xf32> +// LOWER-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: cf.switch %[[FLAG]] : i32, [ +// LOWER: default: ^bb1(%[[A0]], %[[A1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>), +// LOWER: 0: ^bb1(%[[B0]], %[[B1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: ^bb1(%[[VALUE0:.*]]: !pto.vreg<64xf32>, %[[VALUE1:.*]]: !pto.vreg<64xf32>): +// LOWER: return %[[VALUE0]], %[[VALUE1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_channel_merge_count_unsupported_invalid.pto b/test/lit/vmi/vmi_layout_assignment_channel_merge_count_unsupported_invalid.pto new file mode 100644 index 0000000000..351b3f62f8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_channel_merge_count_unsupported_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_channel_merge_count_unsupported_invalid( + %ch0: !pto.vmi.vreg<64xf32>, + %ch1: !pto.vmi.vreg<64xf32>, + %ch2: !pto.vmi.vreg<64xf32>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1, %ch2) + : (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + -> !pto.vmi.vreg<192xf32> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_merge supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_layout_assignment_channel_split_count_unsupported_invalid.pto b/test/lit/vmi/vmi_layout_assignment_channel_split_count_unsupported_invalid.pto new file mode 100644 index 0000000000..572845c1a4 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_channel_split_count_unsupported_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_channel_split_count_unsupported_invalid( + %src: !pto.vmi.vreg<192xf32>) { + %ch0, %ch1, %ch2 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<192xf32>) + -> (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_split supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_layout_assignment_compress.pto b/test/lit/vmi/vmi_layout_assignment_compress.pto new file mode 100644 index 0000000000..dee109ce28 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_compress.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_compress( + %src: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_compress( +// CHECK-SAME: %[[SRC:.*]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout>) +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.compress %[[SRC]], %[[MASK]] +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_compress_store.pto b/test/lit/vmi/vmi_layout_assignment_compress_store.pto new file mode 100644 index 0000000000..93266bdf42 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_compress_store.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_compress_store( + %value: !pto.vmi.vreg<64xf32>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xpred>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.mask<64xpred> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_compress_store( +// CHECK-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %[[DST:.*]]: !pto.ptr +// CHECK-SAME: %[[OFFSET:.*]]: index +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK: pto.vmi.compress_store %[[VALUE]], %[[DST]][%[[OFFSET]]], %[[MASK]] +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_constant_remat.pto b/test/lit/vmi/vmi_layout_assignment_constant_remat.pto new file mode 100644 index 0000000000..e387aa077d --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_constant_remat.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_constant_remat( + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf32> { + %constant = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %constant, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %constant, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_constant_remat( +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[CONST_DEINT:.*]] = "pto.vmi.constant"() +// ASSIGN-SAME: dense<1.000000e+00> : tensor<128xf32> +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[CONST_DEINT]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout %[[CONST_DEINT]] +// ASSIGN: %[[CONST_CONTIG:.*]] = "pto.vmi.constant"() +// ASSIGN-SAME: dense<1.000000e+00> : tensor<128xf32> +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[CONST_CONTIG]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_constant_remat( +// LOWER: arith.constant 1.000000e+00 : f32 +// LOWER-COUNT-4: pto.vdup +// LOWER-NOT: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_expand_load.pto b/test/lit/vmi/vmi_layout_assignment_expand_load.pto new file mode 100644 index 0000000000..501b26b369 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_expand_load.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_expand_load( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xpred>, + !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_expand_load( +// CHECK-SAME: %[[SRC:.*]]: !pto.ptr +// CHECK-SAME: %[[OFFSET:.*]]: index +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: %[[PASSTHRU:.*]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.expand_load %[[SRC]][%[[OFFSET]]], %[[MASK]], %[[PASSTHRU]] +// CHECK-SAME: !pto.ptr +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto b/test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto new file mode 100644 index 0000000000..101f0f9254 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func private @external(!pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> + + func.func @caller(%x: !pto.vmi.vreg<128xf32>) { + %r = call @external(%x) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: VMI typed function declaration requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto b/test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto new file mode 100644 index 0000000000..ffb994287a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto @@ -0,0 +1,15 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func private @external_vmi(!pto.vmi.vreg<128xf32>) +} + +// CHECK: VMI-LAYOUT-CONTRACT: VMI typed function declaration requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_layout_assignment_external_decl_preserve.pto b/test/lit/vmi/vmi_layout_assignment_external_decl_preserve.pto new file mode 100644 index 0000000000..384d0d1171 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_external_decl_preserve.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func private @external_i32(i32) -> i32 + + func.func @vmi_layout_assignment_external_decl_preserve( + %input: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + return %input : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: module +// CHECK: func.func private @external_i32(i32) -> i32 +// CHECK-LABEL: func.func @vmi_layout_assignment_external_decl_preserve( +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_fma.pto b/test/lit/vmi/vmi_layout_assignment_fma.pto new file mode 100644 index 0000000000..c40b09c471 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_fma.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_fma( + %lhs: !pto.vmi.vreg<64xf32>, + %rhs: !pto.vmi.vreg<64xf32>, + %acc: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, + !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_fma( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.fma %arg0, %arg1, %arg2 +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_gather.pto b/test/lit/vmi/vmi_layout_assignment_gather.pto new file mode 100644 index 0000000000..a63919bf6f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_gather.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_gather( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32>, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, !pto.vmi.vreg<64xi32>, + !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_gather( +// CHECK-SAME: %arg1: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: %arg3: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.gather %arg0[%arg1], %arg2, %arg3 +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_indirect_call_invalid.pto b/test/lit/vmi/vmi_layout_assignment_indirect_call_invalid.pto new file mode 100644 index 0000000000..4186b78dfa --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_indirect_call_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @caller( + %fn: (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32>, + %x: !pto.vmi.vreg<128xf32>) { + %r = func.call_indirect %fn(%x) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: VMI typed call requires a direct internal callee with a body diff --git a/test/lit/vmi/vmi_layout_assignment_iota_remat.pto b/test/lit/vmi/vmi_layout_assignment_iota_remat.pto new file mode 100644 index 0000000000..773fd4187c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_iota_remat.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_iota_remat( + %base: f32, + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf32> { + %iota = pto.vmi.iota %base + : f32 -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %iota, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %iota, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_iota_remat( +// ASSIGN-SAME: %[[BASE:.*]]: f32 +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[IOTA_DEINT:.*]] = pto.vmi.iota %[[BASE]] +// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[IOTA_DEINT]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout %[[IOTA_DEINT]] +// ASSIGN: %[[IOTA_CONTIG:.*]] = pto.vmi.iota %[[BASE]] +// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[IOTA_CONTIG]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_iota_remat( +// LOWER: pto.vci +// LOWER: pto.vcvt +// LOWER: pto.vadd +// LOWER-NOT: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_load_truncf.pto b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto new file mode 100644 index 0000000000..6b2d588e04 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto @@ -0,0 +1,133 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_load_truncf( + %src: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf16> { + %wide = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } + + func.func @vmi_layout_assignment_tile_read_truncf( + %src: memref<128xf32>) -> !pto.vmi.vreg<128xf16> { + %wide = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<128xf32> + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } + + func.func @vmi_layout_assignment_load_truncf_multi_use( + %src: !pto.ptr, + %dst: !pto.ptr, + %offset: index) -> !pto.vmi.vreg<128xf16> { + %wide = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + pto.vmi.store %wide, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } + + func.func @vmi_layout_assignment_tile_read_truncf_multi_use( + %src: memref<128xf32>, + %dst: memref<128xf32>) -> !pto.vmi.vreg<128xf16> { + %wide = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<128xf32> + pto.vmi.tile_write %wide, %dst + : !pto.vmi.vreg<128xf32>, memref<128xf32> + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf( +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.load +// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_load_truncf( +// LOWER: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// LOWER: %[[EVEN:.*]] = pto.vcvt %[[P0]], {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: %[[ODD:.*]] = pto.vcvt %[[P1]], {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] +// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_tile_read_truncf( +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.tile_read +// ASSIGN-SAME: memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_tile_read_truncf( +// LOWER: %[[ZERO:.*]] = arith.constant 0 : index +// LOWER: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%[[ZERO]]], "DINTLV_B32" +// LOWER: %[[EVEN:.*]] = pto.vcvt %[[P0]], {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: %[[ODD:.*]] = pto.vcvt %[[P1]], {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] +// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf_multi_use( +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.load +// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_load_truncf_multi_use( +// LOWER: pto.vsts +// LOWER: pto.vdintlv +// LOWER: pto.vcvt +// LOWER: return {{.*}} : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_tile_read_truncf_multi_use( +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.tile_read +// ASSIGN-SAME: memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.tile_write %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_tile_read_truncf_multi_use( +// LOWER: pto.vsts +// LOWER: pto.vdintlv +// LOWER: pto.vcvt +// LOWER: return {{.*}} : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_mask_granularity_conflict_invalid.pto b/test/lit/vmi/vmi_layout_assignment_mask_granularity_conflict_invalid.pto new file mode 100644 index 0000000000..f3942119a7 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_granularity_conflict_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_mask_granularity_conflict_invalid( + %cond: i1, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) { + %mask = scf.if %cond -> !pto.vmi.mask<128xpred> { + %m16 = pto.vmi.cmpf "olt", %a16, %b16 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.mask<128xpred> + scf.yield %m16 : !pto.vmi.mask<128xpred> + } else { + %m32 = pto.vmi.cmpf "olt", %a32, %b32 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %m32 : !pto.vmi.mask<128xpred> + } + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: conflicting mask granularities diff --git a/test/lit/vmi/vmi_layout_assignment_mask_remat.pto b/test/lit/vmi/vmi_layout_assignment_mask_remat.pto new file mode 100644 index 0000000000..b114643836 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_remat.pto @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_create_mask_remat( + %active: index, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = pto.vmi.create_mask %active : index -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } + + func.func @vmi_layout_assignment_constant_mask_remat( + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_create_mask_remat( +// CHECK-SAME: %[[ACTIVE:.*]]: index +// CHECK: %[[M32:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// CHECK-SAME: index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[M16:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// CHECK-SAME: index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[M16]] +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[M32]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-NOT: pto.vmi.ensure_mask_layout +// CHECK-NOT: pto.vmi.ensure_mask_granularity + +// CHECK-LABEL: func.func @vmi_layout_assignment_constant_mask_remat( +// CHECK: %[[CM32:.*]] = "pto.vmi.constant_mask"() +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[CM16:.*]] = "pto.vmi.constant_mask"() +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[CM16]] +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[CM32]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-NOT: pto.vmi.ensure_mask_layout +// CHECK-NOT: pto.vmi.ensure_mask_granularity diff --git a/test/lit/vmi/vmi_layout_assignment_mask_use_ensure.pto b/test/lit/vmi/vmi_layout_assignment_mask_use_ensure.pto new file mode 100644 index 0000000000..fd487d017a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_use_ensure.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_mask_use_ensure( + %m: !pto.vmi.mask<128xpred>, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) { + %sel16 = pto.vmi.select %m, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %m, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_mask_use_ensure( +// CHECK-SAME: %[[M:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[M16:.*]] = pto.vmi.ensure_mask_granularity %[[M]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[M16]] +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: pto.vmi.select %[[M]] +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load.pto b/test/lit/vmi/vmi_layout_assignment_masked_load.pto new file mode 100644 index 0000000000..286c92f6da --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_masked_load.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_masked_load( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xpred>, + %passthru: !pto.vmi.vreg<64xf32>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_masked_load( +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: %arg3: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.masked_load %arg0[%arg1], %arg2, %arg3 +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_multi_return.pto b/test/lit/vmi/vmi_layout_assignment_multi_return.pto new file mode 100644 index 0000000000..380b0d0ef9 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_multi_return.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @multi_return( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^then, ^else + + ^then: + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %ea : !pto.vmi.vreg<128xf32> + + ^else: + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %eb : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @multi_return( +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: return +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: return +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto b/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto new file mode 100644 index 0000000000..4e9b2885fd --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @multi_return_conflict( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf8E4M3FN>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^then, ^else + + ^then: + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return %ea : !pto.vmi.vreg<128xf32> + + ^else: + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf8E4M3FN> -> !pto.vmi.vreg<128xf32> + return %eb : !pto.vmi.vreg<128xf32> + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: conflicting natural layouts #pto.vmi.layout and #pto.vmi.layout diff --git a/test/lit/vmi/vmi_layout_assignment_post_gate_type_attr_invalid.pto b/test/lit/vmi/vmi_layout_assignment_post_gate_type_attr_invalid.pto new file mode 100644 index 0000000000..968aeb1e05 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_post_gate_type_attr_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_layout_assignment_reduce_addf.pto b/test/lit/vmi/vmi_layout_assignment_reduce_addf.pto new file mode 100644 index 0000000000..de71e01d6a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_reduce_addf.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_reduce_addf( + %source: !pto.vmi.vreg<64xf32>, + %init: !pto.vmi.vreg<1xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<1xf32> { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + return %out : !pto.vmi.vreg<1xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_addf( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.reduce_addf %arg0, %arg1, %arg2 +// CHECK-SAME: reassoc +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_reduce_addi.pto b/test/lit/vmi/vmi_layout_assignment_reduce_addi.pto new file mode 100644 index 0000000000..82a516b114 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_reduce_addi.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_reduce_addi( + %source: !pto.vmi.vreg<64xi32>, + %init: !pto.vmi.vreg<1xi32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<1xi32> { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<64xi32>, !pto.vmi.vreg<1xi32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xi32> + return %out : !pto.vmi.vreg<1xi32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_addi( +// CHECK-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: %[[INIT:.*]]: !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.reduce_addi %[[SOURCE]], %[[INIT]], %[[MASK]] +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<1xi32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_reduce_minmaxf.pto b/test/lit/vmi/vmi_layout_assignment_reduce_minmaxf.pto new file mode 100644 index 0000000000..51f8180ef0 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_reduce_minmaxf.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_reduce_maxf( + %source: !pto.vmi.vreg<64xf32>, + %init: !pto.vmi.vreg<1xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<1xf32> { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + return %out : !pto.vmi.vreg<1xf32> + } + + func.func @vmi_layout_assignment_reduce_minf( + %source: !pto.vmi.vreg<128xf16>, + %init: !pto.vmi.vreg<1xf16>, + %mask: !pto.vmi.mask<128xpred>) -> !pto.vmi.vreg<1xf16> { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<1xf16>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf16> + return %out : !pto.vmi.vreg<1xf16> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_maxf( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// CHECK: %[[MAX:.*]] = pto.vmi.reduce_maxf %arg0, %arg1, %arg2 +// CHECK: return %[[MAX]] + +// CHECK-LABEL: func.func @vmi_layout_assignment_reduce_minf( +// CHECK-SAME: %arg0: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.vreg<1xf16, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> +// CHECK: %[[MASK:.*]] = pto.vmi.ensure_mask_granularity %arg2 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: %[[MIN:.*]] = pto.vmi.reduce_minf %arg0, %arg1, %[[MASK]] +// CHECK: return %[[MIN]] diff --git a/test/lit/vmi/vmi_layout_assignment_scatter.pto b/test/lit/vmi/vmi_layout_assignment_scatter.pto new file mode 100644 index 0000000000..9560cfa981 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scatter.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_scatter( + %value: !pto.vmi.vreg<64xf32>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32>, + %mask: !pto.vmi.mask<64xpred>) { + pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + : !pto.vmi.vreg<64xf32>, !pto.ptr, + !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_scatter( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg2: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: %arg3: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK: pto.vmi.scatter %arg0, %arg1[%arg2], %arg3 +// CHECK-SAME: indices_unique +// CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_scf_execute_region.pto b/test/lit/vmi/vmi_layout_assignment_scf_execute_region.pto new file mode 100644 index 0000000000..3bd81dca8c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_execute_region.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_scf_execute_region( + %input: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.execute_region -> !pto.vmi.vreg<128xf32> { + scf.yield %wide : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_scf_execute_region( +// ASSIGN-SAME: %[[INPUT:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[INPUT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[RESULT:.*]] = scf.execute_region -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: scf.yield %[[WIDE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: return %[[RESULT]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_scf_execute_region( +// LOWER: %[[RESULT:.*]]:2 = scf.execute_region -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: return %[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_scf_for.pto b/test/lit/vmi/vmi_layout_assignment_scf_for.pto new file mode 100644 index 0000000000..b63563216b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_for.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_scf_for(%a: !pto.vmi.vreg<128xf16>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %init = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.for %i = %c0 to %c2 step %c1 + iter_args(%acc = %init) -> (!pto.vmi.vreg<128xf32>) { + %next = pto.vmi.addf %acc, %acc + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + %sum = pto.vmi.addf %result, %result + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_scf_for( +// CHECK: %[[INIT:.*]] = pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[RESULT:.*]] = scf.for +// CHECK-SAME: iter_args(%[[ACC:.*]] = %[[INIT]]) +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf %[[ACC]], %[[ACC]] +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: scf.yield +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf %[[RESULT]], %[[RESULT]] +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_scf_if.pto b/test/lit/vmi/vmi_layout_assignment_scf_if.pto new file mode 100644 index 0000000000..f86107920a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_if.pto @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_scf_if( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) { + %value, %mask = scf.if %cond + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpa = pto.vmi.cmpf "olt", %ea, %ea + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %ea, %cmpa : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } else { + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpb = pto.vmi.cmpf "olt", %eb, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %eb, %cmpb : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } + %selected = pto.vmi.select %mask, %value, %value + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_scf_if( +// CHECK: scf.if +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.cmpf +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: scf.yield +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.select +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_scf_index_switch.pto b/test/lit/vmi/vmi_layout_assignment_scf_index_switch.pto new file mode 100644 index 0000000000..24ea65503e --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_index_switch.pto @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_scf_index_switch( + %selector: index, + %input: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.index_switch %selector -> !pto.vmi.vreg<128xf32> + case 0 { + scf.yield %wide : !pto.vmi.vreg<128xf32> + } + default { + scf.yield %wide : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_scf_index_switch( +// ASSIGN-SAME: %[[SELECTOR:.*]]: index +// ASSIGN-SAME: %[[INPUT:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[INPUT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[RESULT:.*]] = scf.index_switch %[[SELECTOR]] -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: scf.yield %[[WIDE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: default +// ASSIGN: scf.yield %[[WIDE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: return %[[RESULT]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_scf_index_switch( +// LOWER: %[[RESULT:.*]]:2 = scf.index_switch {{.*}} -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: default +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: return %[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_scf_while.pto b/test/lit/vmi/vmi_layout_assignment_scf_while.pto new file mode 100644 index 0000000000..917bf1762f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_scf_while.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_scf_while( + %input: !pto.vmi.vreg<128xf16>, + %keep_going: i1) -> !pto.vmi.vreg<128xf32> { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.while (%value = %wide) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + scf.condition(%keep_going) %value : !pto.vmi.vreg<128xf32> + } do { + ^bb0(%value: !pto.vmi.vreg<128xf32>): + scf.yield %value : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_scf_while( +// ASSIGN-SAME: %[[INPUT:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[INPUT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[RESULT:.*]] = scf.while (%[[VALUE:.*]] = %[[WIDE]]) : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: scf.condition(%arg1) %[[VALUE]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: ^bb0(%[[AFTER:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout>): +// ASSIGN: scf.yield %[[AFTER]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: return %[[RESULT]] : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_scf_while( +// LOWER: %[[RESULT:.*]]:2 = scf.while +// LOWER-SAME: (!pto.vreg<64xf32>, !pto.vreg<64xf32>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// LOWER: scf.condition(%arg1) {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER: return %[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_store_ensure.pto b/test/lit/vmi/vmi_layout_assignment_store_ensure.pto new file mode 100644 index 0000000000..430fff7fda --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_store_ensure.pto @@ -0,0 +1,48 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_store_ensure( + %src: !pto.vmi.vreg<128xf16>, + %dst: !pto.ptr, + %offset: index) { + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %wide, %wide + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %sum, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_store_ensure( +// ASSIGN-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[WIDE]], %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[DENSE:.*]] = pto.vmi.ensure_layout %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[DENSE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_store_ensure( +// LOWER: %[[EVEN:.*]] = pto.vcvt +// LOWER: %[[ODD:.*]] = pto.vcvt +// LOWER: %[[SUM0:.*]] = pto.vadd %[[EVEN]], %[[EVEN]] +// LOWER: %[[SUM1:.*]] = pto.vadd %[[ODD]], %[[ODD]] +// LOWER: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[SUM0]], %[[SUM1]] +// LOWER: pto.vsts %[[D0]] +// LOWER: pto.vsts %[[D1]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto b/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto new file mode 100644 index 0000000000..141e85772b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_truncf_ensure( + %wide: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf16> { + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return %narrow : !pto.vmi.vreg<128xf16> + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_truncf_ensure( +// ASSIGN-SAME: %[[WIDE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_truncf_ensure( +// LOWER-SAME: %[[D0:[^,]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> +// LOWER: %[[P0:.*]], %[[P1:.*]] = pto.vdintlv %[[D0]], %[[D1]] +// LOWER: %[[EVEN:.*]] = pto.vcvt %[[P0]]{{.*}}part = "EVEN" +// LOWER: %[[ODD:.*]] = pto.vcvt %[[P1]]{{.*}}part = "ODD" +// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] +// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_widen.pto b/test/lit/vmi/vmi_layout_assignment_widen.pto new file mode 100644 index 0000000000..eceedcb711 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_widen.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_widen( + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) { + %ea = pto.vmi.extf %a : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %eb = pto.vmi.extf %b : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %ea, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %cmp = pto.vmi.cmpf "olt", %ea, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.mask<128xpred> + %sel = pto.vmi.select %cmp, %sum, %ea + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_widen( +// CHECK-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// CHECK: pto.vmi.extf +// CHECK-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.addf +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.cmpf +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.select +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_factor_invalid.pto b/test/lit/vmi/vmi_layout_factor_invalid.pto new file mode 100644 index 0000000000..b908700333 --- /dev/null +++ b/test/lit/vmi/vmi_layout_factor_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_factor_invalid( + %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + return + } +} + +// CHECK: #pto.vmi.layout expected factor to be 2 or 4 diff --git a/test/lit/vmi/vmi_layout_gate_surface_invalid.pto b/test/lit/vmi/vmi_layout_gate_surface_invalid.pto new file mode 100644 index 0000000000..1b1bfdfb52 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_surface_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_surface_invalid(%a: !pto.vmi.vreg<128xf32>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: layout-assigned VMI IR requires !pto.vmi.vreg with layout diff --git a/test/lit/vmi/vmi_layout_gate_surface_mask_invalid.pto b/test/lit/vmi/vmi_layout_gate_surface_mask_invalid.pto new file mode 100644 index 0000000000..79425740d8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_surface_mask_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_surface_mask_invalid( + %m: !pto.vmi.mask<128xpred>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: layout-assigned VMI IR requires !pto.vmi.vreg with layout +// CHECK-SAME: !pto.vmi.mask with b8/b16/b32 granularity plus layout diff --git a/test/lit/vmi/vmi_layout_gate_type_attr_nested_physical_invalid.pto b/test/lit/vmi/vmi_layout_gate_type_attr_nested_physical_invalid.pto new file mode 100644 index 0000000000..7494367606 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_type_attr_nested_physical_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_physical_state = [{nested = !pto.vreg<64xf32>}] +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_layout_gate_type_attr_surface_invalid.pto b/test/lit/vmi/vmi_layout_gate_type_attr_surface_invalid.pto new file mode 100644 index 0000000000..78549ed3e6 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_type_attr_surface_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.mask<128xpred> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_layout_gate_valid.pto b/test/lit/vmi/vmi_layout_gate_valid.pto new file mode 100644 index 0000000000..ebc5778f34 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_valid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -pto-validate-vmi-layout-ir + +module { + func.func @vmi_layout_gate_valid( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %sel = pto.vmi.select %m, %a, %b + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_mask_concrete_without_layout_invalid.pto b/test/lit/vmi/vmi_mask_concrete_without_layout_invalid.pto new file mode 100644 index 0000000000..43aca3fd30 --- /dev/null +++ b/test/lit/vmi/vmi_mask_concrete_without_layout_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_concrete_without_layout_invalid( + %arg0: !pto.vmi.mask<128xb32>) { + return + } +} + +// CHECK: concrete mask granularity requires layout diff --git a/test/lit/vmi/vmi_mask_granularity_invalid.pto b/test/lit/vmi/vmi_mask_granularity_invalid.pto new file mode 100644 index 0000000000..4d85cc9aa0 --- /dev/null +++ b/test/lit/vmi/vmi_mask_granularity_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_granularity_invalid( + %arg0: !pto.vmi.mask<128xb64, #pto.vmi.layout>) { + return + } +} + +// CHECK: expected granularity to be one of pred, b8, b16, b32 diff --git a/test/lit/vmi/vmi_mask_logic_invalid.pto b/test/lit/vmi/vmi_mask_logic_invalid.pto new file mode 100644 index 0000000000..49798b742b --- /dev/null +++ b/test/lit/vmi/vmi_mask_logic_invalid.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_and_lane_mismatch( + %lhs: !pto.vmi.mask<128xpred>, + %rhs: !pto.vmi.mask<64xpred>) { + %and = pto.vmi.mask_and %lhs, %rhs + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<64xpred> + -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: 'pto.vmi.mask_and' op requires all VMI mask values to have the same logical lane count + +// ----- + +module { + func.func @vmi_mask_or_granularity_mismatch( + %lhs: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + %or = pto.vmi.mask_or %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.mask_or' op requires all VMI mask values to have the same granularity + +// ----- + +module { + func.func @vmi_mask_xor_lane_mismatch( + %lhs: !pto.vmi.mask<128xpred>, + %rhs: !pto.vmi.mask<64xpred>) { + %xor = pto.vmi.mask_xor %lhs, %rhs + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<64xpred> + -> !pto.vmi.mask<128xpred> + return + } +} + +// CHECK: 'pto.vmi.mask_xor' op requires all VMI mask values to have the same logical lane count + +// ----- + +module { + func.func @vmi_mask_not_granularity_mismatch( + %src: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %not = pto.vmi.mask_not %src + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.mask_not' op requires all VMI mask values to have the same granularity diff --git a/test/lit/vmi/vmi_mask_pred_with_layout_invalid.pto b/test/lit/vmi/vmi_mask_pred_with_layout_invalid.pto new file mode 100644 index 0000000000..e7d949242e --- /dev/null +++ b/test/lit/vmi/vmi_mask_pred_with_layout_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_mask_pred_with_layout_invalid( + %arg0: !pto.vmi.mask<128xpred, #pto.vmi.layout>) { + return + } +} + +// CHECK: pred mask must not carry layout diff --git a/test/lit/vmi/vmi_masked_store_mask_granularity_invalid.pto b/test/lit/vmi/vmi_masked_store_mask_granularity_invalid.pto new file mode 100644 index 0000000000..4b3a672049 --- /dev/null +++ b/test/lit/vmi/vmi_masked_store_mask_granularity_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_masked_store_mask_granularity_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.masked_store' op requires mask granularity to match data element width diff --git a/test/lit/vmi/vmi_memory_element_type_invalid.pto b/test/lit/vmi/vmi_memory_element_type_invalid.pto new file mode 100644 index 0000000000..4d6a199e11 --- /dev/null +++ b/test/lit/vmi/vmi_memory_element_type_invalid.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_load_element_type_invalid(%src: !pto.ptr, %offset: index) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: 'pto.vmi.load' op requires memory source element type to match VMI data element type + +// ----- + +module { + func.func @vmi_store_element_type_invalid( + %value: !pto.vmi.vreg<128xf16>, %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// CHECK: 'pto.vmi.store' op requires memory destination element type to match VMI data element type + +// ----- + +module { + func.func @vmi_tile_read_element_type_invalid(%src: memref<128xf32>) { + %value = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: 'pto.vmi.tile_read' op requires memory source element type to match VMI data element type + +// ----- + +module { + func.func @vmi_tile_write_element_type_invalid( + %value: !pto.vmi.vreg<128xf16>, %dst: memref<128xf32>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<128xf16>, memref<128xf32> + return + } +} + +// CHECK: 'pto.vmi.tile_write' op requires memory destination element type to match VMI data element type diff --git a/test/lit/vmi/vmi_min_max_integer_invalid.pto b/test/lit/vmi/vmi_min_max_integer_invalid.pto new file mode 100644 index 0000000000..71d0861e82 --- /dev/null +++ b/test/lit/vmi/vmi_min_max_integer_invalid.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_minf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, + %rhs: !pto.vmi.vreg<128xi32>) { + %min = pto.vmi.minf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.minf' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_maxf_integer_invalid( + %lhs: !pto.vmi.vreg<128xi32>, + %rhs: !pto.vmi.vreg<128xi32>) { + %max = pto.vmi.maxf %lhs, %rhs + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.maxf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_negf_integer_invalid.pto b/test/lit/vmi/vmi_negf_integer_invalid.pto new file mode 100644 index 0000000000..6b28584b64 --- /dev/null +++ b/test/lit/vmi/vmi_negf_integer_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_negf_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %neg = pto.vmi.negf %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.negf' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_op_verifier_basic.pto b/test/lit/vmi/vmi_op_verifier_basic.pto new file mode 100644 index 0000000000..bff24c6e07 --- /dev/null +++ b/test/lit/vmi/vmi_op_verifier_basic.pto @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_op_verifier_basic( + %ptr: !pto.ptr, + %tile: memref<128xf32>, + %layouted: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %f32 = arith.constant 1.000000e+00 : f32 + %f16 = arith.constant 1.000000e+00 : f16 + %active = arith.constant 64 : index + + %const = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32> + %broadcast = pto.vmi.broadcast %f32 : f32 -> !pto.vmi.vreg<128xf32> + %broadcast16 = pto.vmi.broadcast %f16 : f16 -> !pto.vmi.vreg<128xf16> + %mask = pto.vmi.create_mask %active : index -> !pto.vmi.mask<128xpred> + %mask_const = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xpred> + + %add = pto.vmi.addf %broadcast, %const + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %cmp = pto.vmi.cmpf "olt", %broadcast, %const + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.mask<128xpred> + %sel = pto.vmi.select %mask, %broadcast, %const + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %ext = pto.vmi.extf %broadcast16 : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %trunc = pto.vmi.truncf %ext : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + + %loaded = pto.vmi.load %ptr[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + pto.vmi.store %loaded, %ptr[%c0] : !pto.vmi.vreg<128xf32>, !pto.ptr + %tile_read = pto.vmi.tile_read %tile : memref<128xf32> -> !pto.vmi.vreg<128xf32> + pto.vmi.tile_write %tile_read, %tile : !pto.vmi.vreg<128xf32>, memref<128xf32> + + %small = "pto.vmi.shuffle"(%broadcast) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<4xf32> + %split0, %split1 = "pto.vmi.channel_split"(%small) + : (!pto.vmi.vreg<4xf32>) -> (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) + %merged = "pto.vmi.channel_merge"(%split0, %split1) + : (!pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32>) -> !pto.vmi.vreg<4xf32> + + %ensure = pto.vmi.ensure_layout %layouted + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %layouted_ext = pto.vmi.extf %ensure + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf64, #pto.vmi.layout> + %layouted_trunc = pto.vmi.truncf %layouted_ext + : !pto.vmi.vreg<128xf64, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %mask_layout = pto.vmi.ensure_mask_layout %mask_b32 + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %mask_granularity = pto.vmi.ensure_mask_granularity %mask_b16 + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + + %part0, %part1 = "pto.vmi.unpack"(%layouted) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + %packed = "pto.vmi.pack"(%part0, %part1) + : (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + %i0 = arith.constant 0 : i32 + %iv0 = pto.vmi.broadcast %i0 : i32 -> !pto.vmi.vreg<128xi32> + %iadd = pto.vmi.addi %iv0, %iv0 + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + %icmp = pto.vmi.cmpi "slt", %iv0, %iv0 + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> -> !pto.vmi.mask<128xpred> + + return + } +} + +// CHECK-LABEL: func.func @vmi_op_verifier_basic +// CHECK: pto.vmi.broadcast +// CHECK: pto.vmi.addf +// CHECK: pto.vmi.cmpf +// CHECK: pto.vmi.select +// CHECK: pto.vmi.extf +// CHECK: pto.vmi.truncf +// CHECK: pto.vmi.load +// CHECK: pto.vmi.store +// CHECK: pto.vmi.tile_read +// CHECK: pto.vmi.tile_write +// CHECK: pto.vmi.ensure_layout +// CHECK: pto.vmi.ensure_mask_layout +// CHECK: pto.vmi.ensure_mask_granularity +// CHECK: "pto.vmi.unpack" +// CHECK: "pto.vmi.pack" +// CHECK: pto.vmi.addi +// CHECK: pto.vmi.cmpi diff --git a/test/lit/vmi/vmi_pack_arity_invalid.pto b/test/lit/vmi/vmi_pack_arity_invalid.pto new file mode 100644 index 0000000000..4ba4eaa180 --- /dev/null +++ b/test/lit/vmi/vmi_pack_arity_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_pack_arity_invalid(%p0: !pto.vreg<64xf32>) { + %a = "pto.vmi.pack"(%p0) + : (!pto.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: requires 2 physical parts, got 1 diff --git a/test/lit/vmi/vmi_producer_boundary_helper_invalid.pto b/test/lit/vmi/vmi_producer_boundary_helper_invalid.pto new file mode 100644 index 0000000000..81805f2a28 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_helper_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_helper_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %r = pto.vmi.ensure_layout %a + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI producer boundary requires surface diff --git a/test/lit/vmi/vmi_producer_boundary_layout_invalid.pto b/test/lit/vmi/vmi_producer_boundary_layout_invalid.pto new file mode 100644 index 0000000000..be6a6414f9 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_layout_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_layout_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI producer boundary requires surface diff --git a/test/lit/vmi/vmi_producer_boundary_mask_layout_invalid.pto b/test/lit/vmi/vmi_producer_boundary_mask_layout_invalid.pto new file mode 100644 index 0000000000..3d3727bdaa --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_mask_layout_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_mask_layout_invalid( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI producer boundary requires surface !pto.vmi.vreg or !pto.vmi.mask type diff --git a/test/lit/vmi/vmi_producer_boundary_non_vmi_op_invalid.pto b/test/lit/vmi/vmi_producer_boundary_non_vmi_op_invalid.pto new file mode 100644 index 0000000000..c5aa0676f0 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_non_vmi_op_invalid.pto @@ -0,0 +1,21 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_non_vmi_op_invalid( + %a: !pto.vmi.vreg<128xf32>) { + %0 = builtin.unrealized_conversion_cast %a + : !pto.vmi.vreg<128xf32> to !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI typed value is used by a non-VMI semantic op diff --git a/test/lit/vmi/vmi_producer_boundary_physical_invalid.pto b/test/lit/vmi/vmi_producer_boundary_physical_invalid.pto new file mode 100644 index 0000000000..c2a3996eb9 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_physical_invalid.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_producer_boundary_physical_invalid(%a: !pto.vreg<64xf32>) { + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: physical VPTO register type appears before VMI-to-VPTO + +// ----- + +module { + func.func @vmi_producer_boundary_physical_op_invalid() { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + return + } +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: physical VPTO register type appears before VMI-to-VPTO diff --git a/test/lit/vmi/vmi_producer_boundary_type_attr_layout_invalid.pto b/test/lit/vmi/vmi_producer_boundary_type_attr_layout_invalid.pto new file mode 100644 index 0000000000..8deed1cecb --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_type_attr_layout_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32, #pto.vmi.layout> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_producer_boundary_type_attr_nested_invalid.pto b/test/lit/vmi/vmi_producer_boundary_type_attr_nested_invalid.pto new file mode 100644 index 0000000000..4163dcfb16 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_type_attr_nested_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_state = {nested = [!pto.vmi.vreg<128xf32>]} +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_producer_boundary_type_attr_surface_invalid.pto b/test/lit/vmi/vmi_producer_boundary_type_attr_surface_invalid.pto new file mode 100644 index 0000000000..8cd353ca13 --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_type_attr_surface_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-ir 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32> +} { +} + +// CHECK: VMI-PASS-INVARIANT +// CHECK: VMI or physical VPTO type appears in a non-signature attribute diff --git a/test/lit/vmi/vmi_producer_boundary_valid.pto b/test/lit/vmi/vmi_producer_boundary_valid.pto new file mode 100644 index 0000000000..dee731bd1f --- /dev/null +++ b/test/lit/vmi/vmi_producer_boundary_valid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -pto-validate-vmi-ir | FileCheck %s + +module { + func.func @vmi_producer_boundary_valid( + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>, + %m: !pto.vmi.mask<128xpred>) -> !pto.vmi.vreg<128xf32> { + %r = pto.vmi.addf %a, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %s = pto.vmi.select %m, %r, %a + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %s : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_producer_boundary_valid +// CHECK: pto.vmi.addf +// CHECK: pto.vmi.select diff --git a/test/lit/vmi/vmi_ptoas_backend_required_invalid.pto b/test/lit/vmi/vmi_ptoas_backend_required_invalid.pto new file mode 100644 index 0000000000..7379984b50 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_backend_required_invalid.pto @@ -0,0 +1,17 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --enable-vmi %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + func.func @vmi_ptoas_backend_required_invalid() { + return + } +} + +// CHECK: Error: --enable-vmi requires --pto-backend=vpto or pto.backend = "vpto". diff --git a/test/lit/vmi/vmi_ptoas_cli_control_flow.pto b/test/lit/vmi/vmi_ptoas_cli_control_flow.pto new file mode 100644 index 0000000000..cd29782d10 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_cli_control_flow.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_cli_control_flow( + %cond: i1, + %lhs: f32, + %rhs: f32, + %dst: !pto.ptr, + %offset: index) { + %lhs_v = pto.vmi.broadcast %lhs + : f32 -> !pto.vmi.vreg<128xf32> + %rhs_v = pto.vmi.broadcast %rhs + : f32 -> !pto.vmi.vreg<128xf32> + %chosen = scf.if %cond -> !pto.vmi.vreg<128xf32> { + scf.yield %lhs_v : !pto.vmi.vreg<128xf32> + } else { + scf.yield %rhs_v : !pto.vmi.vreg<128xf32> + } + pto.vmi.store %chosen, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + } +} + +// CHECK-LABEL: func.func @vmi_ptoas_cli_control_flow +// CHECK: %[[LHS:.*]] = pto.vdup +// CHECK: %[[RHS:.*]] = pto.vdup +// CHECK: %[[CHOSEN:.*]] = arith.select {{.*}}, %[[LHS]], %[[RHS]] : !pto.vreg<64xf32> +// CHECK: pto.vsts %[[CHOSEN]] +// CHECK: pto.vsts %[[CHOSEN]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto new file mode 100644 index 0000000000..8957bb1f40 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s +// RUN: ptoas --pto-arch=a5 --enable-vmi --emit-vpto %s -o - | FileCheck %s --check-prefix=ATTR +// RUN: not ptoas --pto-backend=emitc --enable-vmi %s -o - 2>&1 | FileCheck %s --check-prefix=EMITC + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_cli_pipeline( + %scalar: f32, + %dst: !pto.ptr, + %offset: index) { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + } +} + +// CHECK-LABEL: func.func @vmi_ptoas_cli_pipeline +// CHECK: pto.vecscope +// CHECK: pto.vdup +// CHECK: pto.vsts +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// ATTR-LABEL: func.func @vmi_ptoas_cli_pipeline +// ATTR: pto.vecscope +// ATTR: pto.vdup +// ATTR: pto.vsts +// ATTR-NOT: pto.vmi. +// ATTR-NOT: !pto.vmi. +// ATTR-NOT: unrealized_conversion_cast + +// EMITC: Error: --enable-vmi requires --pto-backend=vpto or pto.backend = "vpto". diff --git a/test/lit/vmi/vmi_ptoas_public_abi_invalid.pto b/test/lit/vmi/vmi_ptoas_public_abi_invalid.pto new file mode 100644 index 0000000000..79b146acd8 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_public_abi_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_public_abi_invalid( + %value: !pto.vmi.vreg<128xf32>) { + return + } + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: public VMI typed function requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto b/test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto new file mode 100644 index 0000000000..a27067e62c --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_public_result_abi_invalid( + %scalar: f32) -> !pto.vmi.vreg<128xf32> { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + return %value : !pto.vmi.vreg<128xf32> + } + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: public VMI typed function requires an explicit external ABI materialization plan diff --git a/test/lit/vmi/vmi_reduce_addf_missing_reassoc_invalid.pto b/test/lit/vmi/vmi_reduce_addf_missing_reassoc_invalid.pto new file mode 100644 index 0000000000..47dc112c04 --- /dev/null +++ b/test/lit/vmi/vmi_reduce_addf_missing_reassoc_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_reduce_addf_missing_reassoc_invalid( + %source: !pto.vmi.vreg<64xf32>, + %init: !pto.vmi.vreg<1xf32>, + %mask: !pto.vmi.mask<64xpred>) { + %out = pto.vmi.reduce_addf %source, %init, %mask + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> + return + } +} + +// CHECK: 'pto.vmi.reduce_addf' op requires reassoc attr because VPTO vcadd performs pair-wise floating-point reduction diff --git a/test/lit/vmi/vmi_scatter_indices_invalid.pto b/test/lit/vmi/vmi_scatter_indices_invalid.pto new file mode 100644 index 0000000000..bd59b81b04 --- /dev/null +++ b/test/lit/vmi/vmi_scatter_indices_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_scatter_indices_invalid( + %value: !pto.vmi.vreg<64xf32>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>) { + pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + : !pto.vmi.vreg<64xf32>, !pto.ptr, + !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + return + } +} + +// CHECK: 'pto.vmi.scatter' op requires signless or unsigned 32-bit integer indices diff --git a/test/lit/vmi/vmi_select_mask_granularity_invalid.pto b/test/lit/vmi/vmi_select_mask_granularity_invalid.pto new file mode 100644 index 0000000000..2e6b9d10f9 --- /dev/null +++ b/test/lit/vmi/vmi_select_mask_granularity_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_select_mask_granularity_invalid( + %m: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %r = pto.vmi.select %m, %a, %b + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: requires mask granularity to match data element width diff --git a/test/lit/vmi/vmi_shli_float_invalid.pto b/test/lit/vmi/vmi_shli_float_invalid.pto new file mode 100644 index 0000000000..e73ee9c232 --- /dev/null +++ b/test/lit/vmi/vmi_shli_float_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_shli_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %shifted = pto.vmi.shli %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.shli' op requires integer-like VMI element type diff --git a/test/lit/vmi/vmi_shrui_float_invalid.pto b/test/lit/vmi/vmi_shrui_float_invalid.pto new file mode 100644 index 0000000000..5de50dfff1 --- /dev/null +++ b/test/lit/vmi/vmi_shrui_float_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_shrui_float_invalid( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) { + %shifted = pto.vmi.shrui %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: 'pto.vmi.shrui' op requires signless or unsigned integer VMI element type diff --git a/test/lit/vmi/vmi_shrui_signed_invalid.pto b/test/lit/vmi/vmi_shrui_signed_invalid.pto new file mode 100644 index 0000000000..c3c57a52e9 --- /dev/null +++ b/test/lit/vmi/vmi_shrui_signed_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_shrui_signed_invalid( + %lhs: !pto.vmi.vreg<128xsi16>, + %rhs: !pto.vmi.vreg<128xsi16>) { + %shifted = pto.vmi.shrui %lhs, %rhs + : !pto.vmi.vreg<128xsi16>, !pto.vmi.vreg<128xsi16> + -> !pto.vmi.vreg<128xsi16> + return + } +} + +// CHECK: 'pto.vmi.shrui' op requires signless or unsigned integer VMI element type diff --git a/test/lit/vmi/vmi_to_vpto_abs.pto b/test/lit/vmi/vmi_to_vpto_abs.pto new file mode 100644 index 0000000000..247a239f66 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_abs.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_absf( + %value: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %abs = pto.vmi.absf %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %abs : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_absi( + %value: !pto.vmi.vreg<256xi16>) -> !pto.vmi.vreg<256xi16> { + %abs = pto.vmi.absi %value + : !pto.vmi.vreg<256xi16> -> !pto.vmi.vreg<256xi16> + return %abs : !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_absf( +// CHECK-SAME: %[[F0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[F1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[AF0:.*]] = pto.vabs %[[F0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[AF1:.*]] = pto.vabs %[[F1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[AF0]], %[[AF1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_absi( +// CHECK-SAME: %[[I0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[I1:[^)]+]]: !pto.vreg<128xi16> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[AI0:.*]] = pto.vabs %[[I0]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[AI1:.*]] = pto.vabs %[[I1]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK: return %[[AI0]], %[[AI1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_active_prefix_index.pto b/test/lit/vmi/vmi_to_vpto_active_prefix_index.pto new file mode 100644 index 0000000000..7d64e0ec0f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_active_prefix_index.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index( + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%idx) + : (!pto.vmi.vreg<64xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_active_prefix_index( +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[M:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[CARRIER:.*]] = pto.vdup %[[ZERO]], %[[M]] : i32, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[IDX:.*]] = pto.vusqz %[[CARRIER]], %arg0 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[IDX]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_active_prefix_index_multichunk_invalid.pto b/test/lit/vmi/vmi_to_vpto_active_prefix_index_multichunk_invalid.pto new file mode 100644 index 0000000000..cb655b0e4f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_active_prefix_index_multichunk_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index_multichunk_invalid( + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%idx) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.active_prefix_index lowers through pto.vusqz only for one contiguous physical chunk +// CHECK-SAME: multi-chunk prefix needs cross-chunk carry diff --git a/test/lit/vmi/vmi_to_vpto_active_prefix_index_tail_invalid.pto b/test/lit/vmi/vmi_to_vpto_active_prefix_index_tail_invalid.pto new file mode 100644 index 0000000000..07fd5307e0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_active_prefix_index_tail_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index_tail_invalid( + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<32xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<32xi32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.active_prefix_index lowers through pto.vusqz only for one contiguous physical chunk +// CHECK-SAME: padding mask lanes cannot affect the observable prefix diff --git a/test/lit/vmi/vmi_to_vpto_add.pto b/test/lit/vmi/vmi_to_vpto_add.pto new file mode 100644 index 0000000000..49b5fdeca3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_add.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_addf( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %sum = pto.vmi.addf %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%sum) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_addi( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %sum = pto.vmi.addi %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%sum) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_addf( +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[M0]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[M1]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-LABEL: func.func @vmi_to_vpto_addi( +// CHECK: %[[IM0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[IM0]] +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[IM1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vadd {{.*}}, {{.*}}, %[[IM1]] +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-NOT: pto.vmi.add +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bf16_arith.pto b/test/lit/vmi/vmi_to_vpto_bf16_arith.pto new file mode 100644 index 0000000000..c7357b5abd --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bf16_arith.pto @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bf16_arith( + %lhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> (!pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.vreg<128xbf16>) { + %sum = pto.vmi.addf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %min = pto.vmi.minf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %max = pto.vmi.maxf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %sum_part = "pto.vmi.unpack"(%sum) + : (!pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> + %min_part = "pto.vmi.unpack"(%min) + : (!pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> + %max_part = "pto.vmi.unpack"(%max) + : (!pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> + return %sum_part, %min_part, %max_part + : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.vreg<128xbf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bf16_arith( +// CHECK: %[[MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[ADD:.*]] = pto.vadd %arg0, %arg1, %[[MASK]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: %[[MIN:.*]] = pto.vmin %arg0, %arg1, %{{.*}} : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: %[[MAX:.*]] = pto.vmax %arg0, %arg1, %{{.*}} : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: return %[[ADD]], %[[MIN]], %[[MAX]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitcast.pto b/test/lit/vmi/vmi_to_vpto_bitcast.pto new file mode 100644 index 0000000000..f73ffbe68a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_f32_to_i16( + %value: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<256xi16> { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<256xi16> + return %cast : !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitcast_f32_to_i16( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[B0:.*]] = pto.vbitcast %[[V0]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK-DAG: %[[B1:.*]] = pto.vbitcast %[[V1]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK: return %[[B0]], %[[B1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_partial.pto b/test/lit/vmi/vmi_to_vpto_bitcast_partial.pto new file mode 100644 index 0000000000..e2a1b3c789 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_partial.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_partial( + %value: !pto.vmi.vreg<65xf32>) -> !pto.vmi.vreg<130xi16> { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<65xf32> -> !pto.vmi.vreg<130xi16> + return %cast : !pto.vmi.vreg<130xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitcast_partial( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[B0:.*]] = pto.vbitcast %[[S0]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK-DAG: %[[B1:.*]] = pto.vbitcast %[[S1]] : !pto.vreg<64xf32> -> !pto.vreg<128xi16> +// CHECK: return %[[B0]], %[[B1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitwise.pto b/test/lit/vmi/vmi_to_vpto_bitwise.pto new file mode 100644 index 0000000000..80a665ccd9 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitwise.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitwise( + %a: !pto.vmi.vreg<256xi16>, + %b: !pto.vmi.vreg<256xi16>) + -> (!pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16>) { + %and = pto.vmi.andi %a, %b + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + %or = pto.vmi.ori %a, %b + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + %xor = pto.vmi.xori %a, %b + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + %not = pto.vmi.not %a + : !pto.vmi.vreg<256xi16> -> !pto.vmi.vreg<256xi16> + return %and, %or, %xor, %not + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16>, + !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitwise( +// CHECK-SAME: %[[A0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[A1:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[B0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[B1:[^)]+]]: !pto.vreg<128xi16> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[AND0:.*]] = pto.vand %[[A0]], %[[B0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[AND1:.*]] = pto.vand %[[A1]], %[[B1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[OR0:.*]] = pto.vor %[[A0]], %[[B0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[OR1:.*]] = pto.vor %[[A1]], %[[B1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[XOR0:.*]] = pto.vxor %[[A0]], %[[B0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[XOR1:.*]] = pto.vxor %[[A1]], %[[B1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[NOT0:.*]] = pto.vnot %[[A0]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[NOT1:.*]] = pto.vnot %[[A1]], {{.*}} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK: return %[[AND0]], %[[AND1]], %[[OR0]], %[[OR1]], %[[XOR0]], %[[XOR1]], %[[NOT0]], %[[NOT1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_broadcast.pto b/test/lit/vmi/vmi_to_vpto_broadcast.pto new file mode 100644 index 0000000000..9cdbf92e1e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_broadcast.pto @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_broadcast_contiguous(%scalar: f32) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_broadcast_deint4(%scalar: f32) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_broadcast_rank0( + %scalar: !pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.broadcast %scalar + : !pto.vmi.vreg<1xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_broadcast_contiguous( +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P0:.*]] = pto.vdup %arg0, %[[M0]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P1:.*]] = pto.vdup %arg0, %[[M1]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_broadcast_deint4( +// CHECK-COUNT-4: pto.vdup %arg0 +// CHECK: return +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_broadcast_rank0( +// CHECK-COUNT-4: pto.vdup %arg0{{.*}}{position = "LOWEST"} +// CHECK: return +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_call_boundary.pto b/test/lit/vmi/vmi_to_vpto_call_boundary.pto new file mode 100644 index 0000000000..0a34ebe197 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_call_boundary.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func private @callee(%x: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %sum = pto.vmi.addf %x, %x + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @caller(%a: !pto.vmi.vreg<128xf16>) + -> !pto.vmi.vreg<128xf32> { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %r = call @callee(%ea) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.addf %r, %r + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func private @callee( +// CHECK-SAME: %[[C0:[^:]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[C1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[CM0:.*]] = pto.vadd %[[C0]], %[[C0]] +// CHECK-DAG: %[[CM1:.*]] = pto.vadd %[[C1]], %[[C1]] +// CHECK: return %[[CM0]], %[[CM1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @caller( +// CHECK-SAME: %[[A:[^)]+]]: !pto.vreg<128xf16> +// CHECK-DAG: %[[EA0:.*]] = pto.vcvt %[[A]] +// CHECK-DAG: %[[EA1:.*]] = pto.vcvt %[[A]] +// CHECK: %[[R:.*]]:2 = call @callee(%[[EA0]], %[[EA1]]) +// CHECK-SAME: (!pto.vreg<64xf32>, !pto.vreg<64xf32>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[S0:.*]] = pto.vadd %[[R]]#0, %[[R]]#0 +// CHECK-DAG: %[[S1:.*]] = pto.vadd %[[R]]#1, %[[R]]#1 +// CHECK: return %[[S0]], %[[S1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_cf_branch.pto b/test/lit/vmi/vmi_to_vpto_cf_branch.pto new file mode 100644 index 0000000000..0a4cf70e1d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cf_branch.pto @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_cf_branch( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^then(%a : !pto.vmi.vreg<128xf16>), + ^else(%b : !pto.vmi.vreg<128xf16>) + + ^then(%then_arg: !pto.vmi.vreg<128xf16>): + %then_value = pto.vmi.extf %then_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + cf.br ^join(%then_value : !pto.vmi.vreg<128xf32>) + + ^else(%else_arg: !pto.vmi.vreg<128xf16>): + %else_value = pto.vmi.extf %else_arg + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %else_sum = pto.vmi.addf %else_value, %else_value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + cf.br ^join(%else_sum : !pto.vmi.vreg<128xf32>) + + ^join(%value: !pto.vmi.vreg<128xf32>): + %sum = pto.vmi.addf %value, %value + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_cf_cond_branch_operands( + %cond: i1, + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + cf.cond_br %cond, ^join(%a : !pto.vmi.vreg<128xf32>), + ^join(%b : !pto.vmi.vreg<128xf32>) + + ^join(%value: !pto.vmi.vreg<128xf32>): + return %value : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_cf_branch( +// CHECK-SAME: %[[COND:[^,]+]]: i1 +// CHECK-SAME: %[[A:[^,]+]]: !pto.vreg<128xf16> +// CHECK-SAME: %[[B:[^)]+]]: !pto.vreg<128xf16> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: cf.cond_br %[[COND]], ^[[THEN:.*]], ^[[ELSE:.*]] +// CHECK: ^[[THEN]]: +// CHECK-DAG: %[[THEN_P0:.*]] = pto.vcvt %[[A]] +// CHECK-DAG: %[[THEN_P1:.*]] = pto.vcvt %[[A]] +// CHECK: cf.br ^[[JOIN:.*]](%[[THEN_P0]], %[[THEN_P1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: ^[[ELSE]]: +// CHECK-DAG: %[[ELSE_P0:.*]] = pto.vcvt %[[B]] +// CHECK-DAG: %[[ELSE_P1:.*]] = pto.vcvt %[[B]] +// CHECK: cf.br ^[[JOIN]]({{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: ^[[JOIN]](%{{.*}}: !pto.vreg<64xf32>, %{{.*}}: !pto.vreg<64xf32>): +// CHECK: pto.vadd +// CHECK-LABEL: func.func @vmi_to_vpto_cf_cond_branch_operands( +// CHECK-SAME: %[[COND2:[^,]+]]: i1 +// CHECK-SAME: %[[A0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[A1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[B0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[B1:[^)]+]]: !pto.vreg<64xf32> +// CHECK: cf.cond_br %[[COND2]], ^[[CB_JOIN:.*]](%[[A0]], %[[A1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>), ^[[CB_JOIN]](%[[B0]], %[[B1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: ^[[CB_JOIN]](%{{.*}}: !pto.vreg<64xf32>, %{{.*}}: !pto.vreg<64xf32>): +// CHECK: return {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge4_contiguous.pto b/test/lit/vmi/vmi_to_vpto_channel_merge4_contiguous.pto new file mode 100644 index 0000000000..4ffb8e384d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge4_contiguous.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge4_contiguous( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch2: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch3: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1, %ch2, %ch3) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + return %merged : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_channel_merge4_contiguous( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P2:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P3:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[E0:.*]], %[[E1:.*]] = pto.vintlv %[[P0]], %[[P2]] +// CHECK: %[[O0:.*]], %[[O1:.*]] = pto.vintlv %[[P1]], %[[P3]] +// CHECK: %[[L0:.*]], %[[L1:.*]] = pto.vintlv %[[E0]], %[[O0]] +// CHECK: %[[H0:.*]], %[[H1:.*]] = pto.vintlv %[[E1]], %[[O1]] +// CHECK: return %[[L0]], %[[L1]], %[[H0]], %[[H1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge_count_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_merge_count_unsupported_invalid.pto new file mode 100644 index 0000000000..8bdc2beb6a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge_count_unsupported_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge_count_unsupported_invalid( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch2: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1, %ch2) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_merge supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge_layout_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_merge_layout_invalid.pto new file mode 100644 index 0000000000..867cbdce65 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge_layout_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge_layout_invalid( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: 'pto.vmi.channel_merge' op requires layout-assigned channel_merge inputs to be contiguous diff --git a/test/lit/vmi/vmi_to_vpto_channel_merge_partial_group_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_merge_partial_group_invalid.pto new file mode 100644 index 0000000000..443a5fedae --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_merge_partial_group_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_merge_partial_group_invalid( + %ch0: !pto.vmi.vreg<2xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<2xf32, #pto.vmi.layout>) { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<2xf32, #pto.vmi.layout>, + !pto.vmi.vreg<2xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_merge requires every input layout to be contiguous +// CHECK-SAME: complete physical channel groups +// CHECK-SAME: requires source and result to have the same physical arity diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_count_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_split_count_unsupported_invalid.pto new file mode 100644 index 0000000000..1bc963d400 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_count_unsupported_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_split_count_unsupported_invalid( + %src: !pto.vmi.vreg<192xf32, #pto.vmi.layout>) { + %ch0, %ch1, %ch2 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<192xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_split supports only 2 or 4 channels diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_layout_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_split_layout_invalid.pto new file mode 100644 index 0000000000..55c9ea862e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_layout_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_split_layout_invalid( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return + } +} + +// CHECK: 'pto.vmi.channel_split' op requires layout-assigned channel_split source to be contiguous or deinterleaved by result count diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_merge.pto b/test/lit/vmi/vmi_to_vpto_channel_split_merge.pto new file mode 100644 index 0000000000..10d90d2869 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_merge.pto @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_channel_split_merge2( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32> + return %merged : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_channel_split4( + %src: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %ch0, %ch1, %ch2, %ch3 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return %ch0, %ch1, %ch2, %ch3 + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } + + func.func @vmi_channel_split_deint2_identity( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + return %ch0, %ch1 + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } + + func.func @vmi_channel_merge_deint2_identity( + %ch0: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %ch1: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %merged : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_channel_split_merge2( +// CHECK-SAME: %[[D0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[CH0:.*]], %[[CH1:.*]] = pto.vdintlv %[[D0]], %[[D1]] +// CHECK: return %[[CH0]], %[[CH1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_channel_split4( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S2:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S3:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[A0:.*]], %[[A1:.*]] = pto.vdintlv %[[S0]], %[[S1]] +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.vdintlv %[[S2]], %[[S3]] +// CHECK: %[[C0:.*]], %[[C2:.*]] = pto.vdintlv %[[A0]], %[[B0]] +// CHECK: %[[C1:.*]], %[[C3:.*]] = pto.vdintlv %[[A1]], %[[B1]] +// CHECK: return %[[C0]], %[[C1]], %[[C2]], %[[C3]] +// CHECK-LABEL: func.func @vmi_channel_split_deint2_identity( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NOT: pto.vdintlv +// CHECK: return %[[P0]], %[[P1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-LABEL: func.func @vmi_channel_merge_deint2_identity( +// CHECK-SAME: %[[M0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NOT: pto.vintlv +// CHECK: return %[[M0]], %[[M1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_merge_tail.pto b/test/lit/vmi/vmi_to_vpto_channel_split_merge_tail.pto new file mode 100644 index 0000000000..25afa0d016 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_merge_tail.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_channel_split_merge2_tail( + %src: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<50xf32, #pto.vmi.layout>, + !pto.vmi.vreg<50xf32, #pto.vmi.layout>) + %merged = "pto.vmi.channel_merge"(%ch0, %ch1) + : (!pto.vmi.vreg<50xf32, #pto.vmi.layout>, + !pto.vmi.vreg<50xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + return %merged : !pto.vmi.vreg<100xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_channel_split_merge2_tail( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^)]+]]: !pto.vreg<64xf32> +// CHECK: %[[CH0:.*]], %[[CH1:.*]] = pto.vdintlv %[[S0]], %[[S1]] +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[CH0]], %[[CH1]] +// CHECK: return %[[D0]], %[[D1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_channel_split_partial_group_invalid.pto b/test/lit/vmi/vmi_to_vpto_channel_split_partial_group_invalid.pto new file mode 100644 index 0000000000..f45b7cdfda --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_channel_split_partial_group_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_channel_split_partial_group_invalid( + %src: !pto.vmi.vreg<4xf32, #pto.vmi.layout>) { + %ch0, %ch1 = "pto.vmi.channel_split"(%src) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<2xf32, #pto.vmi.layout>, + !pto.vmi.vreg<2xf32, #pto.vmi.layout>) + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.channel_split requires source layout to be contiguous or matching deinterleaved channel layout +// CHECK-SAME: complete physical channel groups +// CHECK-SAME: requires source and result to have the same physical arity diff --git a/test/lit/vmi/vmi_to_vpto_cmp_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_cmp_element_type_invalid.pto new file mode 100644 index 0000000000..100f4b7378 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmp_element_type_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_cmpf_f8_invalid( + %lhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %mask = pto.vmi.cmpf "lt", %lhs, %rhs + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.cmpf direct lowering requires f16/bf16/f32 element type +// CHECK-SAME: requires f16/bf16/f32 element type for direct VPTO lowering diff --git a/test/lit/vmi/vmi_to_vpto_cmp_predicate_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_cmp_predicate_unsupported_invalid.pto new file mode 100644 index 0000000000..8689bc8312 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmp_predicate_unsupported_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_cmp_predicate_unsupported_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpf "uno", %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} compare predicate uno cannot be lowered to pto.vcmp +// CHECK-SAME: supported predicates are eq/ne/lt/le/gt/ge, ordered FP forms diff --git a/test/lit/vmi/vmi_to_vpto_cmp_select.pto b/test/lit/vmi/vmi_to_vpto_cmp_select.pto new file mode 100644 index 0000000000..816913c8b2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmp_select.pto @@ -0,0 +1,140 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_cmpf_select( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.vmi.cmpf "lt", %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %selected = pto.vmi.select %mask, %a, %b + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + %p0, %p1 = "pto.vmi.unpack"(%selected) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %m0, %m1, %p0, %p1 + : !pto.mask, !pto.mask, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_cmpi( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpi "ge", %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_cmpf_ordered_predicate( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpf "olt", %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_cmpi_signed_predicate( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpi "slt", %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_cmpf_bf16( + %a: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.mask { + %mask = pto.vmi.cmpf "oge", %a, %b + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.mask + return %part : !pto.mask + } + + func.func @vmi_to_vpto_cmpi_ui16( + %a: !pto.vmi.vreg<128xui16, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xui16, #pto.vmi.layout>) + -> !pto.mask { + %mask = pto.vmi.cmpi "eq", %a, %b + : !pto.vmi.vreg<128xui16, #pto.vmi.layout>, + !pto.vmi.vreg<128xui16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.mask + return %part : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_cmpf_select( +// CHECK: %[[FM0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[CM0:.*]] = pto.vcmp {{.*}}, {{.*}}, %[[FM0]], "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK: %[[FM1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[CM1:.*]] = pto.vcmp {{.*}}, {{.*}}, %[[FM1]], "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK: pto.vsel {{.*}}, {{.*}}, %[[CM0]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vsel {{.*}}, {{.*}}, %[[CM1]] +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-LABEL: func.func @vmi_to_vpto_cmpi( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "ge" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "ge" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpf_ordered_predicate( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpi_signed_predicate( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK-SAME: !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpf_bf16( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "ge" +// CHECK-SAME: !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_cmpi_ui16( +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "eq" +// CHECK-SAME: !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid.pto b/test/lit/vmi/vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid.pto new file mode 100644 index 0000000000..23b1e7f88f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_cmpi_unsigned_predicate_unsupported_invalid( + %a: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %b: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.cmpi "ult", %a, %b + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %m0, %m1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %m0, %m1 : !pto.mask, !pto.mask + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} compare predicate ult cannot be lowered to pto.vcmp +// CHECK-SAME: signed integer forms slt/sle/sgt/sge diff --git a/test/lit/vmi/vmi_to_vpto_compaction_deint_invalid.pto b/test/lit/vmi/vmi_to_vpto_compaction_deint_invalid.pto new file mode 100644 index 0000000000..b4b2af9879 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compaction_deint_invalid.pto @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_active_prefix_index_deint_invalid( + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %idx = pto.vmi.active_prefix_index %mask + : !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xi32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.active_prefix_index lowers through pto.vusqz only for one contiguous physical chunk +// CHECK-SAME: requires contiguous mask and result layouts + +// ----- + +module { + func.func @vmi_to_vpto_compress_deint_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.compress %source, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress lowers through pto.vsqz only for one contiguous full physical chunk +// CHECK-SAME: requires contiguous source, mask, and result layouts + +// ----- + +module { + func.func @vmi_to_vpto_compress_store_deint_invalid( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress_store lowers through pto.vsqz + pto.vstur only for one contiguous full physical chunk +// CHECK-SAME: requires contiguous value and mask layouts diff --git a/test/lit/vmi/vmi_to_vpto_compress.pto b/test/lit/vmi/vmi_to_vpto_compress.pto new file mode 100644 index 0000000000..aba4da0228 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_compress( + %src: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_compress( +// CHECK: %[[OUT:.*]] = pto.vsqz %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_compress_multichunk_invalid.pto b/test/lit/vmi/vmi_to_vpto_compress_multichunk_invalid.pto new file mode 100644 index 0000000000..3122bbb0ee --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_multichunk_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_multichunk_invalid( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress lowers through pto.vsqz only for one contiguous full physical chunk +// CHECK-SAME: multi-chunk compress needs cross-chunk compaction diff --git a/test/lit/vmi/vmi_to_vpto_compress_store.pto b/test/lit/vmi/vmi_to_vpto_compress_store.pto new file mode 100644 index 0000000000..edf8565c5f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_store.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_store( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_compress_store( +// CHECK: %[[BASE:.*]] = pto.addptr %arg1, %arg2 +// CHECK: %[[SQZ:.*]] = pto.vsqz %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ALIGN0:.*]] = pto.init_align : !pto.align +// CHECK: %[[ALIGN1:.*]] = pto.vstur %[[ALIGN0]], %[[SQZ]], %[[BASE]], "POST_UPDATE" : !pto.align, !pto.vreg<64xf32>, !pto.ptr -> !pto.align +// CHECK: pto.vstar %[[ALIGN1]], %[[BASE]] : !pto.align, !pto.ptr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_compress_store_multichunk_invalid.pto b/test/lit/vmi/vmi_to_vpto_compress_store_multichunk_invalid.pto new file mode 100644 index 0000000000..e4fc4738cc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_store_multichunk_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_store_multichunk_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + pto.vmi.compress_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress_store lowers through pto.vsqz + pto.vstur only for one contiguous full physical chunk +// CHECK-SAME: multi-chunk compress_store needs cross-chunk compaction diff --git a/test/lit/vmi/vmi_to_vpto_compress_tail_invalid.pto b/test/lit/vmi/vmi_to_vpto_compress_tail_invalid.pto new file mode 100644 index 0000000000..4d97cf831d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_compress_tail_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_compress_tail_invalid( + %src: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<4xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.compress %src, %mask + : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + !pto.vmi.mask<4xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.compress lowers through pto.vsqz only for one contiguous full physical chunk +// CHECK-SAME: padding mask lanes cannot be squeezed into the result diff --git a/test/lit/vmi/vmi_to_vpto_constant.pto b/test/lit/vmi/vmi_to_vpto_constant.pto new file mode 100644 index 0000000000..c5c93bf2db --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_splat() + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_splat +// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P0:.*]] = pto.vdup %[[CST]], %[[M0]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P1:.*]] = pto.vdup %[[CST]], %[[M1]] : f32, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_mask.pto b/test/lit/vmi/vmi_to_vpto_constant_mask.pto new file mode 100644 index 0000000000..9c38a62148 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_mask.pto @@ -0,0 +1,128 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_mask_all_true() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_all_false() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_b8_all_true() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<512xi1> + } : () -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_b16_all_false() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<256xi1> + } : () -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_plt_fallback() + -> !pto.mask { + %mask = "pto.vmi.constant_mask"() { + value = dense<[true, true, true, true, true, false, false, false]> : tensor<8xi1> + } : () -> !pto.vmi.mask<8xb32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<8xb32, #pto.vmi.layout>) -> !pto.mask + return %p0 : !pto.mask + } + + func.func @vmi_to_vpto_constant_mask_deinterleaved() + -> (!pto.mask, !pto.mask) { + %mask = "pto.vmi.constant_mask"() { + value = dense<[true, false, true, false, false, true, false, true]> : tensor<8xi1> + } : () -> !pto.vmi.mask<8xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<8xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_all_true +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: return %[[M0]], %[[M1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_all_false +// CHECK: %[[F0:.*]] = pto.pset_b32 "PAT_ALLF" : !pto.mask +// CHECK: %[[F1:.*]] = pto.pset_b32 "PAT_ALLF" : !pto.mask +// CHECK: return %[[F0]], %[[F1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_b8_all_true +// CHECK: %[[B8_0:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: %[[B8_1:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: return %[[B8_0]], %[[B8_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_b16_all_false +// CHECK: %[[B16_0:.*]] = pto.pset_b16 "PAT_ALLF" : !pto.mask +// CHECK: %[[B16_1:.*]] = pto.pset_b16 "PAT_ALLF" : !pto.mask +// CHECK: return %[[B16_0]], %[[B16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_plt_fallback +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[P0:.*]], %{{.*}} = pto.plt_b32 %[[C5]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P0]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_deinterleaved +// CHECK: %[[PART0:.*]] = pto.pset_b32 "PAT_VL2" : !pto.mask +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[P4:.*]] = pto.pset_b32 "PAT_VL4" : !pto.mask +// CHECK: %[[P2:.*]] = pto.pset_b32 "PAT_VL2" : !pto.mask +// CHECK: %[[NOT_P2:.*]] = pto.pnot %[[P2]], %[[ALL]] : !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[PART1:.*]] = pto.pand %[[P4]], %[[NOT_P2]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[PART0]], %[[PART1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_mask_nonprefix.pto b/test/lit/vmi/vmi_to_vpto_constant_mask_nonprefix.pto new file mode 100644 index 0000000000..cc3f439e62 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_mask_nonprefix.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_mask_nonprefix() + -> !pto.mask { + %mask = "pto.vmi.constant_mask"() { + value = dense<[true, false, true, false]> : tensor<4xi1> + } : () -> !pto.vmi.mask<4xb32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<4xb32, #pto.vmi.layout>) -> !pto.mask + return %p0 : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_nonprefix +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[RUN0:.*]] = pto.pset_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[P3:.*]] = pto.pset_b32 "PAT_VL3" : !pto.mask +// CHECK: %[[P2:.*]] = pto.pset_b32 "PAT_VL2" : !pto.mask +// CHECK: %[[NOT_P2:.*]] = pto.pnot %[[P2]], %[[ALL]] : !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[RUN1:.*]] = pto.pand %[[P3]], %[[NOT_P2]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[OUT:.*]] = pto.por %[[RUN0]], %[[RUN1]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto b/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto new file mode 100644 index 0000000000..3b2fc0d080 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_mask_rematerialize( + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_constant_mask_rematerialize( +// CHECK: %[[M32_0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M32_1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[S16:.*]] = pto.vsel %arg0, %arg1, %[[M16]] +// CHECK: %[[S32_0:.*]] = pto.vsel %arg2, %arg4, %[[M32_0]] +// CHECK: %[[S32_1:.*]] = pto.vsel %arg3, %arg5, %[[M32_1]] +// CHECK: return %[[S16]], %[[S32_0]], %[[S32_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_constant_nonsplat_invalid.pto b/test/lit/vmi/vmi_to_vpto_constant_nonsplat_invalid.pto new file mode 100644 index 0000000000..1d9fe8377e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_constant_nonsplat_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_constant_nonsplat_invalid() + -> (!pto.vreg<64xf32>) { + %value = "pto.vmi.constant"() { + value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> + } : () -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> + return %p0 : !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{.*}}non-splat pto.vmi.constant requires a vreg immediate or scratch materialization plan diff --git a/test/lit/vmi/vmi_to_vpto_construction_width_invalid.pto b/test/lit/vmi/vmi_to_vpto_construction_width_invalid.pto new file mode 100644 index 0000000000..1c6fdea4b0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_construction_width_invalid.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_broadcast_f64_unsupported(%scalar: f64) { + %value = pto.vmi.broadcast %scalar + : f64 -> !pto.vmi.vreg<32xf64, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.broadcast direct lowering requires physical vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type + +// ----- + +module { + func.func @vmi_constant_f64_unsupported() { + %value = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<32xf64> + } : () -> !pto.vmi.vreg<32xf64, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.constant direct lowering requires physical vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type diff --git a/test/lit/vmi/vmi_to_vpto_create_mask.pto b/test/lit/vmi/vmi_to_vpto_create_mask.pto new file mode 100644 index 0000000000..63417a8a99 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask.pto @@ -0,0 +1,87 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_contiguous() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 96 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_deint2() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 64 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_b8_contiguous() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 320 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_b16_deint2() + -> (!pto.mask, !pto.mask) { + %active = arith.constant 64 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_contiguous +// CHECK: %[[M0:.*]] = pto.pge_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[M1:.*]] = pto.pge_b32 "PAT_VL32" : !pto.mask +// CHECK: return %[[M0]], %[[M1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_deint2 +// CHECK: %[[P0:.*]] = pto.pge_b32 "PAT_VL32" : !pto.mask +// CHECK: %[[P1:.*]] = pto.pge_b32 "PAT_VL32" : !pto.mask +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_b8_contiguous +// CHECK: %[[B8_0:.*]] = pto.pge_b8 "PAT_ALL" : !pto.mask +// CHECK: %[[B8_1:.*]] = pto.pge_b8 "PAT_VL64" : !pto.mask +// CHECK: return %[[B8_0]], %[[B8_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_b16_deint2 +// CHECK: %[[B16_0:.*]] = pto.pge_b16 "PAT_VL32" : !pto.mask +// CHECK: %[[B16_1:.*]] = pto.pge_b16 "PAT_VL32" : !pto.mask +// CHECK: return %[[B16_0]], %[[B16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_create_mask_dynamic.pto b/test/lit/vmi/vmi_to_vpto_create_mask_dynamic.pto new file mode 100644 index 0000000000..c702d80529 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask_dynamic.pto @@ -0,0 +1,132 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_dynamic_contiguous(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_deint2(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_deint4(%active: index) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) + return %p0, %p1, %p2, %p3 + : !pto.mask, !pto.mask, !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_b8_contiguous(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_create_mask_dynamic_b16_deint2(%active: index) + -> (!pto.mask, !pto.mask) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_contiguous +// CHECK: %[[ACTIVE:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG:.*]] = arith.maxsi %[[ACTIVE]], {{.*}} : i32 +// CHECK: %[[CLAMPED:.*]] = arith.minui %[[NONNEG]], {{.*}} : i32 +// CHECK: %[[P0:.*]], %[[REM:.*]] = pto.plt_b32 %[[CLAMPED]] : i32 -> !pto.mask, i32 +// CHECK: %[[P1:.*]], %{{.*}} = pto.plt_b32 %[[REM]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_deint2 +// CHECK: %[[ACTIVE2:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG2:.*]] = arith.maxsi %[[ACTIVE2]], {{.*}} : i32 +// CHECK: %[[CLAMPED2:.*]] = arith.minui %[[NONNEG2]], {{.*}} : i32 +// CHECK: %[[BIAS2:.*]] = arith.addi %[[CLAMPED2]], {{.*}} : i32 +// CHECK: %[[PART0:.*]] = arith.divui %[[BIAS2]], {{.*}} : i32 +// CHECK: %[[P2_0:.*]], %{{.*}} = pto.plt_b32 %[[PART0]] : i32 -> !pto.mask, i32 +// CHECK: %[[PART1:.*]] = arith.divui %[[CLAMPED2]], {{.*}} : i32 +// CHECK: %[[P2_1:.*]], %{{.*}} = pto.plt_b32 %[[PART1]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P2_0]], %[[P2_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_deint4 +// CHECK: %[[ACTIVE4:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG4:.*]] = arith.maxsi %[[ACTIVE4]], {{.*}} : i32 +// CHECK: %[[CLAMPED4:.*]] = arith.minui %[[NONNEG4]], {{.*}} : i32 +// CHECK: %[[BIAS4_0:.*]] = arith.addi %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[PART4_0:.*]] = arith.divui %[[BIAS4_0]], {{.*}} : i32 +// CHECK: %[[P4_0:.*]], %{{.*}} = pto.plt_b32 %[[PART4_0]] : i32 -> !pto.mask, i32 +// CHECK: %[[BIAS4_1:.*]] = arith.addi %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[PART4_1:.*]] = arith.divui %[[BIAS4_1]], {{.*}} : i32 +// CHECK: %[[P4_1:.*]], %{{.*}} = pto.plt_b32 %[[PART4_1]] : i32 -> !pto.mask, i32 +// CHECK: %[[BIAS4_2:.*]] = arith.addi %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[PART4_2:.*]] = arith.divui %[[BIAS4_2]], {{.*}} : i32 +// CHECK: %[[P4_2:.*]], %{{.*}} = pto.plt_b32 %[[PART4_2]] : i32 -> !pto.mask, i32 +// CHECK: %[[PART4_3:.*]] = arith.divui %[[CLAMPED4]], {{.*}} : i32 +// CHECK: %[[P4_3:.*]], %{{.*}} = pto.plt_b32 %[[PART4_3]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P4_0]], %[[P4_1]], %[[P4_2]], %[[P4_3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_b8_contiguous +// CHECK: %[[ACTIVE8:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG8:.*]] = arith.maxsi %[[ACTIVE8]], {{.*}} : i32 +// CHECK: %[[CLAMPED8:.*]] = arith.minui %[[NONNEG8]], {{.*}} : i32 +// CHECK: %[[P8_0:.*]], %[[REM8:.*]] = pto.plt_b8 %[[CLAMPED8]] : i32 -> !pto.mask, i32 +// CHECK: %[[P8_1:.*]], %{{.*}} = pto.plt_b8 %[[REM8]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P8_0]], %[[P8_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_dynamic_b16_deint2 +// CHECK: %[[ACTIVE16:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG16:.*]] = arith.maxsi %[[ACTIVE16]], {{.*}} : i32 +// CHECK: %[[CLAMPED16:.*]] = arith.minui %[[NONNEG16]], {{.*}} : i32 +// CHECK: %[[BIAS16:.*]] = arith.addi %[[CLAMPED16]], {{.*}} : i32 +// CHECK: %[[PART16_0:.*]] = arith.divui %[[BIAS16]], {{.*}} : i32 +// CHECK: %[[P16_0:.*]], %{{.*}} = pto.plt_b16 %[[PART16_0]] : i32 -> !pto.mask, i32 +// CHECK: %[[PART16_1:.*]] = arith.divui %[[CLAMPED16]], {{.*}} : i32 +// CHECK: %[[P16_1:.*]], %{{.*}} = pto.plt_b16 %[[PART16_1]] : i32 -> !pto.mask, i32 +// CHECK: return %[[P16_0]], %[[P16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_create_mask_plt_fallback.pto b/test/lit/vmi/vmi_to_vpto_create_mask_plt_fallback.pto new file mode 100644 index 0000000000..8cd9cd051c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask_plt_fallback.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_plt_fallback() + -> !pto.mask { + %active = arith.constant 5 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<64xb32, #pto.vmi.layout> + %p0 = "pto.vmi.unpack"(%mask) + : (!pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.mask + return %p0 : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_plt_fallback( +// CHECK: %[[C5:.*]] = arith.constant 5 : i32 +// CHECK: %[[MASK:.*]], %{{.*}} = pto.plt_b32 %[[C5]] : i32 -> !pto.mask, i32 +// CHECK: return %[[MASK]] : !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto b/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto new file mode 100644 index 0000000000..74ef8194d5 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_create_mask_rematerialize( + %active: index, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %mask = pto.vmi.create_mask %active : index -> !pto.vmi.mask<128xpred> + %sel16 = pto.vmi.select %mask, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %mask, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_create_mask_rematerialize( +// CHECK: %[[ACTIVE32:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG32:.*]] = arith.maxsi %[[ACTIVE32]], {{.*}} : i32 +// CHECK: %[[CLAMP32:.*]] = arith.minui %[[NONNEG32]], {{.*}} : i32 +// CHECK: %[[M32_0:.*]], %[[REM32:.*]] = pto.plt_b32 %[[CLAMP32]] : i32 -> !pto.mask, i32 +// CHECK: %[[M32_1:.*]], %{{.*}} = pto.plt_b32 %[[REM32]] : i32 -> !pto.mask, i32 +// CHECK: %[[ACTIVE16:.*]] = arith.index_cast %arg0 : index to i32 +// CHECK: %[[NONNEG16:.*]] = arith.maxsi %[[ACTIVE16]], {{.*}} : i32 +// CHECK: %[[CLAMP16:.*]] = arith.minui %[[NONNEG16]], {{.*}} : i32 +// CHECK: %[[M16:.*]], %{{.*}} = pto.plt_b16 %[[CLAMP16]] : i32 -> !pto.mask, i32 +// CHECK: %[[S16:.*]] = pto.vsel %arg1, %arg2, %[[M16]] +// CHECK: %[[S32_0:.*]] = pto.vsel %arg3, %arg5, %[[M32_0]] +// CHECK: %[[S32_1:.*]] = pto.vsel %arg4, %arg6, %[[M32_1]] +// CHECK: return %[[S16]], %[[S32_0]], %[[S32_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_divf.pto b/test/lit/vmi/vmi_to_vpto_divf.pto new file mode 100644 index 0000000000..be21ba5fdc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_divf.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_divf( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %quotient = pto.vmi.divf %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %quotient : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_divf( +// CHECK-SAME: %[[LHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[LHS1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[DIV0:.*]] = pto.vdiv %[[LHS0]], %[[RHS0]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[DIV1:.*]] = pto.vdiv %[[LHS1]], %[[RHS1]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[DIV0]], %[[DIV1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_e2e_widen_add_store.pto b/test/lit/vmi/vmi_to_vpto_e2e_widen_add_store.pto new file mode 100644 index 0000000000..f88f15a8eb --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_e2e_widen_add_store.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_f16_widen_add_store( + %src: !pto.ptr, %dst: !pto.ptr, %offset: index) { + %narrow = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + %wide = pto.vmi.extf %narrow + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %wide, %wide + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + pto.vmi.store %sum, %dst[%offset] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + return + } + + func.func @vmi_to_vpto_f8_widen_add_store( + %src: !pto.ptr, %dst: !pto.ptr, %offset: index) { + %narrow = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + %wide = pto.vmi.extf %narrow + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %wide, %wide + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + pto.vmi.store %sum, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_f16_widen_add_store( +// CHECK: %[[NARROW:.*]] = pto.vlds %arg0[%arg2] : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: %[[CVT_MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[EVEN:.*]] = pto.vcvt %[[NARROW]], %[[CVT_MASK]] {part = "EVEN"} +// CHECK: %[[ODD:.*]] = pto.vcvt %[[NARROW]], %[[CVT_MASK]] {part = "ODD"} +// CHECK: %[[ADD_MASK0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[SUM0:.*]] = pto.vadd %[[EVEN]], %[[EVEN]], %[[ADD_MASK0]] +// CHECK: %[[ADD_MASK1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[SUM1:.*]] = pto.vadd %[[ODD]], %[[ODD]], %[[ADD_MASK1]] +// CHECK: %[[STORE_MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vstsx2 %[[SUM0]], %[[SUM1]], %arg1[%arg2], "INTLV_B32", %[[STORE_MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_f8_widen_add_store( +// CHECK: %[[NARROW8:.*]] = pto.vlds %arg0[%arg2] : !pto.ptr -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P0"} +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P1"} +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P2"} +// CHECK: pto.vcvt %[[NARROW8]], {{.*}} {part = "P3"} +// CHECK-COUNT-4: pto.vadd +// CHECK: pto.vintlv +// CHECK: pto.vintlv +// CHECK: pto.vintlv +// CHECK: pto.vintlv +// CHECK-COUNT-4: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_elementwise_width_invalid.pto b/test/lit/vmi/vmi_to_vpto_elementwise_width_invalid.pto new file mode 100644 index 0000000000..958e6f1f5a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_elementwise_width_invalid.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_addf_f64_unsupported( + %a: !pto.vmi.vreg<32xf64, #pto.vmi.layout>, + %b: !pto.vmi.vreg<32xf64, #pto.vmi.layout>) { + %sum = pto.vmi.addf %a, %b + : !pto.vmi.vreg<32xf64, #pto.vmi.layout>, + !pto.vmi.vreg<32xf64, #pto.vmi.layout> + -> !pto.vmi.vreg<32xf64, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.addf direct lowering requires f16/bf16/f32 element type and physical vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type + +// ----- + +module { + func.func @vmi_addi_index_unsupported( + %a: !pto.vmi.vreg<64xindex, #pto.vmi.layout>, + %b: !pto.vmi.vreg<64xindex, #pto.vmi.layout>) { + %sum = pto.vmi.addi %a, %b + : !pto.vmi.vreg<64xindex, #pto.vmi.layout>, + !pto.vmi.vreg<64xindex, #pto.vmi.layout> + -> !pto.vmi.vreg<64xindex, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.addi direct lowering requires physical vreg parts with b8/b16/b32 predicate masks +// CHECK-SAME: requires an 8/16/32-bit element type diff --git a/test/lit/vmi/vmi_to_vpto_ensure_identity.pto b/test/lit/vmi/vmi_to_vpto_ensure_identity.pto new file mode 100644 index 0000000000..783bc3428d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_identity.pto @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_identity( + %v: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask) { + %ev = "pto.vmi.ensure_layout"(%v) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %em0 = "pto.vmi.ensure_mask_layout"(%m) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %em1 = "pto.vmi.ensure_mask_granularity"(%em0) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%ev) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + %pm0, %pm1 = "pto.vmi.unpack"(%em1) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1, %pm0, %pm1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_ensure_identity_tail( + %v: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + %m: !pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.mask, !pto.mask) { + %ev = "pto.vmi.ensure_layout"(%v) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %em = "pto.vmi.ensure_mask_layout"(%m) + : (!pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%ev) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + %pm0, %pm1 = "pto.vmi.unpack"(%em) + : (!pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1, %pm0, %pm1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_identity( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK-NOT: pto.vmi.ensure +// CHECK-NOT: pto.vmi.unpack +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_identity_tail( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK-NOT: pto.vmi.ensure +// CHECK-NOT: pto.vmi.unpack +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto new file mode 100644 index 0000000000..4c15af9f19 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_deint4_to_contiguous( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint4( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %split = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%split) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint4_to_contiguous( +// CHECK: %[[A0:.*]], %[[A1:.*]] = pto.vintlv %arg0, %arg2 +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.vintlv %arg1, %arg3 +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[A0]], %[[B0]] +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.vintlv %[[A1]], %[[B1]] +// CHECK: return %[[D0]], %[[D1]], %[[D2]], %[[D3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint4( +// CHECK: %[[A0:.*]], %[[B0:.*]] = pto.vdintlv %arg0, %arg1 +// CHECK: %[[A1:.*]], %[[B1:.*]] = pto.vdintlv %arg2, %arg3 +// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.vdintlv %[[A0]], %[[A1]] +// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.vdintlv %[[B0]], %[[B1]] +// CHECK: return %[[P0]], %[[P1]], %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_partial_invalid.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_partial_invalid.pto new file mode 100644 index 0000000000..bbfd4dcfd8 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_partial_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_partial_invalid( + %input: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.ensure_layout cannot materialize the requested data layout conversion +// CHECK-SAME: requires source and result to have the same physical arity +// CHECK-SAME: partial/tail layout materialization requires an explicit packing plan diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_vdintlv.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_vdintlv.pto new file mode 100644 index 0000000000..03661ac669 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_vdintlv.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint2( + %input: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %split = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%split) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_contiguous_to_deint2( +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.vdintlv %arg0, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_vintlv.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_vintlv.pto new file mode 100644 index 0000000000..e4506c86c2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_vintlv.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_layout_deint2_to_contiguous( + %input: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_ensure_layout_deint2_tail_to_contiguous( + %input: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %dense = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint2_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %arg0, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint2_tail_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %arg0, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity.pto new file mode 100644 index 0000000000..989fd0cb74 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_granularity( + %m: !pto.vmi.mask<128xpred>, + %a16: !pto.vmi.vreg<128xf16>, + %b16: !pto.vmi.vreg<128xf16>, + %a32: !pto.vmi.vreg<128xf32>, + %b32: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32>) { + %sel16 = pto.vmi.select %m, %a16, %b16 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<128xf16> + %sel32 = pto.vmi.select %m, %a32, %b32 + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sel16, %sel32 + : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_mask_granularity( +// CHECK: %[[LO:.*]] = pto.ppack %arg0, "LOWER" : !pto.mask -> !pto.mask +// CHECK: %[[HI:.*]] = pto.ppack %arg1, "HIGHER" : !pto.mask -> !pto.mask +// CHECK: %[[ALL:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[M16:.*]] = pto.por %[[LO]], %[[HI]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: pto.vsel %arg2, %arg3, %[[M16]] +// CHECK: pto.vsel %arg4, %arg6, %arg0 +// CHECK: pto.vsel %arg5, %arg7, %arg1 +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_direct.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_direct.pto new file mode 100644 index 0000000000..2512367b64 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_direct.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_granularity_direct( + %m: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %result = pto.vmi.ensure_mask_granularity %m + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%result) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_mask_granularity_direct( +// CHECK: %[[P0:.*]] = pto.punpack %arg0, "LOWER" : !pto.mask -> !pto.mask +// CHECK: %[[P1:.*]] = pto.punpack %arg0, "HIGHER" : !pto.mask -> !pto.mask +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_multistep.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_multistep.pto new file mode 100644 index 0000000000..29bb147489 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_granularity_multistep.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_granularity_multistep( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.mask { + %result = pto.vmi.ensure_mask_granularity %m + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb8, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%result) + : (!pto.vmi.mask<128xb8, #pto.vmi.layout>) + -> !pto.mask + return %part : !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_mask_granularity_multistep( +// CHECK: %[[LO16:.*]] = pto.ppack %arg0, "LOWER" : !pto.mask -> !pto.mask +// CHECK: %[[HI16:.*]] = pto.ppack %arg1, "HIGHER" : !pto.mask -> !pto.mask +// CHECK: %[[ALL16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[M16:.*]] = pto.por %[[LO16]], %[[HI16]], %[[ALL16]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: %[[M8:.*]] = pto.ppack %[[M16]], "LOWER" : !pto.mask -> !pto.mask +// CHECK: return %[[M8]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_layout.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout.pto new file mode 100644 index 0000000000..17a644834b --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout.pto @@ -0,0 +1,114 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_mask_deint2_to_contiguous( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_contiguous_to_deint2( + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_deint2_tail_to_contiguous( + %m: !pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<100xb32, #pto.vmi.layout> + -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<100xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_deint4_to_contiguous( + %m: !pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) + return %p0, %p1, %p2, %p3 + : !pto.mask, !pto.mask, !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_contiguous_to_deint4( + %m: !pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) + return %p0, %p1, %p2, %p3 + : !pto.mask, !pto.mask, !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_deint2_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b32 %arg0, %arg1 +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_contiguous_to_deint2( +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.pdintlv_b32 %arg0, %arg1 +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_deint2_tail_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b32 %arg0, %arg1 +// CHECK: return %[[D0]], %[[D1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_deint4_to_contiguous( +// CHECK: %[[A0:.*]], %[[A1:.*]] = pto.pintlv_b32 %arg0, %arg2 +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.pintlv_b32 %arg1, %arg3 +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b32 %[[A0]], %[[B0]] +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.pintlv_b32 %[[A1]], %[[B1]] +// CHECK: return %[[D0]], %[[D1]], %[[D2]], %[[D3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_contiguous_to_deint4( +// CHECK: %[[A0:.*]], %[[B0:.*]] = pto.pdintlv_b32 %arg0, %arg1 +// CHECK: %[[A1:.*]], %[[B1:.*]] = pto.pdintlv_b32 %arg2, %arg3 +// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.pdintlv_b32 %[[A0]], %[[A1]] +// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.pdintlv_b32 %[[B0]], %[[B1]] +// CHECK: return %[[P0]], %[[P1]], %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_partial_invalid.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_partial_invalid.pto new file mode 100644 index 0000000000..87edcee933 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_partial_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_ensure_mask_layout_partial_invalid( + %input: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + %dense = pto.vmi.ensure_mask_layout %input + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.ensure_mask_layout cannot materialize the requested mask layout conversion +// CHECK-SAME: requires source and result to have the same physical arity +// CHECK-SAME: partial/tail predicate layout materialization requires an explicit packing plan diff --git a/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_widths.pto b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_widths.pto new file mode 100644 index 0000000000..0c8b9a4120 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_ensure_mask_layout_widths.pto @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_mask_b8_deint2_to_contiguous( + %m: !pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<512xb8, #pto.vmi.layout> + -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_b8_contiguous_to_deint2( + %m: !pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<512xb8, #pto.vmi.layout> + -> !pto.vmi.mask<512xb8, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<512xb8, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_b16_deint2_to_contiguous( + %m: !pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %dense = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb16, #pto.vmi.layout> + -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%dense) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } + + func.func @vmi_to_vpto_mask_b16_contiguous_to_deint2( + %m: !pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) { + %deint = pto.vmi.ensure_mask_layout %m + : !pto.vmi.mask<256xb16, #pto.vmi.layout> + -> !pto.vmi.mask<256xb16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%deint) + : (!pto.vmi.mask<256xb16, #pto.vmi.layout>) + -> (!pto.mask, !pto.mask) + return %p0, %p1 : !pto.mask, !pto.mask + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b8_deint2_to_contiguous( +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.pintlv_b8 %arg0, %arg1 +// CHECK: return %[[D0]], %[[D1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b8_contiguous_to_deint2( +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.pdintlv_b8 %arg0, %arg1 +// CHECK: return %[[P0]], %[[P1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b16_deint2_to_contiguous( +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.pintlv_b16 %arg0, %arg1 +// CHECK: return %[[D2]], %[[D3]] + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_b16_contiguous_to_deint2( +// CHECK: %[[P2:.*]], %[[P3:.*]] = pto.pdintlv_b16 %arg0, %arg1 +// CHECK: return %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_all_active.pto b/test/lit/vmi/vmi_to_vpto_expand_load_all_active.pto new file mode 100644 index 0000000000..836e33dec0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_all_active.pto @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_all_active( + %src: !pto.ptr, + %offset: index, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %active = arith.constant 64 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<64xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_expand_load_all_active_safe_tail_memref_nonzero_offset( + %src: memref<132xf32>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %active = arith.constant 100 : index + %offset = arith.constant 4 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %part0, %part1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %part0, %part1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_expand_load_all_active( +// CHECK: %[[LOAD:.*]] = pto.vlds %arg0[%arg1] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: return %[[LOAD]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_expand_load_all_active_safe_tail_memref_nonzero_offset( +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C68:.*]] = arith.constant 68 : index +// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C4]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C68]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto b/test/lit/vmi/vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto new file mode 100644 index 0000000000..d8733b4641 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_all_active_negative_offset_invalid( + %src: memref<132xf32>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %active = arith.constant 100 : index + %offset = arith.constant -1 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<100xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %part0, %part1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %part0, %part1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: all-active path requires full physical chunks or statically safe full-read footprint +// CHECK-SAME: safe-read proof requires non-negative offset +// CHECK-SAME: fallback decision: partial/tail read needs a scratch, guarded, or true masked/non-faulting load fallback diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_partial_mask_invalid.pto b/test/lit/vmi/vmi_to_vpto_expand_load_partial_mask_invalid.pto new file mode 100644 index 0000000000..cdab169262 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_partial_mask_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_partial_mask_invalid( + %src: !pto.ptr, + %offset: index, + %passthru: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %active = arith.constant 4 : index + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %part0, %part1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %part0, %part1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: one physical chunk diff --git a/test/lit/vmi/vmi_to_vpto_expand_load_runtime_mask.pto b/test/lit/vmi/vmi_to_vpto_expand_load_runtime_mask.pto new file mode 100644 index 0000000000..7c9d8a3a5b --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_expand_load_runtime_mask.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_expand_load_runtime_mask( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_expand_load_runtime_mask( +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK-DAG: %[[BASE:.*]] = pto.addptr %arg0, %arg1 +// CHECK: %[[CARRIER:.*]] = pto.vdup %[[ZERO]], %[[ALL]] : i32, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[IDX:.*]] = pto.vusqz %[[CARRIER]], %arg2 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[LOAD:.*]] = pto.vgather2_bc %[[BASE]], %[[IDX]], %arg2 : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vsel %[[LOAD]], %arg3, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_extf.pto b/test/lit/vmi/vmi_to_vpto_extf.pto new file mode 100644 index 0000000000..af4fbca903 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_extf.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extf_f16_to_f32( + %input: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_extf_f16_tail_to_f32( + %input: !pto.vmi.vreg<100xf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<100xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_extf_bf16_to_f32( + %input: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f16_to_f32( +// CHECK-SAME: %[[INPUT:.*]]: !pto.vreg<128xf16> +// CHECK: %[[MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f16_tail_to_f32( +// CHECK-SAME: %[[TAIL_INPUT:.*]]: !pto.vreg<128xf16> +// CHECK: %[[TAIL_MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_bf16_to_f32( +// CHECK-SAME: %[[BF16_INPUT:.*]]: !pto.vreg<128xbf16> +// CHECK: %[[BF16_MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[BF16_INPUT]], %[[BF16_MASK]] {part = "EVEN"} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[BF16_INPUT]], %[[BF16_MASK]] {part = "ODD"} : !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_extf_f8.pto b/test/lit/vmi/vmi_to_vpto_extf_f8.pto new file mode 100644 index 0000000000..c9ab157d0d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_extf_f8.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extf_f8_to_f32( + %input: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_extf_f8_tail_to_f32( + %input: !pto.vmi.vreg<100xf8E4M3FN, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<100xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f8_to_f32( +// CHECK-SAME: %[[INPUT:.*]]: !pto.vreg<256xf8E4M3FN> +// CHECK: %[[MASK:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P0"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P1"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P2"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P3"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_f8_tail_to_f32( +// CHECK-SAME: %[[TAIL_INPUT:.*]]: !pto.vreg<256xf8E4M3FN> +// CHECK: %[[TAIL_MASK:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P0"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P1"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P2"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt %[[TAIL_INPUT]], %[[TAIL_MASK]] {part = "P3"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_extf_multichunk.pto b/test/lit/vmi/vmi_to_vpto_extf_multichunk.pto new file mode 100644 index 0000000000..0803ccde1c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_extf_multichunk.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extf_multichunk( + %input: !pto.vmi.vreg<256xf16, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %wide = pto.vmi.extf %input + : !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extf_multichunk( +// CHECK: %[[MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[EVEN0:.*]] = pto.vcvt %arg0, %[[MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[EVEN1:.*]] = pto.vcvt %arg1, %[[MASK]] {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ODD0:.*]] = pto.vcvt %arg0, %[[MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ODD1:.*]] = pto.vcvt %arg1, %[[MASK]] {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[EVEN0]], %[[EVEN1]], %[[ODD0]], %[[ODD1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_fma.pto b/test/lit/vmi/vmi_to_vpto_fma.pto new file mode 100644 index 0000000000..d222c1ecb9 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_fma.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_fma( + %lhs: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_fma_f16( + %lhs: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %part : !pto.vreg<128xf16> + } + + func.func @vmi_to_vpto_fma_bf16( + %lhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xbf16, #pto.vmi.layout>) + -> !pto.vreg<128xbf16> + return %part : !pto.vreg<128xbf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_fma( +// CHECK: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[OUT:.*]] = pto.vmula %arg2, %arg0, %arg1, %[[MASK]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_fma_f16( +// CHECK: %[[MASK16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[OUT16:.*]] = pto.vmula %arg2, %arg0, %arg1, %[[MASK16]] : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_fma_bf16( +// CHECK: %[[MASKBF16:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: %[[OUTBF16:.*]] = pto.vmula %arg2, %arg0, %arg1, %[[MASKBF16]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.vreg<128xbf16>, !pto.mask -> !pto.vreg<128xbf16> +// CHECK: return %[[OUTBF16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_fma_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_fma_element_type_invalid.pto new file mode 100644 index 0000000000..877568258b --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_fma_element_type_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_fma_f8_invalid( + %lhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %out = pto.vmi.fma %lhs, %rhs, %acc + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.fma lowers through pto.vmula only for f16/bf16/f32 element types +// CHECK-SAME: requires f16, bf16, or f32 element type for pto.vmula diff --git a/test/lit/vmi/vmi_to_vpto_function_type_layout_free_invalid.pto b/test/lit/vmi/vmi_to_vpto_function_type_layout_free_invalid.pto new file mode 100644 index 0000000000..0fedf0d694 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_function_type_layout_free_invalid.pto @@ -0,0 +1,16 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func private @external(!pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> +} + +// CHECK: VMI-PASS-INVARIANT: vmi-to-vpto requires layout-assigned VMI types diff --git a/test/lit/vmi/vmi_to_vpto_gather.pto b/test/lit/vmi/vmi_to_vpto_gather.pto new file mode 100644 index 0000000000..d68e72c1d2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_gather.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_gather( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_gather( +// CHECK: %[[GATHER:.*]] = pto.vgather2_bc %arg0, %arg1, %arg2 : !pto.ptr, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vsel %[[GATHER]], %arg3, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto b/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto new file mode 100644 index 0000000000..83bf5db675 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_gather_f16_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only +// CHECK-SAME: 32-bit result element type diff --git a/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto new file mode 100644 index 0000000000..c271e9f446 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_gather_deint_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only +// CHECK-SAME: contiguous result, indices, passthru, and mask layouts + +// ----- + +module { + func.func @vmi_to_vpto_gather_tail_invalid( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<32xf32, #pto.vmi.layout>) { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + !pto.vmi.mask<32xb32, #pto.vmi.layout>, + !pto.vmi.vreg<32xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<32xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only +// CHECK-SAME: result requires full physical chunks +// CHECK-SAME: found padding lane in physical chunk + +// ----- + +module { + func.func @vmi_to_vpto_scatter_deint_invalid( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.scatter lowers through pto.vscatter only +// CHECK-SAME: contiguous value, indices, and mask layouts + +// ----- + +module { + func.func @vmi_to_vpto_scatter_tail_invalid( + %value: !pto.vmi.vreg<32xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>) { + pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + : !pto.vmi.vreg<32xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + !pto.vmi.mask<32xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.scatter lowers through pto.vscatter only +// CHECK-SAME: value requires full physical chunks +// CHECK-SAME: found padding lane in physical chunk diff --git a/test/lit/vmi/vmi_to_vpto_iota.pto b/test/lit/vmi/vmi_to_vpto_iota.pto new file mode 100644 index 0000000000..a46f767b59 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_iota.pto @@ -0,0 +1,120 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_iota_i32_asc(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_i32_desc(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base {order = "DESC"} + : i32 -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_i32_deint2_asc(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_i16_asc(%base: i16) + -> !pto.vreg<128xi16> { + %value = pto.vmi.iota %base + : i16 -> !pto.vmi.vreg<128xi16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xi16, #pto.vmi.layout>) + -> !pto.vreg<128xi16> + return %part : !pto.vreg<128xi16> + } + + func.func @vmi_to_vpto_iota_f16_deint2_asc(%base: f16) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %value = pto.vmi.iota %base + : f16 -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1 : !pto.vreg<128xf16>, !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i32_asc( +// CHECK: %[[C64:.*]] = arith.constant 64 : i32 +// CHECK: %[[P0:.*]] = pto.vci %arg0 : i32 -> !pto.vreg<64xi32> +// CHECK: %[[B1:.*]] = arith.addi %arg0, %[[C64]] : i32 +// CHECK: %[[P1:.*]] = pto.vci %[[B1]] : i32 -> !pto.vreg<64xi32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i32_desc( +// CHECK: %[[C64:.*]] = arith.constant 64 : i32 +// CHECK: %[[P0:.*]] = pto.vci %arg0 {order = "DESC"} : i32 -> !pto.vreg<64xi32> +// CHECK: %[[B1:.*]] = arith.subi %arg0, %[[C64]] : i32 +// CHECK: %[[P1:.*]] = pto.vci %[[B1]] {order = "DESC"} : i32 -> !pto.vreg<64xi32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i32_deint2_asc( +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[FACTOR:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[PART1:.*]] = arith.constant 1 : i32 +// CHECK: %[[LOCAL0:.*]] = pto.vci %[[ZERO]] : i32 -> !pto.vreg<64xi32> +// CHECK: %[[SCALED0:.*]] = pto.vmuls %[[LOCAL0]], %[[FACTOR]], +// CHECK: %[[P0:.*]] = pto.vadds %[[SCALED0]], %arg0, +// CHECK: %[[LOCAL1:.*]] = pto.vci %[[ZERO]] : i32 -> !pto.vreg<64xi32> +// CHECK: %[[SCALED1:.*]] = pto.vmuls %[[LOCAL1]], %[[FACTOR]], +// CHECK: %[[BASE1:.*]] = arith.addi %arg0, %[[PART1]] : i32 +// CHECK: %[[P1:.*]] = pto.vadds %[[SCALED1]], %[[BASE1]], +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_i16_asc( +// CHECK: %[[P16:.*]] = pto.vci %arg0 : i16 -> !pto.vreg<128xi16> +// CHECK: return %[[P16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_f16_deint2_asc( +// CHECK-DAG: %[[ZERO16:.*]] = arith.constant 0.000000e+00 : f16 +// CHECK-DAG: %[[FACTOR16:.*]] = arith.constant 2.000000e+00 : f16 +// CHECK-DAG: %[[PART16_1:.*]] = arith.constant 1.000000e+00 : f16 +// CHECK: %[[LOCAL16_0:.*]] = pto.vci %[[ZERO16]] : f16 -> !pto.vreg<128xf16> +// CHECK: %[[SCALED16_0:.*]] = pto.vmuls %[[LOCAL16_0]], %[[FACTOR16]], +// CHECK: %[[P16_0:.*]] = pto.vadds %[[SCALED16_0]], %arg0, +// CHECK: %[[LOCAL16_1:.*]] = pto.vci %[[ZERO16]] : f16 -> !pto.vreg<128xf16> +// CHECK: %[[SCALED16_1:.*]] = pto.vmuls %[[LOCAL16_1]], %[[FACTOR16]], +// CHECK: %[[BASE16_1:.*]] = arith.addf %arg0, %[[PART16_1]] : f16 +// CHECK: %[[P16_1:.*]] = pto.vadds %[[SCALED16_1]], %[[BASE16_1]], +// CHECK: return %[[P16_0]], %[[P16_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_iota_tail.pto b/test/lit/vmi/vmi_to_vpto_iota_tail.pto new file mode 100644 index 0000000000..7ba8a31f11 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_iota_tail.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_iota_contiguous_tail(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<100xi32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1 : !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_iota_deint2_tail(%base: i32) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %value = pto.vmi.iota %base + : i32 -> !pto.vmi.vreg<130xi32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<130xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_contiguous_tail( +// CHECK: %[[C64:.*]] = arith.constant 64 : i32 +// CHECK: %[[P0:.*]] = pto.vci %arg0 : i32 -> !pto.vreg<64xi32> +// CHECK: %[[B1:.*]] = arith.addi %arg0, %[[C64]] : i32 +// CHECK: %[[P1:.*]] = pto.vci %[[B1]] : i32 -> !pto.vreg<64xi32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_iota_deint2_tail( +// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : i32 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C129:.*]] = arith.constant 129 : i32 +// CHECK: %[[BASE128:.*]] = arith.addi %arg0, %[[C128]] : i32 +// CHECK: %[[BASE1:.*]] = arith.addi %arg0, %[[C1]] : i32 +// CHECK: %[[BASE129:.*]] = arith.addi %arg0, %[[C129]] : i32 +// CHECK: return {{.*}}, {{.*}}, {{.*}}, {{.*}} +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_deint.pto b/test/lit/vmi/vmi_to_vpto_load_deint.pto new file mode 100644 index 0000000000..715dacdfa6 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_deint.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_deint2(%src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_load_deint4(%src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint2( +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint4( +// CHECK: %[[D0:.*]] = pto.vlds %arg0[%arg1] +// CHECK: %[[D1:.*]] = pto.vlds %arg0[{{.*}}] +// CHECK: %[[D2:.*]] = pto.vlds %arg0[{{.*}}] +// CHECK: %[[D3:.*]] = pto.vlds %arg0[{{.*}}] +// CHECK: %[[A0:.*]], %[[B0:.*]] = pto.vdintlv %[[D0]], %[[D1]] +// CHECK: %[[A1:.*]], %[[B1:.*]] = pto.vdintlv %[[D2]], %[[D3]] +// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.vdintlv %[[A0]], %[[A1]] +// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.vdintlv %[[B0]], %[[B1]] +// CHECK: return %[[P0]], %[[P1]], %[[P2]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto b/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto new file mode 100644 index 0000000000..433f222af3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_deint2_multichunk( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint2_multichunk( +// CHECK: %[[P0_0:.*]], %[[P1_0:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: %[[P0_1:.*]], %[[P1_1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: return %[[P0_0]], %[[P0_1]], %[[P1_0]], %[[P1_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_nonfull_invalid.pto b/test/lit/vmi/vmi_to_vpto_load_nonfull_invalid.pto new file mode 100644 index 0000000000..f87e3753ca --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_nonfull_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_load_nonfull_invalid( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: safe-read proof failed: requires constant index offset +// CHECK-SAME: fallback decision: partial/tail read needs a scratch, guarded, or true masked/non-faulting load fallback +// CHECK-SAME: scratch memory fallback resource allocation is not implemented +// CHECK-SAME: guarded memory fallback control-flow lowering is not implemented diff --git a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto new file mode 100644 index 0000000000..157f57f84d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto @@ -0,0 +1,73 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_safe_tail_memref(%src: memref<128xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.load %src[%c0] + : memref<128xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_load_safe_tail_memref_nonzero_offset(%src: memref<132xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c4 = arith.constant 4 : index + %value = pto.vmi.load %src[%c4] + : memref<132xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_tile_read_safe_tail_memref(%src: memref<128xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_safe_tail_memref( +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C0]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_load_safe_tail_memref_nonzero_offset( +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C68:.*]] = arith.constant 68 : index +// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C4]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C68]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_read_safe_tail_memref( +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C0]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_invalid.pto b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_invalid.pto new file mode 100644 index 0000000000..07975ea70d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_load_safe_tail_memref_invalid(%src: memref<100xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.load %src[%c0] + : memref<100xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: safe-read proof failed: full physical read footprint [0, 128) exceeds static memref element count 100 diff --git a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto new file mode 100644 index 0000000000..863c2b4fa5 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid(%src: memref<132xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %cm1 = arith.constant -1 : index + %value = pto.vmi.load %src[%cm1] + : memref<132xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: safe-read proof failed: requires non-negative offset diff --git a/test/lit/vmi/vmi_to_vpto_load_store_contiguous.pto b/test/lit/vmi/vmi_to_vpto_load_store_contiguous.pto new file mode 100644 index 0000000000..891cb20567 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_load_store_contiguous.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_store_contiguous( + %src: !pto.ptr, %dst: !pto.ptr, %offset: index) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_store_contiguous( +// CHECK: %[[C64_LOAD:.*]] = arith.constant 64 : index +// CHECK: %[[L0:.*]] = pto.vlds %arg0[%arg2] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: %[[OFF1_LOAD:.*]] = arith.addi %arg2, %[[C64_LOAD]] : index +// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[OFF1_LOAD]]] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: %[[M0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vsts %[[L0]], %arg1[%arg2], %[[M0]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: %[[M1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vsts %[[L1]], %arg1[{{.*}}], %[[M1]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_mask_logic.pto b/test/lit/vmi/vmi_to_vpto_mask_logic.pto new file mode 100644 index 0000000000..cf220cc6a4 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_mask_logic.pto @@ -0,0 +1,126 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_mask_logic( + %a: !pto.vmi.vreg<128xf32>, + %b: !pto.vmi.vreg<128xf32>, + %c: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred>, + !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred>) { + %lt = pto.vmi.cmpf "olt", %a, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + %gt = pto.vmi.cmpf "ogt", %a, %c + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + %and = pto.vmi.mask_and %lt, %gt + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xpred> + %or = pto.vmi.mask_or %lt, %gt + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xpred> + %xor = pto.vmi.mask_xor %lt, %gt + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + -> !pto.vmi.mask<128xpred> + %not = pto.vmi.mask_not %lt + : !pto.vmi.mask<128xpred> -> !pto.vmi.mask<128xpred> + return %and, %or, %xor, %not + : !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred>, + !pto.vmi.mask<128xpred>, !pto.vmi.mask<128xpred> + } + + func.func @vmi_to_vpto_mask_logic_b8( + %lhs: !pto.vmi.mask<256xb8, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<256xb8, #pto.vmi.layout>) + -> (!pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>) { + %and = pto.vmi.mask_and %lhs, %rhs + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + %or = pto.vmi.mask_or %lhs, %rhs + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + %xor = pto.vmi.mask_xor %lhs, %rhs + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + %not = pto.vmi.mask_not %lhs + : !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.mask<256xb8, #pto.vmi.layout> + return %and, %or, %xor, %not + : !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + } + + func.func @vmi_to_vpto_mask_logic_b16( + %lhs: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> (!pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %and = pto.vmi.mask_and %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %or = pto.vmi.mask_or %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %xor = pto.vmi.mask_xor %lhs, %rhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %not = pto.vmi.mask_not %lhs + : !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + return %and, %or, %xor, %not + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_mask_logic( +// CHECK-SAME: -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask, !pto.mask) +// CHECK-DAG: %[[AND0:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[AND1:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR0:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR1:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR0:.*]] = pto.pxor {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR1:.*]] = pto.pxor {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT0:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT1:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[AND0]], %[[AND1]], %[[OR0]], %[[OR1]], %[[XOR0]], %[[XOR1]], %[[NOT0]], %[[NOT1]] +// CHECK-LABEL: func.func @vmi_to_vpto_mask_logic_b8( +// CHECK-SAME: -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) +// CHECK-DAG: %[[AND_B8:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR_B8:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR_B8:.*]] = pto.pxor {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT_B8:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[AND_B8]], %[[OR_B8]], %[[XOR_B8]], %[[NOT_B8]] +// CHECK-LABEL: func.func @vmi_to_vpto_mask_logic_b16( +// CHECK-SAME: -> (!pto.mask, !pto.mask, !pto.mask, !pto.mask) +// CHECK-DAG: %[[AND_B16:.*]] = pto.pand {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[OR_B16:.*]] = pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[XOR_B16:.*]] = pto.pxor {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK-DAG: %[[NOT_B16:.*]] = pto.pnot {{.*}} : !pto.mask, !pto.mask -> !pto.mask +// CHECK: return %[[AND_B16]], %[[OR_B16]], %[[XOR_B16]], %[[NOT_B16]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_load.pto b/test/lit/vmi/vmi_to_vpto_masked_load.pto new file mode 100644 index 0000000000..bc46f591ac --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_load( +// CHECK: %[[LOAD:.*]] = pto.vlds %arg0[%arg1] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vsel %[[LOAD]], %arg3, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto b/test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto new file mode 100644 index 0000000000..9b79049c1a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load_nonfull_invalid( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<4xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<4xb32, #pto.vmi.layout>, + !pto.vmi.vreg<4xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<4xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, contiguous result/passthru/mask layouts +// CHECK-SAME: safe-read proof requires constant index offset +// CHECK-SAME: fallback decision: partial/tail read needs a scratch, guarded, or true masked/non-faulting load fallback +// CHECK-SAME: target true masked/non-faulting load is unavailable diff --git a/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref.pto b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref.pto new file mode 100644 index 0000000000..d4b9f23c23 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref.pto @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load_safe_tail_memref( + %src: memref<128xf32>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c0 = arith.constant 0 : index + %out = pto.vmi.masked_load %src[%c0], %mask, %passthru + : memref<128xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_masked_load_safe_tail_memref_nonzero_offset( + %src: memref<132xf32>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c4 = arith.constant 4 : index + %out = pto.vmi.masked_load %src[%c4], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_load_safe_tail_memref( +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[L0:.*]] = pto.vlds %arg0[%[[C0]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O0:.*]] = pto.vsel %[[L0]], %arg3, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O1:.*]] = pto.vsel %[[L1]], %arg4, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[O0]], %[[O1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_load_safe_tail_memref_nonzero_offset( +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C68:.*]] = arith.constant 68 : index +// CHECK: %[[L0:.*]] = pto.vlds %arg0[%[[C4]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O0:.*]] = pto.vsel %[[L0]], %arg3, %arg1 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[C68]]] : memref<132xf32> -> !pto.vreg<64xf32> +// CHECK: %[[O1:.*]] = pto.vsel %[[L1]], %arg4, %arg2 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[O0]], %[[O1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto new file mode 100644 index 0000000000..ab22618d3e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid( + %src: memref<132xf32>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %cm1 = arith.constant -1 : index + %out = pto.vmi.masked_load %src[%cm1], %mask, %passthru + : memref<132xf32>, + !pto.vmi.mask<100xb32, #pto.vmi.layout>, + !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, contiguous result/passthru/mask layouts +// CHECK-SAME: safe-read proof requires non-negative offset diff --git a/test/lit/vmi/vmi_to_vpto_masked_store.pto b/test/lit/vmi/vmi_to_vpto_masked_store.pto new file mode 100644 index 0000000000..01e8d53d89 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_contiguous( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_store_contiguous( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[M1:[^,]+]]: !pto.mask +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: pto.vsts %[[V0]], %[[DST]][%[[OFF]]], %[[M0]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: %[[OFF1:.*]] = arith.addi %[[OFF]], %[[C64]] : index +// CHECK: pto.vsts %[[V1]], %[[DST]][%[[OFF1]]], %[[M1]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_store_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_masked_store_deint_tail.pto new file mode 100644 index 0000000000..e874e8d90d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store_deint_tail.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_deint_tail( + %value: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<4xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<4xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_store_deint_tail( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[M1:[^,]+]]: !pto.mask +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[V0]], %[[V1]] +// CHECK: %[[USER:.*]], %{{.*}} = pto.pintlv_b32 %[[M0]], %[[M1]] +// CHECK: %[[TAIL:.*]], %{{.*}} = pto.plt_b32 %[[C4]] : i32 -> !pto.mask, i32 +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[MASK:.*]] = pto.pand %[[USER]], %[[TAIL]], %[[ALL]] +// CHECK: pto.vsts %[[LOW]], %[[DST]][%[[OFF]]], %[[MASK]] +// CHECK-NOT: pto.vsts %[[HIGH]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_masked_store_nonfull_invalid.pto b/test/lit/vmi/vmi_to_vpto_masked_store_nonfull_invalid.pto new file mode 100644 index 0000000000..375f44c894 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store_nonfull_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_nonfull_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<129xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<129xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_store requires either full physical chunks or contiguous tail-store value/mask layout +// CHECK-SAME: requires every deinterleaved part to have the same physical chunk count diff --git a/test/lit/vmi/vmi_to_vpto_masked_store_tail.pto b/test/lit/vmi/vmi_to_vpto_masked_store_tail.pto new file mode 100644 index 0000000000..361277c4fd --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_masked_store_tail.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_masked_store_tail( + %value: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<100xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<100xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_masked_store_tail( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[M0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[M1:[^,]+]]: !pto.mask +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C36:.*]] = arith.constant 36 : i32 +// CHECK: pto.vsts %[[V0]], %[[DST]][%[[OFF]]], %[[M0]] +// CHECK: %[[TAIL:.*]], %{{.*}} = pto.plt_b32 %[[C36]] : i32 -> !pto.mask, i32 +// CHECK: %[[ALL:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[COMBINED:.*]] = pto.pand %[[M1]], %[[TAIL]], %[[ALL]] : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: pto.vsts %[[V1]], %[[DST]][{{.*}}], %[[COMBINED]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_math_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_math_element_type_invalid.pto new file mode 100644 index 0000000000..1102d992ef --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_math_element_type_invalid.pto @@ -0,0 +1,131 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_addf_f8_invalid( + %lhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %out = pto.vmi.addf %lhs, %rhs + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.addf direct lowering requires f16/bf16/f32 element type +// CHECK-SAME: requires f16/bf16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_divf_bf16_invalid( + %lhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.divf %lhs, %rhs + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout>, + !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.divf direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_sqrt_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.sqrt %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.sqrt direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_exp_f8_invalid( + %source: !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %out = pto.vmi.exp %source + : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.exp direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_negf_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.negf %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.negf direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_ln_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.ln %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.ln direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_absf_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %out = pto.vmi.absf %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.absf direct lowering requires f16/f32 element type +// CHECK-SAME: requires f16/f32 element type for direct VPTO lowering + +// ----- + +module { + func.func @vmi_to_vpto_absi_unsigned_invalid( + %source: !pto.vmi.vreg<128xui16, #pto.vmi.layout>) { + %out = pto.vmi.absi %source + : !pto.vmi.vreg<128xui16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xui16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.absi direct lowering requires signless/signed i8/i16/i32 element type +// CHECK-SAME: requires signless/signed i8/i16/i32 element type for direct VPTO lowering diff --git a/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto b/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto new file mode 100644 index 0000000000..7a222a35ad --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto @@ -0,0 +1,130 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_load_gm_unsupported(%src: !pto.ptr, %offset: index) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: source is GM-backed +// CHECK-SAME: requires UB-backed memory + +// ----- + +module { + func.func @vmi_masked_load_gm_unsupported( + %src: !pto.ptr, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %offset: index) { + %value = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, contiguous result/passthru/mask layouts +// CHECK-SAME: source is GM-backed +// CHECK-SAME: requires UB-backed memory + +// ----- + +module { + func.func @vmi_expand_load_gm_unsupported( + %src: !pto.ptr, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %offset: index) { + %value = pto.vmi.expand_load %src[%offset], %mask, %passthru + : !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: source is GM-backed +// CHECK-SAME: requires UB-backed memory + +// ----- + +module { + func.func @vmi_store_gm_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: with UB-backed destination +// CHECK-SAME: destination is GM-backed + +// ----- + +module { + func.func @vmi_masked_store_gm_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.masked_store %value, %dst[%offset], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_store requires either full physical chunks or contiguous tail-store value/mask layout +// CHECK-SAME: with UB-backed destination +// CHECK-SAME: destination is GM-backed + +// ----- + +module { + func.func @vmi_tile_read_gm_unsupported( + %src: memref<64xf32, #pto.address_space>) { + %value = pto.vmi.tile_read %src + : memref<64xf32, #pto.address_space> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_read requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: source is GM-backed +// CHECK-SAME: requires UB-backed memory + +// ----- + +module { + func.func @vmi_tile_write_gm_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: memref<64xf32, #pto.address_space>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<64xf32, #pto.address_space> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: with UB-backed destination +// CHECK-SAME: destination is GM-backed diff --git a/test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto b/test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto new file mode 100644 index 0000000000..98d92a3262 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_memory_x2_widths.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_load_deint2_f16( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<256xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1 : !pto.vreg<128xf16>, !pto.vreg<128xf16> + } + + func.func @vmi_to_vpto_store_deint2_i8( + %value: !pto.vmi.vreg<512xi8, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<512xi8, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint2_f16( +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B16" +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint2_i8( +// CHECK: %[[MASK:.*]] = pto.pset_b8 "PAT_ALL" +// CHECK: pto.vstsx2 %arg0, %arg1, %arg2[%arg3], "INTLV_B8", %[[MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto b/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto new file mode 100644 index 0000000000..489891c72a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto @@ -0,0 +1,177 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_load_strided_memref_unsupported( + %src: memref<128xf32, strided<[2], offset: 0>>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.load %src[%c0] + : memref<128xf32, strided<[2], offset: 0>> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_load_memref_subview_unsupported(%src: memref<128xf32>) { + %c0 = arith.constant 0 : index + %view = memref.subview %src[%c0] [64] [1] + : memref<128xf32> to memref<64xf32, strided<[1], offset: ?>> + %value = pto.vmi.load %view[%c0] + : memref<64xf32, strided<[1], offset: ?>> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps +// CHECK-SAME: memref.subview requires normalized base/offset/stride lane-to-address planning + +// ----- + +module { + func.func @vmi_masked_load_strided_memref_unsupported( + %src: memref<128xf32, strided<[2], offset: 0>>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.masked_load %src[%c0], %mask, %passthru + : memref<128xf32, strided<[2], offset: 0>>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load direct lowering requires a supported memory source, contiguous result/passthru/mask layouts +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_expand_load_strided_memref_unsupported( + %src: memref<128xf32, strided<[2], offset: 0>>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %value = pto.vmi.expand_load %src[%c0], %mask, %passthru + : memref<128xf32, strided<[2], offset: 0>>, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.expand_load direct lowering is currently supported +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_store_strided_memref_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: memref<128xf32, strided<[2], offset: 0>>) { + %c0 = arith.constant 0 : index + pto.vmi.store %value, %dst[%c0] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<128xf32, strided<[2], offset: 0>> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: destination memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_store_memref_subview_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: memref<128xf32>) { + %c0 = arith.constant 0 : index + %view = memref.subview %dst[%c0] [64] [1] + : memref<128xf32> to memref<64xf32, strided<[1], offset: ?>> + pto.vmi.store %value, %view[%c0] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<64xf32, strided<[1], offset: ?>> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: destination memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps +// CHECK-SAME: memref.subview requires normalized base/offset/stride lane-to-address planning + +// ----- + +module { + func.func @vmi_masked_store_strided_memref_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %dst: memref<128xf32, strided<[2], offset: 0>>) { + %c0 = arith.constant 0 : index + pto.vmi.masked_store %value, %dst[%c0], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<128xf32, strided<[2], offset: 0>>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_store requires either full physical chunks or contiguous tail-store value/mask layout +// CHECK-SAME: destination memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_tile_read_strided_memref_unsupported( + %src: memref<128xf32, strided<[2], offset: 0>>) { + %value = pto.vmi.tile_read %src + : memref<128xf32, strided<[2], offset: 0>> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_read requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK-SAME: source memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps + +// ----- + +module { + func.func @vmi_tile_write_strided_memref_unsupported( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: memref<128xf32, strided<[2], offset: 0>>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + memref<128xf32, strided<[2], offset: 0>> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: destination memref layout is non-identity +// CHECK-SAME: contiguous identity lane-to-address maps diff --git a/test/lit/vmi/vmi_to_vpto_min_max.pto b/test/lit/vmi/vmi_to_vpto_min_max.pto new file mode 100644 index 0000000000..eeefc6ee94 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_min_max.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_min_max( + %lhs: !pto.vmi.vreg<128xf32>, + %rhs: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>) { + %min = pto.vmi.minf %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %max = pto.vmi.maxf %lhs, %rhs + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %min, %max : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_min_max( +// CHECK-SAME: %[[LHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[LHS1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[RHS1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[MIN0:.*]] = pto.vmin %[[LHS0]], %[[RHS0]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MIN1:.*]] = pto.vmin %[[LHS1]], %[[RHS1]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MAX0:.*]] = pto.vmax %[[LHS0]], %[[RHS0]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MAX1:.*]] = pto.vmax %[[LHS1]], %[[RHS1]], {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[MIN0]], %[[MIN1]], %[[MAX0]], %[[MAX1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_negf.pto b/test/lit/vmi/vmi_to_vpto_negf.pto new file mode 100644 index 0000000000..1aafa02c9a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_negf.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_negf(%a: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %neg = pto.vmi.negf %a + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %neg : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_negf( +// CHECK-SAME: %[[A0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[A1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[NEG0:.*]] = pto.vneg %[[A0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[NEG1:.*]] = pto.vneg %[[A1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[NEG0]], %[[NEG1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_pack_unpack.pto b/test/lit/vmi/vmi_to_vpto_pack_unpack.pto new file mode 100644 index 0000000000..e4caa3cc0a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_pack_unpack.pto @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_unpack( + %v: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %p0, %p1 = "pto.vmi.unpack"(%v) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_pack_unpack( + %p0: !pto.vreg<64xf32>, + %p1: !pto.vreg<64xf32>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %v = "pto.vmi.pack"(%p0, %p1) + : (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %q0, %q1 = "pto.vmi.unpack"(%v) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %q0, %q1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_unpack( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-LABEL: func.func @vmi_to_vpto_pack_unpack( +// CHECK: return +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi.pack +// CHECK-NOT: pto.vmi.unpack +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto new file mode 100644 index 0000000000..7d302805d6 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto @@ -0,0 +1,310 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_dequant_matrix_f16_to_f32( + %src: !pto.ptr, + %scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c128 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %dequant, %dst[%dst_offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c128 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = arith.addi %dst_row, %tail_offset : index + %tail_mask = pto.vmi.create_mask %tail + : index -> !pto.vmi.mask<128xpred> + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.masked_store %dequant, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, + !pto.vmi.mask<128xpred> + } + } + return + } + + func.func @vmi_to_vpto_quant_matrix_f32_to_f16( + %src: !pto.ptr, + %inv_scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c128 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<128xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %packed, %dst[%dst_offset] + : !pto.vmi.vreg<128xf16>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c128 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = arith.addi %dst_row, %tail_offset : index + %tail_mask = pto.vmi.create_mask %tail + : index -> !pto.vmi.mask<128xpred> + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<128xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.masked_store %packed, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<128xf16>, !pto.ptr, + !pto.vmi.mask<128xpred> + } + } + return + } + + func.func @vmi_to_vpto_dequant_matrix_fp8_to_f32( + %src: !pto.ptr, + %scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c256 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %dequant, %dst[%dst_offset] + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c256 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = arith.addi %dst_row, %tail_offset : index + %tail_mask = pto.vmi.create_mask %tail + : index -> !pto.vmi.mask<256xpred> + %packed = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %dequant = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.masked_store %dequant, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<256xf32>, !pto.ptr, + !pto.vmi.mask<256xpred> + } + } + return + } + + func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( + %src: !pto.ptr, + %inv_scale: f32, + %dst: !pto.ptr, + %rows: index, + %full_blocks: index, + %tail: index, + %src_stride: index, + %dst_stride: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %has_tail = arith.cmpi ne, %tail, %c0 : index + scf.for %row = %c0 to %rows step %c1 { + %src_row = arith.muli %row, %src_stride : index + %dst_row = arith.muli %row, %dst_stride : index + scf.for %block = %c0 to %full_blocks step %c1 { + %block_offset = arith.muli %block, %c256 : index + %src_offset = arith.addi %src_row, %block_offset : index + %dst_offset = arith.addi %dst_row, %block_offset : index + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %packed, %dst[%dst_offset] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + scf.if %has_tail { + %tail_offset = arith.muli %full_blocks, %c256 : index + %src_offset = arith.addi %src_row, %tail_offset : index + %dst_offset = arith.addi %dst_row, %tail_offset : index + %tail_mask = pto.vmi.create_mask %tail + : index -> !pto.vmi.mask<256xpred> + %wide = pto.vmi.load %src[%src_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.masked_store %packed, %dst[%dst_offset], %tail_mask + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr, + !pto.vmi.mask<256xpred> + } + } + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_dequant_matrix_f16_to_f32( +// CHECK-SAME: %[[DSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[SCALE:[^,]+]]: f32 +// CHECK-SAME: %[[DDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: scf.if +// CHECK: pto.plt_b32 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_f16( +// CHECK-SAME: %[[QSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[INV_SCALE:[^,]+]]: f32 +// CHECK-SAME: %[[QDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vor {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: scf.if +// CHECK: pto.plt_b16 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask + +// CHECK-LABEL: func.func @vmi_to_vpto_dequant_matrix_fp8_to_f32( +// CHECK-SAME: %[[FSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[FSCALE:[^,]+]]: f32 +// CHECK-SAME: %[[FDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P0"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P1"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P2"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P3"} : !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vintlv +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK: scf.if +// CHECK: pto.plt_b32 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( +// CHECK-SAME: %[[FQSRC:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[FINV_SCALE:[^,]+]]: f32 +// CHECK-SAME: %[[FQDST:[^,]+]]: !pto.ptr +// CHECK: scf.for +// CHECK: scf.for +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: scf.if +// CHECK: pto.plt_b8 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_quant_fp8.pto b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto new file mode 100644 index 0000000000..c44de2ec84 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( + %src: !pto.ptr, + %inv_scale: f32, + %dst: !pto.ptr, + %offset: index) { + %wide = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %inv_scale + : f32 -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %packed, %dst[%offset] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addf.pto b/test/lit/vmi/vmi_to_vpto_reduce_addf.pto new file mode 100644 index 0000000000..6f2fadfdba --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addf.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addf( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addf( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcadd %arg0, %arg2 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[OUT:.*]] = pto.vadd %[[REDUCED]], %arg1, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto new file mode 100644 index 0000000000..4e24ee12a8 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addf_f16_invalid( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addf lowers through pto.vcadd only with reassoc +// CHECK-SAME: currently supports only f32 elements diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addf_multichunk.pto b/test/lit/vmi/vmi_to_vpto_reduce_addf_multichunk.pto new file mode 100644 index 0000000000..0389c17e25 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addf_multichunk.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addf_multichunk( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addf_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcadd %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC0:.*]] = pto.vadd %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[RED1:.*]] = pto.vcadd %arg1, %arg4 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC1:.*]] = pto.vadd %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addi.pto b/test/lit/vmi/vmi_to_vpto_reduce_addi.pto new file mode 100644 index 0000000000..fd6c461b2c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addi.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi( + %source: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addi( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcadd %arg0, %arg2 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[OUT:.*]] = pto.vadd %[[REDUCED]], %arg1, %[[FIRST]] : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addi_i16_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_addi_i16_invalid.pto new file mode 100644 index 0000000000..466374c65c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addi_i16_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi_i16_invalid( + %source: !pto.vmi.vreg<128xi16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<128xi16, #pto.vmi.layout>, + !pto.vmi.vreg<1xi16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addi lowers through pto.vcadd only +// CHECK-SAME: currently supports only 32-bit integer elements diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addi_multichunk.pto b/test/lit/vmi/vmi_to_vpto_reduce_addi_multichunk.pto new file mode 100644 index 0000000000..8275a80790 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_addi_multichunk.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi_multichunk( + %source: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addi_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcadd %arg0, %arg3 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[ACC0:.*]] = pto.vadd %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[RED1:.*]] = pto.vcadd %arg1, %arg4 : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[ACC1:.*]] = pto.vadd %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_maxf_multichunk.pto b/test/lit/vmi/vmi_to_vpto_reduce_maxf_multichunk.pto new file mode 100644 index 0000000000..51782e8462 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_maxf_multichunk.pto @@ -0,0 +1,65 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_maxf_multichunk( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_reduce_minf_multichunk( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_maxf_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcmax %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC0:.*]] = pto.vmax %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[RED1:.*]] = pto.vcmax %arg1, %arg4 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC1:.*]] = pto.vmax %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_minf_multichunk( +// CHECK: %[[FIRST:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[RED0:.*]] = pto.vcmin %arg0, %arg3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC0:.*]] = pto.vmin %[[RED0]], %arg2, %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[RED1:.*]] = pto.vcmin %arg1, %arg4 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[ACC1:.*]] = pto.vmin %[[RED1]], %[[ACC0]], %[[FIRST]] : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ACC1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_maxf_tail_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_maxf_tail_invalid.pto new file mode 100644 index 0000000000..a926b48d70 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_maxf_tail_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_maxf_tail_invalid( + %source: !pto.vmi.vreg<65xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<65xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<65xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<65xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return + } +} + +// CHECK: VMI{{.*}} pto.vmi.reduce_maxf lowers through pto.vcmax only +// CHECK-SAME: requires full source physical chunks diff --git a/test/lit/vmi/vmi_to_vpto_reduce_minf.pto b/test/lit/vmi/vmi_to_vpto_reduce_minf.pto new file mode 100644 index 0000000000..96a70a03f3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_minf.pto @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_minf( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %part : !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_minf( +// CHECK: %[[FIRST:.*]] = pto.pge_b16 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcmin %arg0, %arg2 : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[OUT:.*]] = pto.vmin %[[REDUCED]], %arg1, %[[FIRST]] : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_shape_invalid.pto new file mode 100644 index 0000000000..1b2cf33ffa --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_reduce_shape_invalid.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_reduce_addi_tail_invalid( + %source: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_addi %source, %init, %mask + : !pto.vmi.vreg<32xi32, #pto.vmi.layout>, + !pto.vmi.vreg<1xi32, #pto.vmi.layout>, + !pto.vmi.mask<32xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xi32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addi lowers through pto.vcadd only +// CHECK-SAME: requires full source physical chunks +// CHECK-SAME: found padding lane in physical chunk + +// ----- + +module { + func.func @vmi_to_vpto_reduce_addf_deint_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addf lowers through pto.vcadd only +// CHECK-SAME: requires contiguous source, init, mask, and result layouts + +// ----- + +module { + func.func @vmi_to_vpto_reduce_minf_tail_invalid( + %source: !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb16, #pto.vmi.layout>) { + %out = pto.vmi.reduce_minf %source, %init, %mask + : !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + !pto.vmi.vreg<1xf16, #pto.vmi.layout>, + !pto.vmi.mask<64xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_minf lowers through pto.vcmin only +// CHECK-SAME: requires full source physical chunks +// CHECK-SAME: found padding lane in physical chunk + +// ----- + +module { + func.func @vmi_to_vpto_reduce_maxf_deint_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %init: !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.reduce_maxf %source, %init, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.vreg<1xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_maxf lowers through pto.vcmax only +// CHECK-SAME: requires contiguous source, init, mask, and result layouts diff --git a/test/lit/vmi/vmi_to_vpto_relu_element_type_invalid.pto b/test/lit/vmi/vmi_to_vpto_relu_element_type_invalid.pto new file mode 100644 index 0000000000..ab4f204979 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_relu_element_type_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_relu_bf16_invalid( + %source: !pto.vmi.vreg<128xbf16, #pto.vmi.layout>) { + %relu = pto.vmi.relu %source + : !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xbf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.relu direct lowering requires physical vreg parts with b8/b16/b32 predicate masks and f16/f32 element type +// CHECK-SAME: pto.vrelu direct lowering supports only f16/f32 VMI floating-point element types diff --git a/test/lit/vmi/vmi_to_vpto_scatter.pto b/test/lit/vmi/vmi_to_vpto_scatter.pto new file mode 100644 index 0000000000..12799c01fc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_scatter.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_scatter( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_scatter( +// CHECK: pto.vscatter %arg0, %arg1, %arg2, %arg3 : !pto.vreg<64xf32>, !pto.ptr, !pto.vreg<64xi32>, !pto.mask +// CHECK: return +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto b/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto new file mode 100644 index 0000000000..027162ac68 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_scatter_missing_unique_invalid( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + pto.vmi.scatter %value, %dst[%indices], %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.vreg<64xi32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.scatter lowers through pto.vscatter only with an indices_unique proof +// CHECK-SAME: requires indices_unique proof diff --git a/test/lit/vmi/vmi_to_vpto_scf_for.pto b/test/lit/vmi/vmi_to_vpto_scf_for.pto new file mode 100644 index 0000000000..253432b6dc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_scf_for.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_scf_for(%a: !pto.vmi.vreg<128xf16>) + -> !pto.vmi.vreg<128xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %init = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.for %i = %c0 to %c2 step %c1 + iter_args(%acc = %init) -> (!pto.vmi.vreg<128xf32>) { + %next = pto.vmi.addf %acc, %acc + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + return %result : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_scf_for( +// CHECK-SAME: %[[A:[^)]+]]: !pto.vreg<128xf16> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[P0:.*]] = pto.vcvt %[[A]] +// CHECK-DAG: %[[P1:.*]] = pto.vcvt %[[A]] +// CHECK: %[[RESULT:.*]]:2 = scf.for +// CHECK-SAME: iter_args(%[[ACC0:.*]] = %[[P0]], %[[ACC1:.*]] = %[[P1]]) +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: %[[N0:.*]] = pto.vadd %[[ACC0]], %[[ACC0]] +// CHECK: %[[N1:.*]] = pto.vadd %[[ACC1]], %[[ACC1]] +// CHECK: scf.yield %[[N0]], %[[N1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_scf_if.pto b/test/lit/vmi/vmi_to_vpto_scf_if.pto new file mode 100644 index 0000000000..dcc7497ee4 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_scf_if.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_scf_if( + %cond: i1, + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) -> !pto.vmi.vreg<128xf32> { + %value, %mask = scf.if %cond + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred>) { + %ea = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpa = pto.vmi.cmpf "olt", %ea, %ea + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %ea, %cmpa : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } else { + %eb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %cmpb = pto.vmi.cmpf "olt", %eb, %eb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + scf.yield %eb, %cmpb : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + } + %selected = pto.vmi.select %mask, %value, %value + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %selected : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_scf_if( +// CHECK-SAME: %[[COND:[^,]+]]: i1 +// CHECK-SAME: %[[A:[^,]+]]: !pto.vreg<128xf16> +// CHECK-SAME: %[[B:[^)]+]]: !pto.vreg<128xf16> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK: %[[IF:.*]]:4 = scf.if %[[COND]] -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask) +// CHECK: pto.vcvt %[[A]] +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK: else +// CHECK: pto.vcvt %[[B]] +// CHECK: pto.vcmp {{.*}}, {{.*}}, {{.*}}, "lt" +// CHECK: scf.yield {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask, !pto.mask +// CHECK: pto.vsel %[[IF]]#0, %[[IF]]#0, %[[IF]]#2 +// CHECK: pto.vsel %[[IF]]#1, %[[IF]]#1, %[[IF]]#3 +// CHECK: return {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shli.pto b/test/lit/vmi/vmi_to_vpto_shli.pto new file mode 100644 index 0000000000..eb5fa7d64d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shli.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_shli( + %value: !pto.vmi.vreg<256xi16>, + %amount: !pto.vmi.vreg<256xi16>) -> !pto.vmi.vreg<256xi16> { + %shifted = pto.vmi.shli %value, %amount + : !pto.vmi.vreg<256xi16>, !pto.vmi.vreg<256xi16> + -> !pto.vmi.vreg<256xi16> + return %shifted : !pto.vmi.vreg<256xi16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_shli( +// CHECK-SAME: %[[VALUE0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[VALUE1:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[AMOUNT0:[^,]+]]: !pto.vreg<128xi16> +// CHECK-SAME: %[[AMOUNT1:[^)]+]]: !pto.vreg<128xi16> +// CHECK-SAME: -> (!pto.vreg<128xi16>, !pto.vreg<128xi16>) +// CHECK-DAG: %[[SHL0:.*]] = pto.vshl %[[VALUE0]], %[[AMOUNT0]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK-DAG: %[[SHL1:.*]] = pto.vshl %[[VALUE1]], %[[AMOUNT1]], {{.*}} : !pto.vreg<128xi16>, !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<128xi16> +// CHECK: return %[[SHL0]], %[[SHL1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shrui.pto b/test/lit/vmi/vmi_to_vpto_shrui.pto new file mode 100644 index 0000000000..46ccbf8d86 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shrui.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_shrui( + %value: !pto.vmi.vreg<256xui16>, + %amount: !pto.vmi.vreg<256xui16>) -> !pto.vmi.vreg<256xui16> { + %shifted = pto.vmi.shrui %value, %amount + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<256xui16> + -> !pto.vmi.vreg<256xui16> + return %shifted : !pto.vmi.vreg<256xui16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_shrui( +// CHECK-SAME: %[[VALUE0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[VALUE1:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[AMOUNT0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[AMOUNT1:[^)]+]]: !pto.vreg<128xui16> +// CHECK-SAME: -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) +// CHECK-DAG: %[[SHR0:.*]] = pto.vshr %[[VALUE0]], %[[AMOUNT0]], {{.*}} : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> +// CHECK-DAG: %[[SHR1:.*]] = pto.vshr %[[VALUE1]], %[[AMOUNT1]], {{.*}} : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: return %[[SHR0]], %[[SHR1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shuffle_forwarding.pto b/test/lit/vmi/vmi_to_vpto_shuffle_forwarding.pto new file mode 100644 index 0000000000..dc237c02dc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shuffle_forwarding.pto @@ -0,0 +1,159 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_shuffle_identity( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_shuffle_second_chunk( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<64xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } + + func.func @vmi_shuffle_tail_prefix( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<4xf32, #pto.vmi.layout> + } + + func.func @vmi_shuffle_chunk_swap( + %src: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_shuffle_reverse_one_chunk( + %src: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } + + func.func @vmi_shuffle_deint2_identity( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_shuffle_identity( +// CHECK-SAME: %[[D0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[D0]], %[[D1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_second_chunk( +// CHECK-SAME: %{{[^,]+}}: !pto.vreg<64xf32> +// CHECK-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[D1]] : !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_tail_prefix( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %{{[^)]+}}: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[S0]] : !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_chunk_swap( +// CHECK-SAME: %[[S0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[S1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[S1]], %[[S0]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_reverse_one_chunk( +// CHECK-SAME: %[[SRC:[^)]+]]: !pto.vreg<64xf32> +// CHECK-DAG: %[[C63:.*]] = arith.constant 63 : i32 +// CHECK: %[[IDX:.*]] = pto.vci %[[C63]] {order = "DESC"} : i32 -> !pto.vreg<64xi32> +// CHECK: %[[OUT:.*]] = pto.vselr %[[SRC]], %[[IDX]] : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +// CHECK-NEXT: return %[[OUT]] : !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_shuffle_deint2_identity( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-NEXT: return %[[P0]], %[[P1]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_shuffle_lane0_splat.pto b/test/lit/vmi/vmi_to_vpto_shuffle_lane0_splat.pto new file mode 100644 index 0000000000..264b7b6a6a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_shuffle_lane0_splat.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_shuffle_lane0_splat( + %src: !pto.vmi.vreg<1xf32>) -> !pto.vmi.vreg<128xf32> { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<1xf32>) -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_shuffle_lane0_splat( +// CHECK: %[[MASK0:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[DUP0:.*]] = pto.vdup %arg0, %[[MASK0]] {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: %[[MASK1:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[DUP1:.*]] = pto.vdup %arg0, %[[MASK1]] {position = "LOWEST"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[DUP0]], %[[DUP1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto b/test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto new file mode 100644 index 0000000000..6e89595596 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto='enable-stable-gather-masked-load=true' 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_stable_gather_masked_load_todo( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %out = pto.vmi.masked_load %src[%offset], %mask, %passthru + : !pto.ptr, + !pto.vmi.mask<64xb32, #pto.vmi.layout>, + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + pto.vmi.store %out, %src[%offset] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_load stable VGATHER-based lowering is reserved for strict masked/tail loads but is not implemented yet diff --git a/test/lit/vmi/vmi_to_vpto_store_deint.pto b/test/lit/vmi/vmi_to_vpto_store_deint.pto new file mode 100644 index 0000000000..cafebbf14d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_deint.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_store_deint2( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + return + } + + func.func @vmi_to_vpto_store_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + return + } + + func.func @vmi_to_vpto_store_deint2_multichunk( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint2( +// CHECK: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" +// CHECK: pto.vstsx2 %arg0, %arg1, %arg2[%arg3], "INTLV_B32", %[[MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint4( +// CHECK: %[[A0:.*]], %[[A1:.*]] = pto.vintlv %arg0, %arg2 +// CHECK: %[[B0:.*]], %[[B1:.*]] = pto.vintlv %arg1, %arg3 +// CHECK: %[[D0:.*]], %[[D1:.*]] = pto.vintlv %[[A0]], %[[B0]] +// CHECK: %[[D2:.*]], %[[D3:.*]] = pto.vintlv %[[A1]], %[[B1]] +// CHECK: pto.vsts %[[D0]], %arg4[%arg5] +// CHECK: pto.vsts %[[D1]], %arg4[{{.*}}] +// CHECK: pto.vsts %[[D2]], %arg4[{{.*}}] +// CHECK: pto.vsts %[[D3]], %arg4[{{.*}}] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint2_multichunk( +// CHECK: %[[MASK0:.*]] = pto.pset_b32 "PAT_ALL" +// CHECK: pto.vstsx2 %arg0, %arg2, %arg4[%arg5], "INTLV_B32", %[[MASK0]] +// CHECK: %[[MASK1:.*]] = pto.pset_b32 "PAT_ALL" +// CHECK: pto.vstsx2 %arg1, %arg3, %arg4[{{.*}}], "INTLV_B32", %[[MASK1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_store_deint_invalid.pto b/test/lit/vmi/vmi_to_vpto_store_deint_invalid.pto new file mode 100644 index 0000000000..e1068f813c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_deint_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_store_deint_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: requires every deinterleaved part to have the same physical chunk count diff --git a/test/lit/vmi/vmi_to_vpto_store_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_store_deint_tail.pto new file mode 100644 index 0000000000..653d9b6f33 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_deint_tail.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_store_deint_tail( + %value: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_store_deint_tail( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[DST:[^,]+]]: !pto.ptr +// CHECK-SAME: %[[OFF:[^)]+]]: index +// CHECK: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[P0]], %[[P1]] +// CHECK: %[[MASK:.*]], %{{.*}} = pto.plt_b32 %[[C4]] : i32 -> !pto.mask, i32 +// CHECK: pto.vsts %[[LOW]], %[[DST]][%[[OFF]]], %[[MASK]] +// CHECK-NOT: pto.vsts %[[HIGH]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_store_tail.pto b/test/lit/vmi/vmi_to_vpto_store_tail.pto new file mode 100644 index 0000000000..34058b925c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_tail.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_store_tail( + %value: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<100xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_store_tail( +// CHECK: %[[C36:.*]] = arith.constant 36 : i32 +// CHECK: %[[FULL_MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vsts %arg0, %arg2[%arg3], %[[FULL_MASK]] +// CHECK: %[[TAIL_MASK:.*]], %{{.*}} = pto.plt_b32 %[[C36]] : i32 -> !pto.mask, i32 +// CHECK: pto.vsts %arg1, %arg2[{{.*}}], %[[TAIL_MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto b/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto new file mode 100644 index 0000000000..b412afdcca --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_store_f64_unsupported( + %value: !pto.vmi.vreg<32xf64, #pto.vmi.layout>, + %dst: memref<32xf64>, + %offset: index) { + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<32xf64, #pto.vmi.layout>, memref<32xf64> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: requires an 8/16/32-bit element type + +// ----- + +module { + func.func @vmi_tile_write_f64_unsupported( + %value: !pto.vmi.vreg<32xf64, #pto.vmi.layout>, + %dst: memref<32xf64>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<32xf64, #pto.vmi.layout>, memref<32xf64> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: requires an 8/16/32-bit element type diff --git a/test/lit/vmi/vmi_to_vpto_sub_mul.pto b/test/lit/vmi/vmi_to_vpto_sub_mul.pto new file mode 100644 index 0000000000..d76a6bfd3c --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_sub_mul.pto @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_subf_mulf( + %a: !pto.vmi.vreg<128xf16>, + %b: !pto.vmi.vreg<128xf16>) + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>) { + %wa = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %wb = pto.vmi.extf %b + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %diff = pto.vmi.subf %wa, %wb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %prod = pto.vmi.mulf %wa, %wb + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %diff, %prod : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_subi_muli( + %a: !pto.vmi.vreg<128xi32>, + %b: !pto.vmi.vreg<128xi32>) + -> (!pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32>) { + %diff = pto.vmi.subi %a, %b + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %prod = pto.vmi.muli %a, %b + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + return %diff, %prod : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_subf_mulf( +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[SUB0:.*]] = pto.vsub {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[SUB1:.*]] = pto.vsub {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MUL0:.*]] = pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[MUL1:.*]] = pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[SUB0]], %[[SUB1]], %[[MUL0]], %[[MUL1]] + +// CHECK-LABEL: func.func @vmi_to_vpto_subi_muli( +// CHECK-SAME: -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.vreg<64xi32>) +// CHECK-DAG: %[[ISUB0:.*]] = pto.vsub {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-DAG: %[[ISUB1:.*]] = pto.vsub {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-DAG: %[[IMUL0:.*]] = pto.vmul {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-DAG: %[[IMUL1:.*]] = pto.vmul {{.*}} : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[ISUB0]], %[[ISUB1]], %[[IMUL0]], %[[IMUL1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_read_write.pto b/test/lit/vmi/vmi_to_vpto_tile_read_write.pto new file mode 100644 index 0000000000..5b7e6dbe00 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_tile_read_write.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_tile_read_write_contiguous(%src: memref<128xf32>, %dst: memref<128xf32>) { + %value = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, memref<128xf32> + return + } + + func.func @vmi_to_vpto_tile_read_deint2(%src: memref<128xf32>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.tile_read %src + : memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_tile_write_deint2( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: memref<128xf32>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, memref<128xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_read_write_contiguous( +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[L0:.*]] = pto.vlds %arg0[%[[ZERO]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> +// CHECK: pto.vsts %[[L0]], %arg1[%[[ZERO]]] +// CHECK: pto.vsts %[[L1]], %arg1[{{.*}}] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_read_deint2( +// CHECK: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%[[ZERO]]], "DINTLV_B32" +// CHECK: return %[[P0]], %[[P1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_write_deint2( +// CHECK: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" +// CHECK: pto.vstsx2 %arg0, %arg1, %arg2[%[[ZERO]]], "INTLV_B32", %[[MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto new file mode 100644 index 0000000000..701d921186 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_tile_write_deint_tail( + %value: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + %dst: memref<4xf32>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, + memref<4xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_write_deint_tail( +// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[P1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[DST:[^)]+]]: memref<4xf32> +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[P0]], %[[P1]] +// CHECK: %[[MASK:.*]], %{{.*}} = pto.plt_b32 %[[C4]] : i32 -> !pto.mask, i32 +// CHECK: pto.vsts %[[LOW]], %[[DST]][%[[ZERO]]], %[[MASK]] +// CHECK-NOT: pto.vsts %[[HIGH]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_write_tail.pto b/test/lit/vmi/vmi_to_vpto_tile_write_tail.pto new file mode 100644 index 0000000000..d4f37d48fc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_tile_write_tail.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_tile_write_tail( + %value: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, + %dst: memref<100xf32>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<100xf32, #pto.vmi.layout>, memref<100xf32> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_tile_write_tail( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[DST:[^)]+]]: memref<100xf32> +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C36:.*]] = arith.constant 36 : i32 +// CHECK: %[[FULL_MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: pto.vsts %[[V0]], %[[DST]][%[[ZERO]]], %[[FULL_MASK]] +// CHECK: %[[TAIL_MASK:.*]], %{{.*}} = pto.plt_b32 %[[C36]] : i32 -> !pto.mask, i32 +// CHECK: pto.vsts %[[V1]], %[[DST]][{{.*}}], %[[TAIL_MASK]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_write_tail_deint_invalid.pto b/test/lit/vmi/vmi_to_vpto_tile_write_tail_deint_invalid.pto new file mode 100644 index 0000000000..4d4dac9d6d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_tile_write_tail_deint_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_tile_write_tail_deint_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %dst: memref<129xf32>) { + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, memref<129xf32> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type +// CHECK-SAME: requires every deinterleaved part to have the same physical chunk count diff --git a/test/lit/vmi/vmi_to_vpto_truncf.pto b/test/lit/vmi/vmi_to_vpto_truncf.pto new file mode 100644 index 0000000000..e8d8340c83 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_truncf.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_truncf_f32_to_f16( + %even: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %odd: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %wide = pto.vmi.addf %even, %odd + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %p : !pto.vreg<128xf16> + } + + func.func @vmi_to_vpto_truncf_f32_tail_to_f16( + %wide: !pto.vmi.vreg<100xf32, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<100xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<100xf16, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<100xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %p : !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_truncf_f32_to_f16( +// CHECK: %[[EVEN:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[ODD:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vor %[[EVEN]], %[[ODD]], {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_truncf_f32_tail_to_f16( +// CHECK: %[[EVEN:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[ODD:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vor %[[EVEN]], %[[ODD]], {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto new file mode 100644 index 0000000000..5297123e5a --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_truncf_fp8_128_contiguous_invalid( + %input: !pto.vmi.vreg<128xf32>) { + %packed = pto.vmi.truncf %input + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf8E4M3FN> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: but requires {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion +// CHECK: failed helper conversion {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: requires source and result to have the same physical arity diff --git a/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto new file mode 100644 index 0000000000..9d8cb972aa --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_truncf_unsupported_shape_invalid( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) { + %narrow = pto.vmi.truncf %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf supports only f32 deinterleaved=2 source parts +// CHECK-SAME: one contiguous f16 result chunk +// CHECK-SAME: f32 deinterleaved=4 source parts to one contiguous fp8-like result chunk diff --git a/test/lit/vmi/vmi_to_vpto_type_arity.pto b/test/lit/vmi/vmi_to_vpto_type_arity.pto new file mode 100644 index 0000000000..e99e8e9ea0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_arity.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_type_arity_contiguous_partial( + %value: !pto.vmi.vreg<130xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<130xb32, #pto.vmi.layout>) { + return + } + + func.func @vmi_to_vpto_type_arity_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { + return + } + + func.func @vmi_to_vpto_type_arity_deint2_partial( + %value: !pto.vmi.vreg<130xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<130xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_type_arity_contiguous_partial( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return + +// CHECK-LABEL: func.func @vmi_to_vpto_type_arity_deint4( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return + +// CHECK-LABEL: func.func @vmi_to_vpto_type_arity_deint2_partial( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast +// CHECK-NOT: pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_type_attr_nested_residual_invalid.pto b/test/lit/vmi/vmi_to_vpto_type_attr_nested_residual_invalid.pto new file mode 100644 index 0000000000..afc8502caf --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_attr_nested_residual_invalid.pto @@ -0,0 +1,16 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = [{nested = !pto.vmi.vreg<128xf32, #pto.vmi.layout>}] +} { +} + +// CHECK: VMI-RESIDUAL-OP: failed to convert all VMI ops/types to VPTO diff --git a/test/lit/vmi/vmi_to_vpto_type_attr_residual_invalid.pto b/test/lit/vmi/vmi_to_vpto_type_attr_residual_invalid.pto new file mode 100644 index 0000000000..c115c1c3d8 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_attr_residual_invalid.pto @@ -0,0 +1,16 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module attributes { + pto.hidden_vmi_type = !pto.vmi.vreg<128xf32, #pto.vmi.layout> +} { +} + +// CHECK: VMI-RESIDUAL-OP: failed to convert all VMI ops/types to VPTO diff --git a/test/lit/vmi/vmi_to_vpto_type_only.pto b/test/lit/vmi/vmi_to_vpto_type_only.pto new file mode 100644 index 0000000000..777afaf124 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_type_only.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_type_only( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %m: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_type_only( +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.vreg<64xf32> +// CHECK-SAME: !pto.mask +// CHECK-SAME: !pto.mask +// CHECK: return +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast +// CHECK-NOT: pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_unary_math.pto b/test/lit/vmi/vmi_to_vpto_unary_math.pto new file mode 100644 index 0000000000..5a4419bad2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_unary_math.pto @@ -0,0 +1,89 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_unary_math( + %value: !pto.vmi.vreg<128xf32>) + -> (!pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32>) { + %neg = pto.vmi.negf %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %sqrt = pto.vmi.sqrt %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %exp = pto.vmi.exp %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %ln = pto.vmi.ln %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %relu = pto.vmi.relu %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %neg, %sqrt, %exp, %ln, %relu + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_absf( + %value: !pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> { + %abs = pto.vmi.absf %value + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + return %abs : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_to_vpto_absi( + %value: !pto.vmi.vreg<64xi32>) -> !pto.vmi.vreg<64xi32> { + %abs = pto.vmi.absi %value + : !pto.vmi.vreg<64xi32> -> !pto.vmi.vreg<64xi32> + return %abs : !pto.vmi.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_unary_math( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>, +// CHECK-SAME: !pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[NEG0:.*]] = pto.vneg %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[NEG1:.*]] = pto.vneg %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[SQRT0:.*]] = pto.vsqrt %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[SQRT1:.*]] = pto.vsqrt %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[EXP0:.*]] = pto.vexp %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[EXP1:.*]] = pto.vexp %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[LN0:.*]] = pto.vln %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[LN1:.*]] = pto.vln %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[RELU0:.*]] = pto.vrelu %[[V0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[RELU1:.*]] = pto.vrelu %[[V1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[NEG0]], %[[NEG1]], %[[SQRT0]], %[[SQRT1]], %[[EXP0]], %[[EXP1]], %[[LN0]], %[[LN1]], %[[RELU0]], %[[RELU1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_absf( +// CHECK-SAME: %[[F0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[F1:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +// CHECK-DAG: %[[ABSF0:.*]] = pto.vabs %[[F0]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-DAG: %[[ABSF1:.*]] = pto.vabs %[[F1]], {{.*}} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[ABSF0]], %[[ABSF1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_absi( +// CHECK-SAME: %[[I0:[^)]+]]: !pto.vreg<64xi32> +// CHECK-SAME: -> !pto.vreg<64xi32> +// CHECK: %[[ABSI:.*]] = pto.vabs %[[I0]], {{.*}} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: return %[[ABSI]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_unrealized_cast_residual_invalid.pto b/test/lit/vmi/vmi_to_vpto_unrealized_cast_residual_invalid.pto new file mode 100644 index 0000000000..9bf8f25949 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_unrealized_cast_residual_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_unrealized_cast_residual_invalid( + %arg0: i32) -> f32 { + %0 = builtin.unrealized_conversion_cast %arg0 + : i32 to f32 + return %0 : f32 + } +} + +// CHECK: VMI-RESIDUAL-OP: unrealized conversion cast remains after vmi-to-vpto diff --git a/test/lit/vmi/vmi_to_vpto_unsupported_op_invalid.pto b/test/lit/vmi/vmi_to_vpto_unsupported_op_invalid.pto new file mode 100644 index 0000000000..df51608b08 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_unsupported_op_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_unsupported_op_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %shuffled = "pto.vmi.shuffle"(%a) { + indices = array + } : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<4xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.shuffle requires physical chunk forwarding or lane0 splat or vci-materializable vselr indices +// CHECK-SAME: forwarding: +// CHECK-SAME: lane0 splat: +// CHECK-SAME: vselr: diff --git a/test/lit/vmi/vmi_truncf_direction_invalid.pto b/test/lit/vmi/vmi_truncf_direction_invalid.pto new file mode 100644 index 0000000000..934f1e4ba3 --- /dev/null +++ b/test/lit/vmi/vmi_truncf_direction_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_truncf_direction_invalid(%source: !pto.vmi.vreg<128xf16>) { + %result = pto.vmi.truncf %source + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + return + } +} + +// CHECK: requires result element type to be narrower than source element type diff --git a/test/lit/vmi/vmi_truncf_lane_mismatch_invalid.pto b/test/lit/vmi/vmi_truncf_lane_mismatch_invalid.pto new file mode 100644 index 0000000000..56e07a9892 --- /dev/null +++ b/test/lit/vmi/vmi_truncf_lane_mismatch_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_truncf_lane_mismatch_invalid(%source: !pto.vmi.vreg<64xf32>) { + %result = pto.vmi.truncf %source + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: requires source and result logical lane counts to match diff --git a/test/lit/vmi/vmi_type_attr_parse.pto b/test/lit/vmi/vmi_type_attr_parse.pto new file mode 100644 index 0000000000..04613c441d --- /dev/null +++ b/test/lit/vmi/vmi_type_attr_parse.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module attributes { + pto.vmi_contiguous = #pto.vmi.layout, + pto.vmi_deinterleaved2 = #pto.vmi.layout, + pto.vmi_deinterleaved4 = #pto.vmi.layout +} { + func.func @vmi_type_attr_parse( + %surface: !pto.vmi.vreg<128xf32>, + %contiguous: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %wide2: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %wide4: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %surface_mask: !pto.vmi.mask<128xpred>, + %mask_b8: !pto.vmi.mask<128xb8, #pto.vmi.layout>, + %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, + %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + return + } +} + +// CHECK: pto.vmi_contiguous = #pto.vmi.layout +// CHECK: pto.vmi_deinterleaved2 = #pto.vmi.layout +// CHECK: pto.vmi_deinterleaved4 = #pto.vmi.layout +// CHECK-LABEL: func.func @vmi_type_attr_parse( +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xpred> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb8, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_type_element_count_invalid.pto b/test/lit/vmi/vmi_type_element_count_invalid.pto new file mode 100644 index 0000000000..a7548528c9 --- /dev/null +++ b/test/lit/vmi/vmi_type_element_count_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_type_element_count_invalid( + %arg0: !pto.vmi.vreg<0xf32>) { + return + } +} + +// CHECK: expected a positive element count diff --git a/test/lit/vmi/vmi_unary_math_integer_invalid.pto b/test/lit/vmi/vmi_unary_math_integer_invalid.pto new file mode 100644 index 0000000000..8f3af3092e --- /dev/null +++ b/test/lit/vmi/vmi_unary_math_integer_invalid.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_sqrt_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %sqrt = pto.vmi.sqrt %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.sqrt' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_exp_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %exp = pto.vmi.exp %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.exp' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_ln_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %ln = pto.vmi.ln %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.ln' op requires floating-point-like VMI element type + +// ----- + +module { + func.func @vmi_relu_integer_invalid(%value: !pto.vmi.vreg<128xi32>) { + %relu = pto.vmi.relu %value + : !pto.vmi.vreg<128xi32> -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK: 'pto.vmi.relu' op requires floating-point-like VMI element type diff --git a/test/lit/vmi/vmi_unpack_arity_invalid.pto b/test/lit/vmi/vmi_unpack_arity_invalid.pto new file mode 100644 index 0000000000..5cd224a6e6 --- /dev/null +++ b/test/lit/vmi/vmi_unpack_arity_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_unpack_arity_invalid( + %a: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %p0 = "pto.vmi.unpack"(%a) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> + return + } +} + +// CHECK: requires 2 physical parts, got 1 diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/compare.py b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/compare.py new file mode 100644 index 0000000000..8de470b64d --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-5, rtol=1e-5): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/golden.py b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/golden.py new file mode 100644 index 0000000000..8c3eb7acea --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 1024 +LOGICAL_ELEMS = 1000 +SEED = 23 +SCALE = np.float32(2.0) +SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float16) + dst = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden[:LOGICAL_ELEMS] = src[:LOGICAL_ELEMS].astype(np.float32) * SCALE + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/kernel.pto b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/kernel.pto new file mode 100644 index 0000000000..81b8640c0a --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/kernel.pto @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dequant_f16_to_f32_tail_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 2.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000) -> (index) { + %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<128xpred> + %packed = pto.vmi.load %ub_src[%offset] : !pto.ptr -> !pto.vmi.vreg<128xf16> + %wide = pto.vmi.extf %packed : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %scale : f32 -> !pto.vmi.vreg<128xf32> + %out = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + pto.vmi.masked_store %out, %ub_dst[%offset], %mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, !pto.vmi.mask<128xpred> + %next = arith.subi %remaining, %c128 : index + scf.yield %next : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/launch.cpp b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/launch.cpp new file mode 100644 index 0000000000..3c329a34bb --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dequant_f16_to_f32_tail_kernel(__gm__ half *src, __gm__ float *dst); + +void LaunchVmi_dequant_f16_to_f32_tail_kernel(uint16_t *src, float *dst, + void *stream) { + vmi_dequant_f16_to_f32_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/main.cpp b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/main.cpp new file mode 100644 index 0000000000..7797fe7fb0 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dequant_f16_to_f32_tail_kernel(uint16_t *src, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcHost = nullptr; + uint16_t *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dequant_f16_to_f32_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dequant-f16-to-f32-tail/ptoas.flags b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f16-to-f32-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/compare.py b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/compare.py new file mode 100644 index 0000000000..8de470b64d --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-5, rtol=1e-5): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/golden.py b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/golden.py new file mode 100644 index 0000000000..b53b4b2ba9 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 1024 +LOGICAL_ELEMS = 1000 +SCALE = np.float32(2.0) +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32) + dst = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden[:LOGICAL_ELEMS] = decoded[:LOGICAL_ELEMS] * SCALE + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/kernel.pto b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/kernel.pto new file mode 100644 index 0000000000..bddf6b0f06 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/kernel.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dequant_f8_to_f32_tail_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 2.000000e+00 : f32 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1000) -> (index) { + %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<256xpred> + %packed = pto.vmi.load %ub_src_f8[%offset] : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale : f32 -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.masked_store %out, %ub_dst[%offset], %mask + : !pto.vmi.vreg<256xf32>, !pto.ptr, !pto.vmi.mask<256xpred> + %next = arith.subi %remaining, %c256 : index + scf.yield %next : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/launch.cpp b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/launch.cpp new file mode 100644 index 0000000000..02688457e3 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dequant_f8_to_f32_tail_kernel(__gm__ uint8_t *src, __gm__ float *dst); + +void LaunchVmi_dequant_f8_to_f32_tail_kernel(uint8_t *src, float *dst, + void *stream) { + vmi_dequant_f8_to_f32_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/main.cpp b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/main.cpp new file mode 100644 index 0000000000..ee62749258 --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dequant_f8_to_f32_tail_kernel(uint8_t *src, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(uint8_t); + size_t dstBytes = kElems * sizeof(float); + uint8_t *srcHost = nullptr; + uint8_t *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dequant_f8_to_f32_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dequant-f8-to-f32-tail/ptoas.flags b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dequant-f8-to-f32-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/compare.py b/test/vpto/cases/vmi/quant-f32-to-f16-tail/compare.py new file mode 100644 index 0000000000..39f37ccd7c --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float16) + output = np.fromfile("v2.bin", dtype=np.float16) + if golden.shape != output.shape or not np.array_equal(golden, output): + diff = np.nonzero(golden.view(np.uint16) != output.view(np.uint16))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/golden.py b/test/vpto/cases/vmi/quant-f32-to-f16-tail/golden.py new file mode 100644 index 0000000000..7938574cd5 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 1024 +LOGICAL_ELEMS = 1000 +SEED = 29 +SCALE = np.float32(0.5) +SENTINEL = np.float16(-17.5) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32) + dst = np.full(ELEMS, SENTINEL, dtype=np.float16) + golden = np.full(ELEMS, SENTINEL, dtype=np.float16) + golden[:LOGICAL_ELEMS] = (src[:LOGICAL_ELEMS] * SCALE).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/kernel.pto b/test/vpto/cases/vmi/quant-f32-to-f16-tail/kernel.pto new file mode 100644 index 0000000000..2920617624 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/kernel.pto @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_quant_f32_to_f16_tail_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 5.000000e-01 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c128_i64 + nburst(%c32_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c128 iter_args(%remaining = %c1000) -> (index) { + %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<128xpred> + %wide = pto.vmi.load %ub_src[%offset] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.broadcast %scale : f32 -> !pto.vmi.vreg<128xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %packed = pto.vmi.truncf %scaled : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.masked_store %packed, %ub_dst[%offset], %mask + : !pto.vmi.vreg<128xf16>, !pto.ptr, !pto.vmi.mask<128xpred> + %next = arith.subi %remaining, %c128 : index + scf.yield %next : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c64_i64 + nburst(%c32_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/launch.cpp b/test/vpto/cases/vmi/quant-f32-to-f16-tail/launch.cpp new file mode 100644 index 0000000000..bf3aa91f10 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_quant_f32_to_f16_tail_kernel(__gm__ float *src, __gm__ half *dst); + +void LaunchVmi_quant_f32_to_f16_tail_kernel(float *src, uint16_t *dst, + void *stream) { + vmi_quant_f32_to_f16_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ half *)dst); +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/main.cpp b/test/vpto/cases/vmi/quant-f32-to-f16-tail/main.cpp new file mode 100644 index 0000000000..b03ccbce5c --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_quant_f32_to_f16_tail_kernel(float *src, uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(float); + size_t dstBytes = kElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint16_t *dstHost = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_quant_f32_to_f16_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f16-tail/ptoas.flags b/test/vpto/cases/vmi/quant-f32-to-f16-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f16-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/compare.py b/test/vpto/cases/vmi/quant-f32-to-f8-full/compare.py new file mode 100644 index 0000000000..68c53a335e --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.uint8) + output = np.fromfile("v2.bin", dtype=np.uint8) + if golden.shape != output.shape or not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/golden.py b/test/vpto/cases/vmi/quant-f32-to-f8-full/golden.py new file mode 100644 index 0000000000..9c36f02c73 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/golden.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 256 +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32) + golden = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8) + dst = np.full(ELEMS, 0xA5, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/kernel.pto b/test/vpto/cases/vmi/quant-f32-to-f8-full/kernel.pto new file mode 100644 index 0000000000..4c7193f970 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/kernel.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_quant_f32_to_f8_full_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst_u8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst_f8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst_u8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %wide = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %wide : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %packed, %ub_dst_f8[%c0] : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst_u8, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/launch.cpp b/test/vpto/cases/vmi/quant-f32-to-f8-full/launch.cpp new file mode 100644 index 0000000000..18bc01e2d1 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_quant_f32_to_f8_full_kernel(__gm__ float *src, __gm__ uint8_t *dst); + +void LaunchVmi_quant_f32_to_f8_full_kernel(float *src, uint8_t *dst, + void *stream) { + vmi_quant_f32_to_f8_full_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ uint8_t *)dst); +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/main.cpp b/test/vpto/cases/vmi/quant-f32-to-f8-full/main.cpp new file mode 100644 index 0000000000..6e3aae53f2 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/main.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_quant_f32_to_f8_full_kernel(float *src, uint8_t *dst, + void *stream); + +int main() { + constexpr size_t kSrcElems = 256; + constexpr size_t kDstElems = 256; + size_t srcBytes = kSrcElems * sizeof(float); + size_t dstBytes = kDstElems * sizeof(uint8_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint8_t *dstHost = nullptr; + uint8_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_quant_f32_to_f8_full_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-full/ptoas.flags b/test/vpto/cases/vmi/quant-f32-to-f8-full/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-full/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/compare.py b/test/vpto/cases/vmi/quant-f32-to-f8-tail/compare.py new file mode 100644 index 0000000000..68c53a335e --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.uint8) + output = np.fromfile("v2.bin", dtype=np.uint8) + if golden.shape != output.shape or not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/golden.py b/test/vpto/cases/vmi/quant-f32-to-f8-tail/golden.py new file mode 100644 index 0000000000..b662cd604f --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 1024 +LOGICAL_ELEMS = 1000 +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +SENTINEL = np.uint8(0xA5) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32) + packed = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8) + dst = np.full(ELEMS, SENTINEL, dtype=np.uint8) + golden = np.full(ELEMS, SENTINEL, dtype=np.uint8) + golden[:LOGICAL_ELEMS] = packed[:LOGICAL_ELEMS] + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/kernel.pto b/test/vpto/cases/vmi/quant-f32-to-f8-tail/kernel.pto new file mode 100644 index 0000000000..bb3db56ff2 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/kernel.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_quant_f32_to_f8_tail_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst_u8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst_f8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst_u8, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c1024 step %c256 iter_args(%remaining = %c1000) -> (index) { + %mask = pto.vmi.create_mask %remaining : index -> !pto.vmi.mask<256xpred> + %wide = pto.vmi.load %ub_src[%offset] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %packed = pto.vmi.truncf %wide : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.masked_store %packed, %ub_dst_f8[%offset], %mask + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr, !pto.vmi.mask<256xpred> + %next = arith.subi %remaining, %c256 : index + scf.yield %next : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst_u8, %dst_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/launch.cpp b/test/vpto/cases/vmi/quant-f32-to-f8-tail/launch.cpp new file mode 100644 index 0000000000..cf40a3fc57 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_quant_f32_to_f8_tail_kernel(__gm__ float *src, __gm__ uint8_t *dst); + +void LaunchVmi_quant_f32_to_f8_tail_kernel(float *src, uint8_t *dst, + void *stream) { + vmi_quant_f32_to_f8_tail_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ uint8_t *)dst); +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/main.cpp b/test/vpto/cases/vmi/quant-f32-to-f8-tail/main.cpp new file mode 100644 index 0000000000..5f5bda8502 --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_quant_f32_to_f8_tail_kernel(float *src, uint8_t *dst, + void *stream); + +int main() { + constexpr size_t kElems = 1024; + size_t srcBytes = kElems * sizeof(float); + size_t dstBytes = kElems * sizeof(uint8_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint8_t *dstHost = nullptr; + uint8_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_quant_f32_to_f8_tail_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/quant-f32-to-f8-tail/ptoas.flags b/test/vpto/cases/vmi/quant-f32-to-f8-tail/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/quant-f32-to-f8-tail/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/compare.py b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/compare.py new file mode 100644 index 0000000000..5030420250 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.float32) + output = np.fromfile("v3.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-5, rtol=1e-5): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/golden.py b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/golden.py new file mode 100644 index 0000000000..ee2be3c731 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 256 +F16_VALUE = np.float16(0.125) +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src_f16 = np.full(ELEMS, F16_VALUE, dtype=np.float16) + src_f8 = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8) + decoded_f8 = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32) + reduction = np.sum(src_f16.astype(np.float32), dtype=np.float32) + dst = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden = decoded_f8 * reduction + + output_dir.mkdir(parents=True, exist_ok=True) + src_f16.tofile(output_dir / "v1.bin") + src_f8.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.astype(np.float32, copy=False).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/kernel.pto b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/kernel.pto new file mode 100644 index 0000000000..ae307ef525 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/kernel.pto @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_reduce_f16_f8_mul_store_kernel(%src_f16_gm: !pto.ptr, + %src_f8_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %zero = arith.constant 0.000000e+00 : f32 + %c256 = arith.constant 256 : index + + %ub_f16 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_f8_u8 = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_f8 = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_f16_gm, %ub_f16, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src_f8_gm, %ub_f8_u8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %src_f16 = pto.vmi.load %ub_f16[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf16> + %src_f16_f32 = pto.vmi.extf %src_f16 : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %init = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<1xf32> + %sum = pto.vmi.reduce_addf %src_f16_f32, %init, %mask {reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<1xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<1xf32> + %sum_vec = pto.vmi.broadcast %sum + : !pto.vmi.vreg<1xf32> -> !pto.vmi.vreg<256xf32> + %src_f8 = pto.vmi.load %ub_f8[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %src_f8_f32 = pto.vmi.extf %src_f8 : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.mulf %sum_vec, %src_f8_f32 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %out, %ub_dst[%c0] : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/launch.cpp b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/launch.cpp new file mode 100644 index 0000000000..b882f9e0e2 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_reduce_f16_f8_mul_store_kernel(__gm__ half *src_f16, + __gm__ uint8_t *src_f8, + __gm__ float *dst); + +void LaunchVmi_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, float *dst, + void *stream) { + vmi_reduce_f16_f8_mul_store_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src_f16, (__gm__ uint8_t *)src_f8, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/main.cpp b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/main.cpp new file mode 100644 index 0000000000..e48cd97661 --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/main.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 256; + size_t srcF16Bytes = kElems * sizeof(uint16_t); + size_t srcF8Bytes = kElems * sizeof(uint8_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcF16Host = nullptr; + uint16_t *srcF16Device = nullptr; + uint8_t *srcF8Host = nullptr; + uint8_t *srcF8Device = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF16Host), srcF16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF8Host), srcF8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcF16Device, srcF16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&srcF8Device, srcF8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcF16Bytes, srcF16Host, srcF16Bytes); + ReadFile("./v2.bin", srcF8Bytes, srcF8Host, srcF8Bytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcF16Device, srcF16Bytes, srcF16Host, srcF16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(srcF8Device, srcF8Bytes, srcF8Host, srcF8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_reduce_f16_f8_mul_store_kernel(srcF16Device, srcF8Device, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcF16Device); + aclrtFree(srcF8Device); + aclrtFree(dstDevice); + aclrtFreeHost(srcF16Host); + aclrtFreeHost(srcF8Host); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/reduce-f16-f8-mul-store/ptoas.flags b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/reduce-f16-f8-mul-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 5bde9442e7..3732f72313 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -13,3 +13,4 @@ add_subdirectory(ptobc) add_subdirectory(ptoas) +add_subdirectory(pto-test-opt) diff --git a/tools/pto-test-opt/CMakeLists.txt b/tools/pto-test-opt/CMakeLists.txt new file mode 100644 index 0000000000..8f72f0383d --- /dev/null +++ b/tools/pto-test-opt/CMakeLists.txt @@ -0,0 +1,35 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set(LLVM_LINK_COMPONENTS + Support +) + +add_llvm_executable(pto-test-opt + pto-test-opt.cpp +) + +target_link_libraries(pto-test-opt PRIVATE + PTOIR + PTOTransforms + MLIRMlirOptMain + MLIRIR + MLIRParser + MLIRPass + MLIRSupport + MLIRFuncDialect + MLIRArithDialect + MLIRMemRefDialect + MLIRSCFDialect + MLIRControlFlowDialect +) + +add_dependencies(pto-test-opt + PTOOpsIncGen + PTOPassesIncGen +) diff --git a/tools/pto-test-opt/pto-test-opt.cpp b/tools/pto-test-opt/pto-test-opt.cpp new file mode 100644 index 0000000000..6ec1dc70ef --- /dev/null +++ b/tools/pto-test-opt/pto-test-opt.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +//===- pto-test-opt.cpp - PTO lit pass runner -----------------------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registry.insert(); + + mlir::registerAllPasses(); + mlir::pto::registerPTOPasses(); + + return failed(mlir::MlirOptMain(argc, argv, "PTO lit pass runner\n", + registry)); +} diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 8e73de48e7..4d0bc4b877 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -8,6 +8,7 @@ #include "ptoas.h" #include "PTO/IR/PTO.h" +#include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/VPTOLLVMEmitter.h" #include "PTO/Transforms/Passes.h" #include "PTO/Transforms/BufferizableOpInterfaceImpl.h" @@ -441,6 +442,12 @@ static llvm::cl::opt disableInferLayout( llvm::cl::desc("Disable PTO layout inference pass (static-only)"), llvm::cl::init(false)); +static llvm::cl::opt enableVMI( + "enable-vmi", + llvm::cl::desc("Run the experimental VMI-to-VPTO semantic pipeline " + "(requires --pto-backend=vpto or pto.backend = \"vpto\")"), + llvm::cl::init(false)); + static llvm::cl::opt emitAddPtrTrace( "emit-addptr-trace", llvm::cl::desc("Emit addptr trace comments in generated C++ output"), @@ -1654,6 +1661,51 @@ static LogicalResult runVPTOBackendPipeline(OwningOpRef &module, return success(); } +static bool containsVMIType(Type type) { + if (isa(type)) + return true; + if (auto functionType = dyn_cast(type)) { + return llvm::any_of(functionType.getInputs(), containsVMIType) || + llvm::any_of(functionType.getResults(), containsVMIType); + } + if (auto shapedType = dyn_cast(type)) + return containsVMIType(shapedType.getElementType()); + return false; +} + +static LogicalResult verifyNoPublicVMISignature(ModuleOp module) { + WalkResult result = module.walk([&](func::FuncOp func) { + if (!func.isPublic() || !containsVMIType(func.getFunctionType())) + return WalkResult::advance(); + func.emitError() + << pto::kVMIDiagLayoutContractPrefix + << "public VMI typed function requires an explicit external ABI " + "materialization plan"; + return WalkResult::interrupt(); + }); + return failure(result.wasInterrupted()); +} + +static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { + if (failed(verifyNoPublicVMISignature(module.get()))) + return failure(); + + PassManager pm(module->getContext()); + pm.enableVerifier(); + pm.addPass(pto::createPTOValidateVMIIRPass()); + pm.addPass(pto::createVMILayoutAssignmentPass()); + pm.addPass(pto::createPTOValidateVMILayoutIRPass()); + pm.addPass(pto::createVMIToVPTOPass()); + if (failed(applyConfiguredPassManagerCLOptions(pm, + "VMI-to-VPTO pipeline"))) + return failure(); + if (failed(pm.run(module.get()))) { + llvm::errs() << "Error: VMI-to-VPTO pipeline failed.\n"; + return failure(); + } + return success(); +} + int mlir::pto::compilePTOASModule( OwningOpRef &module, PTOASContext &context, PTOBackend effectiveBackend, PTOASCompileResult &result, @@ -1670,6 +1722,11 @@ int mlir::pto::compilePTOASModule( "--pto-backend=vpto or pto.backend = \"vpto\".\n"; return 1; } + if (enableVMI && effectiveBackend != PTOBackend::VPTO) { + llvm::errs() << "Error: --enable-vmi requires --pto-backend=vpto or " + "pto.backend = \"vpto\".\n"; + return 1; + } PTOBuildLevel effectiveLevel = defaultBuildLevel(); if (!parseBuildLevel(ptoBuildLevel, effectiveLevel)) { @@ -1824,6 +1881,11 @@ int mlir::pto::compilePTOASModule( const bool hasTileOpsToExpand = hasUnexpandedTileOps(*module); + if (enableVMI) { + if (failed(runVMISemanticPipeline(module))) + return 1; + } + if (effectiveBackend == PTOBackend::VPTO && !hasTileOpsToExpand) { if (ptoPrintSeamIR || !ptoSeamIRFile.empty()) { llvm::errs() << "Error: shared pre-backend seam IR is unavailable when " From 2c04f4640abd00843c1dd2d0fcbb9fea07efb998 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Thu, 18 Jun 2026 13:03:02 +0800 Subject: [PATCH 04/54] feat: support num_groups layout --- docs/designs/vmi-implementation-manual.md | 121 +++ include/PTO/IR/VMIAttrs.td | 4 + include/PTO/IR/VMIOps.td | 38 + lib/PTO/IR/VMI.cpp | 139 ++- lib/PTO/Transforms/VMILayoutAssignment.cpp | 39 + lib/PTO/Transforms/VMIToVPTO.cpp | 910 +++++++++++++++++- .../vmi/vmi_to_vpto_group_broadcast_deint.pto | 33 + .../vmi/vmi_to_vpto_group_broadcast_vselr.pto | 42 + test/lit/vmi/vmi_to_vpto_group_ops.pto | 41 + .../vmi/vmi_to_vpto_group_reduce_vcgadd.pto | 33 + ...to_vpto_group_reduce_vcgadd_multichunk.pto | 45 + .../group-reduce-f16-f8-mul-store/compare.py | 27 + .../group-reduce-f16-f8-mul-store/golden.py | 59 ++ .../group-reduce-f16-f8-mul-store/kernel.pto | 71 ++ .../group-reduce-f16-f8-mul-store/launch.cpp | 43 + .../group-reduce-f16-f8-mul-store/main.cpp | 91 ++ .../group-reduce-f16-f8-mul-store/ptoas.flags | 1 + 17 files changed, 1725 insertions(+), 12 deletions(-) create mode 100644 test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_ops.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/ptoas.flags diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 772194f64d..cd674db32a 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -748,6 +748,11 @@ deinterleaved=4: part1 chunks for lanes 1,5,9,... part2 chunks for lanes 2,6,10,... part3 chunks for lanes 3,7,11,... + +num_groups=G: + sparse group-slot reduce result layout + physical storage is contiguous chunk order + only canonical group_slot(g) lanes contain semantic values ``` 每个 semantic pattern 必须从 adaptor 拿 physical parts,不允许从 defining op 反推: @@ -1596,6 +1601,11 @@ The type converter must define one canonical physical ordering and every pattern part2 lanes [2,6,10,...] part3 lanes [3,7,11,...] +!pto.vmi.vreg + -> chunks in contiguous physical storage order + only derived group_slot(g) lanes contain semantic values + this layout is valid only for group reduce/broadcast exchange values + !pto.vmi.mask -> same part/chunk ordering as its data layout, one !pto.mask per physical part/chunk ``` @@ -2932,6 +2942,117 @@ pto.vmi.reduce_addf: f16 until accumulator precision and rounding contract are designed partial/tail source chunks because padding lanes must not participate +pto.vmi.group_load / pto.vmi.group_store: + semantic: + num_groups is the only static grouping attribute. + N = logical lane count; G = num_groups; S = N / G. + group_load reads each logical group as one contiguous row: + result[g * S + i] = source[offset + g * row_stride + i] + for 0 <= g < G and 0 <= i < S + group_store writes the inverse row mapping: + destination[offset + g * row_stride + i] = value[g * S + i] + row_stride is an index operand, measured in elements, and may be dynamic. + Tail/valid-lane information is not an attr; it must be represented by a + mask in the producing/consuming computation. The current direct + group_load/group_store path is for full physical chunks. + layout assignment: + group_load result natural layout is contiguous + group_store value use is requested as contiguous + current direct lowering: + source/value element width must be maskable by b8/b16/b32 + layout must be contiguous with full physical chunks + num_groups must evenly divide N, and the derived group size S must be a + multiple of the physical lanes + per part, so every physical chunk belongs to exactly one group + lower each physical chunk with pto.vlds/pto.vsts at: + offset + group * row_stride + chunk_in_group * lanes_per_part + unsupported cases: + derived group size splitting a physical chunk, because this needs partial-vreg + lane insertion/extraction or a gather/scatter plan + partial/tail physical chunks + GM-backed direct vector load/store paths not already accepted by the normal + VMI memory access plan + +pto.vmi.group_reduce_addf: + semantic: + requires {reassoc} + N = logical lane count; G = num_groups; S = N / G + L = physical lanes per 256B chunk for the element type. + The result carries #pto.vmi.layout, a sparse group-slot + layout. It is not a dense vector layout: only group_slot(g) lanes have + semantic values. + group_slot(g) is canonical and derived from N, G, and L: + if S < L: + low_elems = L / S + chunk_stride = 1 + if S >= L: + low_elems = 1 + chunk_stride = S / L + group_slot(g) = (g / low_elems) * chunk_stride * L + (g % low_elems) + for each group g: + result[group_slot(g)] = + reduce_add(source[g * S .. (g + 1) * S), mask in same range) + Non-slot lanes are not consumed by pto.vmi.group_broadcast. The current + direct lowering materializes them as zero where the hardware path does not + already define them. + The result remains a VMI vector with the same element type and logical lane + count as the source, but its layout is #pto.vmi.layout. + layout assignment: + source use is requested as contiguous + result natural layout is #pto.vmi.layout + mask use is requested as contiguous with granularity derived from source + element width + current direct lowering: + source/result element type must be f32 + source, result, and mask must have matching physical arity and full chunks + if S=8 for f32, lower each physical chunk with pto.vcgadd. This is the + hardware 32B VLane group reduction path for f32: each source chunk produces + eight 8-lane group sums in the low lanes of that physical chunk. The + lowering preserves this natural no-pack result. + Otherwise: + derived group size S must be a multiple of physical lanes per part + lower each source chunk with pto.vcadd, combine chunks in the same group + with pto.vadd under PAT_VL1, then place group g at group_slot(g) in the + #pto.vmi.layout result. All other result chunks/lane values + are zero. + unsupported cases: + missing reassoc attr + f16 or integer group reductions until accumulator and result contracts are + designed + derived group size S that neither divides nor is a multiple of L + +pto.vmi.group_broadcast: + semantic: + N = logical lane count; G = num_groups; S = N / G + source must carry #pto.vmi.layout. For each group g, the + source value is read from group_slot(g), using the same canonical group_slot + definition as pto.vmi.group_reduce_addf. The result broadcasts it back to + each logical group: + result[g * S + i] = source[group_slot(g)] + layout assignment: + source use is requested as #pto.vmi.layout + result is consumer-driven. If no consumer requests another layout, it + defaults to contiguous. + current direct lowering: + source must carry #pto.vmi.layout with full physical chunks + result may be contiguous with full physical chunks + result may also be deinterleaved when S is large enough that every physical + result chunk stays inside one logical group, for example N=512, G=2, S=256, + L=64, deinterleaved=4 + derived group size S must divide or be a multiple of L for canonical + group-slot addressing + if result is contiguous and S < L, each physical chunk contains multiple group + slots. Lower by + creating an index vector [0...0, 1...1, ...] and applying pto.vselr to the + corresponding source chunk. + if S >= L and each result physical chunk belongs to one group, lower by + duplicating the first lane of that group's source chunk with pto.vdup LOWEST. + unsupported cases: + partial/tail physical chunks + derived group size S that neither divides nor is a multiple of L + deinterleaved small-group broadcast where one physical result chunk needs + values from multiple source chunks + pto.vmi.reduce_maxf / pto.vmi.reduce_minf: semantic: acc = init[0] diff --git a/include/PTO/IR/VMIAttrs.td b/include/PTO/IR/VMIAttrs.td index fc2a7f2f5b..da8428dd23 100644 --- a/include/PTO/IR/VMIAttrs.td +++ b/include/PTO/IR/VMIAttrs.td @@ -25,9 +25,13 @@ def VMILayoutAttr : PTO_Attr<"VMILayout", "vmi.layout"> { static VMILayoutAttr getContiguous(::mlir::MLIRContext *context); static VMILayoutAttr getDeinterleaved(::mlir::MLIRContext *context, int64_t factor); + static VMILayoutAttr getGroupSlots(::mlir::MLIRContext *context, + int64_t numGroups); bool isContiguous() const { return getKind() == "contiguous"; } bool isDeinterleaved() const { return getKind() == "deinterleaved"; } + bool isGroupSlots() const { return getKind() == "num_groups"; } + int64_t getNumGroups() const { return getFactor(); } }]; } diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 6f567bb8a5..7bd7524118 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -391,6 +391,26 @@ def VMIReduceMinFOp : VMI_Op<"reduce_minf"> { let assemblyFormat = "$source `,` $init `,` $mask attr-dict `:` type($source) `,` type($init) `,` type($mask) `->` type($result)"; } +def VMIGroupReduceAddFOp : VMI_Op<"group_reduce_addf"> { + let summary = "VMI masked floating-point add reduction within fixed logical groups"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask, + I64Attr:$num_groups, + OptionalAttr:$reassoc); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + +def VMIGroupBroadcastOp : VMI_Op<"group_broadcast"> { + let summary = "VMI broadcast group-slot values back to each logical group"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + def VMIExtFOp : VMI_Op<"extf"> { let summary = "VMI floating-point elementwise extension"; let arguments = (ins VMI_VRegTypeConstraint:$source); @@ -423,6 +443,15 @@ def VMILoadOp : VMI_Op<"load", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical grouped vector load with a row stride between groups"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, Index:$row_stride, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $row_stride attr-dict `:` type($source) `->` type($result)"; +} + def VMIMaskedLoadOp : VMI_Op<"masked_load", [DeclareOpInterfaceMethods]> { let summary = "VMI logical masked vector load with passthrough lanes"; let arguments = (ins PtrOrMemRef:$source, Index:$offset, @@ -462,6 +491,15 @@ def VMIStoreOp : VMI_Op<"store", [DeclareOpInterfaceMethods]> { + let summary = "VMI logical grouped vector store with a row stride between groups"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, + Index:$offset, Index:$row_stride, I64Attr:$num_groups); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` `,` $row_stride attr-dict `:` type($value) `,` type($destination)"; +} + def VMIMaskedStoreOp : VMI_Op<"masked_store", [DeclareOpInterfaceMethods]> { let summary = "VMI logical masked vector store"; let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index 1f9a43f51a..e26982e347 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -143,7 +143,7 @@ static FailureOr getLayoutFactor(Type type) { FailureOr layout = getAssignedVMILayout(type); if (failed(layout)) return failure(); - return (*layout).isContiguous() ? 1 : (*layout).getFactor(); + return (*layout).isDeinterleaved() ? (*layout).getFactor() : 1; } static FailureOr getPhysicalLanesPerPart(Type type) { @@ -294,6 +294,17 @@ static LogicalResult verifyMemoryElementMatches(Operation *op, Type memoryType, return success(); } +static LogicalResult verifyNumGroups(Operation *op, VMIVRegType type, + int64_t numGroups) { + if (numGroups <= 0) + return op->emitOpError("requires num_groups to be positive"); + if (type.getElementCount() % numGroups != 0) + return op->emitOpError() + << "requires num_groups to evenly divide VMI logical lane count " + << type.getElementCount(); + return success(); +} + static LogicalResult verifyPhysicalParts(Operation *op, Type vmiType, TypeRange physicalTypes) { FailureOr expectedArity = getVMIPhysicalArity(vmiType); @@ -354,6 +365,11 @@ VMILayoutAttr VMILayoutAttr::getDeinterleaved(MLIRContext *context, return VMILayoutAttr::get(context, "deinterleaved", factor); } +VMILayoutAttr VMILayoutAttr::getGroupSlots(MLIRContext *context, + int64_t numGroups) { + return VMILayoutAttr::get(context, "num_groups", numGroups); +} + Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { SMLoc loc = parser.getCurrentLocation(); StringRef kind; @@ -367,10 +383,13 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { } else if (kind == "deinterleaved") { if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) return {}; + } else if (kind == "num_groups") { + if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) + return {}; } else { parser.emitError(parser.getCurrentLocation(), "expected VMI layout kind 'contiguous' or " - "'deinterleaved'"); + "'deinterleaved' or 'num_groups'"); return {}; } @@ -383,7 +402,7 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { void VMILayoutAttr::print(AsmPrinter &printer) const { printer << "<" << getKind(); - if (isDeinterleaved()) + if (isDeinterleaved() || isGroupSlots()) printer << " = " << getFactor(); printer << ">"; } @@ -406,8 +425,16 @@ VMILayoutAttr::verify(function_ref emitError, return success(); } + if (kind == "num_groups") { + if (factor <= 0) + return emitError() + << "#pto.vmi.layout requires num_groups to be positive"; + return success(); + } + return emitError() << "expected VMI layout kind to be 'contiguous' or " - "'deinterleaved'"; + "'deinterleaved' or 'num_groups'"; } Type VMIVRegType::parse(AsmParser &parser) { @@ -454,6 +481,14 @@ LogicalResult VMIVRegType::verify(function_ref emitError, return emitError() << "'" << formatVMIVRegType(elementCount, elementType, layout) << "' expected layout to be #pto.vmi.layout"; + if (auto layoutAttr = llvm::dyn_cast_or_null(layout)) { + if (layoutAttr.isGroupSlots() && + elementCount % layoutAttr.getNumGroups() != 0) + return emitError() << "'" << formatVMIVRegType(elementCount, elementType, + layout) + << "' expected num_groups layout to evenly divide " + "the VMI logical lane count"; + } return success(); } @@ -509,6 +544,12 @@ LogicalResult VMIMaskType::verify(function_ref emitError, return emitError() << "'" << formatVMIMaskType(elementCount, granularity, layout) << "' expected layout to be #pto.vmi.layout"; + if (auto layoutAttr = llvm::dyn_cast_or_null(layout)) { + if (layoutAttr.isGroupSlots()) + return emitError() << "'" << formatVMIMaskType(elementCount, granularity, + layout) + << "' mask type must not carry num_groups layout"; + } if (granularity == "pred" && layout) return emitError() << "'" << formatVMIMaskType(elementCount, granularity, @@ -958,6 +999,65 @@ LogicalResult VMIReduceMinFOp::verify() { return verifyReduceMinMaxFOp(*this); } +LogicalResult VMIGroupReduceAddFOp::verify() { + auto sourceType = cast(getSource().getType()); + auto maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + if (!getOperation()->hasAttr("reassoc")) + return emitOpError( + "requires reassoc attr because grouped lowering uses pair-wise " + "floating-point reductions"); + if (!isVMIFloatLikeType(sourceType.getElementType())) + return emitOpError("requires floating-point-like VMI source element type"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (sourceType.getElementType() != resultType.getElementType()) + return emitOpError("requires source and result element types to match"); + if (auto sourceLayout = sourceType.getLayoutAttr()) { + if (!sourceLayout.isContiguous()) + return emitOpError( + "requires layout-assigned source to use contiguous layout"); + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) + return emitOpError() + << "requires layout-assigned result to use " + "#pto.vmi.layout"; + } + if (failed(verifyMaskMatchesData(getOperation(), maskType, sourceType))) + return failure(); + return verifyNumGroups(getOperation(), sourceType, + getNumGroupsAttr().getInt()); +} + +LogicalResult VMIGroupBroadcastOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (sourceType.getElementType() != resultType.getElementType()) + return emitOpError("requires source and result element types to match"); + if (auto sourceLayout = sourceType.getLayoutAttr()) { + if (!sourceLayout.isGroupSlots() || + sourceLayout.getNumGroups() != getNumGroupsAttr().getInt()) + return emitOpError() + << "requires layout-assigned source to use " + "#pto.vmi.layout"; + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (resultLayout.isGroupSlots()) + return emitOpError( + "requires layout-assigned result to use a dense VMI layout"); + } + return verifyNumGroups(getOperation(), sourceType, + getNumGroupsAttr().getInt()); +} + LogicalResult VMIExtFOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); @@ -1025,6 +1125,21 @@ void VMILoadOp::getEffects( effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); } +LogicalResult VMIGroupLoadOp::verify() { + auto resultType = cast(getResult().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + return verifyNumGroups(getOperation(), resultType, + getNumGroupsAttr().getInt()); +} + +void VMIGroupLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + LogicalResult VMIMaskedLoadOp::verify() { auto maskType = cast(getMask().getType()); auto passthruType = cast(getPassthru().getType()); @@ -1109,6 +1224,22 @@ void VMIStoreOp::getEffects( effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); } +LogicalResult VMIGroupStoreOp::verify() { + auto valueType = cast(getValue().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), valueType, + "destination"))) + return failure(); + return verifyNumGroups(getOperation(), valueType, + getNumGroupsAttr().getInt()); +} + +void VMIGroupStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + LogicalResult VMIMaskedStoreOp::verify() { auto valueType = cast(getValue().getType()); auto maskType = cast(getMask().getType()); diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index e4d201d45c..27d6b806fe 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -224,6 +224,10 @@ struct LayoutSolver { return VMILayoutAttr::getContiguous(ctx); } + VMILayoutAttr getGroupSlotsLayout(int64_t numGroups) { + return VMILayoutAttr::getGroupSlots(ctx, numGroups); + } + VMILayoutAttr getDataLayout(Value value) { unsigned id = addDataValue(value); if (id == ~0u) @@ -537,6 +541,21 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(reduce.getResult(), + getGroupSlotsLayout( + reduce.getNumGroupsAttr().getInt()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto broadcast = dyn_cast(op)) { + requestDataUse(broadcast.getSourceMutable(), + getGroupSlotsLayout( + broadcast.getNumGroupsAttr().getInt())); + return WalkResult::advance(); + } if (auto extf = dyn_cast(op)) { auto sourceType = cast(extf.getSource().getType()); auto resultType = cast(extf.getResult().getType()); @@ -607,10 +626,20 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto load = dyn_cast(op)) { + if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto store = dyn_cast(op)) { requestDataUse(store.getValueMutable(), getContiguousLayout()); return WalkResult::advance(); } + if (auto store = dyn_cast(op)) { + requestDataUse(store.getValueMutable(), getContiguousLayout()); + return WalkResult::advance(); + } if (auto store = dyn_cast(op)) { auto valueType = cast(store.getValue().getType()); requestDataUse(store.getValueMutable(), getContiguousLayout()); @@ -1136,6 +1165,16 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse(reduce.getMaskMutable(), + sourceType.getLayoutAttr(), + getMaskGranularityForElement( + sourceType.getElementType()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto load = dyn_cast(op)) { auto resultType = cast(load.getResult().getType()); if (failed(requestMaskUse(load.getMaskMutable(), diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index db19c2846b..cf91af1142 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -488,7 +488,7 @@ FailureOr getDataLayoutFactor(VMIVRegType type) { VMILayoutAttr layout = type.getLayoutAttr(); if (!layout) return failure(); - return layout.isContiguous() ? 1 : layout.getFactor(); + return layout.isDeinterleaved() ? layout.getFactor() : 1; } FailureOr getDataChunksInPart(VMIVRegType type, int64_t part) { @@ -568,7 +568,7 @@ FailureOr getVMITypeLayoutFactor(Type type) { auto layoutAttr = dyn_cast_or_null(layout); if (!layoutAttr) return failure(); - return layoutAttr.isContiguous() ? 1 : layoutAttr.getFactor(); + return layoutAttr.isDeinterleaved() ? layoutAttr.getFactor() : 1; } FailureOr getVMITypeElementCount(Type type) { @@ -1087,6 +1087,80 @@ LogicalResult checkSupportedStoreShape( materializationReason); } +FailureOr getGroupSizeFromNumGroups(VMIVRegType type, + int64_t numGroups, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + if (numGroups <= 0) + return fail("requires num_groups to be positive"); + if (type.getElementCount() % numGroups != 0) + return fail("requires num_groups to evenly divide logical lane count"); + return type.getElementCount() / numGroups; +} + +LogicalResult checkSupportedGroupChunkShape(VMIVRegType type, + int64_t groupSize, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return fail("requires assigned contiguous layout"); + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(type, &fullChunkReason))) + return fail(Twine("requires full physical chunks; ") + fullChunkReason); + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known physical lanes per part"); + if (groupSize <= 0 || type.getElementCount() % groupSize != 0) + return fail("requires derived group size to evenly divide logical lane " + "count"); + if (groupSize % *lanesPerPart != 0) + return fail("currently requires group size to be a multiple of physical " + "lanes per part"); + return success(); +} + +LogicalResult checkSupportedGroupLoadShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupLoadOp op, + std::string *reason) { + auto resultType = cast(op.getResult().getType()); + FailureOr groupSize = + getGroupSizeFromNumGroups(resultType, op.getNumGroupsAttr().getInt(), + reason); + if (failed(groupSize)) + return failure(); + if (failed(checkSupportedLoadShape(capabilities, resultType, op.getSource(), + op.getSource().getType(), + std::nullopt, reason))) + return failure(); + return checkSupportedGroupChunkShape(resultType, *groupSize, reason); +} + +LogicalResult checkSupportedGroupStoreShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupStoreOp op, + std::string *reason) { + auto valueType = cast(op.getValue().getType()); + FailureOr groupSize = + getGroupSizeFromNumGroups(valueType, op.getNumGroupsAttr().getInt(), + reason); + if (failed(groupSize)) + return failure(); + if (failed(checkSupportedStoreShape(capabilities, valueType, + op.getDestination(), + op.getDestination().getType(), reason))) + return failure(); + return checkSupportedGroupChunkShape(valueType, *groupSize, reason); +} + LogicalResult checkSupportedMaskedLoadShape(const VMITargetCapabilityRegistry &capabilities, VMIMaskedLoadOp op, std::string *reason) { @@ -1766,7 +1840,7 @@ computeConstantMaskMaterialization(VMIConstantMaskOp op, std::string *reason) { return fail("requires known physical mask lanes per part"); auto boolValues = denseAttr.getValues(); - int64_t factor = layout.isContiguous() ? 1 : layout.getFactor(); + int64_t factor = layout.isDeinterleaved() ? layout.getFactor() : 1; SmallVector materializations; for (int64_t part = 0; part < factor; ++part) { for (int64_t chunk = 0;; ++chunk) { @@ -1897,6 +1971,10 @@ materializeConstantMaskChunk(Location loc, MaskType maskType, return materializePrefixMask(loc, maskType, 0, *lanesPerPart, rewriter); } +FailureOr createScalarOffsetConstant(Location loc, Type type, + int64_t value, + PatternRewriter &rewriter); + Value createChunkOffset(Location loc, Value baseOffset, int64_t laneOffset, PatternRewriter &rewriter) { if (laneOffset == 0) @@ -1905,6 +1983,234 @@ Value createChunkOffset(Location loc, Value baseOffset, int64_t laneOffset, return rewriter.create(loc, baseOffset, delta).getResult(); } +Value createGroupChunkOffset(Location loc, Value baseOffset, Value rowStride, + int64_t group, int64_t inGroupLaneOffset, + PatternRewriter &rewriter) { + Value offset = baseOffset; + if (group != 0) { + Value groupIndex = rewriter.create(loc, group); + Value rowOffset = + rewriter.create(loc, rowStride, groupIndex).getResult(); + offset = rewriter.create(loc, offset, rowOffset).getResult(); + } + return createChunkOffset(loc, offset, inGroupLaneOffset, rewriter); +} + +LogicalResult checkContiguousFullGroupChunks(Operation *op, VMIVRegType type, + int64_t groupSize, + int64_t *lanesPerPart, + int64_t *groupCount, + int64_t *chunksPerGroup, + PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) { + return rewriter.notifyMatchFailure(op, message); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return fail("group op requires contiguous VMI layout"); + if (failed(checkFullDataPhysicalChunks(type, nullptr))) + return fail("group op requires full physical chunks"); + FailureOr lanes = getDataLanesPerPart(type.getElementType()); + if (failed(lanes)) + return fail("group op requires known physical lanes per part"); + if (groupSize <= 0 || type.getElementCount() % groupSize != 0) + return fail("group op requires derived group size to evenly divide lane " + "count"); + if (groupSize % *lanes != 0) + return fail("group op currently requires group size to be a multiple of " + "physical lanes per part"); + + *lanesPerPart = *lanes; + *groupCount = type.getElementCount() / groupSize; + *chunksPerGroup = groupSize / *lanes; + return success(); +} + +LogicalResult checkFullGroupSlotSourceShape(Operation *op, VMIVRegType type, + int64_t groupSize, + int64_t numGroups, + int64_t *lanesPerPart, + int64_t *groupCount, + PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) { + return rewriter.notifyMatchFailure(op, message); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || + layout.getNumGroups() != numGroups) + return fail("group slot op requires matching num_groups VMI layout"); + if (failed(checkFullDataPhysicalChunks(type, nullptr))) + return fail("group slot op requires full physical chunks"); + FailureOr lanes = getDataLanesPerPart(type.getElementType()); + if (failed(lanes)) + return fail("group slot op requires known physical lanes per part"); + if (groupSize <= 0 || type.getElementCount() % groupSize != 0) + return fail( + "group slot op requires derived group size to evenly divide lane count"); + if (*lanes % groupSize != 0 && groupSize % *lanes != 0) + return fail("group slot op requires group size to divide or be a " + "multiple of physical lanes per part"); + + *lanesPerPart = *lanes; + *groupCount = type.getElementCount() / groupSize; + return success(); +} + +LogicalResult checkFullGroupBroadcastResultShape(Operation *op, + VMIVRegType type, + int64_t groupSize, + int64_t lanesPerPart, + int64_t *layoutFactor, + int64_t *groupCount, + PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) { + return rewriter.notifyMatchFailure(op, message); + }; + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return fail("group_broadcast result requires assigned VMI layout"); + if (layout.isGroupSlots()) + return fail("group_broadcast result requires a dense VMI layout"); + if (failed(checkFullDataPhysicalChunks(type, nullptr))) + return fail("group_broadcast result requires full physical chunks"); + FailureOr resultLanes = + getDataLanesPerPart(type.getElementType()); + if (failed(resultLanes) || *resultLanes != lanesPerPart) + return fail("group_broadcast result requires matching physical lanes"); + if (groupSize <= 0 || type.getElementCount() % groupSize != 0) + return fail("group_broadcast result requires derived group size to evenly " + "divide lane count"); + FailureOr factor = getDataLayoutFactor(type); + if (failed(factor)) + return fail("group_broadcast result requires known layout factor"); + + if (*factor == 1) { + if (lanesPerPart % groupSize != 0 && groupSize % lanesPerPart != 0) + return fail("group_broadcast contiguous result requires group size to " + "divide or be a multiple of physical lanes per part"); + } else { + int64_t logicalSpanPerResultChunk = lanesPerPart * *factor; + if (groupSize < lanesPerPart || + groupSize % logicalSpanPerResultChunk != 0) + return fail("group_broadcast deinterleaved result requires every " + "physical result chunk to stay within one logical group"); + } + + *layoutFactor = *factor; + *groupCount = type.getElementCount() / groupSize; + return success(); +} + +FailureOr createZeroVector(Location loc, VRegType type, + PatternRewriter &rewriter) { + FailureOr zero = + createScalarOffsetConstant(loc, type.getElementType(), 0, rewriter); + FailureOr mask = createAllTrueMaskForVReg(loc, type, rewriter); + if (failed(zero) || failed(mask)) + return failure(); + return rewriter.create(loc, type, *zero, *mask, + /*position=*/nullptr) + .getResult(); +} + +FailureOr createLaneRangeMask(Location loc, MaskType maskType, + int64_t begin, int64_t end, + PatternRewriter &rewriter) { + FailureOr lanesPerPart = + getMaskLanesPerPart(maskType.getGranularity()); + if (failed(lanesPerPart) || begin < 0 || begin > end || + end > *lanesPerPart) + return failure(); + SmallVector active(*lanesPerPart, 0); + for (int64_t lane = begin; lane < end; ++lane) + active[lane] = 1; + return materializeConstantMaskChunk(loc, maskType, active, rewriter); +} + +FailureOr createGroupSlotIndexVector(Location loc, VRegType indexType, + int64_t groupSize, + PatternRewriter &rewriter) { + int64_t lanesPerPart = indexType.getElementCount(); + FailureOr zero = + createZeroVector(loc, indexType, rewriter); + FailureOr maskType = getMaskTypeForVReg(indexType, rewriter.getContext()); + FailureOr allMask = createAllTrueMaskForVReg(loc, indexType, rewriter); + if (failed(zero) || failed(maskType) || failed(allMask)) + return failure(); + if (groupSize >= lanesPerPart) + return *zero; + if (lanesPerPart % groupSize != 0) + return failure(); + + Value result = *zero; + int64_t groupsPerChunk = lanesPerPart / groupSize; + for (int64_t localGroup = 1; localGroup < groupsPerChunk; ++localGroup) { + FailureOr groupScalar = createScalarOffsetConstant( + loc, indexType.getElementType(), localGroup, rewriter); + FailureOr laneMask = + createLaneRangeMask(loc, *maskType, localGroup * groupSize, + (localGroup + 1) * groupSize, rewriter); + if (failed(groupScalar) || failed(laneMask)) + return failure(); + Value splat = + rewriter + .create(loc, indexType, *groupScalar, *allMask, + /*position=*/nullptr) + .getResult(); + result = rewriter.create(loc, indexType, splat, result, *laneMask) + .getResult(); + } + return result; +} + +LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, + VMIMaskType maskType, + VMIVRegType resultType, + int64_t groupSize, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!sourceType.getElementType().isF32() || + sourceType.getElementType() != resultType.getElementType()) + return fail("vcgadd group_reduce_addf path requires f32 source/result"); + if (groupSize != 8) + return fail("vcgadd group_reduce_addf path requires group size = 8 for " + "f32 32-byte VLane groups"); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + int64_t numGroups = sourceType.getElementCount() / groupSize; + if (!sourceLayout || !resultLayout || !maskLayout || + !sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != numGroups || + !maskLayout.isContiguous()) + return fail("vcgadd group_reduce_addf path requires contiguous source/mask " + "layouts and matching num_groups result layout"); + std::string sourceFullReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) + return fail(Twine("vcgadd group_reduce_addf path requires full source " + "chunks; ") + + sourceFullReason); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("vcgadd group_reduce_addf path requires computable physical " + "arity"); + if (*sourceArity < 1 || *sourceArity != *maskArity || + *sourceArity != *resultArity) + return fail("vcgadd group_reduce_addf path requires matching non-empty " + "source/mask/result physical arity"); + return success(); +} + std::optional getX2MemoryDistToken(Type elementType, StringRef prefix) { unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); @@ -2713,9 +3019,10 @@ FailureOr createScalarOffsetConstant(Location loc, Type type, } if (auto floatType = dyn_cast(type)) { return rewriter - .create( - loc, FloatAttr::get(floatType, - llvm::APFloat(static_cast(value)))) + .create(loc, + rewriter.getFloatAttr(floatType, + static_cast( + value))) .getResult(); } return failure(); @@ -2987,7 +3294,7 @@ struct OneToNVMICreateMaskOpPattern return failure(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); - int64_t factor = layout.isContiguous() ? 1 : layout.getFactor(); + int64_t factor = layout.isDeinterleaved() ? layout.getFactor() : 1; if (resultTypes.size() % factor != 0) return rewriter.notifyMatchFailure( op, "dynamic create_mask physical result count does not match " @@ -3034,7 +3341,7 @@ struct OneToNVMICreateMaskOpPattern activeLanes = resultVMIType.getElementCount(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); - int64_t factor = layout.isContiguous() ? 1 : layout.getFactor(); + int64_t factor = layout.isDeinterleaved() ? layout.getFactor() : 1; SmallVector results; results.reserve(resultTypes.size()); @@ -3184,6 +3491,73 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { } }; +struct OneToNVMIGroupLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "group_load source must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "group_load offset must convert to one value", + rewriter); + FailureOr rowStride = + getSingleValue(op, adaptor.getRowStride(), + "group_load row_stride must convert to one value", + rewriter); + if (failed(source) || failed(offset) || failed(rowStride)) + return failure(); + + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + int64_t chunksPerGroup = 0; + FailureOr groupSize = getGroupSizeFromNumGroups( + resultVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, "group_load requires num_groups to evenly divide lane count"); + if (failed(checkContiguousFullGroupChunks( + op, resultVMIType, *groupSize, &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) + return failure(); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (static_cast(resultTypes.size()) != + groupCount * chunksPerGroup) + return rewriter.notifyMatchFailure(op, "group_load arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_load result must be vreg"); + int64_t group = index / chunksPerGroup; + int64_t chunkInGroup = index % chunksPerGroup; + Value chunkOffset = createGroupChunkOffset( + op.getLoc(), *offset, *rowStride, group, + chunkInGroup * lanesPerPart, rewriter); + results.push_back( + rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + struct OneToNVMIMaskedLoadOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern< @@ -3502,6 +3876,73 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { } }; +struct OneToNVMIGroupStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupStoreOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto valueVMIType = cast(op.getValue().getType()); + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + int64_t chunksPerGroup = 0; + FailureOr groupSize = getGroupSizeFromNumGroups( + valueVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, "group_store requires num_groups to evenly divide lane count"); + if (failed(checkContiguousFullGroupChunks( + op, valueVMIType, *groupSize, &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) + return failure(); + + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "group_store destination must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "group_store offset must convert to one value", + rewriter); + FailureOr rowStride = + getSingleValue(op, adaptor.getRowStride(), + "group_store row_stride must convert to one value", + rewriter); + if (failed(destination) || failed(offset) || failed(rowStride)) + return failure(); + + ValueRange valueParts = adaptor.getValue(); + if (static_cast(valueParts.size()) != + groupCount * chunksPerGroup) + return rewriter.notifyMatchFailure(op, "group_store arity mismatch"); + + for (auto [index, value] : llvm::enumerate(valueParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_store mask"); + int64_t group = index / chunksPerGroup; + int64_t chunkInGroup = index % chunksPerGroup; + Value chunkOffset = createGroupChunkOffset( + op.getLoc(), *offset, *rowStride, group, + chunkInGroup * lanesPerPart, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + struct OneToNVMIMaskedStoreOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern< @@ -4346,6 +4787,284 @@ struct OneToNVMIReduceAddFOpPattern } }; +struct OneToNVMIGroupReduceAddFOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupReduceAddFOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupReduceAddFOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto maskVMIType = cast(op.getMask().getType()); + auto resultVMIType = cast(op.getResult().getType()); + ValueRange sourceParts = adaptor.getSource(); + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + FailureOr groupSize = getGroupSizeFromNumGroups( + sourceVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, + "group_reduce_addf requires num_groups to evenly divide lane count"); + if (succeeded(checkVcgaddGroupReduceShape( + sourceVMIType, maskVMIType, resultVMIType, + *groupSize, nullptr))) { + if (sourceParts.size() != maskParts.size() || + sourceParts.size() != resultTypes.size() || sourceParts.empty()) + return rewriter.notifyMatchFailure( + op, "vcgadd group_reduce_addf path requires matching physical " + "arity"); + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "vcgadd group_reduce_addf path requires physical vreg/mask"); + for (auto [sourcePart, maskPart, physicalResultType] : + llvm::zip_equal(sourceParts, maskParts, resultTypes)) { + if (sourcePart.getType() != resultType || + maskPart.getType() != maskType || physicalResultType != resultType) + return rewriter.notifyMatchFailure( + op, "vcgadd group_reduce_addf path requires uniform physical " + "chunk types"); + } + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourceIndex, sourcePart] : llvm::enumerate(sourceParts)) { + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, + maskParts[sourceIndex]) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + int64_t chunksPerGroup = 0; + if (failed(checkContiguousFullGroupChunks( + op, sourceVMIType, *groupSize, &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) + return failure(); + if (sourceParts.size() != maskParts.size() || + static_cast(sourceParts.size()) != + groupCount * chunksPerGroup || + resultTypes.size() != sourceParts.size()) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf requires matching source/mask/result arity"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf result must be vreg"); + FailureOr zero = createZeroVector(op.getLoc(), vregType, rewriter); + if (failed(zero)) + return rewriter.notifyMatchFailure( + op, "failed to materialize group_reduce_addf zero result"); + results.push_back(*zero); + } + + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf requires physical vreg result and mask"); + + FailureOr firstLaneMask = + createPrefixMask(op.getLoc(), maskType, "PAT_VL1", rewriter); + if (failed(firstLaneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_reduce_addf masks"); + + for (int64_t group = 0; group < groupCount; ++group) { + FailureOr accumulator = + createZeroVector(op.getLoc(), resultType, rewriter); + if (failed(accumulator)) + return rewriter.notifyMatchFailure( + op, "failed to create group_reduce_addf accumulator"); + + for (int64_t chunk = 0; chunk < chunksPerGroup; ++chunk) { + int64_t index = group * chunksPerGroup + chunk; + if (sourceParts[index].getType() != resultType || + maskParts[index].getType() != maskType) + return rewriter.notifyMatchFailure( + op, "group_reduce_addf requires uniform physical chunk types"); + Value reduced = + rewriter + .create(op.getLoc(), resultType, sourceParts[index], + maskParts[index]) + .getResult(); + *accumulator = + rewriter + .create(op.getLoc(), resultType, reduced, + *accumulator, *firstLaneMask) + .getResult(); + } + + int64_t destChunk = group * chunksPerGroup; + results[destChunk] = + rewriter + .create(op.getLoc(), resultType, *accumulator, + results[destChunk], *firstLaneMask) + .getResult(); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIGroupBroadcastOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupBroadcastOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupBroadcastOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); + FailureOr groupSize = getGroupSizeFromNumGroups( + sourceVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, + "group_broadcast requires num_groups to evenly divide lane count"); + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + if (failed(checkFullGroupSlotSourceShape( + op, sourceVMIType, *groupSize, op.getNumGroupsAttr().getInt(), + &lanesPerPart, &groupCount, rewriter))) + return failure(); + int64_t resultLayoutFactor = 0; + int64_t resultGroupCount = 0; + if (failed(checkFullGroupBroadcastResultShape( + op, resultVMIType, *groupSize, lanesPerPart, &resultLayoutFactor, + &resultGroupCount, rewriter))) + return failure(); + if (resultGroupCount != groupCount) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires matching source/result group slots"); + + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty() || resultTypes.empty()) + return rewriter.notifyMatchFailure(op, "group_broadcast arity mismatch"); + + auto firstSourceType = dyn_cast(sourceParts.front().getType()); + if (!firstSourceType) + return rewriter.notifyMatchFailure( + op, "group_broadcast source must be vreg"); + unsigned indexBits = + pto::getPTOStorageElemBitWidth(firstSourceType.getElementType()); + if (indexBits != 8 && indexBits != 16 && indexBits != 32) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires 8/16/32-bit index elements"); + auto indexElementType = IntegerType::get(rewriter.getContext(), indexBits); + auto indexType = + VRegType::get(rewriter.getContext(), firstSourceType.getElementCount(), + indexElementType); + std::optional groupSlotIndex; + FailureOr allMask = + createAllTrueMaskForVReg(op.getLoc(), firstSourceType, rewriter); + if (failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast all mask"); + if (*groupSize < lanesPerPart) { + FailureOr index = createGroupSlotIndexVector( + op.getLoc(), indexType, *groupSize, rewriter); + if (failed(index)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); + groupSlotIndex = *index; + } + + SmallVector results; + results.resize(resultTypes.size()); + for (auto [flatIndex, resultType] : llvm::enumerate(resultTypes)) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || resultVRegType != firstSourceType) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires uniform physical vreg types"); + int64_t sourceChunk = flatIndex; + if (resultLayoutFactor == 1) { + if (*groupSize >= lanesPerPart) { + int64_t chunksPerGroup = *groupSize / lanesPerPart; + int64_t group = flatIndex / chunksPerGroup; + sourceChunk = group * chunksPerGroup; + } + } else { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + FailureOr firstLogical = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, 0); + FailureOr lastLogical = mapPhysicalLaneToLogical( + resultVMIType, part, chunk, lanesPerPart - 1); + if (failed(firstLogical) || failed(lastLogical)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to map result chunk lanes"); + int64_t firstGroup = *firstLogical / *groupSize; + int64_t lastGroup = *lastLogical / *groupSize; + if (firstGroup != lastGroup) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk crosses logical groups"); + int64_t chunksPerGroup = *groupSize / lanesPerPart; + sourceChunk = firstGroup * chunksPerGroup; + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } + if (*groupSize >= lanesPerPart) { + if (sourceChunk < 0 || + sourceChunk >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "group_broadcast source chunk is out of range"); + results[flatIndex] = + rewriter + .create(op.getLoc(), resultType, sourceParts[sourceChunk], + *allMask, rewriter.getStringAttr("LOWEST")) + .getResult(); + } else { + if (resultLayoutFactor != 1) + return rewriter.notifyMatchFailure( + op, "group_broadcast small-group deinterleaved result is not " + "supported"); + if (sourceChunk < 0 || + sourceChunk >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "group_broadcast source chunk is out of range"); + results[flatIndex] = + rewriter + .create(op.getLoc(), resultType, + sourceParts[sourceChunk], *groupSlotIndex) + .getResult(); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + template struct OneToNVMIReduceMinMaxFOpPattern : OneToNOpConversionPattern { @@ -5032,10 +5751,12 @@ void populateVMIOneToNConversionPatterns( OneToNVMIMaskBinaryOpPattern, OneToNVMIMaskUnaryOpPattern, OneToNVMILoadOpPattern, + OneToNVMIGroupLoadOpPattern, OneToNVMIMaskedLoadOpPattern, OneToNVMIGatherOpPattern, OneToNVMIExpandLoadOpPattern, OneToNVMIStoreOpPattern, + OneToNVMIGroupStoreOpPattern, OneToNVMIMaskedStoreOpPattern, OneToNVMIScatterOpPattern, OneToNVMITileReadOpPattern, @@ -5071,6 +5792,8 @@ void populateVMIOneToNConversionPatterns( OneToNVMICompressStoreOpPattern, OneToNVMIReduceAddIOpPattern, OneToNVMIReduceAddFOpPattern, + OneToNVMIGroupReduceAddFOpPattern, + OneToNVMIGroupBroadcastOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIReduceMinMaxFOpPattern LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!op->hasAttr("reassoc")) + return fail("requires reassoc attr for pair-wise floating-point reduction"); + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!sourceLayout || !resultLayout || !maskLayout) + return fail("requires assigned source, mask, and result layouts"); + if (!sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != op.getNumGroupsAttr().getInt() || + !maskLayout.isContiguous()) + return fail("requires contiguous source/mask layouts and matching " + "num_groups result layout"); + VMICapabilityResult elementCapability = + capabilities.supportsReductionElementType(VMIReductionKind::AddF, + sourceType.getElementType()); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + if (sourceType.getElementType() != resultType.getElementType()) + return fail("requires source/result element type to match"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source/result lane count to match"); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(sourceArity) || failed(resultArity) || failed(maskArity)) + return fail("requires computable source/result/mask physical arity"); + if (*sourceArity != *resultArity || *sourceArity != *maskArity) + return fail("requires source/result/mask physical arity to match"); + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt(), + reason); + if (failed(groupSize)) + return failure(); + if (succeeded(checkVcgaddGroupReduceShape( + sourceType, maskType, resultType, *groupSize, nullptr))) + return success(); + return checkSupportedGroupChunkShape(sourceType, *groupSize, reason); +} + +LogicalResult checkSupportedGroupBroadcastShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, + std::string *reason = nullptr) { + (void)capabilities; + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + if (sourceType.getElementType() != resultType.getElementType() || + sourceType.getElementCount() != resultType.getElementCount()) { + if (reason) + *reason = "requires source/result shape and element type to match"; + return failure(); + } + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (!sourceLayout.isGroupSlots() || + sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) + return fail("requires matching num_groups source layout"); + if (resultLayout.isGroupSlots()) + return fail("requires dense result layout"); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) + return fail(Twine("requires full source physical chunks; ") + + fullChunkReason); + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("requires full result physical chunks; ") + + fullChunkReason); + + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + FailureOr resultLanesPerPart = + getDataLanesPerPart(resultType.getElementType()); + if (failed(lanesPerPart) || failed(resultLanesPerPart) || + *lanesPerPart != *resultLanesPerPart) + return fail("requires matching physical lanes per part"); + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt(), + reason); + if (failed(groupSize)) + return failure(); + if (*lanesPerPart % *groupSize != 0 && *groupSize % *lanesPerPart != 0) + return fail("requires derived group size to divide or be a multiple of " + "physical lanes per part"); + + FailureOr resultFactor = getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires known result layout factor"); + if (*resultFactor == 1) + return success(); + int64_t logicalSpanPerResultChunk = *lanesPerPart * *resultFactor; + if (*groupSize < *lanesPerPart || + *groupSize % logicalSpanPerResultChunk != 0) + return fail("deinterleaved result requires every physical result chunk to " + "stay within one logical group"); + return success(); +} + LogicalResult checkSupportedFmaShape(const VMITargetCapabilityRegistry &capabilities, VMIFmaOp op, std::string *reason = nullptr) { @@ -5678,11 +6517,37 @@ LogicalResult verifySupportedVMIToVPTOOps( return emitMaskableUnsupported( op, "pto.vmi.broadcast", cast(broadcast.getResult().getType())); + if (auto broadcast = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupBroadcastShape(capabilities, broadcast, + &reason))) + return WalkResult::advance(); + broadcast.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_broadcast requires full source chunks with " + "#pto.vmi.layout, a dense full result layout, " + "and num_groups deriving a group size that divides or is a " + "multiple of physical chunk lanes (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto load = dyn_cast(op)) return emitMemoryUnsupported( op, "pto.vmi.load", cast(load.getResult().getType()), load.getSource(), getConstantIndexValue(load.getOffset())); + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_load requires contiguous full result chunks, a " + "supported UB source, and num_groups deriving a group size " + "aligned to physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto load = dyn_cast(op)) { if (enableStableGatherMaskedLoad) { load.emitError() @@ -5744,6 +6609,19 @@ LogicalResult verifySupportedVMIToVPTOOps( << reason << ")"; return WalkResult::interrupt(); } + if (auto store = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupStoreShape(capabilities, store, + &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_store requires contiguous full value chunks, a " + "supported UB destination, and num_groups deriving a group size " + "aligned to physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto store = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedMaskedStoreShape( @@ -6063,6 +6941,22 @@ LogicalResult verifySupportedVMIToVPTOOps( return WalkResult::interrupt(); } + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupReduceAddFShape(capabilities, reduce, + &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_addf lowers through pto.vcgadd for f32 " + "32B groups or through pto.vcadd with reassoc, contiguous full " + "source/mask chunks, #pto.vmi.layout result " + "chunks, and num_groups deriving a group size aligned to " + "physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto reduce = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedReduceShape( diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto new file mode 100644 index 0000000000..078b61b5bf --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_deint( + %sum: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %src_f8: !pto.vmi.vreg<512xf8E4M3FN>) + -> !pto.vmi.vreg<512xf32> { + %src_f32 = pto.vmi.extf %src_f8 + : !pto.vmi.vreg<512xf8E4M3FN> -> !pto.vmi.vreg<512xf32> + %sum_vec = pto.vmi.group_broadcast %sum {num_groups = 2} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32> + %out = pto.vmi.mulf %sum_vec, %src_f32 + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_deint( +// CHECK-COUNT-8: {position = "LOWEST"} +// CHECK-COUNT-8: pto.vmul +// CHECK-NOT: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto new file mode 100644 index 0000000000..01d9711ef0 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_vselr( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_broadcast %source {num_groups = 128} + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_vselr( +// CHECK-COUNT-16: pto.vselr +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_ops.pto b/test/lit/vmi/vmi_to_vpto_group_ops.pto new file mode 100644 index 0000000000..6a10e168dd --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_ops.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_ops( + %src: !pto.ptr, + %dst: !pto.ptr, + %row_stride: index, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %v = pto.vmi.group_load %src[%c0], %row_stride {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %r = pto.vmi.group_reduce_addf %v, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %b = pto.vmi.group_broadcast %r {num_groups = 2} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + pto.vmi.group_store %b, %dst[%c0], %row_stride {num_groups = 2} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_ops( +// CHECK-COUNT-8: pto.vlds +// CHECK-COUNT-8: pto.vcadd +// CHECK-COUNT-8: {position = "LOWEST"} +// CHECK-NOT: pto.vselr +// CHECK-COUNT-8: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto new file mode 100644 index 0000000000..27d246e6d2 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_vcgadd( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_vcgadd( +// CHECK: %[[OUT:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto new file mode 100644 index 0000000000..d3da9416b6 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_vcgadd_multichunk( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<1024xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 128, reassoc} + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout>, + !pto.vmi.mask<1024xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_vcgadd_multichunk( +// CHECK-COUNT-16: pto.vcgadd +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/compare.py b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/compare.py new file mode 100644 index 0000000000..5030420250 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.float32) + output = np.fromfile("v3.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-5, rtol=1e-5): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/golden.py b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/golden.py new file mode 100644 index 0000000000..69fbe13344 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/golden.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 2 +ROW_ELEMS = 256 +ROW_STRIDE = 320 +TOTAL_ELEMS = ROWS * ROW_STRIDE +F16_VALUES = np.array([0.125, 0.25], dtype=np.float16) +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path) -> None: + repeats = (ROW_ELEMS + len(VALUES) - 1) // len(VALUES) + row_f8 = np.tile(F8E4M3FN_BYTES, repeats)[:ROW_ELEMS].astype(np.uint8) + row_decoded_f8 = np.tile(VALUES, repeats)[:ROW_ELEMS].astype(np.float32) + + src_f16 = np.zeros(TOTAL_ELEMS, dtype=np.float16) + src_f8 = np.zeros(TOTAL_ELEMS, dtype=np.uint8) + dst = np.full(TOTAL_ELEMS, SENTINEL, dtype=np.float32) + golden = np.full(TOTAL_ELEMS, SENTINEL, dtype=np.float32) + + for row in range(ROWS): + begin = row * ROW_STRIDE + end = begin + ROW_ELEMS + src_f16[begin:end] = F16_VALUES[row] + src_f8[begin:end] = np.roll(row_f8, row) + decoded_f8 = np.roll(row_decoded_f8, row) + reduction = np.sum(src_f16[begin:end].astype(np.float32), dtype=np.float32) + golden[begin:end] = decoded_f8 * reduction + + output_dir.mkdir(parents=True, exist_ok=True) + src_f16.tofile(output_dir / "v1.bin") + src_f8.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.astype(np.float32, copy=False).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto new file mode 100644 index 0000000000..9cedd97e60 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_f16_f8_mul_store_kernel(%src_f16_gm: !pto.ptr, + %src_f8_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c320 = arith.constant 320 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c256_i64 = arith.constant 256 : i64 + %c320_i64 = arith.constant 320 : i64 + %c512_i64 = arith.constant 512 : i64 + %c640_i64 = arith.constant 640 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c1280_i64 = arith.constant 1280 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_f16 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_f8_u8 = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_f8 = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_f16_gm, %ub_f16, %c0_i64, %c512_i64 + nburst(%c2_i64, %c640_i64, %c640_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src_f8_gm, %ub_f8_u8, %c0_i64, %c256_i64 + nburst(%c2_i64, %c320_i64, %c320_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred> + %src_f16 = pto.vmi.group_load %ub_f16[%c0], %c320 {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf16> + %src_f16_f32 = pto.vmi.extf %src_f16 + : !pto.vmi.vreg<512xf16> -> !pto.vmi.vreg<512xf32> + %sum = pto.vmi.group_reduce_addf %src_f16_f32, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + %src_f8 = pto.vmi.group_load %ub_f8[%c0], %c320 {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf8E4M3FN> + %src_f8_f32 = pto.vmi.extf %src_f8 + : !pto.vmi.vreg<512xf8E4M3FN> -> !pto.vmi.vreg<512xf32> + %sum_vec = pto.vmi.group_broadcast %sum {num_groups = 2} + : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + %out = pto.vmi.mulf %sum_vec, %src_f8_f32 + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %out, %ub_dst[%c0], %c320 {num_groups = 2} + : !pto.vmi.vreg<512xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64 + nburst(%c2_i64, %c1280_i64, %c1280_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/launch.cpp new file mode 100644 index 0000000000..03bf4d7e8f --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_f16_f8_mul_store_kernel(__gm__ half *src_f16, + __gm__ uint8_t *src_f8, + __gm__ float *dst); + +void LaunchVmi_group_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, + float *dst, void *stream) { + vmi_group_reduce_f16_f8_mul_store_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src_f16, (__gm__ uint8_t *)src_f8, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/main.cpp b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/main.cpp new file mode 100644 index 0000000000..e5769e3978 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_f16_f8_mul_store_kernel(uint16_t *src_f16, + uint8_t *src_f8, + float *dst, void *stream); + +int main() { + constexpr size_t kRows = 2; + constexpr size_t kRowStride = 320; + constexpr size_t kElems = kRows * kRowStride; + size_t srcF16Bytes = kElems * sizeof(uint16_t); + size_t srcF8Bytes = kElems * sizeof(uint8_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcF16Host = nullptr; + uint16_t *srcF16Device = nullptr; + uint8_t *srcF8Host = nullptr; + uint8_t *srcF8Device = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF16Host), srcF16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&srcF8Host), srcF8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcF16Device, srcF16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&srcF8Device, srcF8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcF16Bytes, srcF16Host, srcF16Bytes); + ReadFile("./v2.bin", srcF8Bytes, srcF8Host, srcF8Bytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcF16Device, srcF16Bytes, srcF16Host, srcF16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(srcF8Device, srcF8Bytes, srcF8Host, srcF8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_f16_f8_mul_store_kernel(srcF16Device, srcF8Device, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcF16Device); + aclrtFree(srcF8Device); + aclrtFree(dstDevice); + aclrtFreeHost(srcF16Host); + aclrtFreeHost(srcF8Host); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From e2efffab9d16f470d6b4685052ee29bad8812176 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 00:33:13 +0800 Subject: [PATCH 05/54] feat: new layout-lowering design --- docs/designs/vmi-layout-lowering-cases.md | 3103 +++++++++++++++++++++ docs/isa/micro-isa/10-reduction-ops.md | 28 +- 2 files changed, 3118 insertions(+), 13 deletions(-) create mode 100644 docs/designs/vmi-layout-lowering-cases.md diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md new file mode 100644 index 0000000000..807baf841e --- /dev/null +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -0,0 +1,3103 @@ +# VMI Layout Lowering Cases + +本文是 VMI layout/lowering 的典型 case catalog,不是完整设计总文档。它只回答一个问题: +一个 VMI logical vector 在某个场景下选择某种 layout 后,`vmi-to-vpto` 必须生成什么 +VPTO 结果。这里不写动机式描述;每个场景都给出 layout assignment 和 lowering result。 + +## 1. Layout Families + +### 1.1 Dense Layout + +Dense layout 的每个 logical lane 都有语义值。 + +```text +#pto.vmi.layout +``` + +Physical ordering: + +```text +chunk c, lane l -> logical lane c * L + l +``` + +`L` is the physical lanes per 256B VPTO vector register for the element type. + +```text +#pto.vmi.layout +``` + +`block_elems` defaults to `1`. Existing spellings are shorthands: + +```text +#pto.vmi.layout + == #pto.vmi.layout + +#pto.vmi.layout + == #pto.vmi.layout +``` + +Logical-to-physical mapping: + +```text +logical lane i +block q = i / B +in_block lane r = i % B +part p = q % F +part_block t = q / F + +physical part p, physical lane t * B + r +``` + +Required invariants: + +```text +F > 0 +B > 0 +N % (F * B) == 0 for the direct full-chunk paths in this document +``` + +### 1.2 Sparse Group-Slot Layout + +Sparse group-slot layout is not dense. Only `G` lanes have semantic values. + +```text +#pto.vmi.layout +``` + +Physical slot mapping: + +```text +N = logical lane count +S = N / G // logical lanes per source group + +slot_block(g) = g / K +slot_lane(g) = g % K +``` + +Required invariants: + +```text +G > 0 +K > 0 +G % K == 0 +K must fit in the physical vreg element count +``` + +`K` is selected by the producer/consumer plan. It is not always 8. For +`VCGADD`-packed results, `K = 8` matches the eight 32B block results written to +the low lanes of one destination vreg. For row-local reductions where each +logical group already occupies one full 256B vreg, `K = 1` keeps each group's +scalar result in lane 0 of its own physical vreg and avoids an unsupported +cross-vreg scalar pack. + +Only these lanes are semantic: + +```text +physical slot block slot_block(g), lane slot_lane(g) +``` + +All other lanes are undefined for ordinary VMI consumers. They may only be read +by group-aware ops that define how to interpret group slots. + +## 2. Plan Selection Rules + +VMI cast ops must not hard-code one physical `vcvt` plan as their semantic +layout rule. + +```text +dense cast: + source/result are dense layouts. + lowering may require deinterleaved(F, block_elems=1) around VCVT. + +group-slot cast: + source/result are both group_slots(G,K). + lowering preserves slot_block(g) and slot_lane(g). Width-changing casts are + legal only when a slot-preserving VPTO plan is registered, or when the cast + can be commuted through a later group-aware consumer such as group_broadcast. +``` + +Illegal consumer mix: + +```text +group_slots value -> ordinary dense store/add/mul +``` + +This must fail unless an explicit semantic op converts the sparse value: + +```text +group_broadcast +group_store +future explicit group-pack op +``` + +## 3. Lowering Results + +The following examples use symbolic VPTO names. `PAT_ALL_B*` means an all-true +predicate with the element granularity required by the instruction. `PAT_VLk` +means a prefix predicate for the first `k` lanes. + +Completeness rule for this section: every numbered endpoint below must contain +VMI input, assigned layouts, VPTO lowering result, and either a memory result or +an explicit diagnostic. Non-endpoint layout notes may appear only as setup for +the immediately following complete endpoints. + +```text +3.1 f16 -> f32 -> store complete +3.2 f32 -> f16 -> store complete +3.3 f8 -> f32 -> compute -> f8 complete +3.4 group_reduce S=8 -> group_store complete +3.5.1 group_reduce S=16 -> group_store complete +3.5.2 group_reduce S=16 -> broadcast -> compute -> reduce -> store + complete +3.5.3 group_reduce S=16 -> elemwise(rhs) -> group_store complete +3.6.1 group_reduce S=32 -> group_store complete +3.6.2 group_reduce S=32 -> elemwise(rhs) -> group_store complete +3.6.3 group_reduce S=32 -> broadcast -> compute -> reduce -> store + complete +3.7.1 group_reduce S=64 -> group_store complete +3.7.2 group_reduce S=64 -> elemwise(rhs) -> group_store complete +3.7.3 group_reduce S=64 -> broadcast -> compute -> reduce -> store + complete +3.8 group_reduce -> truncf -> broadcast -> dense store complete +3.9 dense store of group slots illegal diagnostic +3.10 non-load producer feeding S=32 group_reduce complete +3.11 partial tail groups complete/diagnostic +3.12 control-flow join before group_reduce complete +3.13 direct group-slot f32 -> f16 cast illegal diagnostic +3.14 unsupported group size illegal diagnostic +3.15 compact S=12 written as logical S=16 complete/design +3.16 group_slot_load layout contract complete +3.17 group_broadcast physical arity alias complete +3.18 one value with dense and group-reduce consumers complete/materialization +3.19 S=16 reduce block_elems plan selection complete/diagnostic +3.20 group_slots control-flow join complete +3.21 S=32 tail with full-tile-readable source complete/design +3.22 scf.for loop-carried layout complete +3.23 group_broadcast with multiple dense consumers complete +3.24 mask with elementwise/select/store complete +3.25 function boundary layout specialization complete/design +``` + +### 3.1 `f16 -> f32 -> store` + +VMI input: + +```text +%x16 = pto.vmi.load %base[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +pto.vmi.store %x32, %out[%off] +``` + +Assigned layouts: + +```text +%x16 : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +%x32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%x16_0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xf16> + +%x32_p0 = pto.vcvt %x16_0, PAT_ALL_B16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x16_0, PAT_ALL_B16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +pto.vstsx2 %x32_p0, %x32_p1, %out[%off], "INTLV_B32", PAT_ALL_B32 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask +``` + +Alternative complete VPTO lowering result if `vstsx2 INTLV_B32` is unavailable: + +```text +%x16_0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xf16> + +%x32_p0 = pto.vcvt %x16_0, PAT_ALL_B16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x16_0, PAT_ALL_B16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +%x32_d0, %x32_d1 = pto.vintlv %x32_p0, %x32_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %x32_d0, %out[%off], PAT_ALL_B32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %x32_d1, %out[%off_plus_64], PAT_ALL_B32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..127: + out[off + i] = extf(base[off + i]) +``` + +### 3.2 Dense `f32 -> f16 -> store` + +VMI input: + +```text +%x32 = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%x16 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.store %x16, %out[%off] +``` + +Assigned layouts: + +```text +%x32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%x16 : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%x32_p0, %x32_p1 = pto.vldsx2 %base[%off], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%part0 = pto.vcvt %x32_p0, PAT_ALL_B32 + {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%part1 = pto.vcvt %x32_p1, PAT_ALL_B32 + {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%x16_0 = pto.vor %part0, %part1, PAT_ALL_B16 + : !pto.vreg<128xf16> + +pto.vsts %x16_0, %out[%off], PAT_ALL_B16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Alternative complete VPTO lowering result if the source has already been loaded +as two contiguous f32 chunks and must be materialized to `deinterleaved=2` before +the conversion: + +```text +%x32_d0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x32_d1 = pto.vlds %base[%off_plus_64] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x32_p0, %x32_p1 = pto.vdintlv %x32_d0, %x32_d1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%part0 = pto.vcvt %x32_p0, PAT_ALL_B32 + {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%part1 = pto.vcvt %x32_p1, PAT_ALL_B32 + {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%x16_0 = pto.vor %part0, %part1, PAT_ALL_B16 + : !pto.vreg<128xf16> + +pto.vsts %x16_0, %out[%off], PAT_ALL_B16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..127: + out[off + i] = truncf(base[off + i]) +``` + +### 3.3 Dense `f8 -> f32 -> compute -> f8` + +VMI input: + +```text +%x8 = pto.vmi.load %base[%off] +%x32 = pto.vmi.extf %x8 +%scale = pto.vmi.broadcast %scale_s : f32 -> !pto.vmi.vreg<256xf32> +%y32 = pto.vmi.mulf %x32, %scale +%y8 = pto.vmi.truncf %y32 +pto.vmi.store %y8, %out[%off] +``` + +Assigned layouts: + +```text +%x8 : !pto.vmi.vreg<256xf8, #pto.vmi.layout> +%x32 : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +%scale : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +%y32 : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +%y8 : !pto.vmi.vreg<256xf8, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%x8_0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<256xf8> + +%x32_p0 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P0"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P1"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p2 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P2"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p3 = pto.vcvt %x8_0, PAT_ALL_B8 {part = "P3"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> + +%scale_p0 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p1 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p2 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p3 = pto.vdup %scale_s, PAT_ALL_B32 + : f32, !pto.mask -> !pto.vreg<64xf32> + +%y32_p0 = pto.vmul %x32_p0, %scale_p0, PAT_ALL_B32 +%y32_p1 = pto.vmul %x32_p1, %scale_p1, PAT_ALL_B32 +%y32_p2 = pto.vmul %x32_p2, %scale_p2, PAT_ALL_B32 +%y32_p3 = pto.vmul %x32_p3, %scale_p3, PAT_ALL_B32 + +%y8_p0 = pto.vcvt %y32_p0, PAT_ALL_B32 + {part = "P0", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> +%y8_p1 = pto.vcvt %y32_p1, PAT_ALL_B32 + {part = "P1", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> +%y8_p2 = pto.vcvt %y32_p2, PAT_ALL_B32 + {part = "P2", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> +%y8_p3 = pto.vcvt %y32_p3, PAT_ALL_B32 + {part = "P3", rnd = "R", sat = "SAT"} -> !pto.vreg<256xf8> + +%y8_01 = pto.vor %y8_p0, %y8_p1, PAT_ALL_B8 +%y8_23 = pto.vor %y8_p2, %y8_p3, PAT_ALL_B8 +%y8_0 = pto.vor %y8_01, %y8_23, PAT_ALL_B8 + +pto.vsts %y8_0, %out[%off], PAT_ALL_B8 {dist = "NORM_B8"} + : !pto.vreg<256xf8>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..255: + out[off + i] = truncf(extf(base[off + i]) * scale_s) +``` + +### 3.4 `group_reduce` S=8 f32 + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<64xf32> -> !pto.vmi.vreg<64xf32> +%mask = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x : !pto.vmi.vreg<64xf32, #pto.vmi.layout> +%mask : !pto.vmi.mask<64xpred, #pto.vmi.layout> +%sum : !pto.vmi.vreg<64xf32, + #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%mask_chunk = pto.pge_b32 "PAT_ALL" + +%x_chunk = pto.vlds %base[%tile_off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> + +%sum_block = pto.vcgadd %x_chunk, %mask_chunk + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%store8 = pto.pge_b32 "PAT_VL8" +pto.vsts %sum_block, %sum_out[%group_tile_off], %store8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Lowering result for one chunk, per the `visa.txt` VCGADD contract: + +```text +%sum_block lane 0 = reduce %x lanes 0..7 +%sum_block lane 1 = reduce %x lanes 8..15 +... +%sum_block lane 7 = reduce %x lanes 56..63 +all non-slot lanes are non-semantic +``` + +Layout result: + +```text +G = N / 8 +K = 8 + +slot_block(g) = g / 8 +slot_lane(g) = g % 8 +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_tile_off + r] = reduce(row_r[0..7]) +``` + +### 3.5 `group_reduce` S=16 f32, load-fused split + +The facts used by this lowering are checked against the current repo: + +```text +pto.vldsx2 supports "BDINTLV". +pto.vstsx2 supports only "INTLV_B8" / "INTLV_B16" / "INTLV_B32". +visa.txt says VCGADD writes one 32B-block result continuously to destination +LSBs; the current repository golden tests follow lanes 0..7 for f32. +``` + +There are three complete consumers for this layout today: + +```text +load -> group_reduce -> group_store(sum) +load -> group_reduce -> elementwise compute on group-slot values + -> group_store +load -> group_reduce -> group_broadcast -> elementwise compute + -> group_reduce -> group_store +``` + +#### 3.5.1 Reduce And Store Group Sums + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref -> !pto.vmi.vreg +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = N / 16} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = N / 16} +``` + +Assigned layouts: + +```text +%x : !pto.vmi.vreg> + +%sum : !pto.vmi.vreg> +``` + +For each 8-row tile: + +```text +row r = 16xf32 = row_r.lo8, row_r.hi8 +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lo, %hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo lanes 0..7 = row0.lo8 +%lo lanes 8..15 = row1.lo8 +... +%lo lanes 56..63 = row7.lo8 + +%hi lanes 0..7 = row0.hi8 +%hi lanes 8..15 = row1.hi8 +... +%hi lanes 56..63 = row7.hi8 + +%lo_sum = pto.vcgadd %lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%store8 = pto.pge_b32 "PAT_VL8" +pto.vsts %sum_block, %sum_out[%group_tile_off], %store8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +`BDINTLV` here denotes the ISA `#bdintlv` block-based interleaving load mode: +it loads `2 * VL` bytes and sends even 32B blocks to the first destination +register and odd 32B blocks to the second destination register. For f32, +one 32B block is `8xf32`, matching `block_elems = 8`. + +Tail tiles use the same dataflow with `%all_b32` replaced by masks derived from +the VMI mask for the low and high 8-lane halves of each row. + +Layout result: + +```text +G = N / 16 +K = 8 + +slot_block(g) = g / 8 +slot_lane(g) = g % 8 + +%sum_block lane 0 = reduce row0 lanes 0..15 +%sum_block lane 1 = reduce row1 lanes 0..15 +... +%sum_block lane 7 = reduce row7 lanes 0..15 +``` + +No VMI value exposes `%lo_sum` or `%hi_sum`. They are internal VPTO values. + +Memory result: + +```text +sum_out[group_tile_off + 0] = reduce row0 lanes 0..15 +sum_out[group_tile_off + 1] = reduce row1 lanes 0..15 +... +sum_out[group_tile_off + 7] = reduce row7 lanes 0..15 +``` + +This endpoint is fully specified: the only sparse value is `%sum`; `group_store` +stores the low 8 slot lanes with an ordinary prefix store. + +#### 3.5.2 Reduce, Broadcast, Elementwise, Reduce, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref -> !pto.vmi.vreg +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = N / 16} +%b = pto.vmi.group_broadcast %sum {num_groups = N / 16} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = N / 16} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = N / 16} +``` + +Assigned layouts: + +```text +%x : !pto.vmi.vreg> +%sum : !pto.vmi.vreg> +%b : !pto.vmi.vreg> +%y : !pto.vmi.vreg> +%ysum : !pto.vmi.vreg> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// This is the materialization of pto.vmi.group_broadcast. The group sums are +// in %sum_block lanes 0..7; vselr expands each sum to the 8 lanes of the +// corresponding row half. The following vmul/vcgadd consume an ordinary dense +// physical vector. +%b_rows = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_lo = pto.vmul %x_lo, %b_rows, %all_b32 + : !pto.vreg<64xf32> +%y_hi = pto.vmul %x_hi, %b_rows, %all_b32 + : !pto.vreg<64xf32> + +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Final per-row reduction and store. +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%store8 = pto.pge_b32 "PAT_VL8" +pto.vsts %ysum_block, %out[%group_tile_off], %store8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +This trace processes 8 logical rows at once. `num_groups = N / 16` means each +logical group is one `16xf32` row, and one full f32 VPTO tile covers 8 such +groups: + +```text +64 f32 lanes per physical part = 8 rows * 8 f32 lanes per half-row +``` + +Tail tiles use the same dataflow with `%all_b32` replaced by masks derived from +the VMI mask for the low and high 8-lane halves of each row. + +Physical lane result for the tile: + +```text +%x_lo lanes 0..7 = row0[0..7] +%x_lo lanes 8..15 = row1[0..7] +... +%x_lo lanes 56..63 = row7[0..7] + +%x_hi lanes 0..7 = row0[8..15] +%x_hi lanes 8..15 = row1[8..15] +... +%x_hi lanes 56..63 = row7[8..15] + +%sum_block lanes 0..7 = + reduce(row0[0..15]), reduce(row1[0..15]), ..., reduce(row7[0..15]) + +%b_rows lanes 0..7 = reduce(row0[0..15]) +%b_rows lanes 8..15 = reduce(row1[0..15]) +... +%b_rows lanes 56..63 = reduce(row7[0..15]) + +For each row `r` in this 8-row tile: + +%y_lo lanes r*8 .. r*8+7 = + row_r[0..7] * reduce(row_r[0..15]) + +%y_hi lanes r*8 .. r*8+7 = + row_r[8..15] * reduce(row_r[0..15]) + +Concretely: +%y_lo lanes 0..7 = row0[0..7] * reduce(row0[0..15]) +%y_lo lanes 8..15 = row1[0..7] * reduce(row1[0..15]) +... +%y_lo lanes 56..63 = row7[0..7] * reduce(row7[0..15]) + +%y_hi lanes 0..7 = row0[8..15] * reduce(row0[0..15]) +%y_hi lanes 8..15 = row1[8..15] * reduce(row1[0..15]) +... +%y_hi lanes 56..63 = row7[8..15] * reduce(row7[0..15]) + +%ysum_block lanes 0..7 = + reduce(%y row0), reduce(%y row1), ..., reduce(%y row7) +``` + +Memory result: + +```text +out[group_tile_off + r] = + reduce_i((row_r[i] * reduce_j(row_r[j])) for i in 0..15) + = reduce(row_r[0..15]) * reduce(row_r[0..15]) +for r = 0..7 +``` + +If a later consumer requires row-major contiguous order, `vmi-to-vpto` must +materialize: + +```text +deinterleaved=2, block_elems=8 -> contiguous +``` + +This materialization cannot be implemented with `vstsx2 INTLV_B32`, because +that instruction interleaves individual b32 elements, not 32B row halves. Until +a concrete block-interleave register materialization or store op is selected, +row-major store of this layout must be rejected with: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.store requires materializing + #pto.vmi.layout to contiguous, but no + VPTO block-interleave materialization/store plan is registered. +``` + +#### 3.5.3 Reduce Result, Elementwise, Store + +This case computes a per-row reduction, applies an elementwise operation to the +reduced values themselves, and stores one result per group. There is no +`group_broadcast` in this flow because the elementwise op is not applied to the +original `8x16xf32` matrix elements. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%outv = pto.vmi.addf %sum, %rhs +pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x for reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%rhs: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%outv: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +For this endpoint, the RHS is a packed per-group vector: + +```text +rhs_base[rhs_off + r] = rhs(row r), for r = 0..7 +``` + +Layout assignment must treat `group_slot_load` as a group-slot producer: one +f32 value per group is placed in the live slot lanes. It must not use +`group_load`, which loads `group_size` data elements per group instead of one +per-group scalar. + +The elementwise op runs only on the live group-slot lanes: + +```text +%sum lanes 0..7 = + reduce(row0[0..15]), reduce(row1[0..15]), ..., reduce(row7[0..15]) + +%rhs lanes 0..7 = + rhs(row0), rhs(row1), ..., rhs(row7) + +%outv lanes 0..7 = + %sum lanes 0..7 + %rhs lanes 0..7 + +lanes 8..63 remain dead/zero and are masked off by PAT_VL8. +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" +%one_block = pto.pge_b32 "PAT_VL1" + +// Reduction path: use BDINTLV to feed two VCG reductions. +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +// Packed RHS group-slot load. %rhs_tile_base points to rhs_base[rhs_off]. +// One 32B block contains 8 f32 RHS values and materializes lanes 0..7; all +// other lanes are dead/zero. +%rhs_block = pto.vsldb %rhs_tile_base, %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +// Elementwise compute on group-slot values. Only lanes 0..7 are live. +%outv_block = pto.vadd %sum_block, %rhs_block, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %outv_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + out[group_tile_off + r] = s + rhs[r] +``` + +### 3.6 `group_reduce` S=32 f32, 4-way split + +This case covers one `8x32xf32` tile. Each logical row is 128B, so it must be +split into four 32B partial rows before `vcgadd` can reduce it efficiently. + +The canonical layout for the input is: + +```text +%x : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +With `deinterleaved = 4`, physical part `p` contains columns whose logical +column index is `p mod 4`: + +```text +%x_p0 lanes r*8 .. r*8+7 = + row_r[0], row_r[4], row_r[8], ..., row_r[28] + +%x_p1 lanes r*8 .. r*8+7 = + row_r[1], row_r[5], row_r[9], ..., row_r[29] + +%x_p2 lanes r*8 .. r*8+7 = + row_r[2], row_r[6], row_r[10], ..., row_r[30] + +%x_p3 lanes r*8 .. r*8+7 = + row_r[3], row_r[7], row_r[11], ..., row_r[31] +``` + +Each physical part now has exactly 8 f32 values per row, so one `vcgadd` per +part computes one partial sum per row. The four partial sums are then added +under `PAT_VL8`. + +The full contiguous-to-4-way materialization for one tile should fuse the first +deinterleave level into the load. `vldsx2 DINTLV_B32` loads `2 * VL` bytes and +splits even/odd f32 elements into two physical vectors. Two such loads cover +the `8x32xf32` tile, and a second register `vdintlv` level splits even columns +into `mod4 = 0/2` and odd columns into `mod4 = 1/3`. + +This setup documentation is repeated inside every complete 32-wide endpoint +below. + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +``` + +Each endpoint below inlines this materialization before the first consumer of +`%x_p0..%x_p3`. + +#### 3.6.1 Reduce And Store Group Sums + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_tile_off + r] = reduce(row_r[0..31]) +``` + +#### 3.6.2 Reduce Result, Elementwise, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%outv = pto.vmi.addf %sum, %rhs +pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum, %rhs, %outv: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" +%one_block = pto.pge_b32 "PAT_VL1" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +// Packed RHS group-slot load. %rhs_tile_base points to rhs_base[rhs_off]. +%rhs_block = pto.vsldb %rhs_tile_base, %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%outv_block = pto.vadd %sum_block, %rhs_block, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %outv_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = reduce(row_r[0..31]) + rhs[r] +``` + +#### 3.6.3 Reduce, Broadcast, Elementwise, Reduce, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %b, %y: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// group_broadcast materialized for each deinterleaved=4 physical part. +%b_p0 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p1 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p2 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p3 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_p0 = pto.vmul %x_p0, %b_p0, %all_b32 : !pto.vreg<64xf32> +%y_p1 = pto.vmul %x_p1, %b_p1, %all_b32 : !pto.vreg<64xf32> +%y_p2 = pto.vmul %x_p2, %b_p2, %all_b32 : !pto.vreg<64xf32> +%y_p3 = pto.vmul %x_p3, %b_p3, %all_b32 : !pto.vreg<64xf32> + +%ys0 = pto.vcgadd %y_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ys1 = pto.vcgadd %y_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ys2 = pto.vcgadd %y_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ys3 = pto.vcgadd %y_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%ys01 = pto.vadd %ys0, %ys1, %sum_mask : !pto.vreg<64xf32> +%ys23 = pto.vadd %ys2, %ys3, %sum_mask : !pto.vreg<64xf32> +%ysum_block = pto.vadd %ys01, %ys23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..31]) + out[group_tile_off + r] = + reduce_i(row_r[i] * s for i = 0..31) + = s * s +``` + +### 3.7 `group_reduce` S=64 f32, row-local reduction + +This case covers one `8x64xf32` tile. Each logical row is exactly 256B, so the +input does not need a deinterleaved layout: + +```text +row r = 64xf32 = one !pto.vreg<64xf32> +``` + +The reduction is two-stage but row-local: + +```text +vcgadd(row_r) -> 8 partial sums in lanes 0..7 +vcadd(PAT_VL8) -> one row sum in lane 0 +``` + +The result layout is therefore not `slots = 8`. It is: + +```text +#pto.vmi.layout +``` + +Physical slot mapping for this tile: + +```text +slot_block(r) = r +slot_lane(r) = 0 + +%sum0 lane 0 = reduce row0 lanes 0..63 +%sum1 lane 0 = reduce row1 lanes 0..63 +... +%sum7 lane 0 = reduce row7 lanes 0..63 +``` + +Trying to canonicalize this result to `slots = 8` would require packing lane 0 +from eight different physical vregs into lanes 0..7 of one vreg. This document +does not use that plan. `slots = 1` is the canonical layout for S=64 row-local +group reductions. + +#### 3.7.1 Reduce And Store Group Sums + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +%x0 = pto.vlds %base[%row_off_0] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%row_off_1] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x2 = pto.vlds %base[%row_off_2] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x3 = pto.vlds %base[%row_off_3] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x4 = pto.vlds %base[%row_off_4] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x5 = pto.vlds %base[%row_off_5] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x6 = pto.vlds %base[%row_off_6] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%x7 = pto.vlds %base[%row_off_7] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> + +%p0 = pto.vcgadd %x0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p1 = pto.vcgadd %x1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p2 = pto.vcgadd %x2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p3 = pto.vcgadd %x3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p4 = pto.vcgadd %x4, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p5 = pto.vcgadd %x5, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p6 = pto.vcgadd %x6, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p7 = pto.vcgadd %x7, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum0 = pto.vcadd %p0, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum1 = pto.vcadd %p1, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum2 = pto.vcadd %p2, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum3 = pto.vcadd %p3, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum4 = pto.vcadd %p4, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum5 = pto.vcadd %p5, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum6 = pto.vcadd %p6, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum7 = pto.vcadd %p7, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %sum0, %sum_out[%group_tile_off_0], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum1, %sum_out[%group_tile_off_1], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum2, %sum_out[%group_tile_off_2], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum3, %sum_out[%group_tile_off_3], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum4, %sum_out[%group_tile_off_4], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum5, %sum_out[%group_tile_off_5], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum6, %sum_out[%group_tile_off_6], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum7, %sum_out[%group_tile_off_7], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_tile_off + r] = reduce(row_r[0..63]) +``` + +#### 3.7.2 Reduce Result, Elementwise, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%outv = pto.vmi.addf %sum, %rhs +pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum, %rhs, %outv: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +%x0 = pto.vlds %base[%row_off_0] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%row_off_1] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x2 = pto.vlds %base[%row_off_2] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x3 = pto.vlds %base[%row_off_3] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x4 = pto.vlds %base[%row_off_4] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x5 = pto.vlds %base[%row_off_5] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x6 = pto.vlds %base[%row_off_6] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> +%x7 = pto.vlds %base[%row_off_7] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> + +%p0 = pto.vcgadd %x0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p1 = pto.vcgadd %x1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p2 = pto.vcgadd %x2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p3 = pto.vcgadd %x3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p4 = pto.vcgadd %x4, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p5 = pto.vcgadd %x5, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p6 = pto.vcgadd %x6, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%p7 = pto.vcgadd %x7, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum0 = pto.vcadd %p0, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum1 = pto.vcadd %p1, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum2 = pto.vcadd %p2, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum3 = pto.vcadd %p3, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum4 = pto.vcadd %p4, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum5 = pto.vcadd %p5, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum6 = pto.vcadd %p6, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum7 = pto.vcadd %p7, %block8 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%rhs0 = pto.vsldb %rhs_ptr_0, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs1 = pto.vsldb %rhs_ptr_1, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs2 = pto.vsldb %rhs_ptr_2, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs3 = pto.vsldb %rhs_ptr_3, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs4 = pto.vsldb %rhs_ptr_4, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs5 = pto.vsldb %rhs_ptr_5, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs6 = pto.vsldb %rhs_ptr_6, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%rhs7 = pto.vsldb %rhs_ptr_7, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%out0 = pto.vadd %sum0, %rhs0, %one_b32 : !pto.vreg<64xf32> +%out1 = pto.vadd %sum1, %rhs1, %one_b32 : !pto.vreg<64xf32> +%out2 = pto.vadd %sum2, %rhs2, %one_b32 : !pto.vreg<64xf32> +%out3 = pto.vadd %sum3, %rhs3, %one_b32 : !pto.vreg<64xf32> +%out4 = pto.vadd %sum4, %rhs4, %one_b32 : !pto.vreg<64xf32> +%out5 = pto.vadd %sum5, %rhs5, %one_b32 : !pto.vreg<64xf32> +%out6 = pto.vadd %sum6, %rhs6, %one_b32 : !pto.vreg<64xf32> +%out7 = pto.vadd %sum7, %rhs7, %one_b32 : !pto.vreg<64xf32> + +pto.vsts %out0, %out[%group_tile_off_0], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out1, %out[%group_tile_off_1], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out2, %out[%group_tile_off_2], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out3, %out[%group_tile_off_3], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out4, %out[%group_tile_off_4], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out5, %out[%group_tile_off_5], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out6, %out[%group_tile_off_6], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %out7, %out[%group_tile_off_7], %one_b32 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = reduce(row_r[0..63]) + rhs[r] +``` + +#### 3.7.3 Reduce, Broadcast, Elementwise, Reduce, Store + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %b, %y: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +// The compiler emits this row-local block once for each r in 0..7. +%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> + +%p_r = pto.vcgadd %x_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_r = pto.vcadd %p_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// This vdup is the lowering of pto.vmi.group_broadcast for slots=1. +%b_r = pto.vdup %sum_r, %all_b32 {position = "LOWEST"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%y_r = pto.vmul %x_r, %b_r, %all_b32 : !pto.vreg<64xf32> + +%yp_r = pto.vcgadd %y_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_r = pto.vcadd %yp_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %ysum_r, %out[%group_tile_off_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +The row-local block above is not a runtime loop requirement. It is the repeated +VPTO shape for row offsets `%row_off_0` through `%row_off_7` and store offsets +`%group_tile_off_0` through `%group_tile_off_7`. + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..63]) + out[group_tile_off + r] = + reduce_i(row_r[i] * s for i = 0..63) + = s * s +``` + +### 3.8 `group_reduce -> truncf -> group_broadcast -> store` + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%sum16 = pto.vmi.truncf %sum32 +%b16 = pto.vmi.group_broadcast %sum16 {num_groups = 8} +pto.vmi.store %b16, %out[%off] +``` + +Assigned layouts: + +```text +%x : !pto.vmi.vreg<128xf32, + #pto.vmi.layout> +%sum32 : !pto.vmi.vreg<128xf32, + #pto.vmi.layout> +%sum16 : semantic value only; not materialized as a group-slot VPTO value +%b32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b16 : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +This case is supported by commuting `truncf` after `group_broadcast`: + +```text +group_broadcast(truncf(group_reduce(x))) + == truncf(group_broadcast(group_reduce(x))) +``` + +This avoids materializing a group-slot f16 value. The only cast emitted is the +existing dense `f32 deinterleaved=2 -> contiguous f16` truncation. + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum32_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// This vselr is the VPTO lowering of pto.vmi.group_broadcast. The later store +// only writes lanes as-is; it does not duplicate group-slot values. +%b32_rows = pto.vselr %sum32_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +// The broadcasted f32 value is dense deinterleaved=2. +// Both parity parts carry the same per-row broadcast values. +%b16_even = pto.vcvt %b32_rows, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%b16_odd = pto.vcvt %b32_rows, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%b16 = pto.vor %b16_even, %b16_odd, %all_b16 + : !pto.vreg<128xf16> + +pto.vsts %b16, %out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s32 = reduce(row_r[0..15]) + s16 = truncf(s32) + out[r * 16 + 0 .. r * 16 + 15] = splat(s16) +``` + +### 3.9 Illegal Dense Consumer Of Group Slots + +VMI input: + +```text +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = G} +pto.vmi.store %sum32, %out[%off] +``` + +Assigned layouts before the illegal consumer: + +```text +%sum32 : group_slots(G,K) +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.store cannot consume #pto.vmi.layout + as a dense vector. Use pto.vmi.group_store, pto.vmi.group_broadcast, or an + explicit group-pack op. +``` + +It must not be diagnosed as: + +```text +dense store materializes group slots implicitly +``` + +That behavior would silently reinterpret a sparse group-slot value as a dense +vector. + +### 3.10 Non-Load Producer Feeding S=32 `group_reduce` + +This case proves that layout assignment is consumer-driven. The producer of the +S=32 input is an elementwise op, not a load. The S=32 `group_reduce` still +requires the elementwise result to be `deinterleaved = 4`, and that requirement +must propagate backward through the elementwise op to both operands. + +VMI input: + +```text +%a = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%bias = pto.vmi.broadcast %bias_s + : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.addf %a, %bias +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%a, %bias, %x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one full `8x32xf32` tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%a_even_0, %a_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%a_even_1, %a_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%a_p0, %a_p2 = pto.vdintlv %a_even_0, %a_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%a_p1, %a_p3 = pto.vdintlv %a_odd_0, %a_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%bias_p0 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%bias_p1 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%bias_p2 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%bias_p3 = pto.vdup %bias_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> + +%x_p0 = pto.vadd %a_p0, %bias_p0, %all_b32 : !pto.vreg<64xf32> +%x_p1 = pto.vadd %a_p1, %bias_p1, %all_b32 : !pto.vreg<64xf32> +%x_p2 = pto.vadd %a_p2, %bias_p2, %all_b32 : !pto.vreg<64xf32> +%x_p3 = pto.vadd %a_p3, %bias_p3, %all_b32 : !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = + reduce_i(base[row_r, i] + bias_s for i = 0..31) +``` + +### 3.11 Partial Tail Groups + +Tail handling must be separated by the physical input layout. Row-local S=64 +can avoid inactive rows entirely. Load-fused S=16/S=32 cannot safely do that +with the current `vldsx2` materialization unless the source is known to be +full-tile readable. + +#### 3.11.1 S=64 Active Row Tail + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<384xf32> -> !pto.vmi.vreg<384xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<384xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<384xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this row-local block for r = 0..5 only. No load or store is emitted for +// rows 6 and 7. +%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p_r = pto.vcgadd %x_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_r = pto.vcadd %p_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %sum_r, %out[%group_tile_off_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..5: + out[group_tile_off + r] = reduce(row_r[0..63]) +``` + +#### 3.11.2 S=32 Tail Without Full-Tile Read Contract + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<192xf32> -> !pto.vmi.vreg<192xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} +``` + +Assigned layout requested by the consumer: + +```text +%x: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> +``` + +Required diagnostic when the source does not carry a full-tile-readable +contract: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.group_reduce_addf with group size 32 and num_groups tail 6 requires + materializing #pto.vmi.layout. The registered fast plan + uses vldsx2 DINTLV_B32 over a full 8-row tile. This source is not marked + full-tile-readable, and the stable gather tail plan is not implemented. +``` + +If a future option enables the stable gather tail plan, the same VMI input may +lower by gathering only the active lanes. Until that plan is registered, the +converter must not silently issue the full-tile `vldsx2` loads. + +### 3.12 Control-Flow Join Before `group_reduce` + +The layout carried by a value must survive block arguments. In MLIR converter +terms, the logical VMI value lowered through control flow becomes a tuple of +physical VPTO values with one tuple type per assigned layout. + +VMI input: + +```text +%x = scf.if %cond -> !pto.vmi.vreg<256xf32> { + %a = pto.vmi.load %a_base[%a_off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + scf.yield %a : !pto.vmi.vreg<256xf32> +} else { + %b = pto.vmi.load %b_base[%b_off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + scf.yield %b : !pto.vmi.vreg<256xf32> +} +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%a, %b, %x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the join: + +```text +%x_p0, %x_p1, %x_p2, %x_p3 = + scf.if %cond + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %a_even_0, %a_odd_0 = pto.vldsx2 %a_base[%a_tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %a_even_1, %a_odd_1 = pto.vldsx2 %a_base[%a_tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %a_p0, %a_p2 = pto.vdintlv %a_even_0, %a_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %a_p1, %a_p3 = pto.vdintlv %a_odd_0, %a_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + scf.yield %a_p0, %a_p1, %a_p2, %a_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } else { + %b_even_0, %b_odd_0 = pto.vldsx2 %b_base[%b_tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %b_even_1, %b_odd_1 = pto.vldsx2 %b_base[%b_tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %b_p0, %b_p2 = pto.vdintlv %b_even_0, %b_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %b_p1, %b_p3 = pto.vdintlv %b_odd_0, %b_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + scf.yield %b_p0, %b_p1, %b_p2, %b_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +``` + +The consumer after the join is the same S=32 reduction plan as section 3.6: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + selected_row = cond ? a_row_r : b_row_r + out[group_tile_off + r] = reduce(selected_row[0..31]) +``` + +If the two branches cannot be assigned the same layout and no materialization +plan exists before `scf.yield`, the required diagnostic is: + +```text +VMI-LAYOUT-CONTRACT: + scf.yield joins incompatible VMI layouts for !pto.vmi.vreg<256xf32>. + Expected #pto.vmi.layout on every incoming value. +``` + +### 3.13 Direct Group-Slot `f32 -> f16` Cast + +This case is intentionally illegal for the current S=16/S=32 packed +group-slot layout. It prevents the compiler from treating a width-changing +`vcvt` as if it preserved low-lane group slots. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%sum16 = pto.vmi.truncf %sum32 +pto.vmi.group_store %sum16, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts before the illegal cast: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum32: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.truncf cannot lower from + #pto.vmi.layout f32 to f16 because no + slot-preserving width-changing VPTO plan is registered. f32->f16 vcvt writes + even/odd sub-lanes, not lanes 0..7. Use group_broadcast before truncf, or + keep the group_store element type as f32. +``` + +This does not contradict section 3.8. Section 3.8 is legal because the cast is +commuted after `group_broadcast`, where the value is dense again. + +### 3.14 Unsupported Group Size + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<96xf32> -> !pto.vmi.vreg<96xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Here `S = 96 / 8 = 12` f32 elements per group. The current VCG-based plans use +32B groups, i.e. 8 f32 elements per row fragment: + +```text +S = 8 -> one VCGADD block per group +S = 16 -> two 8-lane row fragments, add partial sums +S = 32 -> four 8-lane row fragments, add partial sums +S = 64 -> one full 256B row, VCGADD then VCADD +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.group_reduce_addf with f32 group size 12 has no registered VPTO + layout plan. Supported VCG-based f32 group sizes are 8, 16, 32, and 64. + A scalar/gather fallback or a rewrite to logical group size 16 with an + explicit per-group mask is required. +``` + +### 3.15 Compact S=12 Written As Logical S=16 + +If the program wants to use the S=16 lowering for data with 12 semantic f32 +elements per group, the IR must distinguish two sizes: + +```text +logical group size used by VMI ops: 16 +active elements per group: 12 +``` + +The mask is not a prefix mask over the whole vector. It is a per-group mask: + +```text +mask lane i is active iff (i % 16) < 12 +``` + +The group load surface carries the physical source stride as an SSA operand: + +```text +%x = pto.vmi.group_load %base[%off], %source_group_stride + {num_groups = G, group_size = S} + : !pto.ptr, index -> !pto.vmi.vreg +``` + +`source_group_stride` is in elements, not bytes. It is an operand because it may +come from a dynamic leading dimension, a subview, or a runtime tile descriptor. +Static strides use a constant index operand and can be canonicalized later. +`group_size` remains an attribute in this design because it selects the logical +load layout. `active_elems_per_group` belongs to the mask producer, not to the +load. + +Grouped masks use a paired `pto.vmi.create_group_mask` op. It is intentionally +separate from ordinary prefix `pto.vmi.create_mask` so the IR makes group +semantics explicit next to `pto.vmi.group_load` / `pto.vmi.group_reduce_*`: + +```text +%mask = pto.vmi.create_group_mask %active_elems_per_group + {num_groups = G, group_size = S} + : index -> !pto.vmi.mask<(G*S)xpred> +``` + +Semantics: + +```text +lane i is active iff (i % S) < active_elems_per_group +``` + +Ordinary `pto.vmi.create_mask %active_lanes` keeps the prefix-mask meaning: + +```text +lane i is active iff i < active_lanes +``` + +#### 3.15.1 Existing Design Works If Source Row Stride Is 16 + +If memory already has a 16-f32 row stride, the user can write a logical S=16 +tile and mask off the last four lanes of every group. + +VMI input: + +```text +%stride16 = arith.constant 16 : index +%x = pto.vmi.group_load %base[%off], %stride16 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%c12 = arith.constant 12 : index +%mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one `8x16xf32` tile: + +```text +%lo_mask = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row8 = pto.vshls %row, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row8, %lo_mask + : !pto.vreg<64xi32> +%hi4_mask = pto.vcmps %col, %c4_i32, %lo_mask, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + +%lo, %hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo lanes r*8 .. r*8+7 = row_r[0..7] +%hi lanes r*8 .. r*8+3 = row_r[8..11] +%hi lanes r*8+4 .. r*8+7 = row_r[12..15] // inactive by mask + +%lo_sum = pto.vcgadd %lo, %lo_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = reduce(row_r[0..11]) +``` + +Design requirement added by this case: VMI mask lowering must support +group-periodic masks by generating the predicate from lane indices. It must not +rewrite this mask to `PAT_M4`: VISA defines `M4` as multiples of 4, not the +first four lanes of each 8-lane block. + +```text +lane = vci(0) +row = lane >> 3 +col = lane - (row << 3) +mask = col < 4 +``` + +#### 3.15.2 Source Row Stride Greater Than 16 + +For now, support the non-compact case where each physical row has at least 16 +f32 slots and the row stride is greater than 16. The fast strided-block path +requires the row stride to be a multiple of one 32B block: + +```text +source_group_stride % 8 == 0 +``` + +The example below uses `source_group_stride = 24`. Each row has 12 semantic +values, 4 masked-but-readable slots, and 8 extra skipped slots: + +```text +row_r[0..11] semantic +row_r[12..15] readable but inactive for the S=16 logical group +row_r[16..23] outside the logical group +``` + +VMI input: + +```text +%stride24 = arith.constant 24 : index +%x = pto.vmi.group_load %base[%off], %stride24 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%c12 = arith.constant 12 : index +%mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts are the same as section 3.15.1: + +```text +%x, %mask: + #pto.vmi.layout +%sum: + #pto.vmi.layout +``` + +VPTO lowering result: + +```text +%lo_mask = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row8 = pto.vshls %row, %c3_i16, %lo_mask + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row8, %lo_mask + : !pto.vreg<64xi32> +%hi4_mask = pto.vcmps %col, %c4_i32, %lo_mask, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + +// source_group_stride = 24 f32 = 3 * 32B blocks. +%stride_blocks = %c3_i16 + +%base_lo = %base + tile_off +%base_hi = %base + tile_off + 8 + +%lo = pto.vsldb %base_lo, %stride_blocks, %c0_i16, %lo_mask + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%hi = pto.vsldb %base_hi, %stride_blocks, %c0_i16, %lo_mask + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%lo lanes r*8 .. r*8+7 = row_r[0..7] +%hi lanes r*8 .. r*8+7 = row_r[8..15] + +%lo_sum = pto.vcgadd %lo, %lo_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = + reduce(base[tile_off + r * 24 + 0 .. tile_off + r * 24 + 11]) +``` + +If `source_group_stride > 16` but is not a multiple of 8 f32 elements, this +strided-block path is not legal because `vsldb` block addresses are 32B based. +That case remains unsupported until a gather materialization is selected. + +#### 3.15.3 Compact Source Row Stride 12 + +Compact storage is explicitly out of scope for the first implementation: + +```text +row0[0..11], row1[0..11], row2[0..11], ... +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + logical group size 16 with active_elems_per_group 12 and + source_group_stride 12 requires compact-row gather materialization. This + plan is not part of the initial VMI layout lowering. +``` + +### 3.16 `group_slot_load` Layout Contract + +`group_slot_load` is separate from `group_load`. + +```text +group_load: + loads group_size data elements per group and produces dense grouped data. + +group_slot_load: + loads one scalar value per group and produces sparse group slots. +``` + +Surface form: + +```text +%v = pto.vmi.group_slot_load %base[%off], %source_group_stride + {num_groups = G} + : !pto.ptr, index -> !pto.vmi.vreg +``` + +Semantics: + +```text +semantic group slot g = base[off + g * source_group_stride] +``` + +The result logical lane count `N` remains the surrounding VMI value shape. Only +the `G` group slots are semantic. Layout assignment chooses the sparse physical +placement requested by the consumer: + +```text +#pto.vmi.layout +#pto.vmi.layout +``` + +#### 3.16.1 Packed `group_slot_load`, `slots = 8` + +VMI input: + +```text +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +pto.vmi.group_store %rhs, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layout: + +```text +%rhs: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%slot_mask = pto.pge_b32 "PAT_VL8" +%one_block = pto.pge_b32 "PAT_VL1" + +// source_group_stride = 1, so one 32B block contains all 8 scalar group slots. +%rhs_block = pto.vsldb %rhs_base[%rhs_off], %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %rhs_block, %out[%group_off], %slot_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for g = 0..7: + out[group_off + g] = rhs_base[rhs_off + g] +``` + +If `source_group_stride != 1`, this packed `slots = 8` plan requires a +strided/gather group-slot load materializer. Until that plan is registered, +`group_slot_load` with `slots = 8` and non-unit stride must diagnose instead of +silently using full-group `group_load`. + +#### 3.16.2 Row-Local `group_slot_load`, `slots = 1` + +VMI input: + +```text +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<512xf32> +pto.vmi.group_store %rhs, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layout: + +```text +%rhs: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this shape for r = 0..7. Each result value carries one semantic slot +// in lane 0, matching the S=64 row-local group_reduce result layout. +%rhs_r = pto.vsldb %rhs_base[%rhs_off_plus_r], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %rhs_r, %out[%group_off_plus_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = rhs_base[rhs_off + r] +``` + +### 3.17 `group_broadcast` Physical Arity Alias + +This case fixes a lowering invariant: a layout determines physical arity. A +`deinterleaved = 2` result has two physical bundle entries even when both +entries can reuse the same VPTO SSA value. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%h = pto.vmi.truncf %b +pto.vmi.store %h, %out[%off] +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +%b_rows = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +// Physical bundle binding for %b, not emitted VPTO ops: +// physical entry 0 = %b_rows +// physical entry 1 = %b_rows +// The layout still has two physical entries; they alias the same SSA value +// because every even/odd logical lane pair contains the same broadcast value. + +%h_even = pto.vcvt %b_rows, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %b_rows, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 + : !pto.vreg<128xf16> + +pto.vsts %h0, %out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + out[r * 16 + 0 .. r * 16 + 15] = truncf(s) +``` + +### 3.18 One Value With Dense And Group-Reduce Consumers + +This case forces layout assignment to handle a solvable use-site conflict. One +consumer requires an S=32 group-reduce layout; another consumer requires dense +row-major store. This is not semantically illegal. It must be solved by +use-site materialization or producer rematerialization when a registered plan +exists. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +pto.vmi.store %x, %copy_out[%off] +``` + +Assigned layouts: + +```text +%x for group_reduce: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x for dense store: + requires #pto.vmi.layout +``` + +If `%x` is cheap to rematerialize, layout assignment may clone the producer for +the dense store. Otherwise, if the registry has a `deinterleaved = 4 -> +contiguous` materialization plan, layout assignment may keep `%x` in +`deinterleaved = 4` and insert `ensure_layout` before the dense store. + +VPTO lowering result: + +```text +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Dense store materialization for the second consumer. +%even0, %even1 = pto.vintlv %x_p0, %x_p2 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%odd0, %odd1 = pto.vintlv %x_p1, %x_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%d0, %d1 = pto.vintlv %even0, %odd0 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%d2, %d3 = pto.vintlv %even1, %odd1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %d0, %copy_out[%off_0], %all_b32 {dist = "NORM_B32"} +pto.vsts %d1, %copy_out[%off_64], %all_b32 {dist = "NORM_B32"} +pto.vsts %d2, %copy_out[%off_128], %all_b32 {dist = "NORM_B32"} +pto.vsts %d3, %copy_out[%off_192], %all_b32 {dist = "NORM_B32"} +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = reduce(row_r[0..31]) + +for i = 0..255: + copy_out[off + i] = base[off + i] +``` + +If the `deinterleaved = 4 -> contiguous` plan is not registered, the required +diagnostic is: + +```text +VMI-LAYOUT-CONTRACT: + value %x is required as #pto.vmi.layout by + pto.vmi.group_reduce_addf and as #pto.vmi.layout by + pto.vmi.store, but no registered materialization plan exists at the store + use site. +``` + +### 3.19 S=16 Reduce `block_elems` Plan Selection + +S=16 f32 group reduction has two legal dense input layouts: + +```text +#pto.vmi.layout +#pto.vmi.layout +``` + +`block_elems = 1` is the element-parity layout required by f32->f16 `truncf`. +It is also a valid S=16 reduction layout: each physical part contains eight +values per row, so `VCGADD` can reduce each part and `VADD` can combine the two +partial sums. + +`block_elems = 8` is still useful when the producer is a block load plan such +as `BDINTLV` or `vsldb` over 32B row fragments. Layout assignment must select +between these plans by producer/consumer cost. It must not hard-code S=16 +reduce to `block_elems = 8`. + +#### 3.19.1 Continuous S=16 Reduce And Truncf, `block_elems = 1` + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +%h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.store %h, %out[%off] +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +Physical lane map: + +```text +%x_p0 lanes r*8 .. r*8+7 = + row_r[0], row_r[2], row_r[4], ..., row_r[14] + +%x_p1 lanes r*8 .. r*8+7 = + row_r[1], row_r[3], row_r[5], ..., row_r[15] +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_p0, %x_p1 = pto.vldsx2 %base[%tile_off], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %s0, %s1, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +%h_even = pto.vcvt %x_p0, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %x_p1, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 + : !pto.vreg<128xf16> +pto.vsts %h0, %out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = reduce(row_r[0..15]) + +for i = 0..127: + out[off + i] = truncf(base[off + i]) +``` + +#### 3.19.2 Block-Load Producer Fixed To `block_elems = 8` + +This is the real conflict case. The value is fixed to `block_elems = 8` +because the producer is a registered block-load plan. A later `truncf` +requires element-parity `block_elems = 1`. + +VMI input: + +```text +%stride24 = arith.constant 24 : index +%x = pto.vmi.group_load %base[%off], %stride24 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +%h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.store %h, %out[%off] +``` + +Assigned layouts before the conflicting `truncf` use: + +```text +%x from strided block group_load: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +The reduction path is legal and uses the same `vsldb` block-load shape as +section 3.15.2. The `truncf` path is legal only if one of these plans exists: + +```text +1. rematerialize the original memory producer as block_elems=1 +2. materialize block_elems=8 -> block_elems=1 in registers +3. use an explicitly enabled scratch/reload fallback +``` + +If no such plan is registered, the required diagnostic is: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.truncf requires + #pto.vmi.layout, but the source value is + fixed to #pto.vmi.layout by the selected + strided group_load plan. Register a rematerialization or preserving + materialization plan, or avoid consuming this block-loaded value with truncf. +``` + +### 3.20 `group_slots` Control-Flow Join + +`group_slots` values must be allowed to cross control flow. The join type is a +sparse physical tuple, not a dense vector. + +VMI input: + +```text +%sum = scf.if %cond -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> + %a = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + scf.yield %a : !pto.vmi.vreg<128xf32> +} else { + %b = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> + scf.yield %b : !pto.vmi.vreg<128xf32> +} +%bias = pto.vmi.group_slot_load %bias_base[%bias_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%outv = pto.vmi.addf %sum, %bias +pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%a, %b, %sum, %bias, %outv: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the join: + +```text +%sum_block = scf.if %cond -> !pto.vreg<64xf32> { + %all_b32 = pto.pge_b32 "PAT_ALL" + %sum_mask = pto.pge_b32 "PAT_VL8" + + %x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %a_block = pto.vadd %lo_sum, %hi_sum, %sum_mask + : !pto.vreg<64xf32> + scf.yield %a_block : !pto.vreg<64xf32> +} else { + %one_block = pto.pge_b32 "PAT_VL1" + %b_block = pto.vsldb %rhs_base[%rhs_off], %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + scf.yield %b_block : !pto.vreg<64xf32> +} + +%one_block = pto.pge_b32 "PAT_VL1" +%slot_mask = pto.pge_b32 "PAT_VL8" +%bias_block = pto.vsldb %bias_base[%bias_off], %c0_i16, %c0_i16, %one_block + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%out_block = pto.vadd %sum_block, %bias_block, %slot_mask + : !pto.vreg<64xf32> + +pto.vsts %out_block, %out[%group_off], %slot_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + lhs = cond ? reduce(row_r[0..15]) : rhs_base[rhs_off + r] + out[group_off + r] = lhs + bias_base[bias_off + r] +``` + +### 3.21 S=32 Tail With Full-Tile-Readable Source + +This is the positive counterpart to section 3.11.2. Tail participation is +still expressed by masks, but the source additionally promises that reading the +rounded-up 8-row physical tile is memory-safe. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] {full_tile_readable} + : memref<192xf32> -> !pto.vmi.vreg<192xf32> +%mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<192xpred, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<192xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +// Full-tile-readable allows the load plan to read the rounded-up 8-row tile. +// Only rows 0..5 are semantically active. +%data_mask = pto.pge_b32 "PAT_VL48" // 6 rows * 8 lanes per physical part +%sum_mask = pto.pge_b32 "PAT_VL6" + +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %data_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %data_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %data_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %data_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..5: + out[group_off + r] = reduce(row_r[0..31]) +``` + +Rows 6 and 7 may be physically loaded because of `full_tile_readable`, but +their lanes are not active in `%data_mask`, and their group slots are not stored +because `%sum_mask` is `PAT_VL6`. + +### 3.22 `scf.for` Loop-Carried Layout + +Loop-carried VMI values require a layout fixed point. The iter_arg, body block +argument, yield operand, loop result, and later consumer must all agree on one +layout, or `vmi-layout-assignment` must insert a materialization at a legal +dominating use site. + +VMI input: + +```text +%init = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%acc = scf.for %i = %c0 to %steps step %c1 + iter_args(%arg = %init) -> !pto.vmi.vreg<256xf32> { + %bias = pto.vmi.broadcast %bias_s + : f32 -> !pto.vmi.vreg<256xf32> + %next = pto.vmi.addf %arg, %bias + scf.yield %next : !pto.vmi.vreg<256xf32> +} +%sum = pto.vmi.group_reduce_addf %acc, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%init, %arg, %bias, %next, %acc: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%init_even_0, %init_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%init_even_1, %init_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%init_p0, %init_p2 = pto.vdintlv %init_even_0, %init_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%init_p1, %init_p3 = pto.vdintlv %init_odd_0, %init_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%acc_p0, %acc_p1, %acc_p2, %acc_p3 = + scf.for %i = %c0 to %steps step %c1 + iter_args(%arg_p0 = %init_p0, %arg_p1 = %init_p1, + %arg_p2 = %init_p2, %arg_p3 = %init_p3) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %all_b32 = pto.pge_b32 "PAT_ALL" + %bias_p0 = pto.vdup %bias_s, %all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + %bias_p1 = pto.vdup %bias_s, %all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + %bias_p2 = pto.vdup %bias_s, %all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + %bias_p3 = pto.vdup %bias_s, %all_b32 + : f32, !pto.mask -> !pto.vreg<64xf32> + + %next_p0 = pto.vadd %arg_p0, %bias_p0, %all_b32 : !pto.vreg<64xf32> + %next_p1 = pto.vadd %arg_p1, %bias_p1, %all_b32 : !pto.vreg<64xf32> + %next_p2 = pto.vadd %arg_p2, %bias_p2, %all_b32 : !pto.vreg<64xf32> + %next_p3 = pto.vadd %arg_p3, %bias_p3, %all_b32 : !pto.vreg<64xf32> + scf.yield %next_p0, %next_p1, %next_p2, %next_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" +%s0 = pto.vcgadd %acc_p0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %acc_p1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %acc_p2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %acc_p3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> +pto.vsts %sum_block, %out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + for c = 0..31: + acc[row_r, c] = base[row_r, c] + steps * bias_s + out[group_off + r] = reduce(acc[row_r, 0..31]) +``` + +### 3.23 `group_broadcast` With Multiple Dense Consumers + +One `group_slots` value may feed multiple `group_broadcast` uses with different +dense result layout requirements. Layout assignment should rematerialize the +broadcast per use instead of forcing one result layout onto all consumers. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + +%b_for_mul = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b_for_mul +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %sum_out[%group_off], %c1 {num_groups = 8} + +%b_for_cast = pto.vmi.group_broadcast %sum {num_groups = 8} +%h = pto.vmi.truncf %b_for_cast +pto.vmi.store %h, %dense_out[%off] +``` + +Assigned layouts: + +```text +%x, %b_for_mul, %y: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b_for_cast: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// Use 1: broadcast for the S=16 block_elems=8 multiply path. Both row halves +// use the same per-row broadcast vector. +%b_rows_for_mul = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%y_lo = pto.vmul %x_lo, %b_rows_for_mul, %all_b32 : !pto.vreg<64xf32> +%y_hi = pto.vmul %x_hi, %b_rows_for_mul, %all_b32 : !pto.vreg<64xf32> +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %sum_mask + : !pto.vreg<64xf32> +pto.vsts %ysum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Use 2: rematerialize broadcast for the f32->f16 parity cast path. The +// deinterleaved=2 physical bundle has two entries that alias this SSA value. +%b_rows_for_cast = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%h_even = pto.vcvt %b_rows_for_cast, %all_b32 + {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %b_rows_for_cast, %all_b32 + {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%all_b16 = pto.pge_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 : !pto.vreg<128xf16> +pto.vsts %h0, %dense_out[%off], %all_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + sum_out[group_off + r] = reduce_i(row_r[i] * s for i = 0..15) + dense_out[r * 16 + 0 .. r * 16 + 15] = truncf(s) +``` + +### 3.24 Mask With Elementwise, Select, And Store + +This case separates compute masking from memory effects. A masked elementwise +operation with passthrough semantics can be represented as ordinary compute +plus `select`; a masked store uses the mask only on the store effect. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<64xf32> -> !pto.vmi.vreg<64xf32> +%rhs = pto.vmi.load %rhs_base[%off] + : memref<64xf32> -> !pto.vmi.vreg<64xf32> +%mask = pto.vmi.create_mask %c48 + : index -> !pto.vmi.mask<64xpred> +%sum = pto.vmi.addf %x, %rhs +%passthrough = pto.vmi.select %mask, %sum, %x +pto.vmi.store %passthrough, %dense_out[%off] +pto.vmi.masked_store %sum, %masked_out[%off], %mask +``` + +Assigned layouts: + +```text +%x, %rhs, %sum, %passthrough: + !pto.vmi.vreg<64xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<64xpred, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%m = pto.pge_b32 "PAT_VL48" + +%x0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%rhs0 = pto.vlds %rhs_base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%sum0 = pto.vadd %x0, %rhs0, %all_b32 : !pto.vreg<64xf32> + +%pass0 = pto.vsel %sum0, %x0, %m + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +pto.vsts %pass0, %dense_out[%off], %all_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +pto.vsts %sum0, %masked_out[%off], %m {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..63: + if i < 48: + dense_out[off + i] = base[off + i] + rhs_base[off + i] + masked_out[off + i] = base[off + i] + rhs_base[off + i] + else: + dense_out[off + i] = base[off + i] + masked_out[off + i] is unchanged +``` + +### 3.25 Function Boundary Layout Specialization + +Function boundaries cannot rely on hidden layout side tables. Either the +function is internal and layout-specialized by `vmi-layout-assignment`, or a +public/external VMI boundary must diagnose until a stable VMI ABI is defined. + +#### 3.25.1 Internal Function Specialized To Consumer Layout + +VMI input: + +```text +func.func private @producer(%base: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf32> { + %x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + return %x : !pto.vmi.vreg<256xf32> +} + +func.func @caller(%base: !pto.ptr, %off: index, %out: !pto.ptr) { + %x = call @producer(%base, %off) + : (!pto.ptr, index) -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + pto.vmi.group_store %sum, %out[%off], %c1 {num_groups = 8} + return +} +``` + +Assigned layouts: + +```text +@producer result: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x in @caller: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the function boundary: + +```text +func.func private @producer(...) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + return %x_p0, %x_p1, %x_p2, %x_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> +} + +func.func @caller(...) { + %x_p0, %x_p1, %x_p2, %x_p3 = call @producer(...) + : (...) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + + %all_b32 = pto.pge_b32 "PAT_ALL" + %sum_mask = pto.pge_b32 "PAT_VL8" + %s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s2 = pto.vcgadd %x_p2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s3 = pto.vcgadd %x_p3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> + %s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> + %sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + pto.vsts %sum_block, %out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +} +``` + +Memory result: + +```text +for r = 0..7: + out[off + r] = reduce(row_r[0..31]) +``` + +#### 3.25.2 Public Or External VMI Boundary + +VMI input: + +```text +func.func @public_producer(%base: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf32> attributes {public} { + %x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + return %x : !pto.vmi.vreg<256xf32> +} +``` + +Required diagnostic for the initial design: + +```text +VMI-LAYOUT-CONTRACT: + public or external function boundary returns !pto.vmi.vreg<256xf32> without a + stable VMI layout ABI. Mark the function internal for layout specialization, + inline it before vmi-layout-assignment, or define an explicit ABI layout. +``` diff --git a/docs/isa/micro-isa/10-reduction-ops.md b/docs/isa/micro-isa/10-reduction-ops.md index ecae818f2c..2129f91ce0 100644 --- a/docs/isa/micro-isa/10-reduction-ops.md +++ b/docs/isa/micro-isa/10-reduction-ops.md @@ -206,7 +206,9 @@ VLane 4: [32..39] VLane 5: [40..47] VLane 6: [48..55] VLane 7: [56..63] - **syntax:** `%result = pto.vcgadd %input, %mask : !pto.vreg, !pto.mask -> !pto.vreg` - **A5 types:** i16-i32, f16, f32 -- **semantics:** Sum within each VLane. 8 results at indices 0, 8, 16, 24, 32, 40, 48, 56 (for f32). +- **semantics:** Sum within each 32-byte VLane. The 8 VLane results are written + continuously to the low lanes of the destination vector. For f32, results are + at indices 0, 1, 2, 3, 4, 5, 6, 7. ```c int K = N / 8; // elements per VLane @@ -214,17 +216,17 @@ for (int g = 0; g < 8; g++) { T sum = 0; for (int i = 0; i < K; i++) sum += src[g*K + i]; - dst[g*K] = sum; - for (int i = 1; i < K; i++) - dst[g*K + i] = 0; + dst[g] = sum; } -// For f32: results at dst[0], dst[8], dst[16], dst[24], dst[32], dst[40], dst[48], dst[56] +for (int i = 8; i < N; i++) + dst[i] = 0; +// For f32: results at dst[0], dst[1], ..., dst[7]. ``` - **inputs:** `%input` is the source vector and `%mask` selects participating lanes. - **outputs:** `%result` contains one sum per 32-byte VLane group, written - contiguously into the low slot of each group. + continuously to the low lanes of the destination vector. - **constraints and limitations:** This is a per-32-byte VLane-group reduction. Inactive lanes are treated as zero. @@ -242,10 +244,10 @@ for (int g = 0; g < 8; g++) { T mx = -INF; for (int i = 0; i < K; i++) if (src[g*K + i] > mx) mx = src[g*K + i]; - dst[g*K] = mx; - for (int i = 1; i < K; i++) - dst[g*K + i] = 0; + dst[g] = mx; } +for (int i = 8; i < N; i++) + dst[i] = 0; ``` - **inputs:** `%input` is the source vector and `%mask` selects participating @@ -268,10 +270,10 @@ for (int g = 0; g < 8; g++) { T mn = INF; for (int i = 0; i < K; i++) if (src[g*K + i] < mn) mn = src[g*K + i]; - dst[g*K] = mn; - for (int i = 1; i < K; i++) - dst[g*K + i] = 0; + dst[g] = mn; } +for (int i = 8; i < N; i++) + dst[i] = 0; ``` - **inputs:** `%input` is the source vector and `%mask` selects participating @@ -320,7 +322,7 @@ for (int i = 1; i < N; i++) // Row-wise sum using vcgadd (for 8-row tile) %row_sums = pto.vcgadd %tile, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -// Results at indices 0, 8, 16, 24, 32, 40, 48, 56 +// Results at indices 0, 1, 2, 3, 4, 5, 6, 7 // Full vector sum for normalization %total = pto.vcadd %values, %mask : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> From 11809268e9406405636ab7aaeff79f56bc583148 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 11:19:45 +0800 Subject: [PATCH 06/54] Add VMI layout assignment lowering coverage --- README.md | 5 + .../vmi-layout-assignment-implementation.md | 1469 ++++++++ .../vmi-layout-assignment-lowering-design.md | 625 ++++ docs/designs/vmi-layout-lowering-cases.md | 2318 ++++++++++++- include/PTO/IR/VMIAttrs.td | 10 +- include/PTO/IR/VMIOps.td | 28 +- lib/PTO/IR/PTO.cpp | 24 + lib/PTO/IR/VMI.cpp | 548 ++- lib/PTO/Transforms/VMILayoutAssignment.cpp | 806 ++++- lib/PTO/Transforms/VMIToVPTO.cpp | 2940 +++++++++++------ .../lit/vmi/vmi_create_group_mask_invalid.pto | 20 + ...assignment_broadcast_dense_group_users.pto | 75 + ...yout_assignment_call_argument_boundary.pto | 74 + ...ayout_assignment_create_group_mask_s16.pto | 54 + ..._layout_assignment_dense_f16_f32_store.pto | 77 + ...ment_dense_group_reduce_multi_consumer.pto | 58 + ...gnment_dense_store_group_slots_invalid.pto | 32 + ..._layout_assignment_f32_f8_store_reduce.pto | 65 + .../vmi_layout_assignment_f8_compute_f8.pto | 61 + ...ignment_group_broadcast_multi_consumer.pto | 83 + ...yout_assignment_group_broadcast_slots8.pto | 27 + .../vmi/vmi_layout_assignment_group_load.pto | 27 + ...nment_group_load_block8_truncf_invalid.pto | 42 + ...roup_load_s16_compact_stride12_invalid.pto | 31 + ...assignment_group_load_s16_stride_store.pto | 50 + ...roup_load_s16_unaligned_stride_invalid.pto | 31 + ...group_load_s32_stride_broadcast_reduce.pto | 71 + ...assignment_group_load_s32_stride_store.pto | 51 + ...roup_load_s32_unaligned_stride_invalid.pto | 31 + ...ut_assignment_group_reduce_s12_invalid.pto | 29 + ...yout_assignment_group_reduce_s16_store.pto | 53 + ...roup_reduce_s16_truncf_broadcast_store.pto | 59 + ...ment_group_reduce_s32_broadcast_reduce.pto | 66 + ...nment_group_reduce_s32_multitile_store.pto | 53 + ...yout_assignment_group_reduce_s32_store.pto | 52 + ...gnment_group_reduce_s32_tail_full_tile.pto | 85 + ...p_reduce_s32_tail_no_full_tile_invalid.pto | 33 + ...vmi_layout_assignment_group_reduce_s64.pto | 29 + ...ment_group_reduce_s64_broadcast_reduce.pto | 57 + ...assignment_group_reduce_s64_tail_store.pto | 42 + ...out_assignment_group_reduce_s64_truncf.pto | 51 + ..._layout_assignment_group_reduce_slots8.pto | 29 + ...t_assignment_group_reduce_slots8_store.pto | 44 + .../vmi_layout_assignment_group_slot_load.pto | 58 + ...assignment_group_slot_load_dual_layout.pto | 76 + ...lot_load_slots1_dynamic_stride_invalid.pto | 24 + ...t_load_slots1_unaligned_stride_invalid.pto | 25 + ..._layout_assignment_group_slots_cf_join.pto | 59 + ...i_layout_assignment_group_slots_fanout.pto | 68 + ..._layout_assignment_group_slots_scf_for.pto | 79 + ...group_store_slots1_unit_stride_invalid.pto | 32 + ...ignment_mask_granularity_f32_f16_store.pto | 61 + ...mi_layout_assignment_mask_select_store.pto | 64 + ...signment_masked_load_dense_group_users.pto | 66 + ..._assignment_masked_load_group_tail_s32.pto | 39 + ..._layout_assignment_non_load_s32_reduce.pto | 62 + ...ment_packed_group_slots_truncf_invalid.pto | 35 + ...yout_assignment_widen_f16_store_reduce.pto | 64 + .../vmi/vmi_layout_group_slots_invalid.pto | 18 + .../vmi/vmi_load_full_read_elems_invalid.pto | 20 + test/lit/vmi/vmi_op_verifier_basic.pto | 7 + ...i_ptoas_call_boundary_vecscope_invalid.pto | 35 + .../vmi_to_vpto_group_broadcast_slots8.pto | 43 + ..._broadcast_slots8_missing_plan_invalid.pto | 29 + ...o_vpto_group_load_missing_plan_invalid.pto | 29 + test/lit/vmi/vmi_to_vpto_group_ops.pto | 3 +- test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto | 45 + ..._group_reduce_s64_missing_plan_invalid.pto | 30 + .../vmi/vmi_to_vpto_group_reduce_slots8.pto | 34 + ...oup_reduce_slots8_missing_plan_invalid.pto | 28 + test/lit/vmi/vmi_to_vpto_group_slot_load.pto | 74 + ...o_group_slot_load_missing_plan_invalid.pto | 27 + ...group_slot_load_nonunit_slots8_invalid.pto | 25 + .../vmi_to_vpto_group_slot_truncf_slots1.pto | 39 + ...lot_truncf_slots1_missing_plan_invalid.pto | 28 + ...pto_group_store_slots8_nonunit_invalid.pto | 26 + test/lit/vmi/vmi_to_vpto_quant_dequant.pto | 4 +- test/lit/vmi/vmi_type_attr_parse.pto | 15 +- .../broadcast-dense-group-users/compare.py | 40 + .../vmi/broadcast-dense-group-users/golden.py | 47 + .../broadcast-dense-group-users/kernel.pto | 68 + .../broadcast-dense-group-users/launch.cpp | 33 + .../vmi/broadcast-dense-group-users/main.cpp | 97 + .../broadcast-dense-group-users/ptoas.flags | 1 + .../compare.py | 32 + .../golden.py | 50 + .../kernel.pto | 57 + .../launch.cpp | 35 + .../main.cpp | 94 + .../ptoas.flags | 1 + .../vmi/f32-to-f8-store-reduce/compare.py | 49 + .../vmi/f32-to-f8-store-reduce/golden.py | 55 + .../vmi/f32-to-f8-store-reduce/kernel.pto | 62 + .../vmi/f32-to-f8-store-reduce/launch.cpp | 41 + .../cases/vmi/f32-to-f8-store-reduce/main.cpp | 94 + .../vmi/f32-to-f8-store-reduce/ptoas.flags | 1 + test/vpto/cases/vmi/f8-compute-f8/compare.py | 27 + test/vpto/cases/vmi/f8-compute-f8/golden.py | 40 + test/vpto/cases/vmi/f8-compute-f8/kernel.pto | 55 + test/vpto/cases/vmi/f8-compute-f8/launch.cpp | 40 + test/vpto/cases/vmi/f8-compute-f8/main.cpp | 76 + test/vpto/cases/vmi/f8-compute-f8/ptoas.flags | 1 + .../group-broadcast-multi-consumer/compare.py | 44 + .../group-broadcast-multi-consumer/golden.py | 54 + .../group-broadcast-multi-consumer/kernel.pto | 69 + .../group-broadcast-multi-consumer/launch.cpp | 42 + .../group-broadcast-multi-consumer/main.cpp | 92 + .../ptoas.flags | 1 + .../group-load-s16-stride-store/compare.py | 27 + .../vmi/group-load-s16-stride-store/golden.py | 48 + .../group-load-s16-stride-store/kernel.pto | 51 + .../group-load-s16-stride-store/launch.cpp | 32 + .../vmi/group-load-s16-stride-store/main.cpp | 80 + .../group-load-s16-stride-store/ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 49 + .../kernel.pto | 59 + .../launch.cpp | 34 + .../main.cpp | 82 + .../ptoas.flags | 1 + .../group-load-s32-stride-store/compare.py | 27 + .../vmi/group-load-s32-stride-store/golden.py | 48 + .../group-load-s32-stride-store/kernel.pto | 51 + .../group-load-s32-stride-store/launch.cpp | 32 + .../vmi/group-load-s32-stride-store/main.cpp | 80 + .../group-load-s32-stride-store/ptoas.flags | 1 + .../vmi/group-reduce-basic-store/compare.py | 42 + .../vmi/group-reduce-basic-store/golden.py | 50 + .../vmi/group-reduce-basic-store/kernel.pto | 92 + .../vmi/group-reduce-basic-store/launch.cpp | 40 + .../vmi/group-reduce-basic-store/main.cpp | 123 + .../vmi/group-reduce-basic-store/ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 48 + .../kernel.pto | 57 + .../launch.cpp | 34 + .../main.cpp | 82 + .../ptoas.flags | 1 + .../compare.py | 30 + .../golden.py | 48 + .../kernel.pto | 63 + .../launch.cpp | 34 + .../main.cpp | 81 + .../ptoas.flags | 1 + .../compare.py | 30 + .../golden.py | 46 + .../kernel.pto | 55 + .../launch.cpp | 34 + .../main.cpp | 82 + .../ptoas.flags | 1 + .../compare.py | 30 + .../golden.py | 49 + .../kernel.pto | 55 + .../launch.cpp | 34 + .../main.cpp | 81 + .../ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 48 + .../kernel.pto | 54 + .../launch.cpp | 42 + .../main.cpp | 80 + .../ptoas.flags | 1 + .../compare.py | 27 + .../group-reduce-s32-add-bias-store/golden.py | 48 + .../kernel.pto | 54 + .../launch.cpp | 33 + .../group-reduce-s32-add-bias-store/main.cpp | 81 + .../ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 48 + .../kernel.pto | 57 + .../launch.cpp | 34 + .../main.cpp | 82 + .../ptoas.flags | 1 + .../group-reduce-s32-cf-join-store/compare.py | 27 + .../group-reduce-s32-cf-join-store/golden.py | 47 + .../group-reduce-s32-cf-join-store/kernel.pto | 63 + .../group-reduce-s32-cf-join-store/launch.cpp | 33 + .../group-reduce-s32-cf-join-store/main.cpp | 81 + .../ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 47 + .../kernel.pto | 49 + .../launch.cpp | 33 + .../group-reduce-s32-multitile-store/main.cpp | 81 + .../ptoas.flags | 1 + .../compare.py | 28 + .../golden.py | 49 + .../kernel.pto | 53 + .../launch.cpp | 34 + .../main.cpp | 82 + .../ptoas.flags | 1 + .../compare.py | 27 + .../golden.py | 50 + .../kernel.pto | 61 + .../launch.cpp | 34 + .../main.cpp | 83 + .../ptoas.flags | 1 + .../compare.py | 36 + .../group-reduce-s64-slot-add-store/golden.py | 51 + .../kernel.pto | 64 + .../launch.cpp | 35 + .../group-reduce-s64-slot-add-store/main.cpp | 94 + .../ptoas.flags | 1 + .../group-reduce-s64-tail-store/compare.py | 30 + .../vmi/group-reduce-s64-tail-store/golden.py | 46 + .../group-reduce-s64-tail-store/kernel.pto | 52 + .../group-reduce-s64-tail-store/launch.cpp | 32 + .../vmi/group-reduce-s64-tail-store/main.cpp | 81 + .../group-reduce-s64-tail-store/ptoas.flags | 1 + .../group-reduce-s64-truncf-store/compare.py | 32 + .../group-reduce-s64-truncf-store/golden.py | 47 + .../group-reduce-s64-truncf-store/kernel.pto | 54 + .../group-reduce-s64-truncf-store/launch.cpp | 40 + .../group-reduce-s64-truncf-store/main.cpp | 79 + .../group-reduce-s64-truncf-store/ptoas.flags | 1 + .../group-reduce-slot-add-store/compare.py | 41 + .../vmi/group-reduce-slot-add-store/golden.py | 57 + .../group-reduce-slot-add-store/kernel.pto | 86 + .../group-reduce-slot-add-store/launch.cpp | 38 + .../vmi/group-reduce-slot-add-store/main.cpp | 113 + .../group-reduce-slot-add-store/ptoas.flags | 1 + .../vmi/group-slots-cf-join-store/compare.py | 38 + .../vmi/group-slots-cf-join-store/golden.py | 53 + .../vmi/group-slots-cf-join-store/kernel.pto | 97 + .../vmi/group-slots-cf-join-store/launch.cpp | 44 + .../vmi/group-slots-cf-join-store/main.cpp | 102 + .../vmi/group-slots-cf-join-store/ptoas.flags | 1 + .../compare.py | 38 + .../golden.py | 53 + .../kernel.pto | 71 + .../launch.cpp | 44 + .../main.cpp | 93 + .../ptoas.flags | 1 + .../vmi/group-slots-scf-for-store/compare.py | 36 + .../vmi/group-slots-scf-for-store/golden.py | 44 + .../vmi/group-slots-scf-for-store/kernel.pto | 68 + .../vmi/group-slots-scf-for-store/launch.cpp | 33 + .../vmi/group-slots-scf-for-store/main.cpp | 95 + .../vmi/group-slots-scf-for-store/ptoas.flags | 1 + .../mask-granularity-f32-f16-store/compare.py | 52 + .../mask-granularity-f32-f16-store/golden.py | 49 + .../mask-granularity-f32-f16-store/kernel.pto | 60 + .../mask-granularity-f32-f16-store/launch.cpp | 43 + .../mask-granularity-f32-f16-store/main.cpp | 91 + .../ptoas.flags | 1 + .../cases/vmi/mask-select-store/compare.py | 32 + .../cases/vmi/mask-select-store/golden.py | 51 + .../cases/vmi/mask-select-store/kernel.pto | 71 + .../cases/vmi/mask-select-store/launch.cpp | 42 + .../vpto/cases/vmi/mask-select-store/main.cpp | 99 + .../cases/vmi/mask-select-store/ptoas.flags | 1 + .../masked-load-dense-group-users/compare.py | 40 + .../masked-load-dense-group-users/golden.py | 46 + .../masked-load-dense-group-users/kernel.pto | 61 + .../masked-load-dense-group-users/launch.cpp | 33 + .../masked-load-dense-group-users/main.cpp | 97 + .../masked-load-dense-group-users/ptoas.flags | 1 + .../vmi/scf-for-loop-carried-store/compare.py | 27 + .../vmi/scf-for-loop-carried-store/golden.py | 41 + .../vmi/scf-for-loop-carried-store/kernel.pto | 53 + .../vmi/scf-for-loop-carried-store/launch.cpp | 32 + .../vmi/scf-for-loop-carried-store/main.cpp | 78 + .../scf-for-loop-carried-store/ptoas.flags | 1 + .../widen-f16-to-f32-store-reduce/compare.py | 38 + .../widen-f16-to-f32-store-reduce/golden.py | 50 + .../widen-f16-to-f32-store-reduce/kernel.pto | 67 + .../widen-f16-to-f32-store-reduce/launch.cpp | 42 + .../widen-f16-to-f32-store-reduce/main.cpp | 92 + .../widen-f16-to-f32-store-reduce/ptoas.flags | 1 + 270 files changed, 18976 insertions(+), 1444 deletions(-) create mode 100644 docs/designs/vmi-layout-assignment-implementation.md create mode 100644 docs/designs/vmi-layout-assignment-lowering-design.md create mode 100644 test/lit/vmi/vmi_create_group_mask_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_slots8.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slot_load.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_mask_select_store.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto create mode 100644 test/lit/vmi/vmi_layout_group_slots_invalid.pto create mode 100644 test/lit/vmi/vmi_load_full_read_elems_invalid.pto create mode 100644 test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_slot_load.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/compare.py create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/golden.py create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/compare.py create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/golden.py create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/launch.cpp create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/main.cpp create mode 100644 test/vpto/cases/vmi/dense-group-reduce-multi-consumer/ptoas.flags create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/compare.py create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/golden.py create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/launch.cpp create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/main.cpp create mode 100644 test/vpto/cases/vmi/f32-to-f8-store-reduce/ptoas.flags create mode 100644 test/vpto/cases/vmi/f8-compute-f8/compare.py create mode 100644 test/vpto/cases/vmi/f8-compute-f8/golden.py create mode 100644 test/vpto/cases/vmi/f8-compute-f8/kernel.pto create mode 100644 test/vpto/cases/vmi/f8-compute-f8/launch.cpp create mode 100644 test/vpto/cases/vmi/f8-compute-f8/main.cpp create mode 100644 test/vpto/cases/vmi/f8-compute-f8/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/compare.py create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/golden.py create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/launch.cpp create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/main.cpp create mode 100644 test/vpto/cases/vmi/group-broadcast-multi-consumer/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/compare.py create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/golden.py create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-load-s16-stride-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/compare.py create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/golden.py create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/launch.cpp create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/main.cpp create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-store/compare.py create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-store/golden.py create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-load-s32-stride-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-basic-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-add-bias-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-cf-join-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-multitile-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-slot-add-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-tail-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-s64-truncf-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-truncf-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-s64-truncf-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-truncf-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-s64-truncf-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-slot-add-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/compare.py create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/golden.py create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-slots-cf-join-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/compare.py create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/golden.py create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/launch.cpp create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/main.cpp create mode 100644 test/vpto/cases/vmi/group-slots-fanout-store-broadcast/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/compare.py create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/golden.py create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-slots-scf-for-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/mask-granularity-f32-f16-store/compare.py create mode 100644 test/vpto/cases/vmi/mask-granularity-f32-f16-store/golden.py create mode 100644 test/vpto/cases/vmi/mask-granularity-f32-f16-store/kernel.pto create mode 100644 test/vpto/cases/vmi/mask-granularity-f32-f16-store/launch.cpp create mode 100644 test/vpto/cases/vmi/mask-granularity-f32-f16-store/main.cpp create mode 100644 test/vpto/cases/vmi/mask-granularity-f32-f16-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/mask-select-store/compare.py create mode 100644 test/vpto/cases/vmi/mask-select-store/golden.py create mode 100644 test/vpto/cases/vmi/mask-select-store/kernel.pto create mode 100644 test/vpto/cases/vmi/mask-select-store/launch.cpp create mode 100644 test/vpto/cases/vmi/mask-select-store/main.cpp create mode 100644 test/vpto/cases/vmi/mask-select-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/compare.py create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/golden.py create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/launch.cpp create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/main.cpp create mode 100644 test/vpto/cases/vmi/masked-load-dense-group-users/ptoas.flags create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/compare.py create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/golden.py create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/kernel.pto create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/launch.cpp create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/main.cpp create mode 100644 test/vpto/cases/vmi/scf-for-loop-carried-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/compare.py create mode 100644 test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/golden.py create mode 100644 test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto create mode 100644 test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/launch.cpp create mode 100644 test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/main.cpp create mode 100644 test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/ptoas.flags diff --git a/README.md b/README.md index 0d8399783f..b3a547ab04 100644 --- a/README.md +++ b/README.md @@ -206,6 +206,11 @@ ptoas test/lit/pto/empty_func.pto --pto-arch=a5 -o outputfile.cpp # 指定构建 Level(level3 会禁用 PlanMemory/InsertSync) ptoas test/lit/pto/empty_func.pto --pto-level=level3 -o outputfile.cpp +# 启用实验性 VMI -> VPTO 语义 pipeline +# 该模式要求 --pto-backend=vpto,或输入 IR 中带 pto.backend = "vpto" +# public function signature 不能直接暴露 !pto.vmi.* 类型 +ptoas test/lit/vmi/vmi_ptoas_cli_pipeline.pto --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto -o - + # 查看当前 ptoas release 版本号 ptoas --version diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md new file mode 100644 index 0000000000..e6b39dd984 --- /dev/null +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -0,0 +1,1469 @@ +# VMI Layout Assignment Implementation Plan + +本文是 `vmi-layout-assignment` 和 `vmi-to-vpto` 的实现计划。它配套 +`vmi-layout-assignment-lowering-design.md`,并以 +`vmi-layout-lowering-cases.md` 为测试和验收来源。 + +不使用旧 `vmi-dialect-design.md` 作为设计输入。 + +## 1. Pipeline + +Recommended pass pipeline: + +```text +pto-validate-vmi-surface + -> vmi-layout-assignment + -> pto-validate-vmi-layout + -> vmi-to-vpto + -> canonicalize/cse + -> existing VPTO lowering/codegen +``` + +Pass responsibilities: + +```text +pto-validate-vmi-surface: + verify surface VMI has no physical VPTO layout dependency + reject public/external VMI ABI unless explicitly enabled + +vmi-layout-assignment: + solve value layouts + choose selected lowering plans + insert ensure/rematerialization helpers + make internal function boundary layouts explicit + rewrite VMI types with layout attrs + +pto-validate-vmi-layout: + verify every VMI data/mask value has layout + verify every context-sensitive op has selected_plan + verify helper ops have registered materialization plans + +vmi-to-vpto: + use OneToN type conversion + lower only from explicit layout/plan information + emit VPTO or precise unsupported diagnostic +``` + +## 2. Files To Add Or Update + +Expected implementation files: + +```text +include/PTO/IR/VMITypes.td +include/PTO/IR/VMIOps.td +include/PTO/IR/VMIAttrs.td +lib/PTO/IR/VMI.cpp + +include/PTO/Transforms/Passes.td +lib/PTO/Transforms/ValidateVMI.cpp +lib/PTO/Transforms/VMILayoutAssignment.cpp +lib/PTO/Transforms/VMIToVPTO.cpp +lib/PTO/Transforms/VMILayoutPlanRegistry.cpp + +test/lit/vmi/vmi_layout_assignment_*.pto +test/lit/vmi/vmi_to_vpto_*.pto +test/vpto/cases/vmi/*/ +``` + +Exact names may follow project conventions, but the layering should remain: + +```text +IR definitions + -> validation + -> assignment + -> OneToN lowering + -> lit and sim tests +``` + +## 3. IR Types And Attributes + +### 3.1 Layout Attribute + +Represent layout as a closed attribute family: + +```text +#pto.vmi.layout +#pto.vmi.layout +#pto.vmi.layout +``` + +C++ form: + +```c++ +enum class VMILayoutKind { + Contiguous, + Deinterleaved, + GroupSlots, +}; + +struct VMILayoutKey { + VMILayoutKind kind; + int64_t deinterleaveFactor = 1; + int64_t blockElems = 1; + int64_t numGroups = 0; + int64_t slots = 0; +}; +``` + +Verifier rules: + +```text +contiguous: + no extra parameters + +deinterleaved: + F > 1 + B > 0 + direct full-chunk plans require N % (F * B) == 0 + +group_slots: + G > 0 + K > 0 + G % K == 0 + K fits in one physical vreg for element type +``` + +Parser compatibility during migration: + +```text +#pto.vmi.layout +``` + +is accepted as a legacy spelling for the pre-design implicit group layout. New +`vmi-layout-assignment` output must not rely on that implicit form. It must +print one of: + +```text +#pto.vmi.layout +#pto.vmi.layout +``` + +so `vmi-to-vpto` can lower from the assigned type without reconstructing group +slot placement from producer or consumer context. + +### 3.2 VMI Types + +Surface: + +```text +!pto.vmi.vreg +!pto.vmi.mask +``` + +Layout-assigned: + +```text +!pto.vmi.vreg> +!pto.vmi.mask> +``` + +Surface VMI types are legal before assignment. Layout-assigned VMI types are +required after assignment. + +### 3.3 Selected Plan Attribute + +Every context-sensitive op gets a selected plan attr after assignment. The +initial implementation may use a stable string attr: + +```text +vmi.selected_plan = "s16_reduce_parity" +``` + +Once the plan registry syntax is stable, this can become a dedicated plan attr: + +```text +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +vmi.selected_plan = #pto.vmi.plan +``` + +Ops that are uniquely determined by layout may omit this attr, but the rule +should be conservative. If future maintainers could reasonably ask "why this +lowering?", assignment should write a plan. + +## 4. VMI Surface Ops Required By Cases + +Initial op set from the case catalog: + +```text +load +group_load +group_slot_load +store +masked_store + +create_mask +create_group_mask + +extf +truncf +addf +mulf +select +broadcast + +group_reduce_addf +group_broadcast +group_store + +ensure_layout // internal +ensure_mask_layout // internal +ensure_mask_granularity // internal +``` + +Important semantic split: + +```text +load: + optional full_read_elems=N is a memory-safety contract for pointer sources. + It states that source[offset : offset + N) may be physically read even if the + VMI logical result has fewer active lanes. + +group_load: + loads group_size data elements per group + +group_slot_load: + loads one scalar per group and produces group_slots +``` + +## 5. Plan Registry + +Create one registry object shared by assignment and lowering. + +```c++ +class VMILayoutPlanRegistry { +public: + SmallVector getProducerPlans(Operation *op); + SmallVector getConsumerPlans(OpOperand &use); + SmallVector getTransferPlans(Operation *op); + FailureOr getMaterializationPlan(Type valueType, + VMILayoutKey from, + VMILayoutKey to); + bool isCheaplyRematerializable(Operation *op); + bool hasTargetCapability(PlanID plan) const; +}; +``` + +Plan record: + +```c++ +struct VMILayoutPlan { + PlanID id; + SmallVector operandLayouts; + SmallVector resultLayouts; + int64_t cost; + bool requiresSelectedPlanAttr; + bool requiresFullTileReadable; + bool mayReadInactivePhysicalLanes; + DiagnosticBuilder (*explainFailure)(...); +}; +``` + +The registry must be target-aware but deterministic. It should not read global +mutable state. Pass options configure fallback availability: + +```text +enableScratchFallback +enableGatherFallback +enablePublicVMIABI +diagnosticVerbosity +``` + +## 6. Layout Assignment Data Model + +### 6.1 Solver State + +```c++ +struct ValueLayoutState { + Value value; + Type logicalType; + SmallVector candidates; + std::optional chosen; + SmallVector useRequests; +}; + +struct UseRequest { + OpOperand *operand; + VMILayoutKey requestedLayout; + PlanID requestingPlan; + bool hard; +}; + +struct OpPlanState { + Operation *op; + SmallVector candidates; + std::optional chosen; +}; +``` + +### 6.2 Collection Phase + +Walk the module and collect: + +```text +1. every VMI value +2. every VMI block argument +3. every VMI function argument/result +4. every VMI op with candidate plans +5. every branch/yield/call/return edge carrying VMI +``` + +Build SCCs over: + +```text +dataflow uses +region yields +loop iter_args +function call graph for private/internal functions +``` + +Public/external VMI function boundaries are rejected unless +`enablePublicVMIABI` is explicitly supported. + +Block arguments are first-class layout variables. Assignment must write the +chosen layout into the block argument type or specialized function signature. +`vmi-to-vpto` must never recover a block argument layout by walking to an +incoming branch, yield, or call operand. + +### 6.3 Constraint Generation + +Examples: + +```text +truncf f32->f16: + source request deinterleaved=2, block_elems=1 + result contiguous + +group_reduce S=16: + source candidate deinterleaved=2, block_elems=1 + source candidate deinterleaved=2, block_elems=8 + result group_slots(G, slots=8) + +group_reduce S=32: + source candidate deinterleaved=4, block_elems=1 + source candidate deinterleaved=4, block_elems=8 + result group_slots(G, slots=8) + +group_reduce S=64: + source request contiguous + result group_slots(G, slots=1) + +group_broadcast: + source request group_slots(G,K) + result candidate comes from each dense consumer request + op is rematerializable per use + +ordinary dense add/mul/select: + operands/results same dense layout + +group-slot add/mul: + operands/results same group_slots(G,K) + +ordinary store: + dense source required + group_slots source is illegal + +group_store: + source request group_slots(G,K) +``` + +Consumer-driven adoption is limited to producers that are layout-transparent or +can produce the requested memory layout directly: + +```text +direct layout producer: + load, tile_read + +layout-transparent producer: + broadcast, constant, iota + add/sub/mul/fma/div/min/max/neg/abs/sqrt/exp/ln/relu + integer bitwise/shift/not + select, bitcast +``` + +For a non-load layout-transparent producer, only non-contiguous consumer +requests may be adopted by the producer equivalence class. Contiguous requests +from ordinary stores are handled by use-site `ensure_layout` or +rematerialization instead. This prevents a dense store from overwriting a +natural `deinterleaved` cast layout while still allowing: + +```text +load -> broadcast -> addf -> S=32 group_reduce +``` + +to assign the whole producer chain as +`deinterleaved = 4, block_elems = 8` before `vmi-to-vpto`. + +Memory legality constraints: + +```text +S=32 tail fast load: + requires full_tile_readable + otherwise require gather fallback or diagnose + +compact S=12 logical S=16: + requires compact-row gather materialization + diagnose if gather fallback is disabled/missing +``` + +### 6.4 Solving And Rewriting + +Algorithm: + +```text +1. Pick candidate plan sets for every op. +2. Propagate hard constraints through SCCs. +3. Resolve transfer-equivalent dense values. +4. Choose multi-plan ops by cost: + - S=16 parity vs block8 + - load memory-fused vs load+materialize + - group_slot_load slots=8 vs slots=1 +5. For conflicting uses: + - rematerialize cheap producer where legal + - otherwise insert ensure_layout at use + - otherwise diagnose +6. Rewrite VMI result/block/function types with chosen layouts. +7. Attach selected_plan attrs where required. +8. Insert helper ops with source/result layout attrs. +``` + +Rewrite invariants: + +```text +No VMI data/mask value after assignment has a null layout. +No context-sensitive VMI op after assignment lacks selected_plan. +Every ensure_* helper has a registered materialization plan. +Every function/call signature carrying VMI is specialized or diagnosed. +``` + +## 7. OneToN Type Conversion + +`vmi-to-vpto` should use OneToN conversion for VMI values. + +Conversion rules: + +```text +contiguous: + ceil(N / lanesPerVReg(T)) physical vregs + +deinterleaved=F: + F * ceil((N / F) / lanesPerVReg(T)) physical vregs + ordering: part-major, then chunk + +group_slots(G,K): + ceil(G / K) physical vregs + each vreg has logical slot lanes 0..K-1 live +``` + +Mask conversion: + +```text +mask layout follows data layout +mask granularity is selected from consumer element width: + f32/i32 -> b32 + f16/i16 -> b16 + f8/i8 -> b8 +``` + +If one logical mask is used by multiple widths, assignment inserts +`ensure_mask_granularity` or rematerializes the mask producer. + +## 8. VMI-to-VPTO Pattern Rules + +Each pattern uses: + +```text +op +operand/result layouts +selected_plan +adaptor physical values +``` + +Each pattern rejects: + +```text +missing selected_plan for context-sensitive op +layout not matching selected_plan +missing target capability +unexpected group_slots dense consumer +``` + +Target selected-plan matrix: + +```text +load, selected_plan=dense_load_norm: + result layout contiguous + emits pto.vlds / pto.vsts NORM paths + covers dense store users and S=64 row-local reduce input + +load, selected_plan=load_dintlv2: + result layout deinterleaved=2, block_elems=1 + emits vldsx2 DINTLV_B32 or normal load + vdintlv materialization + covers f32->f16, S=16 parity reduce, f16->f32 widened values + +load, selected_plan=load_dintlv4: + result layout deinterleaved=4, block_elems=1 + emits two vldsx2 DINTLV_B32 plus vdintlv + covers f32->f8, S=32 dintlv4 reduce + +group_load, selected_plan=s16_group_load_block8_unit_stride: + result layout deinterleaved=2, block_elems=8 + emits vldsx2/BDINTLV for 8 rows of 16xf32 + covers compact logical S=16 when source_group_stride == 16 + +group_load, selected_plan=s16_group_load_block8_stride: + result layout deinterleaved=2, block_elems=8 + emits two vsldb strided 32B block loads + requires source_group_stride % 8 == 0 + +group_load, selected_plan=s32_group_load_block8_stride: + result layout deinterleaved=4, block_elems=8 + emits four vsldb strided 32B block loads + requires source_group_stride % 8 == 0 + +group_load, selected_plan=group_load_contiguous_chunks: + result layout contiguous + emits one vlds per physical group chunk using row_stride address arithmetic + covers the currently implemented full-chunk row-local group_load path + +group_reduce_addf, selected_plan=s8_reduce_contiguous: + consumes contiguous f32 with group size 8 + produces group_slots(G, slots=8) + emits one vcgadd + +group_reduce_addf, selected_plan=s16_reduce_parity: + consumes deinterleaved=2, block_elems=1 + produces group_slots(G, slots=8) + emits two vcgadd operations and one vadd + +group_reduce_addf, selected_plan=s16_reduce_block8: + consumes deinterleaved=2, block_elems=8 + produces group_slots(G, slots=8) + emits two vcgadd operations and one vadd + +group_reduce_addf, selected_plan=s32_reduce_dintlv4: + consumes deinterleaved=4, block_elems=1 + produces group_slots(G, slots=8) + emits four vcgadd operations and a vadd tree + +group_reduce_addf, selected_plan=s32_reduce_block8_stride: + consumes deinterleaved=4, block_elems=8 + produces group_slots(G, slots=8) + emits four vcgadd operations and a vadd tree + +group_reduce_addf, selected_plan=s64_reduce_row_local: + consumes contiguous f32 with group size 64 + produces group_slots(G, slots=1) + target lowering emits per-row vcgadd plus vcadd; the current prototype uses + the existing row-local VCADD/VADD/VSEL sequence while preserving the same + group_slots(G, slots=1) value contract + +group_slot_load, selected_plan=group_slot_load_slots8_unit_stride: + result group_slots(G, slots=8) + requires source_group_stride == 1 + emits one packed vsldb load + +group_slot_load, selected_plan=group_slot_load_slots1_row_local: + result group_slots(G, slots=1) + supports aligned non-unit source_group_stride + requires constant positive source_group_stride divisible by 256 / elementBits + emits one lane-0 vsldb per group + +group_broadcast, selected_plan=group_broadcast_slots8_vselr: + source group_slots(G, slots=8) + result dense layout selected per use + emits vselr using assigned result layout + +group_broadcast, selected_plan=group_broadcast_slots1_vselr: + source group_slots(G, slots=1) + result dense layout selected per use + emits vdup/vselr row-local materialization + +truncf, selected_plan=group_slot_cast_slots1_f32_to_f16: + source/result group_slots(G, slots=1) + emits one lane-0 vcvt per group slot block + rejects packed slots=8 unless another plan is registered +``` + +The target matrix is the implementation contract. The staged status below +records how much of that contract the current prototype has already enforced. + +Current staged implementation status: + +```text +group_slot_load: + vmi-to-vpto requires vmi.selected_plan and checks it against + #pto.vmi.layout. + +group_reduce_addf: + explicit slots=8 VCGADD lowering requires + vmi.selected_plan = "s8_reduce_contiguous". Legacy bare num_groups and + generic VCADD lowering still need the plan-registry migration. + S=16 block8 assignment emits source/mask + #pto.vmi.layout, result + #pto.vmi.layout, and + vmi.selected_plan = "s16_reduce_block8"; vmi-to-vpto checks that plan and + lowers through two VCGADDs plus a PAT_VL8 VADD per packed result block. + S=32 block8 assignment emits source/mask + #pto.vmi.layout, result + #pto.vmi.layout, and + vmi.selected_plan = "s32_reduce_block8_stride"; vmi-to-vpto checks that + plan and lowers through four VCGADDs plus a PAT_VL8 VADD tree per packed + result block. + S=64 row-local assignment now emits + vmi.selected_plan = "s64_reduce_row_local" and has focused + layout-assignment/vmi-to-vpto lit coverage; the explicit slots=1 generic + VCADD row-local path also requires and checks that selected_plan. Other + legacy bare num_groups generic VCADD paths still need the plan-registry + migration. + +group_broadcast: + explicit slots=8/1 source layouts require + vmi.selected_plan = "group_broadcast_slots8_vselr" or + "group_broadcast_slots1_vselr". Deinterleaved block-fragment results use + the result layout block_elems as the local vselr selection group, so + `deinterleaved = 4, block_elems = 8` broadcasts one group slot across each + 32B row fragment. VSELR index vectors are materialized per physical result + chunk. For small-group results, layout assignment has already fixed the + result layout, and vmi-to-vpto computes: + `firstGroup = first logical group covered by this result chunk`, + `sourceChunk = firstGroup / slots`, and + `baseGroupSlot = firstGroup % slots`. The generated index vector selects + `baseGroupSlot .. baseGroupSlot + groupsPerResultChunk - 1`; it must not be + reused across result chunks. Legacy bare num_groups still needs the + plan-registry migration. + +group_load: + contiguous full-chunk path emits and checks + vmi.selected_plan = "group_load_contiguous_chunks". S=16/S=32 + block-aligned strided loads emit and check + vmi.selected_plan = "s16_group_load_block8_stride" or + "s32_group_load_block8_stride", assign + #pto.vmi.layout, and lower to one + vsldb per 32B row fragment and physical chunk. The dedicated S=16 unit-stride + vldsx2/BDINTLV plan remains a design target. S=16/S=32 group_load with a + non-constant, non-positive, or non-8-f32-aligned row_stride is rejected by + vmi-layout-assignment because the stable gather fallback is not implemented. + +truncf group-slot cast: + layout assignment and vmi-to-vpto support and check + vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" for + group_slots(G, slots=1) f32 -> f16. The reduce->truncf->group_store + slots=1 flow has focused lit coverage and no longer relies on vmi-to-vpto + inspecting the truncf producer. + +group_store: + row-local group_slots(G, slots=1) lowering is implemented as one lane-0 + vsts per group and is covered by the reduce->truncf->group_store lit case. + The current plan is accepted only when row_stride is a constant positive + multiple of the 32B store alignment in destination elements: 8 for f32, + 16 for f16, and 32 for f8. Unit-stride f32 output is rejected because only + the first row-local store is 32B-aligned; later `group_off + r` stores are + 4B apart. A future pack-to-slots=8 or unaligned-store plan is required before + contiguous `%c1` slots=1 group_store can be accepted. + Packed group_slots(G, slots=8) group_store is implemented only when + num_groups is a multiple of 8 and row_stride is constant 1; it emits one + PAT_VL8 store per packed slot block. Non-unit packed group stores remain a + design target unless a strided packed-lane store plan is selected explicitly. +``` + +Examples: + +```text +group_reduce_addf, selected_plan=s16_reduce_parity: + consume deinterleaved=2, block_elems=1 + emit two VCGADDs and one VADD + +group_reduce_addf, selected_plan=s16_reduce_block8: + consume deinterleaved=2, block_elems=8 + emit two VCGADDs and one VADD + +group_reduce_addf, selected_plan=s32_reduce_dintlv4: + consume deinterleaved=4 + emit four VCGADDs and reduction tree + +group_broadcast: + consume group_slots + emit VSELR or VDUP depending slots and target dense layout + +group_slot_load slots=8: + emit one packed block load for unit stride + +group_slot_load slots=1: + emit row-local lane-0 loads for constant positive 32B-aligned strides +``` + +## 9. Validation Passes + +### 9.1 Surface Validation + +Before assignment: + +```text +VMI types may omit layout. +VPTO physical op must not consume VMI values. +Public/external VMI function ABI rejected unless enabled. +Unsupported vector-to-scalar extract rejected. +``` + +### 9.2 Layout Validation + +After assignment: + +```text +Every VMI value has layout. +Every VMI mask has layout and granularity plan. +Every context-sensitive op has selected_plan. +Every selected_plan matches operand/result layouts. +Every ensure_* helper has a materialization plan. +Every control-flow edge has matching VMI layouts. +``` + +### 9.3 `vmi-to-vpto` Context Read Audit + +`vmi-to-vpto` may still read defining ops in narrowly scoped cases that do not +select a layout or plan: + +```text +allowed: + arith.constant for the current op's scalar operands + create_mask/create_group_mask internals when lowering that mask op itself + ensure_mask_layout / ensure_mask_granularity stripping for static mask facts + memref.subview only to improve an already-failed non-identity memref + diagnostic + +not allowed: + walking from a consumer to a producer to decide a selected_plan + walking from a consumer to a mask producer to decide whether a plan is legal + inspecting users to choose a result layout or materialization + recovering full_tile_readable from surrounding MTE/caller context +``` + +Current audit result: + +```text +3.44 partial S=32 create_group_mask: + decision moved to vmi-layout-assignment. vmi-to-vpto no longer walks from + group_reduce_addf to the mask defining op to reject the plan. + +masked_load: + direct lowering is load + vsel. It does not inspect the mask producer to + choose a different load form; memory safety is provided by full physical + chunks, shaped memref proof, or load full_read_elems. + +memref.subview: + mentioned only after identity lane-to-address planning fails. It is not used + to recover a hidden base/stride lowering. +``` + +## 10. Diagnostics + +Implement diagnostics with stable prefixes: + +```text +VMI-LAYOUT-CONTRACT +VMI-UNSUPPORTED-PLAN +VMI-MISSING-CAPABILITY +VMI-PUBLIC-ABI +VMI-MASK-GRANULARITY +VMI-CONTROL-FLOW-LAYOUT +``` + +Minimum diagnostic payload: + +```text +op name +logical type +actual layout +requested layout +selected/missing plan +recommended rewrite or option +``` + +Example: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.truncf requires + #pto.vmi.layout, but the source value is + fixed to #pto.vmi.layout by the selected + strided group_load plan. Register a rematerialization or preserving + materialization plan, or avoid consuming this block-loaded value with truncf. +``` + +## 11. Test And Simulator Acceptance + +Each numbered endpoint in `vmi-layout-lowering-cases.md` should become: + +```text +1. a layout-assignment lit test +2. a vmi-to-vpto lit test +3. a simulator case when the VPTO sequence is supported by the current backend +4. a diagnostic lit test when the case is explicitly unsupported +``` + +Repository locations: + +```text +test/lit/vmi/ +test/vpto/cases/vmi/ +``` + +The current repository uses descriptive flat lit names rather than +case-numbered subdirectories. New tests should follow the existing prefixes: + +```text +vmi_layout_assignment_.pto +vmi_to_vpto_.pto +/kernel.pto +``` + +The case number should still be recoverable from the coverage table in this +document and from the corresponding section in `vmi-layout-lowering-cases.md`. + +### 11.1 Layout Assignment Checks + +Each positive layout-assignment test must check: + +```text +assigned data layouts +assigned mask layouts +selected_plan attrs +inserted ensure_layout/rematerialized producers +control-flow/function signature specialization +``` + +Negative tests check diagnostic text. + +### 11.2 VMI-to-VPTO Checks + +Each positive vmi-to-vpto test must check: + +```text +no pto.vmi ops remain +VPTO op sequence matches the case lowering +physical value arity and ordering are correct +mask granularity is correct +stores preserve observable logical memory order +``` + +### 11.3 Simulator Checks + +Simulator cases should compare final memory against the memory result written in +the case catalog. + +Current broad runtime sweep: + +```text +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-39 CASE_PREFIX='vmi/' JOBS=4 \ + test/vpto/scripts/run_host_vpto_validation_parallel.sh + +PASS=39 FAIL=0 +summary: .tmp/vmi-runtime-batch-39/parallel-summary.tsv +log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ + .tmp/vmi-runtime-batch-39.log +result: no matches +``` + +The `find: Permission denied` messages printed while discovering CANN simulator +paths are environment noise and are not treated as simulator failures. + +Required groups: + +```text +dense conversion: + 3.1, 3.2, 3.3, 3.31, 3.32 + +group reduce: + 3.4, 3.5.1, 3.5.2, 3.5.3 + 3.6.1, 3.6.2, 3.6.3 + 3.7.1, 3.7.2, 3.7.3 + 3.7.4 diagnostic + +layout/rematerialization: + 3.8, 3.10, 3.17, 3.18, 3.19.1, 3.22, 3.23, 3.31, + 3.32, 3.33, 3.34, 3.35, 3.36, 3.38, 3.40, 3.41 + +mask/tail: + 3.11.1, 3.15.1, 3.15.2, 3.21, 3.24, 3.26, 3.29, + 3.30, 3.44 + +strided/group-slot memory: + 3.27, 3.28, 3.37, 3.39 + +function/control-flow: + 3.12, 3.20, 3.22, 3.25.1, 3.42, 3.43 +``` + +Aggregate catalog headings are covered through their endpoint subcases: + +```text +3.11 partial tail groups: + 3.11.1 positive S=64 active-row tail + 3.11.2 diagnostic S=32 tail without full_tile_readable + +3.15 compact S=12 written as logical S=16: + 3.15.1 positive source row stride 16 + 3.15.2 positive source row stride greater than 16 + 3.15.3 diagnostic compact source row stride 12 + +3.16 group_slot_load layout contract: + 3.16.1 packed slots=8 positive and non-unit-stride diagnostic + 3.16.2 row-local slots=1 positive plus dynamic/unaligned diagnostics + +3.25 function boundary layout specialization: + 3.25.1 private/internal boundary lit coverage, runtime backend gap + 3.25.2 public/external boundary diagnostics +``` + +Current checked-in coverage for 3.3 dense f8->f32->compute->f8: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto + +runtime SIM: + test/vpto/cases/vmi/f8-compute-f8 +``` + +Current checked-in coverage for 3.1/3.2 dense f16/f32 conversion stores: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto + +runtime SIM: + test/vpto/cases/vmi/widen-f16-to-f32-store-reduce + test/vpto/cases/vmi/quant-f32-to-f16-tail +``` + +Current checked-in coverage for basic packed group_reduce -> group_store paths +for 3.4, 3.5.1, and 3.6.1: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto + test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-basic-store +``` + +Current checked-in coverage for S=16 group broadcast continuation: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store +``` + +Current checked-in coverage for 3.35 group_slots fanout to direct group_store +and group_broadcast: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto + +runtime SIM: + test/vpto/cases/vmi/group-slots-fanout-store-broadcast +``` + +Current checked-in coverage for 3.8 `group_reduce -> truncf -> +group_broadcast -> dense store` and 3.17 `group_broadcast` feeding a +deinterleaved consumer: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store +``` + +Current checked-in coverage for 3.18 one dense value with dense and +group-reduce consumers: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto + +runtime SIM: + test/vpto/cases/vmi/dense-group-reduce-multi-consumer +``` + +Current checked-in coverage for 3.10 non-load producer feeding S=32 +`group_reduce`: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-add-bias-store +``` + +Current checked-in coverage for 3.23 group_broadcast with multiple dense +consumers: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto + +runtime SIM: + test/vpto/cases/vmi/group-broadcast-multi-consumer +``` + +Current checked-in coverage for S=32 contiguous group broadcast continuation: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store +``` + +Current checked-in coverage for 3.21 S=32 tail with a statically safe +full-read source: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store + This case has `ptoas.flags` with `--enable-vmi`, because the partial pointer + load must run through layout assignment before VPTO/LLVM emission. +``` + +Current checked-in runtime coverage for 3.12 control-flow join before S=32 +`group_reduce`: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_cf_branch.pto + test/lit/vmi/vmi_to_vpto_cf_branch.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-cf-join-store +``` + +Current checked-in runtime coverage for 3.20 `group_slots` control-flow join: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto + +runtime SIM: + test/vpto/cases/vmi/group-slots-cf-join-store +``` + +Current checked-in runtime coverage for 3.22 `scf.for` loop-carried VMI layout: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_scf_for.pto + test/lit/vmi/vmi_to_vpto_scf_for.pto + +runtime SIM: + test/vpto/cases/vmi/scf-for-loop-carried-store +``` + +Current checked-in runtime coverage for 3.42 `group_slots` `scf.for` +loop-carried accumulator: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto + +runtime SIM: + test/vpto/cases/vmi/group-slots-scf-for-store +``` + +Current checked-in lit coverage for 3.43 internal function argument boundary +materialization: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto + +runtime SIM: + blocked by the current private vector callee backend path; see known + implementation gaps below +``` + +Current checked-in coverage for packed group-slot RHS elementwise continuations +for 3.5.3 and 3.6.2: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-slot-add-store +``` + +Current checked-in coverage for S=64 row-local group broadcast continuation +with aligned row_stride: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store +``` + +Current checked-in coverage for S=64 active-row tail with aligned row_stride: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-tail-store +``` + +The companion negative lit case for contiguous `%c1` slots=1 group_store is: + +```text +test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto +``` + +Current checked-in coverage for S=64 row-local group-slot RHS elementwise +continuation with aligned source_group_stride and aligned output row_stride: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-slot-add-store +``` + +Current checked-in coverage for 3.34 S=64 `slots = 1` group-slot f32->f16 cast: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s64-truncf-store +``` + +The companion negative lit cases for dynamic or unaligned `%c2` slots=1 +group_slot_load, and non-unit `slots = 8` group_slot_load, are: + +```text +test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto +test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto +test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto +``` + +Current checked-in coverage for the strided block-load cases: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto + test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto + +runtime SIM: + test/vpto/cases/vmi/group-load-s16-stride-store + test/vpto/cases/vmi/group-load-s32-stride-store + test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce +``` + +Current checked-in coverage for grouped mask S=16 tail: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto + test/lit/vmi/vmi_create_group_mask_invalid.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store + test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store + test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store +``` + +Current checked-in coverage for 3.24 mask/select/masked-store semantics: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_mask_select_store.pto + +runtime SIM: + test/vpto/cases/vmi/mask-select-store +``` + +Current checked-in coverage for 3.29 one semantic mask with f32 and f16 +consumers: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto + +runtime SIM: + test/vpto/cases/vmi/mask-granularity-f32-f16-store +``` + +Current checked-in coverage for 3.31 f16->f32 feeding dense store and S=16 +reduce: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/widen-f16-to-f32-store-reduce +``` + +Current checked-in coverage for 3.32 f32 feeding f8 store and S=32 reduce: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto + +runtime SIM: + test/vpto/cases/vmi/f32-to-f8-store-reduce +``` + +Current checked-in coverage for multi-tile group-slot arity: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto + +runtime SIM: + test/vpto/cases/vmi/group-reduce-s32-multitile-store +``` + +Current checked-in coverage for 3.40 scalar broadcast feeding dense and grouped +users: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto + +runtime SIM: + test/vpto/cases/vmi/broadcast-dense-group-users +``` + +Current checked-in coverage for 3.41 non-rematerializable `masked_load` feeding +dense and grouped users: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto + +runtime SIM: + test/vpto/cases/vmi/masked-load-dense-group-users +``` + +Diagnostic-only cases: + +```text +3.9 dense store of group slots +3.11.2 S=32 tail without full_tile_readable +3.7.4 S=64 slots=1 group_store with unit output stride +3.13 packed group-slot f32 -> f16 cast +3.14 unsupported group size +3.15.3 compact source row stride 12 +3.16.1 group_slot_load slots=8 non-unit stride +3.16.2 group_slot_load slots=1 dynamic or unaligned stride +3.27 S=32 source_group_stride not divisible by 8 f32 elements +3.19.2 block_elems=8 value consumed by truncf without materialization plan +3.25.1 full ptoas emission for private VMI callees that return VPTO vector values +3.25.2 public/external VMI boundary +3.30 unsafe masked_load tail without stable masked/gather fallback +3.44 masked_load grouped tail with S=32 partial create_group_mask +``` + +Current checked-in diagnostic coverage for 3.9/3.13/3.14: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto + test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto +``` + +Current checked-in diagnostic coverage for the remaining non-SIM diagnostic +entries: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto + test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto + test/lit/vmi/vmi_ptoas_public_abi_invalid.pto + test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto + test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto + test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto + test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto + test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto + test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto + test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto +``` + +Known implementation gaps before all catalog cases can become runtime SIM +coverage: + +```text +dynamic grouped masks: + pto.vmi.create_group_mask exists and supports constant + active_elems_per_group. Dynamic active_elems_per_group is not implemented + yet. Do not replace grouped masks with prefix create_mask; that would change + the semantics. + +S=32 partial grouped masks: + 3.44 `masked_load` grouped tail with `active_elems_per_group < 32` is + diagnostic-only for the current S=32 block8 reduce path, and the diagnostic + is emitted by `vmi-layout-assignment` before a selected plan is written. A + runtime probe of the previously allowed lowering did not preserve the logical + 25-lane row sum. A second probe with `active_elems_per_group = 25` produced + row 0 `golden=-3.6290324` but `output=-3.6592741`, and the row-wise error + grew monotonically. This combination must stay unsupported until the + deinterleaved grouped-mask materialization is fixed and validated by SIM. + +remaining function runtime coverage: + 3.25.1 internal function boundary specialization has layout-assignment and + vmi-to-vpto lit coverage, but full ptoas emission still fails after + physicalization because today's inferred pto.vecscope is resultless and VPTO + vector-scope values cannot escape through a function return. Runtime coverage + requires either a resultful vecscope/VPTO vector ABI or an explicit inlining + policy before vecscope inference. + + 3.43 internal function argument boundary materialization has + layout-assignment and vmi-to-vpto lit coverage. Full ptoas emission for a + private void vector callee currently reaches the Bisheng device backend and + fails on the physicalized callee with: + + fatal error: error in backend: Do not know how to split the result of this operator! + + Runtime coverage requires either inlining private vector callees before the + device backend path or adding backend support for the physical VPTO vector + function ABI. This is a runtime/backend gap, not a license for `vmi-to-vpto` + to infer layouts from caller/callee context. + +memory-proof runtime coverage: + 3.21 S=32 full-tile-readable tail is covered by a runtime case that uses + `pto.vmi.load {full_read_elems = 256}` on a UB pointer source. The attr is + the explicit safe-read proof consumed by `vmi-to-vpto`; no surrounding MTE, + caller/body context, or producer/user scan is inspected to justify the + rounded-up physical reads. +``` + +## 12. Implementation Slices + +### Slice 1: IR Skeleton And Verifiers + +```text +layout attrs +vmi.vreg/vmi.mask types +surface op definitions +selected_plan attr +surface/layout validators +``` + +### Slice 2: Straight-Line Dense Assignment/Lowering + +```text +3.1 f16->f32->store +3.2 f32->f16->store +3.3 f8->f32->compute->f8 +``` + +### Slice 3: Group Slots And Reductions + +```text +3.4 S=8 +3.5 S=16 parity/block8 +3.6 S=32 +3.7 S=64 +group_slot_load +group_broadcast +group_store +``` + +### Slice 4: Layout Conflicts And Materialization + +```text +3.8 cast commute through group_broadcast +3.18 dense/group-reduce multi-consumer +3.19 block_elems plan selection +3.23 group_broadcast multi-consumer +3.32 f32 feeding f8 store and S=32 reduce +3.33 S=16/S=32 reduce multi-consumer rematerialization +3.34 slots=1 group-slot f32->f16 cast +3.35 group_slots fanout to group_store and group_broadcast +3.36 group_slot_load rematerialized for slots=8/slots=1 +3.38 multi-tile group_slots arity +3.40 scalar broadcast rematerialized for dense/grouped users +3.41 non-rematerializable value with ensure_layout +``` + +### Slice 5: Masks, Tail, And Memory Legality + +```text +create_mask +create_group_mask +masked_store +safe full-read proof +compact/gather diagnostics +mask granularity per use +group_load stride greater than group size +group_slot_load slots=1 aligned non-unit stride plus dynamic/unaligned diagnostic +group_store slots=1 non-unit output stride +strided group_load feeding broadcast and a second reduce +masked_load grouped tail feeding S=32 reduce +``` + +### Slice 6: Control Flow And Functions + +```text +scf.if +scf.for +group_slots across control flow +group_slots loop-carried accumulator +internal function specialization +internal function argument boundary materialization +public ABI diagnostic +``` + +## 13. Completion Checklist + +The implementation is not complete until: + +```text +1. every case has a layout-assignment test +2. every positive case has a vmi-to-vpto test +3. every simulator-supported case has a sim validation +4. every unsupported case has a diagnostic test +5. vmi-to-vpto contains no producer/user context inference +6. missing selected_plan on context-sensitive ops is a hard failure +7. release docs are updated only after the design stabilizes +``` diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md new file mode 100644 index 0000000000..710ab267a7 --- /dev/null +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -0,0 +1,625 @@ +# VMI Layout Assignment And Lowering Design + +本文是新的 VMI layout assignment / lowering 设计文档。它只以 +`docs/designs/vmi-layout-lowering-cases.md` 为 source of truth,不继承旧 +`vmi-dialect-design.md` 的 layout 设计,以避免旧上下文污染。 + +目标: + +```text +VMI surface IR + -> vmi-layout-assignment + -> layout-assigned VMI IR + -> vmi-to-vpto + -> VPTO IR +``` + +核心验收约束: + +```text +vmi-to-vpto 不允许通过上下文猜 lowering。 + +任何需要 producer/consumer/control-flow/memory/mask 上下文才能决定的事, +必须在 vmi-layout-assignment 阶段变成显式 IR 信息: + +1. vmi.vreg/vmi.mask 的 layout +2. op 的 selected lowering plan +3. use-site ensure_layout / ensure_mask_layout +4. rematerialized producer +5. target capability diagnostic +``` + +## 1. Source Case Coverage + +设计必须覆盖 case catalog 中的端到端场景: + +```text +dense cast: + f16 -> f32 -> store + f32 -> f16 -> store + f8 -> f32 -> compute -> f8 + f16 -> f32 shared by dense store and S=16 reduce + f32 shared by f8 store and S=32 reduce + +group reduce: + S=8, S=16, S=32, S=64 + reduce -> group_store + reduce -> group_slot_load/elemwise -> group_store + reduce -> group_broadcast -> elemwise -> reduce -> store + one group_slots result fanning out to group_store and group_broadcast + grouped tail -> broadcast -> reduce -> store + +layout conflict: + one value with dense and group-reduce consumers + one value with S=16 and S=32 group-reduce consumers + one scalar broadcast rematerialized for dense and grouped users + one non-rematerializable value materialized with use-site ensure_layout + one scalar group-slot source rematerialized as slots=8 and slots=1 + S=16 block_elems=1/8 plan selection + dense consumer of group_slots diagnostic + packed group-slot width-changing cast diagnostic + S=64 slots=1 group-slot width-changing cast + +control flow: + scf.if before group_reduce + group_slots across scf.if + scf.for loop-carried layout fixed point + group_slots as scf.for loop-carried accumulator + internal function boundary specialization + internal function argument boundary materialization + public/external VMI ABI diagnostic + +mask and tail: + prefix mask + group-periodic mask + masked_load tail with explicit passthrough instead of padding + masked_load grouped tail feeding group_reduce + masked select/store + one semantic mask used by multiple predicate granularities + S=32 tail with and without full_tile_readable + compact S=12 diagnostic + +strided memory: + group_load source stride greater than logical group size + strided group_load feeding broadcast and a second group_reduce + group_slot_load slots=1 with non-unit source stride + group_store slots=1 with non-unit output stride +``` + +## 2. Layout Domain + +Layout is a property of a layout-assigned VMI value, not a property inferred by +the final lowering pattern. + +### 2.1 Dense Layouts + +```text +#pto.vmi.layout +#pto.vmi.layout +``` + +`block_elems` defaults to `1`: + +```text +#pto.vmi.layout + == #pto.vmi.layout +``` + +Dense layouts preserve one semantic value for every logical lane. + +Lane map for `deinterleaved = F, block_elems = B`: + +```text +logical lane i +block q = i / B +in-block lane r = i % B +part p = q % F +part block t = q / F + +physical part p, physical lane t * B + r +``` + +Important consequence: + +```text +deinterleaved=2, block_elems=1 +deinterleaved=2, block_elems=8 +``` + +are different layouts. They cannot be treated as compatible because `F` is the +same. + +### 2.2 Sparse Group-Slot Layouts + +```text +#pto.vmi.layout +``` + +Only `G` lanes have semantic values: + +```text +slot_block(g) = g / K +slot_lane(g) = g % K +``` + +All non-slot lanes are undefined and may only be read by group-aware operations. +Ordinary dense `add/mul/store/truncf` cannot consume `group_slots`. + +`K` is selected by the lowering plan: + +```text +S=8/16/32 packed VCG result -> slots=8 +S=64 row-local result -> slots=1 +``` + +## 3. Lowering Context Must Become Assignment Output + +`vmi-to-vpto` may inspect only: + +```text +1. op name and explicit op attrs +2. converted operand/result types with layout +3. selected plan attrs written by layout assignment +4. inserted helper ops +5. target capability registry +``` + +It must not: + +```text +1. walk to defining op to infer layout +2. inspect all users to choose a lowering path +3. infer memory legality from a later mask +4. decide S=16 block_elems=1 vs block_elems=8 locally +5. decide whether group_broadcast should be materialized for one or many users +6. specialize function signatures during vmi-to-vpto +``` + +Any of those decisions belongs to `vmi-layout-assignment`. + +## 4. Explicit Assignment Products + +After `vmi-layout-assignment`, every VMI data and mask value must be in one of +these states: + +```text +layout-assigned type: + !pto.vmi.vreg> + !pto.vmi.mask> + +or explicit helper: + pto.vmi.ensure_layout + pto.vmi.ensure_mask_layout + pto.vmi.ensure_mask_granularity +``` + +Every context-sensitive op must also have a selected plan if layout alone does +not uniquely identify the lowering: + +```text +vmi.selected_plan = "dense_load_norm" +vmi.selected_plan = "load_dintlv2" +vmi.selected_plan = "load_dintlv4" +vmi.selected_plan = "group_load_contiguous_chunks" +vmi.selected_plan = "s16_group_load_block8_unit_stride" +vmi.selected_plan = "s16_group_load_block8_stride" +vmi.selected_plan = "s32_group_load_block8_stride" +vmi.selected_plan = "s8_reduce_contiguous" +vmi.selected_plan = "s16_reduce_parity" +vmi.selected_plan = "s16_reduce_block8" +vmi.selected_plan = "s32_reduce_dintlv4" +vmi.selected_plan = "s32_reduce_block8_stride" +vmi.selected_plan = "s64_reduce_row_local" +vmi.selected_plan = "group_slot_load_slots8_unit_stride" +vmi.selected_plan = "group_slot_load_slots1_row_local" +vmi.selected_plan = "group_broadcast_slots8_vselr" +vmi.selected_plan = "group_broadcast_slots1_vselr" +vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" +``` + +The spelling above is illustrative; implementation may use an enum attr. The +invariant is not illustrative: if a lowering decision is not uniquely implied +by op + assigned operand/result layouts + explicit attrs, assignment must write +a selected plan. + +## 5. Plan Registry + +The compiler owns a target-aware plan registry. Layout assignment queries this +registry; vmi-to-vpto verifies and consumes the chosen plan. + +### 5.1 Plan Kinds + +```text +ProducerPlan: + op can produce result layout L + example: load -> deinterleaved=4 using DINTLV_B32 + vdintlv + +ConsumerPlan: + op can consume operand layout L + example: group_reduce S=32 consumes deinterleaved=4 + +TransferPlan: + op ties operand/result layouts + example: addf requires same dense layout for operands/result + +MaterializationPlan: + layout A -> layout B without changing logical value + example: deinterleaved=4 -> contiguous by vintlv tree + +RematerializationPlan: + cheap producer can be cloned for a use-site layout + example: broadcast/create_mask/group_broadcast + +DiagnosticPlan: + known unsupported semantic/capability boundary + example: compact S=12 requires gather materialization +``` + +### 5.2 Dense Plans From Cases + +```text +f16 -> f32: + source contiguous f16 + result deinterleaved=2, block_elems=1 + +f8 -> f32: + source contiguous f8 + result deinterleaved=4, block_elems=1 + +f32 -> f16: + source deinterleaved=2, block_elems=1 + result contiguous f16 + +f32 -> f8: + source deinterleaved=4, block_elems=1 + result contiguous f8 + +elementwise dense: + all dense operands/results share the same layout + +broadcast scalar: + rematerializable to any dense layout requested by the consumer + +load: + may be rematerialized per use when two consumers request incompatible dense + layouts, such as S=16 deinterleaved=2 and S=32 deinterleaved=4 +``` + +### 5.3 Group Plans From Cases + +```text +group_reduce f32 S=8: + input contiguous + result group_slots(G, slots=8) + +group_reduce f32 S=16: + legal input layout A: deinterleaved=2, block_elems=1 + legal input layout B: deinterleaved=2, block_elems=8 + result group_slots(G, slots=8) + +group_reduce f32 S=32: + legal input layout A: deinterleaved=4, block_elems=1 + legal input layout B: deinterleaved=4, block_elems=8 + result group_slots(G, slots=8) + +group_reduce f32 S=64: + input contiguous + result group_slots(G, slots=1) + +group_slot_load: + result group_slots(G, slots=8) for packed slots + result group_slots(G, slots=1) for row-local slots + +group_broadcast: + source group_slots(G,K) + result is dense layout requested by each consumer + rematerialize per use instead of forcing one result layout + +group_store: + source group_slots(G,K) + +group_slot_cast f32 -> f16: + slots=1 row-local source/result is legal with + group_slot_cast_slots1_f32_to_f16 + slots=8 packed source is illegal unless a packed slot-preserving plan is + registered +``` + +### 5.4 Tail And Memory Safety Plans + +Mask semantics and memory legality are separate: + +```text +mask: + decides which logical lanes participate in compute/store semantics + +full_tile_readable: + decides whether a rounded-up physical load is allowed to read inactive lanes +``` + +The full-tile-readable proof must be explicit. It may be carried by a +statically shaped memref source, or by `pto.vmi.load {full_read_elems = N}` for +pointer sources. `vmi-to-vpto` consumes only this proof carrier; it does not +inspect surrounding MTE copies, producer bodies, callers, or later consumers to +decide whether inactive physical lanes are safe to read. + +Example: + +```text +S=32 tail num_groups=6: + without full_tile_readable: + fast DINTLV_B32 full-tile load is illegal + + with full_tile_readable: + full 8-row physical tile may be loaded + compute mask is PAT_VL48 per physical part + group store mask is PAT_VL6 + +S=16 grouped tail active_elems_per_group=12: + low 8-lane row half uses PAT_ALL + high 8-lane row half uses lane_mod_8 < 4 + the same split applies before and after group_broadcast + +one mask used by f32 and f16 consumers: + f32 use materializes a b32 predicate + f16 use materializes a b16 predicate + vmi-to-vpto consumes the assigned per-use mask materialization +``` + +## 6. Layout Assignment Algorithm + +`vmi-layout-assignment` is module-level. It must see function/call/control-flow +connections before choosing layouts. + +### 6.1 Variables + +Create a layout variable for: + +```text +1. every VMI OpResult +2. every VMI BlockArgument +3. every function argument/result that is allowed to carry VMI +4. every VMI mask value +``` + +Create a use-site request for: + +```text +1. every operand use that requires a specific layout +2. every control-flow yield/branch/call/return edge +3. every memory operation that requires a memory legality plan +``` + +### 6.2 Constraints + +Hard constraints: + +```text +group_slots cannot feed ordinary dense consumers +direct group-slot width-changing cast requires a slot-preserving plan +public/external VMI function boundary requires a stable ABI or diagnostic +S=32 fast tail load requires full_tile_readable or gather fallback +``` + +`slots = 1` row-local cast may satisfy the slot-preserving plan requirement. +Packed `slots = 8` f32->f16 remains a diagnostic unless a separate packed cast +or unpack/materialization plan is registered. + +Equivalence constraints: + +```text +dense add/mul/select: + operands/results use same dense layout unless an explicit materialization is + inserted at a use site + +scf.if/scf.for: + region yield operands and block arguments must have the same assigned layout + as the region result/iter_arg +``` + +Candidate constraints: + +```text +S=16 group_reduce: + choose block_elems=1 or block_elems=8 by cost and explicit assignment constraints + +one dense value feeding S=16 and S=32 group_reduce: + rematerialize a cheap producer per consumer layout, or insert an explicit + materialization plan; the final lowering pass must not pick one layout after + seeing both users + +load/group_load: + choose memory plan and result layout together + +group_broadcast: + rematerialize per dense consumer layout +``` + +### 6.3 Solving + +Recommended solving order: + +```text +1. Build function/control-flow SCCs. +2. Collect candidate plans for every op. +3. Propagate hard required layouts from consumers. +4. Propagate producer natural layouts where they are unique. +5. Resolve multi-plan ops by cost. +6. Insert use-site materialization where a value has multiple incompatible uses. +7. Rematerialize cheap producers instead of materializing when cheaper. +8. Specialize internal function signatures. +9. Emit diagnostics for unsatisfied hard constraints. +10. Rewrite VMI types and selected plan attrs. +``` + +Tie-breaking must be deterministic. Suggested priority: + +```text +1. Avoid unsupported plans. +2. Prefer rematerializing cheap producers over register materialization. +3. Prefer layouts accepted by all consumers without conversion. +4. Prefer memory-fused layout plans over load + register rearrange. +5. Prefer fewer VPTO instructions. +6. Prefer contiguous only when cost ties and no consumer requests a special layout. +``` + +## 7. Control Flow And Functions + +### 7.1 `scf.if` + +All branch yields for one result must agree on one assigned layout. If they do +not, assignment inserts materialization before `scf.yield` where possible. +The `scf.if` result type after assignment carries that layout, so +`vmi-to-vpto` does not need to inspect either branch body. + +### 7.2 `scf.for` + +Loop-carried VMI values are fixed-point variables: + +```text +initial iter_arg layout +body block argument layout +yield operand layout +loop result layout +``` + +must converge to one layout. If a body consumer needs another layout, it is a +use-site request inside the loop body. +The loop body block argument has no defining op. Its layout is therefore part +of the block argument type after assignment, not information reconstructed from +the initial value or previous iteration during lowering. + +### 7.3 Calls + +Internal/private VMI function boundaries must make layout choices explicit in +the assigned IR. The baseline implementation keeps function arguments in a +contiguous VMI ABI and inserts callee-entry `ensure_layout` helpers when the +callee body needs another layout. A later private-function optimization may +specialize signatures directly: + +```text +func @producer() -> !vmi.vreg<256xf32, deinterleaved=4> +``` + +then physicalized by `vmi-to-vpto` into multiple VPTO function results. + +Public/external VMI function boundaries are rejected until a stable VMI ABI is +defined. + +## 8. vmi-to-vpto Contract + +`vmi-to-vpto` receives layout-assigned VMI. It performs no global reasoning. + +For each op, the pattern: + +```text +1. reads operand/result layouts +2. reads selected_plan if required +3. asks TypeConverter for ordered physical values +4. emits the registered VPTO recipe +5. fails if the selected plan is missing or target capability is absent +``` + +The pattern must not: + +```text +1. inspect all users to decide result layout +2. inspect defining ops to decide source layout +3. choose between S=16 block_elems=1 and block_elems=8 +4. decide whether a load is full_tile_readable +5. decide function signature specialization +``` + +Allowed local reads are deliberately narrower: + +```text +arith.constant defining op: + allowed only to materialize an operand of the current op, such as + create_mask active_lanes or a constant memory offset + +current VMI op body/attrs: + allowed for op-local semantics, such as create_group_mask + active_elems_per_group when lowering the create_group_mask op itself + +helper materialization chain: + allowed only to strip ensure_mask_layout / ensure_mask_granularity for + static predicate analysis that does not choose a different layout or plan + +diagnostic embellishment: + allowed only to improve an already-failed capability message, such as naming + memref.subview after identity lane-to-address planning has failed +``` + +Anything else is a layout-assignment responsibility. In particular, an +unsupported producer/consumer combination must be rejected before assignment +writes a selected plan. Section 3.44 is the model: partial S=32 grouped masks +are diagnosed in `vmi-layout-assignment`, not by `vmi-to-vpto` walking from +`group_reduce_addf` to the mask producer. + +## 9. Physical Value Ordering + +The OneToN lowering order is fixed. + +```text +contiguous: + chunk0, chunk1, ... + +deinterleaved=F: + part0_chunk0, part0_chunk1, ..., + part1_chunk0, part1_chunk1, ..., + ... + part(F-1)_chunk0, ... + +group_slots(G,K): + slot_block0, slot_block1, ... +``` + +Two physical bundle entries may alias the same VPTO SSA value when the selected +plan proves they have the same contents, such as group_broadcast feeding both +parts of a `deinterleaved=2` broadcast result. Arity still follows the layout; +aliasing is not a different layout. + +## 10. Diagnostics + +Diagnostics are part of the design. They must name: + +```text +1. the VMI op +2. source logical type +3. assigned source layout +4. requested layout +5. missing plan or disabled fallback +6. suggested rewrite when available +``` + +Examples: + +```text +dense store of group_slots: + use group_store, group_broadcast, or explicit group-pack + +packed group-slot f32->f16: + group_broadcast before truncf, or keep group_store as f32 + +S=32 tail without full_tile_readable: + mark source full_tile_readable or enable stable gather fallback + +S=32 group_load with unaligned source_group_stride: + choose a stride divisible by 8 f32 elements or enable stable gather fallback + +public VMI function boundary: + make function internal, inline before assignment, or define ABI layout +``` + +## 11. Design Completion Criteria + +The design is complete only when: + +```text +1. every case in vmi-layout-lowering-cases.md maps to registered plans +2. every selected plan can be emitted without looking at producer/user context +3. every unsupported case has a precise capability diagnostic +4. every control-flow/function boundary either specializes layout or diagnoses +5. every mask has explicit data layout and predicate granularity +6. every case has an end-to-end test and simulator validation +``` diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 807baf841e..b111397fc9 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -154,20 +154,22 @@ the immediately following complete endpoints. 3.6.2 group_reduce S=32 -> elemwise(rhs) -> group_store complete 3.6.3 group_reduce S=32 -> broadcast -> compute -> reduce -> store complete -3.7.1 group_reduce S=64 -> group_store complete -3.7.2 group_reduce S=64 -> elemwise(rhs) -> group_store complete +3.7.1 group_reduce S=64 -> aligned group_store complete +3.7.2 group_reduce S=64 -> elemwise(rhs) -> aligned group_store + complete 3.7.3 group_reduce S=64 -> broadcast -> compute -> reduce -> store complete +3.7.4 group_reduce S=64 -> unit-stride group_store illegal diagnostic 3.8 group_reduce -> truncf -> broadcast -> dense store complete 3.9 dense store of group slots illegal diagnostic 3.10 non-load producer feeding S=32 group_reduce complete 3.11 partial tail groups complete/diagnostic 3.12 control-flow join before group_reduce complete -3.13 direct group-slot f32 -> f16 cast illegal diagnostic +3.13 packed group-slot f32 -> f16 cast illegal diagnostic 3.14 unsupported group size illegal diagnostic 3.15 compact S=12 written as logical S=16 complete/design 3.16 group_slot_load layout contract complete -3.17 group_broadcast physical arity alias complete +3.17 group_broadcast feeding deinterleaved consumer complete 3.18 one value with dense and group-reduce consumers complete/materialization 3.19 S=16 reduce block_elems plan selection complete/diagnostic 3.20 group_slots control-flow join complete @@ -176,6 +178,25 @@ the immediately following complete endpoints. 3.23 group_broadcast with multiple dense consumers complete 3.24 mask with elementwise/select/store complete 3.25 function boundary layout specialization complete/design +3.26 S=16 grouped tail through broadcast/reduce/store complete +3.27 S=32 group_load with stride greater than group size complete +3.28 group_slot_load slots=1 aligned non-unit stride complete +3.29 one semantic mask with f32 and f16 consumers complete +3.30 masked_load tail without padding complete/diagnostic +3.31 f16->f32 feeding dense store and S=16 reduce complete +3.32 f32 feeding f8 store and S=32 reduce complete +3.33 one dense value feeding S=16 and S=32 reduces complete/materialization +3.34 S=64 group-slot result f32->f16 cast complete +3.35 group_slots fanout to group_store and broadcast complete/design +3.36 same scalar source materialized as slots=8/slots=1 complete/design +3.37 S=64 group_store with non-unit output stride complete/design +3.38 multi-tile S=32 group_reduce complete +3.39 strided S=32 group_load through broadcast/reduce complete +3.40 scalar broadcast feeding dense and grouped users complete/materialization +3.41 non-rematerializable value with incompatible users complete/materialization +3.42 group_slots scf.for loop-carried accumulator complete +3.43 internal function argument boundary materialization complete/design +3.44 masked_load grouped tail feeding S=32 reduce complete/design ``` ### 3.1 `f16 -> f32 -> store` @@ -1170,7 +1191,8 @@ VMI input: %x = pto.vmi.load %base[%off] : memref<512xf32> -> !pto.vmi.vreg<512xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} -pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +%c8 = arith.constant 8 : index +pto.vmi.group_store %sum, %sum_out[%group_off], %c8 {num_groups = 8} ``` Assigned layouts: @@ -1247,7 +1269,7 @@ Memory result: ```text for r = 0..7: - sum_out[group_tile_off + r] = reduce(row_r[0..63]) + sum_out[group_tile_off + r * 8] = reduce(row_r[0..63]) ``` #### 3.7.2 Reduce Result, Elementwise, Store @@ -1261,7 +1283,8 @@ VMI input: : !pto.ptr -> !pto.vmi.vreg<512xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} %outv = pto.vmi.addf %sum, %rhs -pto.vmi.group_store %outv, %out[%group_off], %c1 {num_groups = 8} +%c8 = arith.constant 8 : index +pto.vmi.group_store %outv, %out[%group_off], %c8 {num_groups = 8} ``` Assigned layouts: @@ -1348,7 +1371,7 @@ Memory result: ```text for r = 0..7: - out[group_tile_off + r] = reduce(row_r[0..63]) + rhs[r] + out[group_tile_off + r * 8] = reduce(row_r[0..63]) + rhs[r] ``` #### 3.7.3 Reduce, Broadcast, Elementwise, Reduce, Store @@ -1362,7 +1385,8 @@ VMI input: %b = pto.vmi.group_broadcast %sum {num_groups = 8} %y = pto.vmi.mulf %x, %b %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} -pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +%c8 = arith.constant 8 : index +pto.vmi.group_store %ysum, %out[%group_off], %c8 {num_groups = 8} ``` Assigned layouts: @@ -1415,11 +1439,50 @@ Memory result: ```text for r = 0..7: s = reduce(row_r[0..63]) - out[group_tile_off + r] = + out[group_tile_off + r * 8] = reduce_i(row_r[i] * s for i = 0..63) = s * s ``` +#### 3.7.4 Unit-Stride Store Is Not A Valid Lowering Yet + +The row-local S=64 result uses one physical vreg per group with the semantic +value in lane 0: + +```text +%sum_r lane 0 = reduce(row_r[0..63]) +``` + +The current VPTO lowering for `slots = 1` group_store emits one lane-0 `vsts` +per group. Therefore unit-stride f32 output would issue stores at: + +```text +group_off + 0, group_off + 1, group_off + 2, ... +``` + +Only the first address is necessarily 32B-aligned. The remaining f32 addresses +are 4B apart and are not valid for this `vsts` lowering. The compiler must not +accept this as a clean lowering until either a pack-to-slots=8 plan or an +unaligned-store plan is selected. + +VMI input: + +```text +%c1 = arith.constant 1 : index +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Required diagnostic: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.group_store with #pto.vmi.layout lowers + as one lane-0 vsts per group and requires constant positive row_stride + divisible by 8 f32 elements for 32B store alignment. Packed or unaligned + contiguous store lowering is not implemented. +``` + ### 3.8 `group_reduce -> truncf -> group_broadcast -> store` VMI input: @@ -1441,7 +1504,8 @@ Assigned layouts: %sum32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> %sum16 : semantic value only; not materialized as a group-slot VPTO value -%b32 : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b32_dense : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b32_split : !pto.vmi.vreg<128xf32, #pto.vmi.layout> %b16 : !pto.vmi.vreg<128xf16, #pto.vmi.layout> ``` @@ -1452,8 +1516,13 @@ group_broadcast(truncf(group_reduce(x))) == truncf(group_broadcast(group_reduce(x))) ``` -This avoids materializing a group-slot f16 value. The only cast emitted is the -existing dense `f32 deinterleaved=2 -> contiguous f16` truncation. +This avoids materializing a group-slot f16 value. Current lowering makes the +layout transition explicit: `group_broadcast` first produces a dense contiguous +f32 value, then `pto.vmi.ensure_layout` materializes the deinterleaved=2 f32 +view required by dense `f32 -> f16` truncation. A future direct +`group_broadcast -> deinterleaved=2` lowering may remove that materialization, +but it must be implemented as a `group_broadcast` selected plan rather than +hidden inside `truncf` lowering. VPTO lowering result for one full 8-row tile: @@ -1473,19 +1542,28 @@ VPTO lowering result for one full 8-row tile: : !pto.vreg<64xf32> %lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> -%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 - : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%broadcast_idx_lo = compute index vector [0 repeated 16, 1 repeated 16, + 2 repeated 16, 3 repeated 16] + : !pto.vreg<64xi32> +%broadcast_idx_hi = compute index vector [4 repeated 16, 5 repeated 16, + 6 repeated 16, 7 repeated 16] + : !pto.vreg<64xi32> -// This vselr is the VPTO lowering of pto.vmi.group_broadcast. The later store -// only writes lanes as-is; it does not duplicate group-slot values. -%b32_rows = pto.vselr %sum32_block, %broadcast_idx +// These vselr ops are the VPTO lowering of pto.vmi.group_broadcast for the two +// dense contiguous f32 physical chunks. +%b32_rows_lo = pto.vselr %sum32_block, %broadcast_idx_lo : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b32_rows_hi = pto.vselr %sum32_block, %broadcast_idx_hi + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +// ensure_layout contiguous -> deinterleaved=2 materializes the two f32 parity +// inputs expected by f32 -> f16 truncation. +%b32_even_input, %b32_odd_input = pto.vdintlv %b32_rows_lo, %b32_rows_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// The broadcasted f32 value is dense deinterleaved=2. -// Both parity parts carry the same per-row broadcast values. -%b16_even = pto.vcvt %b32_rows, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} +%b16_even = pto.vcvt %b32_even_input, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> -%b16_odd = pto.vcvt %b32_rows, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} +%b16_odd = pto.vcvt %b32_odd_input, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> %all_b16 = pto.pge_b16 "PAT_ALL" @@ -1561,7 +1639,7 @@ Assigned layouts: ```text %a, %bias, %x: - !pto.vmi.vreg<256xf32, #pto.vmi.layout> + !pto.vmi.vreg<256xf32, #pto.vmi.layout> %sum: !pto.vmi.vreg<256xf32, #pto.vmi.layout> @@ -1629,7 +1707,8 @@ VMI input: %x = pto.vmi.load %base[%off] : memref<384xf32> -> !pto.vmi.vreg<384xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} -pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} +%c8 = arith.constant 8 : index +pto.vmi.group_store %sum, %out[%group_off], %c8 {num_groups = 6} ``` Assigned layouts: @@ -1665,7 +1744,7 @@ Memory result: ```text for r = 0..5: - out[group_tile_off + r] = reduce(row_r[0..63]) + out[group_tile_off + r * 8] = reduce(row_r[0..63]) ``` #### 3.11.2 S=32 Tail Without Full-Tile Read Contract @@ -1803,7 +1882,7 @@ VMI-LAYOUT-CONTRACT: Expected #pto.vmi.layout on every incoming value. ``` -### 3.13 Direct Group-Slot `f32 -> f16` Cast +### 3.13 Packed Group-Slot `f32 -> f16` Cast This case is intentionally illegal for the current S=16/S=32 packed group-slot layout. It prevents the compiler from treating a width-changing @@ -1922,6 +2001,10 @@ Semantics: lane i is active iff (i % S) < active_elems_per_group ``` +Current lowering support covers constant `active_elems_per_group`. Dynamic +grouped masks require a runtime lane-index predicate materializer and remain a +separate implementation item. + Ordinary `pto.vmi.create_mask %active_lanes` keeps the prefix-mask meaning: ```text @@ -1960,6 +2043,10 @@ Assigned layouts: %sum: !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x32_for_store: + pto.vmi.ensure_layout %x32 + : #pto.vmi.layout -> #pto.vmi.layout ``` VPTO lowering result for one `8x16xf32` tile: @@ -2212,9 +2299,10 @@ silently using full-group `group_load`. VMI input: ```text -%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} +%c8 = arith.constant 8 : index +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c8 {num_groups = 8} : !pto.ptr, index -> !pto.vmi.vreg<512xf32> -pto.vmi.group_store %rhs, %out[%group_off], %c1 {num_groups = 8} +pto.vmi.group_store %rhs, %out[%group_off], %c8 {num_groups = 8} ``` Assigned layout: @@ -2231,6 +2319,8 @@ VPTO lowering result: // Emit this shape for r = 0..7. Each result value carries one semantic slot // in lane 0, matching the S=64 row-local group_reduce result layout. +// For f32, source_group_stride = 8 elements = 32B, so every lane-0 vsldb is +// aligned. %rhs_r = pto.vsldb %rhs_base[%rhs_off_plus_r], %c0_i16, %c0_i16, %one_b32 : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> @@ -2242,14 +2332,32 @@ Memory result: ```text for r = 0..7: - out[group_off + r] = rhs_base[rhs_off + r] + out[group_off + r * 8] = rhs_base[rhs_off + r * 8] +``` + +Current lowering rule: + +```text +slots = 1 group_slot_load uses one lane-0 vsldb per semantic group slot. +For f32, source_group_stride must be a positive constant divisible by 8 +elements. For f16 it must be divisible by 16 elements, and for f8 it must be +divisible by 32 elements. ``` -### 3.17 `group_broadcast` Physical Arity Alias +### 3.17 `group_broadcast` Feeding A Deinterleaved Consumer + +This case fixes a lowering invariant: `group_broadcast` itself does not infer a +consumer-specific deinterleaved result. It produces the layout selected by +layout assignment. If a later consumer requires another layout, assignment must +insert an explicit `ensure_layout`. -This case fixes a lowering invariant: a layout determines physical arity. A -`deinterleaved = 2` result has two physical bundle entries even when both -entries can reuse the same VPTO SSA value. +The current endpoint is: + +```text +group_reduce -> group_broadcast(contiguous f32) + -> ensure_layout(deinterleaved = 2) + -> truncf(contiguous f16) +``` VMI input: @@ -2272,8 +2380,12 @@ Assigned layouts: %sum: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -%b: - !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b_dense: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b_split = pto.vmi.ensure_layout %b_dense: + #pto.vmi.layout + -> #pto.vmi.layout %h: !pto.vmi.vreg<128xf16, #pto.vmi.layout> @@ -2295,22 +2407,26 @@ VPTO lowering result: %sum_block = pto.vadd %lo_sum, %hi_sum, %sum_mask : !pto.vreg<64xf32> -%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> -%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 - : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +// group_broadcast lowers to two contiguous f32 chunks. +%idx_lo = materialize indices [0 repeated 16, 1 repeated 16, + 2 repeated 16, 3 repeated 16] + : !pto.vreg<64xi32> +%idx_hi = materialize indices [4 repeated 16, 5 repeated 16, + 6 repeated 16, 7 repeated 16] + : !pto.vreg<64xi32> -%b_rows = pto.vselr %sum_block, %broadcast_idx +%b_lo = pto.vselr %sum_block, %idx_lo + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_hi = pto.vselr %sum_block, %idx_hi : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> -// Physical bundle binding for %b, not emitted VPTO ops: -// physical entry 0 = %b_rows -// physical entry 1 = %b_rows -// The layout still has two physical entries; they alias the same SSA value -// because every even/odd logical lane pair contains the same broadcast value. +// ensure_layout contiguous -> deinterleaved=2 is explicit in assigned VMI. +%b_even_input, %b_odd_input = pto.vdintlv %b_lo, %b_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%h_even = pto.vcvt %b_rows, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} +%h_even = pto.vcvt %b_even_input, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> -%h_odd = pto.vcvt %b_rows, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} +%h_odd = pto.vcvt %b_odd_input, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> %all_b16 = pto.pge_b16 "PAT_ALL" @@ -2329,6 +2445,15 @@ for r = 0..7: out[r * 16 + 0 .. r * 16 + 15] = truncf(s) ``` +Required assignment rule: + +```text +`group_broadcast` layout is chosen before `vmi-to-vpto`. A width-changing +consumer such as `truncf` may require a deinterleaved f32 source, but that +requirement must be represented by `ensure_layout`; `truncf` lowering must not +look through the defining `group_broadcast` and choose a hidden broadcast shape. +``` + ### 3.18 One Value With Dense And Group-Reduce Consumers This case forces layout assignment to handle a solvable use-site conflict. One @@ -2663,27 +2788,40 @@ for r = 0..7: ### 3.21 S=32 Tail With Full-Tile-Readable Source This is the positive counterpart to section 3.11.2. Tail participation is -still expressed by masks, but the source additionally promises that reading the -rounded-up 8-row physical tile is memory-safe. +still expressed by masks, but the source must provide a static proof that +reading the rounded-up 8-row physical tile is memory-safe. That proof is +explicit: it can come from a statically shaped memref source, or from +`pto.vmi.load {full_read_elems = N}` on a pointer source. The pointer attr +means the memory interval starting at the load offset is safe to read for `N` +logical elements; it is not inferred from surrounding MTE copies or caller +context. VMI input: ```text -%x = pto.vmi.load %base[%off] {full_tile_readable} - : memref<192xf32> -> !pto.vmi.vreg<192xf32> +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<192xf32> %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6} pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} ``` +Equivalent pointer-source VMI input for runtime kernels: + +```text +%x = pto.vmi.load %base[%off] {full_read_elems = 256} + : !pto.ptr -> !pto.vmi.vreg<192xf32> +``` + Assigned layouts: ```text %x: - !pto.vmi.vreg<192xf32, #pto.vmi.layout> + !pto.vmi.vreg<192xf32, #pto.vmi.layout> %mask: - !pto.vmi.mask<192xpred, #pto.vmi.layout> + !pto.vmi.mask<192xpred, + #pto.vmi.layout> %sum: !pto.vmi.vreg<192xf32, #pto.vmi.layout> @@ -2692,28 +2830,44 @@ Assigned layouts: VPTO lowering result: ```text -// Full-tile-readable allows the load plan to read the rounded-up 8-row tile. -// Only rows 0..5 are semantically active. -%data_mask = pto.pge_b32 "PAT_VL48" // 6 rows * 8 lanes per physical part -%sum_mask = pto.pge_b32 "PAT_VL6" - -%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" - : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" - : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// A statically safe full-read proof allows the load plan to read the +// rounded-up 8-row tile. Only rows 0..5 are semantically active. +%x_c0 = pto.vlds %base[%tile_off_0] + : memref<256xf32> -> !pto.vreg<64xf32> +%x_c1 = pto.vlds %base[%tile_off_1] + : memref<256xf32> -> !pto.vreg<64xf32> +%x_c2 = pto.vlds %base[%tile_off_2] + : memref<256xf32> -> !pto.vreg<64xf32> +%x_c3 = pto.vlds %base[%tile_off_3] + : memref<256xf32> -> !pto.vreg<64xf32> -%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 +%x_lo01, %x_hi01 = pto.vdintlv %x_c0, %x_c1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 +%x_lo23, %x_hi23 = pto.vdintlv %x_c2, %x_c3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p0, %x_p2 = pto.vdintlv %x_lo01, %x_lo23 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_hi01, %x_hi23 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%s0 = pto.vcgadd %x_p0, %data_mask +%data_mask0, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%data_mask1, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%data_mask2, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%data_mask3, %_ = pto.plt_b32 %c48_i32 + : i32 -> !pto.mask, i32 +%sum_mask, %_ = pto.plt_b32 %c6_i32 + : i32 -> !pto.mask, i32 + +%s0 = pto.vcgadd %x_p0, %data_mask0 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -%s1 = pto.vcgadd %x_p1, %data_mask +%s1 = pto.vcgadd %x_p1, %data_mask1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -%s2 = pto.vcgadd %x_p2, %data_mask +%s2 = pto.vcgadd %x_p2, %data_mask2 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -%s3 = pto.vcgadd %x_p3, %data_mask +%s3 = pto.vcgadd %x_p3, %data_mask3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> @@ -2731,9 +2885,9 @@ for r = 0..5: out[group_off + r] = reduce(row_r[0..31]) ``` -Rows 6 and 7 may be physically loaded because of `full_tile_readable`, but -their lanes are not active in `%data_mask`, and their group slots are not stored -because `%sum_mask` is `PAT_VL6`. +Rows 6 and 7 may be physically loaded because of the safe full-read proof, but +their lanes are not active in `%data_mask*`, and their group slots are not +stored because `%sum_mask` is produced by `plt_b32 %c6_i32`. ### 3.22 `scf.for` Loop-Carried Layout @@ -2851,23 +3005,49 @@ pto.vmi.group_store %ysum, %sum_out[%group_off], %c1 {num_groups = 8} pto.vmi.store %h, %dense_out[%off] ``` -Assigned layouts: +Assigned layouts in the current implementation: ```text -%x, %b_for_mul, %y: +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x_for_reduce: !pto.vmi.vreg<128xf32, #pto.vmi.layout> %sum, %ysum: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%b_for_mul, %y: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%y_for_reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + %b_for_cast: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b_for_cast_split: !pto.vmi.vreg<128xf32, #pto.vmi.layout> %h: !pto.vmi.vreg<128xf16, #pto.vmi.layout> ``` +The important invariant is not that both dense consumers choose the same dense +layout. It is that each use has an explicit layout boundary: + +```text +%x_for_reduce = pto.vmi.ensure_layout %x +%y_for_reduce = pto.vmi.ensure_layout %y +%b_for_cast_split = pto.vmi.ensure_layout %b_for_cast +``` + +If a future `group_broadcast -> deinterleaved` selected plan is added, layout +assignment may assign `%b_for_mul` or `%b_for_cast` directly to that layout, but +the choice must still be visible in the assigned IR and selected plan. + VPTO lowering result: ```text @@ -2887,12 +3067,17 @@ VPTO lowering result: %broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> -// Use 1: broadcast for the S=16 block_elems=8 multiply path. Both row halves -// use the same per-row broadcast vector. -%b_rows_for_mul = pto.vselr %sum_block, %broadcast_idx +// Use 1: broadcast for the multiply path. Current lowering materializes two +// contiguous f32 chunks, multiplies them with the original contiguous chunks, +// then deinterleaves the product for the second group_reduce. +%b_rows_for_mul_0 = pto.vselr %sum_block, %broadcast_idx_0 : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> -%y_lo = pto.vmul %x_lo, %b_rows_for_mul, %all_b32 : !pto.vreg<64xf32> -%y_hi = pto.vmul %x_hi, %b_rows_for_mul, %all_b32 : !pto.vreg<64xf32> +%b_rows_for_mul_1 = pto.vselr %sum_block, %broadcast_idx_1 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%y0 = pto.vmul %x0, %b_rows_for_mul_0, %all_b32 : !pto.vreg<64xf32> +%y1 = pto.vmul %x1, %b_rows_for_mul_1, %all_b32 : !pto.vreg<64xf32> +%y_lo, %y_hi = pto.vdintlv %y0, %y1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> %y_lo_sum = pto.vcgadd %y_lo, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %y_hi_sum = pto.vcgadd %y_hi, %all_b32 @@ -2902,14 +3087,17 @@ VPTO lowering result: pto.vsts %ysum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask -// Use 2: rematerialize broadcast for the f32->f16 parity cast path. The -// deinterleaved=2 physical bundle has two entries that alias this SSA value. -%b_rows_for_cast = pto.vselr %sum_block, %broadcast_idx +// Use 2: rematerialize broadcast for the f32->f16 parity cast path. +%b_rows_for_cast_0 = pto.vselr %sum_block, %broadcast_idx_0 : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> -%h_even = pto.vcvt %b_rows_for_cast, %all_b32 +%b_rows_for_cast_1 = pto.vselr %sum_block, %broadcast_idx_1 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%cast_lo, %cast_hi = pto.vdintlv %b_rows_for_cast_0, %b_rows_for_cast_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%h_even = pto.vcvt %cast_lo, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> -%h_odd = pto.vcvt %b_rows_for_cast, %all_b32 +%h_odd = pto.vcvt %cast_hi, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> %all_b16 = pto.pge_b16 "PAT_ALL" @@ -2962,7 +3150,7 @@ VPTO lowering result: ```text %all_b32 = pto.pge_b32 "PAT_ALL" -%m = pto.pge_b32 "PAT_VL48" +%m, %_ = pto.plt_b32 %c48_i32 : i32 -> !pto.mask, i32 %x0 = pto.vlds %base[%off] {dist = "NORM"} : !pto.ptr -> !pto.vreg<64xf32> @@ -3101,3 +3289,1967 @@ VMI-LAYOUT-CONTRACT: stable VMI layout ABI. Mark the function internal for layout specialization, inline it before vmi-layout-assignment, or define an explicit ABI layout. ``` + +### 3.26 S=16 Grouped Tail Through Broadcast, Reduce, Store + +This case extends section 3.15.1 from `reduce -> group_store` to the full +grouped compute path. It is needed because `create_group_mask` must remain a +group-periodic mask after a `group_broadcast`; it cannot collapse to a prefix +mask or an all-true mask. + +VMI input: + +```text +%stride16 = arith.constant 16 : index +%x = pto.vmi.group_load %base[%off], %stride16 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%c12 = arith.constant 12 : index +%mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %b, %y: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xpred, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result for one `8x16xf32` tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row8 = pto.vshls %row, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row8, %all_b32 + : !pto.vreg<64xi32> +%hi4_mask = pto.vcmps %col, %c4_i32, %all_b32, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask + +%x_lo, %x_hi = pto.vldsx2 %base[%tile_off], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%x_hi_sum = pto.vcgadd %x_hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %x_lo_sum, %x_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +%broadcast_idx = pto.vshrs %lane, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%b_rows = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_lo = pto.vmul %x_lo, %b_rows, %all_b32 : !pto.vreg<64xf32> +%y_hi = pto.vmul %x_hi, %b_rows, %hi4_mask : !pto.vreg<64xf32> + +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %hi4_mask + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..11]) + out[group_tile_off + r] = + reduce_i(row_r[i] * s for i = 0..11) + = s * s +``` + +Required assignment rule: + +```text +%mask is a grouped mask with S=16 and active_elems_per_group=12. +For the low half, the physical predicate is PAT_ALL. +For the high half, the physical predicate is lane_mod_8 < 4. +The same split must be reused for both group_reduce operations. +``` + +### 3.27 S=32 `group_load` With Stride Greater Than Group Size + +This case is the S=32 counterpart to section 3.15.2. The logical group is +`32xf32`, but rows in memory have a larger stride. The fast plan is legal only +when the stride is a multiple of one 32B f32 block. + +VMI input: + +```text +%stride40 = arith.constant 40 : index +%x = pto.vmi.group_load %base[%off], %stride40 + {num_groups = 8, group_size = 32} + : !pto.ptr, index -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<256xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<256xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +// source_group_stride = 40 f32 = 5 * 32B blocks. +%stride_blocks = %c5_i16 + +%frag0 = pto.vsldb %base_frag0, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%frag1 = pto.vsldb %base_frag1, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%frag2 = pto.vsldb %base_frag2, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%frag3 = pto.vsldb %base_frag3, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%frag0 lanes r*8 .. r*8+7 = row_r[0..7] +%frag1 lanes r*8 .. r*8+7 = row_r[8..15] +%frag2 lanes r*8 .. r*8+7 = row_r[16..23] +%frag3 lanes r*8 .. r*8+7 = row_r[24..31] + +%s0 = pto.vcgadd %frag0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %frag1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s2 = pto.vcgadd %frag2, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s3 = pto.vcgadd %frag3, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_tile_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_tile_off + r] = + reduce(base[tile_off + r * 40 + 0 .. tile_off + r * 40 + 31]) +``` + +Required diagnostic when the stride is not block-aligned: + +```text +VMI-LAYOUT-CONTRACT: + pto.vmi.group_load group_size 32 with source_group_stride not divisible by + 8 f32 elements cannot use the registered vsldb strided-block plan. Enable a + stable gather plan or choose a block-aligned source_group_stride. +``` + +Required assignment rule: + +```text +This producer selects the S=32 block-fragment plan: + #pto.vmi.layout + +It must not be unified with the contiguous-load S=32 plan from section 3.6: + #pto.vmi.layout + +Both layouts are legal inputs to group_reduce_addf S=32, but they require +different producer materialization plans. +``` + +### 3.28 `group_slot_load` `slots = 1` With Aligned Non-Unit Stride + +Section 3.16.1 diagnoses non-unit stride for the packed `slots = 8` plan. The +row-local `slots = 1` plan supports non-unit stride only when each one-lane +load can be issued as an aligned `vsldb`. In the current lowering this means +the stride is a positive compile-time constant and is divisible by the 32B +alignment expressed in source elements. + +VMI input: + +```text +%c8 = arith.constant 8 : index +%rhs = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c8 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<512xf32> +pto.vmi.group_store %rhs, %out[%group_off], %c8 {num_groups = 8} +``` + +Assigned layout: + +```text +%rhs: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this shape for r = 0..7. The address expression is scalar/index +// arithmetic outside the vector register layout. For f32, %c8 is 32B. +%addr_r = %rhs_base + %rhs_off + r * 8 +%rhs_r = pto.vsldb %addr_r, %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %rhs_r, %out[%group_tile_off_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r * 8] = rhs_base[rhs_off + r * 8] +``` + +Required assignment rule: + +```text +If a non-unit-stride group_slot_load has only slots=1 consumers and its stride +is a positive constant divisible by the element count of 32B, select +group_slot_load_slots1_row_local. Do not diagnose it using the slots=8 +unit-stride restriction. +``` + +Required diagnostic: + +```text +%c2 = arith.constant 2 : index +%bad = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c2 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + +VMI-UNSUPPORTED: pto.vmi.group_slot_load + slots=1 group_slot_load currently lowers as one lane-0 vsldb per group and + requires constant positive source_group_stride divisible by 8 elements for + 32B load alignment; packed or unaligned scalar load lowering is not + implemented. +``` + +Dynamic stride has the same status until a stable gather or scalarized packed +load plan is designed: + +```text +%bad = pto.vmi.group_slot_load %rhs_base[%rhs_off], %runtime_stride + {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + +VMI-UNSUPPORTED: pto.vmi.group_slot_load + requires constant positive source_group_stride divisible by 8 elements. +``` + +### 3.29 One Semantic Mask With f32 And f16 Consumers + +One VMI mask may feed consumers with different physical predicate +granularities. Layout assignment must keep the semantic mask value single, but +materialize per-use physical masks after element type is known. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%mask = pto.vmi.create_mask %c96 + : index -> !pto.vmi.mask<128xpred> +pto.vmi.masked_store %x, %out32[%off], %mask +%h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> +pto.vmi.masked_store %h, %out16[%off], %mask +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xb32, #pto.vmi.layout> + +%x_for_cast: + pto.vmi.ensure_layout %x + : #pto.vmi.layout -> #pto.vmi.layout + +%mask_for_h_store: + pto.vmi.create_mask %c96 + : index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + +%h: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +Physical mask materialization: + +```text +use at masked_store %x: + predicate granularity b32, PAT_VL96, layout contiguous + +use at vcvt %x -> %h: + predicate granularity b32, PAT_ALL. The cast may compute inactive lanes + because the following masked_store controls the external memory effect. + +use at masked_store %h: + predicate granularity b16, PAT_VL96, layout contiguous +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%mask32_0 = pto.pge_b32 "PAT_ALL" +%mask32_1 = pto.pge_b32 "PAT_VL32" + +%x0 = pto.vlds %base[%off] + : !pto.ptr, index -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%off_plus_64] + : !pto.ptr, index -> !pto.vreg<64xf32> + +pto.vsts %x0, %out32[%off], %mask32_0 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %x1, %out32[%off_plus_64], %mask32_1 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +%x_p0, %x_p1 = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%h_even = pto.vcvt %x_p0, %all_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +%h_odd = pto.vcvt %x_p1, %all_b32 {part = "ODD", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +%all_b16 = pto.pset_b16 "PAT_ALL" +%h0 = pto.vor %h_even, %h_odd, %all_b16 + : !pto.vreg<128xf16> +%mask_b16, %scalar_out = pto.plt_b16 %c96_i32 + : i32 -> !pto.mask, i32 +pto.vsts %h0, %out16[%off], %mask_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..95: + out32[off + i] = base[off + i] + out16[off + i] = truncf(base[off + i]) + +for i = 96..127: + out32[off + i] is unchanged + out16[off + i] is unchanged +``` + +Required assignment rule: + +```text +`vmi-to-vpto` must not decide mask granularity by inspecting users. It consumes +the per-use typed mask materialization inserted by vmi-layout-assignment. For +a rematerializable `create_mask`, assignment may clone it as b32/b16 masks. For +a non-rematerializable mask producer, assignment must insert +`ensure_mask_granularity` or diagnose if no materialization plan is registered. +``` + +### 3.30 `masked_load` Tail Without Padding + +This case is the replacement for `vector.transfer_read` padding semantics in the +initial VMI surface. Tail lanes are expressed by a mask and a passthrough value; +there is no implicit padding constant in the load. The direct lowering is legal +only when every physical chunk read by `vlds` is memory-safe. + +VMI input: + +```text +%c100 = arith.constant 100 : index +%mask = pto.vmi.create_mask %c100 : index -> !pto.vmi.mask<100xpred> +%zero = pto.vmi.broadcast %c0_f32 : f32 -> !pto.vmi.vreg<100xf32> +%x = pto.vmi.masked_load %base[%c0], %mask, %zero + : memref<128xf32>, !pto.vmi.mask<100xpred>, !pto.vmi.vreg<100xf32> + -> !pto.vmi.vreg<100xf32> +pto.vmi.store %x, %out[%c0] +``` + +Assigned layouts: + +```text +%mask: + !pto.vmi.mask<100xb32, #pto.vmi.layout> + +%zero, %x: + !pto.vmi.vreg<100xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%m0 = pto.pge_b32 "PAT_ALL" +%m1 = pto.pge_b32 "PAT_VL36" + +%zero0 = pto.vdup %c0_f32, %m0 + : f32, !pto.mask -> !pto.vreg<64xf32> +%zero1 = pto.vdup %c0_f32, %m0 + : f32, !pto.mask -> !pto.vreg<64xf32> + +%l0 = pto.vlds %base[%c0] + : memref<128xf32> -> !pto.vreg<64xf32> +%x0 = pto.vsel %l0, %zero0, %m0 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +%l1 = pto.vlds %base[%c64] + : memref<128xf32> -> !pto.vreg<64xf32> +%x1 = pto.vsel %l1, %zero1, %m1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +pto.vsts %x0, %out[%c0], %m0 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +pto.vsts %x1, %out[%c64], %m1 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +``` + +Memory result: + +```text +for i = 0..99: + out[i] = base[i] + +for i = 100..127: + out[i] is unchanged +``` + +Required diagnostic when the source cannot prove a safe full-read footprint: + +```text +VMI-UNSUPPORTED: + pto.vmi.masked_load direct lowering requires a supported memory source, + contiguous result/passthru/mask layouts, and either full physical chunks or a + statically safe full-read footprint. Use a memref with enough static extent, + enable the future stable masked/gather load plan, or make the logical vector a + full physical chunk. +``` + +Required assignment rule: + +```text +`masked_load` requests contiguous result, passthru, and mask layouts. Padding +is not a layout decision; it is the explicit passthrough operand selected by the +user. +``` + +### 3.31 `f16 -> f32` Feeding Dense Store And S=16 Reduce + +This case proves that the `deinterleaved = 2` layout produced by widening +`f16 -> f32` is not just a store layout. It must also be a legal S=16 grouped +reduction input. Layout assignment must not force the reduce consumer to +`block_elems = 8` and then rematerialize the widened value. + +VMI input: + +```text +%x16 = pto.vmi.load %base[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +%mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +pto.vmi.store %x32, %dense_out[%off] +``` + +Assigned layouts: + +```text +%x16: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +%x32: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b16 = pto.pge_b16 "PAT_ALL" +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x16_0 = pto.vlds %base[%off] + : memref<128xf16> -> !pto.vreg<128xf16> +%x32_p0 = pto.vcvt %x16_0, %all_b16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x16_0, %all_b16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x32_p0, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%s1 = pto.vcgadd %x32_p1, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %s0, %s1, %sum_mask + : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<8xf32>, !pto.mask + +%dense0, %dense1 = pto.vintlv %x32_p0, %x32_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %dense0, %dense_out[%off], %all_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +pto.vsts %dense1, %dense_out[%off_plus_64], %all_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<128xf32>, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = + reduce(extf(base[off + r * 16 + 0 .. off + r * 16 + 15])) + +for i = 0..127: + dense_out[off + i] = extf(base[off + i]) +``` + +Required assignment rule: + +```text +When S=16 group_reduce consumes an existing `deinterleaved = 2` dense value, +the reduce plan must accept `block_elems = 1`. `block_elems = 8` is only a +producer-driven fast plan for block-fragment loads, not the semantic +requirement of S=16 reduction. +``` + +### 3.32 `f32` Feeding f8 Store And S=32 Reduce + +This is the `f32 -> f8` counterpart to section 3.31. A 256-lane f32 value can +serve both `truncf -> f8` and S=32 group reduction with the same +`deinterleaved = 4, block_elems = 1` layout. The value must not be forced to a +block-fragment `block_elems = 8` layout unless its producer requires that plan. + +VMI input: + +```text +%x32 = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +%x8 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8> +pto.vmi.store %x8, %out8[%off] +``` + +Assigned layouts: + +```text +%x32: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%mask: + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x8: + !pto.vmi.vreg<256xf8, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum_mask = pto.pge_b32 "PAT_VL8" + +%x0 = pto.vlds %base[%off] : memref<256xf32> -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%off_plus_64] : memref<256xf32> -> !pto.vreg<64xf32> +%x2 = pto.vlds %base[%off_plus_128] : memref<256xf32> -> !pto.vreg<64xf32> +%x3 = pto.vlds %base[%off_plus_192] : memref<256xf32> -> !pto.vreg<64xf32> + +%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %sum_mask : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %sum_mask : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %sum_mask : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %sum_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, memref<8xf32>, !pto.mask + +%x8_p0 = pto.vcvt %x_p0, %all_b32 {part = "P0", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> +%x8_p1 = pto.vcvt %x_p1, %all_b32 {part = "P1", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> +%x8_p2 = pto.vcvt %x_p2, %all_b32 {part = "P2", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> +%x8_p3 = pto.vcvt %x_p3, %all_b32 {part = "P3", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8> + +%x8_01 = pto.vor %x8_p0, %x8_p1, PAT_ALL_B8 + : !pto.vreg<256xf8> +%x8_23 = pto.vor %x8_p2, %x8_p3, PAT_ALL_B8 + : !pto.vreg<256xf8> +%x8_0 = pto.vor %x8_01, %x8_23, PAT_ALL_B8 + : !pto.vreg<256xf8> + +pto.vsts %x8_0, %out8[%off], PAT_ALL_B8 {dist = "NORM_B8"} + : !pto.vreg<256xf8>, memref<256xf8>, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + sum_out[group_off + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) + +for i = 0..255: + out8[off + i] = truncf(base[off + i]) +``` + +Required assignment rule: + +```text +The common layout selected for `%x32` is +`#pto.vmi.layout`. This satisfies both +`truncf f32 -> f8` and S=32 `group_reduce_addf`. A later strided block-load +producer may introduce `block_elems = 8`, but that is a different case and +requires an explicit materialization/rematerialization decision. +``` + +### 3.33 One Dense Value Feeding S=16 And S=32 Reduces + +This case is a pure layout-assignment conflict. The same logical +`256xf32` value is consumed by two legal reductions, but their efficient input +layouts are different: + +```text +S=16 reduce over 16 groups: + #pto.vmi.layout + +S=32 reduce over 8 groups: + #pto.vmi.layout +``` + +The program is semantically legal. Layout assignment must solve it by cloning +or rematerializing the cheap load for one use, or by inserting an explicit +registered materialization plan. `vmi-to-vpto` must not inspect both users and +choose one locally. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + +%mask16 = pto.vmi.create_group_mask %c16 {num_groups = 16, group_size = 16} + : index -> !pto.vmi.mask<256xpred> +%sum16 = pto.vmi.group_reduce_addf %x, %mask16 {num_groups = 16} +pto.vmi.group_store %sum16, %out16[%group_off16], %c1 {num_groups = 16} + +%mask32 = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum32 = pto.vmi.group_reduce_addf %x, %mask32 {num_groups = 8} +pto.vmi.group_store %sum32, %out32[%group_off32], %c1 {num_groups = 8} +``` + +Assigned layouts after rematerializing the load: + +```text +%x_s16: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%mask16: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%sum16: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_s32: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%mask32: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%sum32: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%sum8_mask = pto.pge_b32 "PAT_VL8" + +// Rematerialized S=16 use. The first vldsx2 covers rows 0..7, the second +// covers rows 8..15. Each pair is deinterleaved by element parity. +%s16_p0, %s16_p1 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%s16_p2, %s16_p3 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s16_0 = pto.vcgadd %s16_p0, %all_b32 : !pto.vreg<64xf32> +%s16_1 = pto.vcgadd %s16_p1, %all_b32 : !pto.vreg<64xf32> +%s16_2 = pto.vcgadd %s16_p2, %all_b32 : !pto.vreg<64xf32> +%s16_3 = pto.vcgadd %s16_p3, %all_b32 : !pto.vreg<64xf32> + +%sum16_lo = pto.vadd %s16_0, %s16_1, %sum8_mask + : !pto.vreg<64xf32> +%sum16_hi = pto.vadd %s16_2, %s16_3, %sum8_mask + : !pto.vreg<64xf32> + +pto.vsts %sum16_lo, %out16[%group_off16], %sum8_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +pto.vsts %sum16_hi, %out16[%group_off16_plus_8], %sum8_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Rematerialized S=32 use. Two DINTLV loads plus one register deinterleave +// level produce mod-4 columns for rows 0..7. +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%tile_off_0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%tile_off_1], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s32_0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s32_1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s32_2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s32_3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> + +%s32_01 = pto.vadd %s32_0, %s32_1, %sum8_mask : !pto.vreg<64xf32> +%s32_23 = pto.vadd %s32_2, %s32_3, %sum8_mask : !pto.vreg<64xf32> +%sum32_block = pto.vadd %s32_01, %s32_23, %sum8_mask : !pto.vreg<64xf32> + +pto.vsts %sum32_block, %out32[%group_off32], %sum8_mask {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..15: + out16[group_off16 + r] = + reduce(base[off + r * 16 + 0 .. off + r * 16 + 15]) + +for r = 0..7: + out32[group_off32 + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +If a cheap producer such as load can produce both requested layouts, clone or +rematerialize it at the use sites and assign each clone independently. If the +producer is not rematerializable and no deinterleaved=2 <-> deinterleaved=4 +materialization plan is registered, emit a layout-contract diagnostic naming +both consumers and both required layouts. +``` + +### 3.34 S=64 Group-Slot Result `f32 -> f16` Cast + +Section 3.13 rejects direct width-changing cast for packed `slots = 8` +group-slot values. This case is the positive counterpart for row-local +`slots = 1`: each group result is already lane 0 of its own physical vreg, so a +slot-preserving cast can lower one row-local result at a time. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%mask = pto.vmi.create_group_mask %c64 {num_groups = 8, group_size = 64} + : index -> !pto.vmi.mask<512xpred> +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf16> +pto.vmi.group_store %sum16, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum32: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum16: + !pto.vmi.vreg<512xf16, #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" +%one_b16 = pto.pge_b16 "PAT_VL1" + +// The compiler emits this row-local sequence for r = 0..7. +%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p_r = pto.vcgadd %x_r, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum32_r = pto.vcadd %p_r, %block8 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + +// Only lane 0 is semantic. EVEN keeps f32 lane 0 in f16 lane 0; all other +// lanes are non-semantic for group_slots(num_groups=8, slots=1). +%sum16_r = pto.vcvt %sum32_r, %one_b32 {part = "EVEN", rnd = "R", sat = "SAT"} + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> + +pto.vsts %sum16_r, %out[%group_tile_off_r], %one_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = + truncf(reduce(base[off + r * 64 + 0 .. off + r * 64 + 63])) +``` + +Required assignment rule: + +```text +Group-slot casts are layout-specific. `slots = 1` may use a slot-preserving +row-local cast because each semantic scalar is lane 0 of its own physical vreg. +This does not legalize packed `slots = 8` casts from section 3.13. +``` + +### 3.35 `group_slots` Fanout To `group_store` And `group_broadcast` + +This case fixes the fanout rule for sparse values. A `group_slots` value may +feed multiple group-aware consumers directly. Layout assignment must not +materialize it as dense just because one later use broadcasts it. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%mask = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} + +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x_for_reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask_for_reduce: + !pto.vmi.mask<128xb32, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%b, %y: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%y_for_reduce: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> +``` + +VPTO lowering result for one full 8-row tile: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +%x0 = pto.vlds %base[%tile_off] + : !pto.ptr, index -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%tile_off_plus_64] + : !pto.ptr, index -> !pto.vreg<64xf32> + +// ensure_layout for the first group_reduce. +%x_lo, %x_hi = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%lo_sum = pto.vcgadd %x_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%hi_sum = pto.vcgadd %x_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%sum_block = pto.vadd %lo_sum, %hi_sum, %slot8 : !pto.vreg<64xf32> + +// First sparse consumer: store the group slots without changing layout. +pto.vsts %sum_block, %sum_out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Second sparse consumer: materialize only this use as dense grouped data. +%broadcast_idx0 = compute index vector [0 repeated 16, 1 repeated 16, + 2 repeated 16, 3 repeated 16] + : !pto.vreg<64xi32> +%broadcast_idx1 = compute index vector [4 repeated 16, 5 repeated 16, + 6 repeated 16, 7 repeated 16] + : !pto.vreg<64xi32> +%b0 = pto.vselr %sum_block, %broadcast_idx0 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b1 = pto.vselr %sum_block, %broadcast_idx1 + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y0 = pto.vmul %x0, %b0, %all_b32 : !pto.vreg<64xf32> +%y1 = pto.vmul %x1, %b1, %all_b32 : !pto.vreg<64xf32> + +// ensure_layout for the second group_reduce. +%y_lo, %y_hi = pto.vdintlv %y0, %y1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%y_lo_sum = pto.vcgadd %y_lo, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%y_hi_sum = pto.vcgadd %y_hi, %all_b32 + : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +%ysum_block = pto.vadd %y_lo_sum, %y_hi_sum, %slot8 : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(row_r[0..15]) + sum_out[group_off + r] = s + out[group_off + r] = reduce_i(row_r[i] * s for i = 0..15) +``` + +Required assignment rule: + +```text +`%sum` keeps one assigned layout: + #pto.vmi.layout + +`group_store` consumes that sparse layout directly. +`group_broadcast` is a use-site materialization to a dense layout. It must not +rewrite the defining `group_reduce` result or the sibling `group_store` use. +``` + +### 3.36 Same Scalar Source Materialized As `slots = 8` And `slots = 1` + +The same memory scalar stream may be used by both packed S=16 group-slot +compute and row-local S=64 group-slot compute. The two uses require different +logical vector shapes and different sparse layouts, so the source must be +rematerialized as two VMI values. There is no single `group_slots` layout that +serves both uses. + +VMI input: + +```text +%rhs16 = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> +%x16 = pto.vmi.load %base16[%off16] + : memref<128xf32> -> !pto.vmi.vreg<128xf32> +%sum16 = pto.vmi.group_reduce_addf %x16, %mask16 {num_groups = 8} +%out16v = pto.vmi.addf %sum16, %rhs16 +pto.vmi.group_store %out16v, %out16[%group_off16], %c1 {num_groups = 8} + +%rhs64 = pto.vmi.group_slot_load %rhs_base[%rhs_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<512xf32> +%x64 = pto.vmi.load %base64[%off64] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum64 = pto.vmi.group_reduce_addf %x64, %mask64 {num_groups = 8} +%out64v = pto.vmi.addf %sum64, %rhs64 +pto.vmi.group_store %out64v, %out64[%group_off64], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%rhs16, %sum16, %out16v: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x16, %mask16: + #pto.vmi.layout + +%rhs64, %sum64, %out64v: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%x64, %mask64: + #pto.vmi.layout +``` + +VPTO lowering result: + +```text +// Packed S=16 RHS: one 32B scalar block in lanes 0..7. +%slot8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" +%rhs16_block = pto.vsldb %rhs_base[%rhs_off], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +// S=16 reduction is the section 3.5.1 shape. +%x16_lo, %x16_hi = pto.vldsx2 %base16[%tile_off16], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%s16_lo = pto.vcgadd %x16_lo, PAT_ALL_B32 : !pto.vreg<64xf32> +%s16_hi = pto.vcgadd %x16_hi, PAT_ALL_B32 : !pto.vreg<64xf32> +%sum16_block = pto.vadd %s16_lo, %s16_hi, %slot8 : !pto.vreg<64xf32> +%out16_block = pto.vadd %sum16_block, %rhs16_block, %slot8 + : !pto.vreg<64xf32> +pto.vsts %out16_block, %out16[%group_off16], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + +// Row-local S=64 RHS: rematerialize the same scalar stream into one lane-0 +// value per physical row-local result. +%rhs64_r = pto.vsldb %rhs_base[%rhs_off_plus_r], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +// Emit this row-local reduction/add/store shape for r = 0..7. +%x64_r = pto.vlds %base64[%row_off64_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p64_r = pto.vcgadd %x64_r, PAT_ALL_B32 : !pto.vreg<64xf32> +%sum64_r = pto.vcadd %p64_r, PAT_VL8_B32 : !pto.vreg<64xf32> +%out64_r = pto.vadd %sum64_r, %rhs64_r, %one_b32 : !pto.vreg<64xf32> +pto.vsts %out64_r, %out64[%group_off64_plus_r], %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out16[group_off16 + r] = reduce(base16[row_r, 0..15]) + rhs_base[rhs_off + r] + out64[group_off64 + r] = reduce(base64[row_r, 0..63]) + rhs_base[rhs_off + r] +``` + +Required assignment rule: + +```text +`group_slot_load` is cheaply rematerializable. If two use sites request +different `group_slots` layouts, clone/rematerialize the load per use. Do not +invent a common layout or make `vmi-to-vpto` inspect both users. +``` + +### 3.37 S=64 `group_store` With Non-Unit Output Stride + +Packed `slots = 8` stores currently require unit output stride. Row-local +`slots = 1` does not have that restriction because each group scalar is stored +by a separate lane-0 store. + +VMI input: + +```text +%row_stride = arith.index_cast %ld : i64 to index +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %row_stride {num_groups = 8} +``` + +Assigned layouts: + +```text +%x: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%block8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +// Emit this row-local sequence for r = 0..7. +%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<64xf32> +%p_r = pto.vcgadd %x_r, %all_b32 : !pto.vreg<64xf32> +%sum_r = pto.vcadd %p_r, %block8 : !pto.vreg<64xf32> + +%dst_r = %out + %group_off + r * %row_stride +pto.vsts %sum_r, %dst_r, %one_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r * row_stride] = reduce(row_r[0..63]) +``` + +Required assignment rule: + +```text +If `group_store` has non-unit row_stride and the source can legally use +`slots = 1`, assignment may select `slots = 1` to keep the store legal. If the +source is fixed to `slots = 8`, the current target plan must diagnose unless a +strided packed store materializer is registered. +``` + +### 3.38 Multi-Tile S=32 `group_reduce` + +The S=32 plan is not only a one-tile special case. For more than eight groups, +layout assignment keeps the same layout and `vmi-to-vpto` emits the same +8-row tile recipe for each physical tile. + +VMI input: + +```text +%x = pto.vmi.load %base[%off] + : memref<512xf32> -> !pto.vmi.vreg<512xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 16, group_size = 32} + : index -> !pto.vmi.mask<512xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 16} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 16} +``` + +Assigned layouts: + +```text +%x, %mask: + !pto.vmi.vreg<512xf32, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<512xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +// Emit this shape for tile t = 0 and tile t = 1. +// Each tile covers eight 32-f32 rows. +%tile_base_t = %base + %off + t * 256 +%tile_out_t = %out + %group_off + t * 8 + +%x_even_0_t, %x_odd_0_t = pto.vldsx2 %tile_base_t[%c0], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1_t, %x_odd_1_t = pto.vldsx2 %tile_base_t[%c128], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%x_p0_t, %x_p2_t = pto.vdintlv %x_even_0_t, %x_even_1_t + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1_t, %x_p3_t = pto.vdintlv %x_odd_0_t, %x_odd_1_t + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0_t = pto.vcgadd %x_p0_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s1_t = pto.vcgadd %x_p1_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s2_t = pto.vcgadd %x_p2_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s3_t = pto.vcgadd %x_p3_t, PAT_ALL_B32 : !pto.vreg<64xf32> +%s01_t = pto.vadd %s0_t, %s1_t, PAT_VL8_B32 : !pto.vreg<64xf32> +%s23_t = pto.vadd %s2_t, %s3_t, PAT_VL8_B32 : !pto.vreg<64xf32> +%sum_block_t = pto.vadd %s01_t, %s23_t, PAT_VL8_B32 + : !pto.vreg<64xf32> + +pto.vsts %sum_block_t, %tile_out_t, PAT_VL8_B32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..15: + out[group_off + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +For `group_slots(num_groups = 16, slots = 8)`, the physical arity is +`num_groups / slots = 2`. The type conversion must expose two packed result +blocks in group order. `group_store` stores both blocks with offsets +`group_off + 0` and `group_off + 8`. +``` + +### 3.39 Strided S=32 `group_load` Through Broadcast And Second Reduce + +Section 3.27 covers strided S=32 `group_load -> group_reduce -> group_store`. +This case adds the missing dense continuation. The important layout fact is +that a strided block load naturally produces +`deinterleaved = 4, block_elems = 8`; `group_broadcast` must materialize the +broadcast into that same block-fragment layout when the broadcast feeds +elementwise compute and another S=32 group reduction. + +VMI input: + +```text +%stride40 = arith.constant 40 : index +%x = pto.vmi.group_load %base[%off], %stride40 + {num_groups = 8, group_size = 32} + : !pto.ptr, index -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +%b = pto.vmi.group_broadcast %sum {num_groups = 8} +%y = pto.vmi.mulf %x, %b +%ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8} +pto.vmi.group_store %ysum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %mask, %b, %y: + !pto.vmi.vreg<256xf32, + #pto.vmi.layout> + +%sum, %ysum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" +%stride_blocks = %c5_i16 // 40 f32 = 5 * 32B blocks. + +%x_p0 = pto.vsldb %base_frag0, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%x_p1 = pto.vsldb %base_frag1, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%x_p2 = pto.vsldb %base_frag2, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +%x_p3 = pto.vsldb %base_frag3, %stride_blocks, %c0_i16, %all_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +%lane_id = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%broadcast_idx = pto.vshrs %lane_id, %c3_i16, %all_b32 + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> + +// Materialize the same per-row scalar into every 32B row fragment. The four +// bundle entries have the same lane contents, but the result layout remains +// deinterleaved=4, block_elems=8 because the consumer `%y = mulf %x, %b` +// operates on the block-fragment layout. +%b_p0 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p1 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p2 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> +%b_p3 = pto.vselr %sum_block, %broadcast_idx + : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + +%y_p0 = pto.vmul %x_p0, %b_p0, %all_b32 : !pto.vreg<64xf32> +%y_p1 = pto.vmul %x_p1, %b_p1, %all_b32 : !pto.vreg<64xf32> +%y_p2 = pto.vmul %x_p2, %b_p2, %all_b32 : !pto.vreg<64xf32> +%y_p3 = pto.vmul %x_p3, %b_p3, %all_b32 : !pto.vreg<64xf32> + +%ys0 = pto.vcgadd %y_p0, %all_b32 : !pto.vreg<64xf32> +%ys1 = pto.vcgadd %y_p1, %all_b32 : !pto.vreg<64xf32> +%ys2 = pto.vcgadd %y_p2, %all_b32 : !pto.vreg<64xf32> +%ys3 = pto.vcgadd %y_p3, %all_b32 : !pto.vreg<64xf32> +%ys01 = pto.vadd %ys0, %ys1, %slot8 : !pto.vreg<64xf32> +%ys23 = pto.vadd %ys2, %ys3, %slot8 : !pto.vreg<64xf32> +%ysum_block = pto.vadd %ys01, %ys23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %ysum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + s = reduce(base[off + r * 40 + 0 .. off + r * 40 + 31]) + out[group_off + r] = + reduce_i(base[off + r * 40 + i] * s for i = 0..31) +``` + +Required assignment rule: + +```text +`block_elems` is part of dense layout compatibility. A broadcast result feeding +an elementwise op with `%x : deinterleaved=4, block_elems=8` must also be +assigned `deinterleaved=4, block_elems=8`. Reusing a +`deinterleaved=4, block_elems=1` broadcast would be a layout mismatch even +though both have four physical parts. +``` + +### 3.40 Scalar Broadcast Feeding Dense And Grouped Users + +This case fixes the rule for ordinary scalar broadcasts. A scalar broadcast is +not born with a physical layout. Layout assignment may either rematerialize it +per use, or assign the transfer-equivalent producer chain to the non-contiguous +layout requested by the grouped consumer and insert an explicit materialization +at the dense store use. The latter is the concrete plan below. + +VMI input: + +```text +%scale = pto.vmi.broadcast %scale_s + : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.load %base[%off] + : memref<256xf32> -> !pto.vmi.vreg<256xf32> + +%copy = pto.vmi.addf %x, %scale +pto.vmi.store %copy, %copy_out[%off] + +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%prod = pto.vmi.mulf %x, %scale +%sum = pto.vmi.group_reduce_addf %prod, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %scale, %copy, %prod: + !pto.vmi.vreg<256xf32, + #pto.vmi.layout> + +%copy_dense = pto.vmi.ensure_layout %copy: + #pto.vmi.layout + -> #pto.vmi.layout + +%mask: + !pto.vmi.mask<256xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +// The shared load is assigned deinterleaved=4, block_elems=8 because the +// grouped consumer dominates the useful compute layout. +%x0 = pto.vlds %base[%off] : !pto.ptr, index -> !pto.vreg<64xf32> +%x1 = pto.vlds %base[%off_plus_64] : !pto.ptr, index -> !pto.vreg<64xf32> +%x2 = pto.vlds %base[%off_plus_128] : !pto.ptr, index -> !pto.vreg<64xf32> +%x3 = pto.vlds %base[%off_plus_192] : !pto.ptr, index -> !pto.vreg<64xf32> + +%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%scale_p0 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p1 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p2 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%scale_p3 = pto.vdup %scale_s, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> + +// Dense store use: compute in deinterleaved=4, then ensure_layout materializes +// the contiguous memory order for the external effect. +%copy_p0 = pto.vadd %x_p0, %scale_p0, %all_b32 : !pto.vreg<64xf32> +%copy_p1 = pto.vadd %x_p1, %scale_p1, %all_b32 : !pto.vreg<64xf32> +%copy_p2 = pto.vadd %x_p2, %scale_p2, %all_b32 : !pto.vreg<64xf32> +%copy_p3 = pto.vadd %x_p3, %scale_p3, %all_b32 : !pto.vreg<64xf32> + +%c01_lo, %c01_hi = pto.vintlv %copy_p0, %copy_p2 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%c23_lo, %c23_hi = pto.vintlv %copy_p1, %copy_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%copy0, %copy1 = pto.vintlv %c01_lo, %c23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%copy2, %copy3 = pto.vintlv %c01_hi, %c23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +pto.vsts %copy0, %copy_out[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %copy1, %copy_out[%off_plus_64], %all_b32 {dist = "NORM_B32"} +pto.vsts %copy2, %copy_out[%off_plus_128], %all_b32 {dist = "NORM_B32"} +pto.vsts %copy3, %copy_out[%off_plus_192], %all_b32 {dist = "NORM_B32"} + +// Grouped use: reuse the same deinterleaved operands directly. +%prod_p0 = pto.vmul %x_p0, %scale_p0, %all_b32 : !pto.vreg<64xf32> +%prod_p1 = pto.vmul %x_p1, %scale_p1, %all_b32 : !pto.vreg<64xf32> +%prod_p2 = pto.vmul %x_p2, %scale_p2, %all_b32 : !pto.vreg<64xf32> +%prod_p3 = pto.vmul %x_p3, %scale_p3, %all_b32 : !pto.vreg<64xf32> + +%s0 = pto.vcgadd %prod_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %prod_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %prod_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %prod_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..255: + copy_out[off + i] = base[off + i] + scale_s + +for r = 0..7: + sum_out[group_off + r] = + reduce_i(base[off + r * 32 + i] * scale_s for i = 0..31) +``` + +Required assignment rule: + +```text +`broadcast` is layout-transparent and cheaply rematerializable, but assignment +does not have to force a separate contiguous broadcast just because a dense +store exists. It may choose a common deinterleaved compute layout for +transfer-equivalent elementwise ops and insert `ensure_layout` at the dense +store. The required invariant is that this choice is explicit in the assigned +IR; `vmi-to-vpto` must not infer it by inspecting both users. +``` + +### 3.41 Non-Rematerializable Value With Incompatible Users + +This is the non-cheap counterpart to section 3.18. A `masked_load` has explicit +mask and passthrough semantics, so layout assignment should not clone it as a +normal cheap load unless the registry explicitly marks that clone legal. The +conflict is solved by inserting `ensure_layout` at one use site. + +VMI input: + +```text +%mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> +%zero = pto.vmi.broadcast %c0_f32 : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.masked_load %base[%off], %mask, %zero + : memref<256xf32>, !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + +pto.vmi.store %x, %copy_out[%off] + +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %sum_out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %zero for masked_load/store: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%mask for masked_load/store: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%x_for_reduce = pto.vmi.ensure_layout %x + : #pto.vmi.layout + -> #pto.vmi.layout + +%mask_for_reduce = pto.vmi.ensure_mask_layout %mask + : #pto.vmi.layout + -> #pto.vmi.layout + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +%zero0 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%zero1 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%zero2 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> +%zero3 = pto.vdup %c0_f32, %all_b32 : f32, !pto.mask -> !pto.vreg<64xf32> + +%l0 = pto.vlds %base[%off] : !pto.ptr, index -> !pto.vreg<64xf32> +%l1 = pto.vlds %base[%off_plus_64] : !pto.ptr, index -> !pto.vreg<64xf32> +%l2 = pto.vlds %base[%off_plus_128] : !pto.ptr, index -> !pto.vreg<64xf32> +%l3 = pto.vlds %base[%off_plus_192] : !pto.ptr, index -> !pto.vreg<64xf32> + +%x0 = pto.vsel %l0, %zero0, %all_b32 : !pto.vreg<64xf32> +%x1 = pto.vsel %l1, %zero1, %all_b32 : !pto.vreg<64xf32> +%x2 = pto.vsel %l2, %zero2, %all_b32 : !pto.vreg<64xf32> +%x3 = pto.vsel %l3, %zero3, %all_b32 : !pto.vreg<64xf32> + +pto.vsts %x0, %copy_out[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %x1, %copy_out[%off_plus_64], %all_b32 {dist = "NORM_B32"} +pto.vsts %x2, %copy_out[%off_plus_128], %all_b32 {dist = "NORM_B32"} +pto.vsts %x3, %copy_out[%off_plus_192], %all_b32 {dist = "NORM_B32"} + +// ensure_layout contiguous -> deinterleaved=4 at the reduce use. +%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %sum_block, %sum_out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for i = 0..255: + copy_out[off + i] = base[off + i] + +for r = 0..7: + sum_out[group_off + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +For non-rematerializable producers, assignment must insert a registered +use-site materialization plan, such as contiguous -> deinterleaved=4. If no +plan exists, it must diagnose at assignment time. `vmi-to-vpto` must not clone +the masked_load or choose a materialization after seeing both users. +``` + +### 3.42 `group_slots` `scf.for` Loop-Carried Accumulator + +Section 3.22 covers dense loop-carried values. Sparse group-slot values need a +separate case because the loop-carried block argument has no dense lane +semantics outside the live group slots. + +VMI input: + +```text +%acc0 = pto.vmi.group_slot_load %init[%group_off], %c1 {num_groups = 8} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> + +%acc = scf.for %k = %c0 to %steps step %c1 + iter_args(%arg = %acc0) -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.group_load %base[%tile_off_k], %c16 + {num_groups = 8, group_size = 16} + : !pto.ptr, index -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + %next = pto.vmi.addf %arg, %sum + scf.yield %next : !pto.vmi.vreg<128xf32> +} + +pto.vmi.group_store %acc, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%acc0, %arg, %sum, %next, %acc: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x: + !pto.vmi.vreg<128xf32, + #pto.vmi.layout> + +%mask: + !pto.vmi.mask<128xpred, + #pto.vmi.layout> +``` + +VPTO lowering result: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" +%one_b32 = pto.pge_b32 "PAT_VL1" + +%acc0_block = pto.vsldb %init[%group_off], %c0_i16, %c0_i16, %one_b32 + : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> + +%acc_block = scf.for %k = %c0 to %steps step %c1 + iter_args(%arg_block = %acc0_block) -> !pto.vreg<64xf32> { + %lo, %hi = pto.vldsx2 %base[%tile_off_k], "BDINTLV" + : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %lo_sum = pto.vcgadd %lo, %all_b32 : !pto.vreg<64xf32> + %hi_sum = pto.vcgadd %hi, %all_b32 : !pto.vreg<64xf32> + %sum_block = pto.vadd %lo_sum, %hi_sum, %slot8 : !pto.vreg<64xf32> + %next_block = pto.vadd %arg_block, %sum_block, %slot8 : !pto.vreg<64xf32> + scf.yield %next_block : !pto.vreg<64xf32> +} + +pto.vsts %acc_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = + init[group_off + r] + + sum_k reduce(base[tile_k, row_r, 0..15]) +``` + +Required assignment rule: + +```text +Loop-carried `group_slots` values are valid. The iter_arg, body block +argument, yield operand, loop result, and final group_store operand all carry +the same `group_slots(num_groups=8, slots=8)` layout. Ordinary dense consumers +inside the loop still require an explicit `group_broadcast` or diagnostic. +``` + +### 3.43 Internal Function Argument Boundary Materialization + +Section 3.25 covers a private function returning a VMI value. A callee argument +is the other direction of the same ABI problem: the callee body may require a +layout that is different from the layout naturally produced at a call site. + +The current implementation keeps the internal function VMI signature +contiguous and makes the callee-entry materialization explicit with +`ensure_layout` / `ensure_mask_layout`. This is less aggressive than +specializing the VMI function signature to `deinterleaved = 4`, but it preserves +the same invariant: after layout assignment, `vmi-to-vpto` lowers only from +explicit type and helper information and does not inspect the callee body while +lowering a call. + +VMI input: + +```text +func.func private @consume(%x: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %out: !pto.ptr, %group_off: index) { + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} + pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} + return +} + +func.func @caller(%base: !pto.ptr, %off: index, + %out: !pto.ptr, %group_off: index) { + %x = pto.vmi.load %base[%off] + : !pto.ptr, index -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + call @consume(%x, %mask, %out, %group_off) + : (!pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>, + !pto.ptr, index) -> () + return +} +``` + +Assigned layouts: + +```text +@consume argument %x: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +@consume argument %mask: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +inside @consume: + %x_split = pto.vmi.ensure_layout %x + : #pto.vmi.layout + -> #pto.vmi.layout + + %mask_split = pto.vmi.ensure_mask_layout %mask + : #pto.vmi.layout + -> #pto.vmi.layout + +@caller %x and %mask: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering result for the function boundary: + +```text +func.func private @consume(%x_p0: !pto.vreg<64xf32>, + %x_p1: !pto.vreg<64xf32>, + %x_p2: !pto.vreg<64xf32>, + %x_p3: !pto.vreg<64xf32>, + %m0: !pto.mask, + %m1: !pto.mask, + %m2: !pto.mask, + %m3: !pto.mask, + %out: !pto.ptr, + %group_off: index) { + // Callee-entry lowering of ensure_layout contiguous -> deinterleaved=4, + // block_elems=8. + %x01_lo, %x01_hi = pto.vdintlv %x_p0, %x_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x23_lo, %x23_hi = pto.vdintlv %x_p2, %x_p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_d0, %x_d2 = pto.vdintlv %x01_lo, %x23_lo + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %x_d1, %x_d3 = pto.vdintlv %x01_hi, %x23_hi + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + + %m01_lo, %m01_hi = pto.pdintlv_b32 %m0, %m1 + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + %m23_lo, %m23_hi = pto.pdintlv_b32 %m2, %m3 + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + %m_d0, %m_d2 = pto.pdintlv_b32 %m01_lo, %m23_lo + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + %m_d1, %m_d3 = pto.pdintlv_b32 %m01_hi, %m23_hi + : !pto.mask, !pto.mask -> !pto.mask, !pto.mask + + %slot8 = pto.pge_b32 "PAT_VL8" + %s0 = pto.vcgadd %x_d0, %m_d0 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s1 = pto.vcgadd %x_d1, %m_d1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s2 = pto.vcgadd %x_d2, %m_d2 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s3 = pto.vcgadd %x_d3, %m_d3 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> + %s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> + %sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + pto.vsts %sum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + return +} + +func.func @caller(...) { + // Caller keeps the load and group mask in the contiguous function ABI layout. + %x0 = pto.vlds %base[%off] : !pto.ptr -> !pto.vreg<64xf32> + %x1 = pto.vlds %base[%off_plus_64] : !pto.ptr -> !pto.vreg<64xf32> + %x2 = pto.vlds %base[%off_plus_128] : !pto.ptr -> !pto.vreg<64xf32> + %x3 = pto.vlds %base[%off_plus_192] : !pto.ptr -> !pto.vreg<64xf32> + + %m0 = pto.pset_b32 "PAT_ALL" : !pto.mask + %m1 = pto.pset_b32 "PAT_ALL" : !pto.mask + %m2 = pto.pset_b32 "PAT_ALL" : !pto.mask + %m3 = pto.pset_b32 "PAT_ALL" : !pto.mask + + call @consume(%x0, %x1, %x2, %x3, %m0, %m1, %m2, %m3, %out, %group_off) + : (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.mask, !pto.mask, + !pto.mask, !pto.mask, !pto.ptr, index) -> () + return +} +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 31]) +``` + +Required assignment rule: + +```text +Private function boundary layout is explicit in the assigned function type and +callee-entry helpers. The current endpoint chooses a contiguous VMI function +ABI and inserts callee-entry materialization for the grouped body requirement. +`vmi-to-vpto` does not inspect the callee body while lowering the call and does +not inspect callers while lowering the callee block argument. + +Future optimization may specialize private VMI function signatures directly to +`deinterleaved = 4, block_elems = 8` when all call sites agree. That +optimization must still be expressed in the assigned VMI function type before +`vmi-to-vpto` runs. +``` + +### 3.44 `masked_load` Grouped Tail Feeding S=32 Reduce + +This case connects the explicit `masked_load` tail model from section 3.30 with +grouped reduction. The load has no padding constant hidden in the op; inactive +lanes are provided by the passthrough value and excluded from the reduction by +the same grouped mask. + +VMI input: + +```text +%c25 = arith.constant 25 : index +%mask = pto.vmi.create_group_mask %c25 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%zero = pto.vmi.broadcast %c0_f32 : f32 -> !pto.vmi.vreg<256xf32> +%x = pto.vmi.masked_load %base[%off], %mask, %zero + : memref<256xf32>, !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%mask for masked_load: + !pto.vmi.mask<256xpred, #pto.vmi.layout> + +%zero, %x for masked_load: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_for_reduce = pto.vmi.ensure_layout %x: + #pto.vmi.layout + -> #pto.vmi.layout + +%mask_for_reduce: + pto.vmi.create_group_mask %c25 {num_groups = 8, group_size = 32} + -> !pto.vmi.mask<256xpred, + #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +Current implementation result: + +```text +VMI-UNSUPPORTED: pto.vmi.group_reduce_addf s32 block8 lowering does not yet +support partial create_group_mask active_elems_per_group during layout +assignment +``` + +This must remain a layout-assignment diagnostic until the S=32 block8 +grouped-mask lowering is proven against runtime SIM. Assignment must not write +`vmi.selected_plan = "s32_reduce_block8_stride"` for this case and leave +`vmi-to-vpto` to discover the partial mask by walking the mask defining op. A +`masked_load` can be lowered contiguously and then materialized to +`deinterleaved = 4, block_elems = 8`, but the grouped reduce still needs a +physically correct `create_group_mask` for `active_elems_per_group = 25`. +Allowing the current S=32 block8 path to proceed would not preserve the logical +memory result below. + +Intended VPTO lowering shape after the grouped-mask issue is fixed: + +```text +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8 = pto.pge_b32 "PAT_VL8" + +// masked_load direct lowering stays contiguous. +%m0, %m1, %m2, %m3 = materialize contiguous create_group_mask(c25, S=32) +%z0, %z1, %z2, %z3 = vdup zero +%l0 = pto.vlds %base[%off] +%l1 = pto.vlds %base[%off_plus_64] +%l2 = pto.vlds %base[%off_plus_128] +%l3 = pto.vlds %base[%off_plus_192] +%x0 = pto.vsel %l0, %z0, %m0 : !pto.vreg<64xf32> +%x1 = pto.vsel %l1, %z1, %m1 : !pto.vreg<64xf32> +%x2 = pto.vsel %l2, %z2, %m2 : !pto.vreg<64xf32> +%x3 = pto.vsel %l3, %z3, %m3 : !pto.vreg<64xf32> + +// ensure_layout contiguous -> deinterleaved=4, block_elems=8. +%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 +%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 +%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo +%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi + +// Correct deinterleaved grouped mask for active columns 0..24: +// part 0 covers columns 0..7 for every row: all active +// part 1 covers columns 8..15 for every row: all active +// part 2 covers columns 16..23 for every row: all active +// part 3 covers columns 24..31 for every row: one active lane per row +%mask_p0 = pto.pset_b32 "PAT_ALL" +%mask_p1 = pto.pset_b32 "PAT_ALL" +%mask_p2 = pto.pset_b32 "PAT_ALL" +%mask_p3 = materialize one lane per 8-lane row block + +%s0 = pto.vcgadd %x_p0, %mask_p0 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x_p1, %mask_p1 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x_p2, %mask_p2 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x_p3, %mask_p3 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8 : !pto.vreg<64xf32> +%sum_block = pto.vadd %s01, %s23, %slot8 : !pto.vreg<64xf32> + +pto.vsts %sum_block, %out[%group_off], %slot8 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = + reduce(base[off + r * 32 + 0 .. off + r * 32 + 24]) +``` + +Required assignment rule: + +```text +`masked_load` and `group_reduce` must share the same grouped mask layout. The +passthrough value defines inactive loaded lanes, while the reduce mask defines +participation. Assignment may select a deinterleaved S=32 load plan only when +the rounded physical reads are memory-safe; otherwise it must diagnose or use a +future stable gather fallback. + +Current implementation additionally diagnoses the S=32 block8 partial grouped +mask itself. This is deliberate: the case is not implemented until the +deinterleaved grouped-mask materialization and `vcgadd` interpretation are +validated end to end by SIM. +``` diff --git a/include/PTO/IR/VMIAttrs.td b/include/PTO/IR/VMIAttrs.td index da8428dd23..e8c44a1454 100644 --- a/include/PTO/IR/VMIAttrs.td +++ b/include/PTO/IR/VMIAttrs.td @@ -16,7 +16,9 @@ def VMILayoutAttr : PTO_Attr<"VMILayout", "vmi.layout"> { let summary = "VMI logical vector register layout"; let parameters = (ins StringRefParameter<"layout kind">:$kind, - "int64_t":$factor + "int64_t":$factor, + "int64_t":$blockElems, + "int64_t":$slots ); let hasCustomAssemblyFormat = 1; let genVerifyDecl = 1; @@ -24,9 +26,11 @@ def VMILayoutAttr : PTO_Attr<"VMILayout", "vmi.layout"> { let extraClassDeclaration = [{ static VMILayoutAttr getContiguous(::mlir::MLIRContext *context); static VMILayoutAttr getDeinterleaved(::mlir::MLIRContext *context, - int64_t factor); + int64_t factor, + int64_t blockElems = 1); static VMILayoutAttr getGroupSlots(::mlir::MLIRContext *context, - int64_t numGroups); + int64_t numGroups, + int64_t slots = 0); bool isContiguous() const { return getKind() == "contiguous"; } bool isDeinterleaved() const { return getKind() == "deinterleaved"; } diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 7bd7524118..80036f9946 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -76,6 +76,22 @@ def VMICreateMaskOp : VMI_Op<"create_mask"> { let assemblyFormat = "$active_lanes attr-dict `:` type($active_lanes) `->` type($result)"; } +def VMICreateGroupMaskOp : VMI_Op<"create_group_mask"> { + let summary = "Create a VMI logical grouped predicate mask"; + let description = [{ + Creates a mask where lane i is active iff + `(i % group_size) < active_elems_per_group`. + }]; + let arguments = (ins + Index:$active_elems_per_group, + I64Attr:$num_groups, + I64Attr:$group_size + ); + let results = (outs VMI_MaskTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$active_elems_per_group attr-dict `:` type($active_elems_per_group) `->` type($result)"; +} + def VMIConstantMaskOp : VMI_Op<"constant_mask"> { let summary = "VMI logical predicate mask constant"; let arguments = (ins AnyAttr:$value); @@ -437,7 +453,8 @@ def VMIBitcastOp : VMI_Op<"bitcast"> { def VMILoadOp : VMI_Op<"load", [DeclareOpInterfaceMethods]> { let summary = "VMI logical vector load"; - let arguments = (ins PtrOrMemRef:$source, Index:$offset); + let arguments = (ins PtrOrMemRef:$source, Index:$offset, + OptionalAttr:$full_read_elems); let results = (outs VMI_VRegTypeConstraint:$result); let hasVerifier = 1; let assemblyFormat = "$source `[` $offset `]` attr-dict `:` type($source) `->` type($result)"; @@ -452,6 +469,15 @@ def VMIGroupLoadOp : VMI_Op<"group_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI load one scalar value per logical group into group slots"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, Index:$source_group_stride, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $source_group_stride attr-dict `:` type($source) `->` type($result)"; +} + def VMIMaskedLoadOp : VMI_Op<"masked_load", [DeclareOpInterfaceMethods]> { let summary = "VMI logical masked vector load with passthrough lanes"; let arguments = (ins PtrOrMemRef:$source, Index:$offset, diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 3b50fff4c7..2a966b3eb8 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -30,6 +30,7 @@ #include "mlir/IR/Types.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/InliningUtils.h" #include "mlir/Parser/Parser.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -120,6 +121,27 @@ static bool isKnownZeroOrUnitExtent(int64_t value); static bool isByteIntegerType(Type ty); static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name, bool allowLowPrecision = false); + +namespace { +struct PTOInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + + bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } +}; +} // namespace static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, StringRef lhsName, StringRef rhsName); @@ -2668,6 +2690,8 @@ void PTODialect::initialize() { #define GET_ATTRDEF_LIST #include "PTO/IR/PTOAttrs.cpp.inc" >(); + + addInterfaces(); } diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index e26982e347..ff7170044e 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -1,10 +1,12 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. //===- VMI.cpp - PTO VMI type and attribute support -----------------------===// //===----------------------------------------------------------------------===// @@ -36,8 +38,7 @@ static std::string formatVMIVRegType(int64_t elementCount, Type elementType, } static std::string formatVMIMaskType(int64_t elementCount, - StringRef granularity, - Attribute layout) { + StringRef granularity, Attribute layout) { std::string result; llvm::raw_string_ostream os(result); os << "!pto.vmi.mask<" << elementCount << "x" << granularity; @@ -146,6 +147,13 @@ static FailureOr getLayoutFactor(Type type) { return (*layout).isDeinterleaved() ? (*layout).getFactor() : 1; } +static FailureOr getLayoutBlockElems(Type type) { + FailureOr layout = getAssignedVMILayout(type); + if (failed(layout)) + return failure(); + return (*layout).isDeinterleaved() ? (*layout).getBlockElems() : 1; +} + static FailureOr getPhysicalLanesPerPart(Type type) { if (auto vregType = dyn_cast(type)) return getDataLanesPerPart(vregType.getElementType()); @@ -172,26 +180,29 @@ static bool isLayoutAssigned(VMIMaskType type) { return static_cast(type.getLayoutAttr()); } -static LogicalResult verifyAllSameVRegShapeAndLayout(Operation *op, - ArrayRef types, - bool requireSameElement) { +static LogicalResult +verifyAllSameVRegShapeAndLayout(Operation *op, ArrayRef types, + bool requireSameElement) { if (types.empty()) return success(); VMIVRegType first = types.front(); - bool anyLayout = llvm::any_of(types, [](VMIVRegType type) { - return isLayoutAssigned(type); - }); + bool anyLayout = llvm::any_of( + types, [](VMIVRegType type) { return isLayoutAssigned(type); }); for (VMIVRegType type : types) { if (type.getElementCount() != first.getElementCount()) - return op->emitOpError("requires all VMI data values to have the same logical lane count"); + return op->emitOpError( + "requires all VMI data values to have the same logical lane count"); if (requireSameElement && type.getElementType() != first.getElementType()) - return op->emitOpError("requires all VMI data values to have the same element type"); + return op->emitOpError( + "requires all VMI data values to have the same element type"); if (anyLayout && !isLayoutAssigned(type)) - return op->emitOpError("requires either all or no VMI data values to carry layout"); + return op->emitOpError( + "requires either all or no VMI data values to carry layout"); if (anyLayout && type.getLayout() != first.getLayout()) - return op->emitOpError("requires all layout-assigned VMI data values to have the same layout"); + return op->emitOpError("requires all layout-assigned VMI data values to " + "have the same layout"); } return success(); } @@ -203,8 +214,7 @@ static LogicalResult verifyElementwiseVRegOp(Operation *op, VMIVRegType lhs, /*requireSameElement=*/true); } -static LogicalResult verifyFloatUnaryVRegOp(Operation *op, - VMIVRegType source, +static LogicalResult verifyFloatUnaryVRegOp(Operation *op, VMIVRegType source, VMIVRegType result) { if (!isVMIFloatLikeType(source.getElementType())) return op->emitOpError("requires floating-point-like VMI element type"); @@ -221,15 +231,15 @@ static LogicalResult verifyFloatTernaryVRegOp(Operation *op, VMIVRegType lhs, /*requireSameElement=*/true); } -static LogicalResult verifyAllSameMaskShapeLayoutAndGranularity( - Operation *op, ArrayRef types) { +static LogicalResult +verifyAllSameMaskShapeLayoutAndGranularity(Operation *op, + ArrayRef types) { if (types.empty()) return success(); VMIMaskType first = types.front(); - bool anyLayout = llvm::any_of(types, [](VMIMaskType type) { - return isLayoutAssigned(type); - }); + bool anyLayout = llvm::any_of( + types, [](VMIMaskType type) { return isLayoutAssigned(type); }); for (VMIMaskType type : types) { if (type.getElementCount() != first.getElementCount()) @@ -252,11 +262,13 @@ static LogicalResult verifyAllSameMaskShapeLayoutAndGranularity( static LogicalResult verifyMaskMatchesData(Operation *op, VMIMaskType maskType, VMIVRegType dataType) { if (maskType.getElementCount() != dataType.getElementCount()) - return op->emitOpError("requires mask logical lane count to match data lane count"); + return op->emitOpError( + "requires mask logical lane count to match data lane count"); if (isLayoutAssigned(maskType) || isLayoutAssigned(dataType)) { if (!isLayoutAssigned(maskType) || !isLayoutAssigned(dataType)) - return op->emitOpError("requires either both mask and data to carry layout or neither to carry layout"); + return op->emitOpError("requires either both mask and data to carry " + "layout or neither to carry layout"); if (maskType.getLayout() != dataType.getLayout()) return op->emitOpError("requires mask layout to match data layout"); } @@ -268,7 +280,8 @@ static LogicalResult verifyMaskMatchesData(Operation *op, VMIMaskType maskType, int64_t maskBitWidth = getMaskGranularityBitWidth(maskType.getGranularity()); if (elementBitWidth != 0 && maskBitWidth != 0 && elementBitWidth != static_cast(maskBitWidth)) - return op->emitOpError("requires mask granularity to match data element width"); + return op->emitOpError( + "requires mask granularity to match data element width"); return success(); } @@ -288,9 +301,8 @@ static LogicalResult verifyMemoryElementMatches(Operation *op, Type memoryType, if (!memoryElementType) return success(); if (memoryElementType != dataType.getElementType()) - return op->emitOpError() - << "requires memory " << role - << " element type to match VMI data element type"; + return op->emitOpError() << "requires memory " << role + << " element type to match VMI data element type"; return success(); } @@ -309,24 +321,26 @@ static LogicalResult verifyPhysicalParts(Operation *op, Type vmiType, TypeRange physicalTypes) { FailureOr expectedArity = getVMIPhysicalArity(vmiType); if (failed(expectedArity)) - return op->emitOpError("requires a layout-assigned VMI type with computable physical arity"); + return op->emitOpError( + "requires a layout-assigned VMI type with computable physical arity"); if (static_cast(physicalTypes.size()) != *expectedArity) - return op->emitOpError() - << "requires " << *expectedArity << " physical parts, got " - << physicalTypes.size(); + return op->emitOpError() << "requires " << *expectedArity + << " physical parts, got " << physicalTypes.size(); if (auto vregType = dyn_cast(vmiType)) { FailureOr lanesPerPart = getDataLanesPerPart(vregType.getElementType()); if (failed(lanesPerPart)) - return op->emitOpError("requires data element type with known physical lane count"); + return op->emitOpError( + "requires data element type with known physical lane count"); for (Type physicalType : physicalTypes) { auto partType = dyn_cast(physicalType); if (!partType) return op->emitOpError("requires physical data parts to be !pto.vreg"); if (partType.getElementCount() != *lanesPerPart || partType.getElementType() != vregType.getElementType()) - return op->emitOpError("requires physical data part type to match VMI lane-map helper"); + return op->emitOpError( + "requires physical data part type to match VMI lane-map helper"); } return success(); } @@ -335,45 +349,86 @@ static LogicalResult verifyPhysicalParts(Operation *op, Type vmiType, if (!maskType) return op->emitOpError("requires VMI data or mask type"); if (maskType.isPred()) - return op->emitOpError("requires layout-assigned mask with concrete granularity"); + return op->emitOpError( + "requires layout-assigned mask with concrete granularity"); for (Type physicalType : physicalTypes) { auto partType = dyn_cast(physicalType); if (!partType) return op->emitOpError("requires physical mask parts to be !pto.mask"); if (partType.getGranularity() != maskType.getGranularity()) - return op->emitOpError("requires physical mask part granularity to match VMI mask"); + return op->emitOpError( + "requires physical mask part granularity to match VMI mask"); } return success(); } -static int64_t getLogicalLanesInPart(int64_t elementCount, int64_t factor, - int64_t part) { - if (part < 0 || part >= factor || part >= elementCount) - return 0; - return ((elementCount - 1 - part) / factor) + 1; +static std::optional +mapDenseLogicalLaneToPartIndex(int64_t elementCount, int64_t factor, + int64_t blockElems, int64_t logicalLane, + int64_t &part) { + if (logicalLane < 0 || logicalLane >= elementCount || factor <= 0 || + blockElems <= 0) + return std::nullopt; + int64_t block = logicalLane / blockElems; + int64_t inBlockLane = logicalLane % blockElems; + part = block % factor; + int64_t partBlock = block / factor; + return partBlock * blockElems + inBlockLane; +} + +static std::optional +mapDensePartIndexToLogicalLane(int64_t elementCount, int64_t factor, + int64_t blockElems, int64_t part, + int64_t indexInPart) { + if (part < 0 || part >= factor || indexInPart < 0 || factor <= 0 || + blockElems <= 0) + return std::nullopt; + int64_t partBlock = indexInPart / blockElems; + int64_t inBlockLane = indexInPart % blockElems; + int64_t logicalBlock = partBlock * factor + part; + int64_t logicalLane = logicalBlock * blockElems + inBlockLane; + if (logicalLane >= elementCount) + return std::nullopt; + return logicalLane; +} + +static int64_t getDenseLogicalLanesInPart(int64_t elementCount, int64_t factor, + int64_t blockElems, int64_t part) { + int64_t maxIndex = -1; + for (int64_t lane = 0; lane < elementCount; ++lane) { + int64_t lanePart = 0; + std::optional index = mapDenseLogicalLaneToPartIndex( + elementCount, factor, blockElems, lane, lanePart); + if (index && lanePart == part) + maxIndex = std::max(maxIndex, *index); + } + return maxIndex + 1; } } // namespace VMILayoutAttr VMILayoutAttr::getContiguous(MLIRContext *context) { - return VMILayoutAttr::get(context, "contiguous", 1); + return VMILayoutAttr::get(context, "contiguous", 1, 1, 0); } VMILayoutAttr VMILayoutAttr::getDeinterleaved(MLIRContext *context, - int64_t factor) { - return VMILayoutAttr::get(context, "deinterleaved", factor); + int64_t factor, + int64_t blockElems) { + return VMILayoutAttr::get(context, "deinterleaved", factor, blockElems, 0); } VMILayoutAttr VMILayoutAttr::getGroupSlots(MLIRContext *context, - int64_t numGroups) { - return VMILayoutAttr::get(context, "num_groups", numGroups); + int64_t numGroups, int64_t slots) { + return VMILayoutAttr::get(context, "num_groups", numGroups, 1, slots); } Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { SMLoc loc = parser.getCurrentLocation(); StringRef kind; int64_t factor = 1; + int64_t blockElems = 1; + int64_t slots = 0; if (failed(parser.parseLess()) || failed(parser.parseKeyword(&kind))) return {}; @@ -383,9 +438,28 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { } else if (kind == "deinterleaved") { if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) return {}; + if (succeeded(parser.parseOptionalComma())) { + StringRef field; + if (failed(parser.parseKeyword(&field)) || field != "block_elems" || + failed(parser.parseEqual()) || + failed(parser.parseInteger(blockElems))) { + parser.emitError(parser.getCurrentLocation(), + "expected 'block_elems = '"); + return {}; + } + } } else if (kind == "num_groups") { if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) return {}; + if (succeeded(parser.parseOptionalComma())) { + StringRef field; + if (failed(parser.parseKeyword(&field)) || field != "slots" || + failed(parser.parseEqual()) || failed(parser.parseInteger(slots))) { + parser.emitError(parser.getCurrentLocation(), + "expected 'slots = '"); + return {}; + } + } } else { parser.emitError(parser.getCurrentLocation(), "expected VMI layout kind 'contiguous' or " @@ -397,39 +471,62 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { return {}; return parser.getChecked(loc, parser.getContext(), kind, - factor); + factor, blockElems, slots); } void VMILayoutAttr::print(AsmPrinter &printer) const { printer << "<" << getKind(); - if (isDeinterleaved() || isGroupSlots()) + if (isDeinterleaved()) { + printer << " = " << getFactor(); + if (getBlockElems() != 1) + printer << ", block_elems = " << getBlockElems(); + } else if (isGroupSlots()) { printer << " = " << getFactor(); + if (getSlots() != 0) + printer << ", slots = " << getSlots(); + } printer << ">"; } LogicalResult VMILayoutAttr::verify(function_ref emitError, - StringRef kind, int64_t factor) { + StringRef kind, int64_t factor, int64_t blockElems, + int64_t slots) { if (kind == "contiguous") { - if (factor != 1) + if (factor != 1 || blockElems != 1 || slots != 0) return emitError() - << "#pto.vmi.layout requires factor to be 1"; + << "#pto.vmi.layout requires factor, block_elems, " + "and slots to be their defaults"; return success(); } if (kind == "deinterleaved") { if (factor != 2 && factor != 4) - return emitError() - << "#pto.vmi.layout expected factor to be 2 or 4"; + return emitError() << "#pto.vmi.layout expected factor to be 2 or 4"; + if (blockElems <= 0) + return emitError() << "#pto.vmi.layout requires block_elems to be positive"; + if (slots != 0) + return emitError() << "#pto.vmi.layout requires slots to be omitted"; return success(); } if (kind == "num_groups") { if (factor <= 0) + return emitError() << "#pto.vmi.layout requires num_groups to be positive"; + if (blockElems != 1) + return emitError() << "#pto.vmi.layout requires block_elems to be 1"; + if (slots < 0 || (slots != 0 && factor % slots != 0)) return emitError() << "#pto.vmi.layout requires num_groups to be positive"; + << ", slots = " << slots + << "> requires slots to be positive and divide num_groups when " + "specified"; return success(); } @@ -451,8 +548,8 @@ Type VMIVRegType::parse(AsmParser &parser) { failed(parser.parseGreater())) return {}; - return parser.getChecked(loc, parser.getContext(), - shape.front(), elementType, layout); + return parser.getChecked(loc, parser.getContext(), shape.front(), + elementType, layout); } void VMIVRegType::print(AsmPrinter &printer) const { @@ -467,25 +564,25 @@ LogicalResult VMIVRegType::verify(function_ref emitError, int64_t elementCount, Type elementType, Attribute layout) { if (elementCount <= 0) - return emitError() << "'" << formatVMIVRegType(elementCount, elementType, - layout) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) << "' expected a positive element count"; if (!isSupportedVMIElementType(elementType)) - return emitError() << "'" << formatVMIVRegType(elementCount, elementType, - layout) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) << "' expected an integer, index, floating-point, or " "PTO low-precision element type"; if (layout && !mlir::isa(layout)) - return emitError() << "'" << formatVMIVRegType(elementCount, elementType, - layout) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) << "' expected layout to be #pto.vmi.layout"; if (auto layoutAttr = llvm::dyn_cast_or_null(layout)) { if (layoutAttr.isGroupSlots() && elementCount % layoutAttr.getNumGroups() != 0) - return emitError() << "'" << formatVMIVRegType(elementCount, elementType, - layout) + return emitError() << "'" + << formatVMIVRegType(elementCount, elementType, layout) << "' expected num_groups layout to evenly divide " "the VMI logical lane count"; } @@ -515,8 +612,8 @@ Type VMIMaskType::parse(AsmParser &parser) { failed(parser.parseGreater())) return {}; - return parser.getChecked(loc, parser.getContext(), - shape.front(), granularity, layout); + return parser.getChecked(loc, parser.getContext(), shape.front(), + granularity, layout); } void VMIMaskType::print(AsmPrinter &printer) const { @@ -530,35 +627,35 @@ LogicalResult VMIMaskType::verify(function_ref emitError, int64_t elementCount, StringRef granularity, Attribute layout) { if (elementCount <= 0) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' expected a positive element count"; if (!isSupportedGranularity(granularity)) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' expected granularity to be one of pred, b8, b16, " "b32"; if (layout && !mlir::isa(layout)) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' expected layout to be #pto.vmi.layout"; if (auto layoutAttr = llvm::dyn_cast_or_null(layout)) { if (layoutAttr.isGroupSlots()) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' mask type must not carry num_groups layout"; } if (granularity == "pred" && layout) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' pred mask must not carry layout"; if (granularity != "pred" && !layout) - return emitError() << "'" << formatVMIMaskType(elementCount, granularity, - layout) + return emitError() << "'" + << formatVMIMaskType(elementCount, granularity, layout) << "' concrete mask granularity requires layout"; return success(); @@ -570,9 +667,11 @@ LogicalResult VMIConstantOp::verify() { if (!denseAttr) return emitOpError("requires dense elements constant attribute"); if (denseAttr.getElementType() != resultType.getElementType()) - return emitOpError("requires dense constant element type to match result element type"); + return emitOpError( + "requires dense constant element type to match result element type"); if (denseAttr.getNumElements() != resultType.getElementCount()) - return emitOpError("requires dense constant element count to match result logical lane count"); + return emitOpError("requires dense constant element count to match result " + "logical lane count"); return success(); } @@ -616,6 +715,22 @@ LogicalResult VMICreateMaskOp::verify() { return success(); } +LogicalResult VMICreateGroupMaskOp::verify() { + auto resultType = cast(getResult().getType()); + int64_t numGroups = getNumGroupsAttr().getInt(); + int64_t groupSize = getGroupSizeAttr().getInt(); + if (numGroups <= 0) + return emitOpError("requires positive num_groups"); + if (groupSize <= 0) + return emitOpError("requires positive group_size"); + if (resultType.getElementCount() != numGroups * groupSize) + return emitOpError("requires result lane count to equal num_groups * " + "group_size"); + if (!resultType.isPred() && !isLayoutAssigned(resultType)) + return emitOpError("requires concrete mask result to carry layout"); + return success(); +} + LogicalResult VMIConstantMaskOp::verify() { auto resultType = cast(getResult().getType()); auto denseAttr = dyn_cast(getValue()); @@ -624,7 +739,8 @@ LogicalResult VMIConstantMaskOp::verify() { if (!denseAttr.getElementType().isInteger(1)) return emitOpError("requires dense mask constant element type to be i1"); if (denseAttr.getNumElements() != resultType.getElementCount()) - return emitOpError("requires dense mask constant element count to match result logical lane count"); + return emitOpError("requires dense mask constant element count to match " + "result logical lane count"); return success(); } @@ -655,8 +771,8 @@ LogicalResult VMIMaskXOrOp::verify() { LogicalResult VMIMaskNotOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); - return verifyAllSameMaskShapeLayoutAndGranularity( - getOperation(), {sourceType, resultType}); + return verifyAllSameMaskShapeLayoutAndGranularity(getOperation(), + {sourceType, resultType}); } LogicalResult VMIAddFOp::verify() { @@ -847,7 +963,8 @@ LogicalResult VMINotOp::verify() { auto resultType = cast(getResult().getType()); if (!isVMIIntegerLikeType(sourceType.getElementType())) return emitOpError("requires integer-like VMI element type"); - return verifyAllSameVRegShapeAndLayout(getOperation(), {sourceType, resultType}, + return verifyAllSameVRegShapeAndLayout(getOperation(), + {sourceType, resultType}, /*requireSameElement=*/true); } @@ -880,9 +997,9 @@ LogicalResult VMISelectOp::verify() { auto trueType = cast(getTrueValue().getType()); auto falseType = cast(getFalseValue().getType()); auto resultType = cast(getResult().getType()); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {trueType, falseType, resultType}, - /*requireSameElement=*/true))) + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {trueType, falseType, resultType}, + /*requireSameElement=*/true))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, resultType); } @@ -903,9 +1020,9 @@ LogicalResult VMICompressOp::verify() { auto sourceType = cast(getSource().getType()); auto maskType = cast(getMask().getType()); auto resultType = cast(getResult().getType()); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {sourceType, resultType}, - /*requireSameElement=*/true))) + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {sourceType, resultType}, + /*requireSameElement=*/true))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, sourceType); } @@ -970,14 +1087,14 @@ LogicalResult VMIReduceAddFOp::verify() { return verifyMaskMatchesData(getOperation(), maskType, sourceType); } -template -LogicalResult verifyReduceMinMaxFOp(OpTy op) { +template LogicalResult verifyReduceMinMaxFOp(OpTy op) { auto sourceType = cast(op.getSource().getType()); auto initType = cast(op.getInit().getType()); auto maskType = cast(op.getMask().getType()); auto resultType = cast(op.getResult().getType()); if (!isVMIFloatLikeType(sourceType.getElementType())) - return op.emitOpError("requires floating-point-like VMI source element type"); + return op.emitOpError( + "requires floating-point-like VMI source element type"); if (sourceType.getElementType() != initType.getElementType() || sourceType.getElementType() != resultType.getElementType()) return op.emitOpError( @@ -991,13 +1108,9 @@ LogicalResult verifyReduceMinMaxFOp(OpTy op) { return verifyMaskMatchesData(op.getOperation(), maskType, sourceType); } -LogicalResult VMIReduceMaxFOp::verify() { - return verifyReduceMinMaxFOp(*this); -} +LogicalResult VMIReduceMaxFOp::verify() { return verifyReduceMinMaxFOp(*this); } -LogicalResult VMIReduceMinFOp::verify() { - return verifyReduceMinMaxFOp(*this); -} +LogicalResult VMIReduceMinFOp::verify() { return verifyReduceMinMaxFOp(*this); } LogicalResult VMIGroupReduceAddFOp::verify() { auto sourceType = cast(getSource().getType()); @@ -1015,17 +1128,25 @@ LogicalResult VMIGroupReduceAddFOp::verify() { if (sourceType.getElementType() != resultType.getElementType()) return emitOpError("requires source and result element types to match"); if (auto sourceLayout = sourceType.getLayoutAttr()) { - if (!sourceLayout.isContiguous()) + bool supportedSourceLayout = + sourceLayout.isContiguous() || + (sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 2 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)) || + (sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 4 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)); + if (!supportedSourceLayout) return emitOpError( - "requires layout-assigned source to use contiguous layout"); + "requires layout-assigned source to use contiguous layout or " + "deinterleaved=2/4 layout with block_elems=1 or block_elems=8"); } if (auto resultLayout = resultType.getLayoutAttr()) { if (!resultLayout.isGroupSlots() || resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) - return emitOpError() - << "requires layout-assigned result to use " - "#pto.vmi.layout"; + return emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; } if (failed(verifyMaskMatchesData(getOperation(), maskType, sourceType))) return failure(); @@ -1044,10 +1165,9 @@ LogicalResult VMIGroupBroadcastOp::verify() { if (auto sourceLayout = sourceType.getLayoutAttr()) { if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != getNumGroupsAttr().getInt()) - return emitOpError() - << "requires layout-assigned source to use " - "#pto.vmi.layout"; + return emitOpError() << "requires layout-assigned source to use " + "#pto.vmi.layout"; } if (auto resultLayout = resultType.getLayoutAttr()) { if (resultLayout.isGroupSlots()) @@ -1062,13 +1182,16 @@ LogicalResult VMIExtFOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); if (sourceType.getElementCount() != resultType.getElementCount()) - return emitOpError("requires source and result logical lane counts to match"); + return emitOpError( + "requires source and result logical lane counts to match"); if (!isVMIFloatLikeType(sourceType.getElementType()) || !isVMIFloatLikeType(resultType.getElementType())) - return emitOpError("requires floating-point-like source and result element types"); + return emitOpError( + "requires floating-point-like source and result element types"); if (getVMIElementBitWidth(sourceType.getElementType()) >= getVMIElementBitWidth(resultType.getElementType())) - return emitOpError("requires result element type to be wider than source element type"); + return emitOpError( + "requires result element type to be wider than source element type"); return success(); } @@ -1076,13 +1199,16 @@ LogicalResult VMITruncFOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); if (sourceType.getElementCount() != resultType.getElementCount()) - return emitOpError("requires source and result logical lane counts to match"); + return emitOpError( + "requires source and result logical lane counts to match"); if (!isVMIFloatLikeType(sourceType.getElementType()) || !isVMIFloatLikeType(resultType.getElementType())) - return emitOpError("requires floating-point-like source and result element types"); + return emitOpError( + "requires floating-point-like source and result element types"); if (getVMIElementBitWidth(sourceType.getElementType()) <= getVMIElementBitWidth(resultType.getElementType())) - return emitOpError("requires result element type to be narrower than source element type"); + return emitOpError( + "requires result element type to be narrower than source element type"); return success(); } @@ -1114,6 +1240,10 @@ LogicalResult VMIBitcastOp::verify() { } LogicalResult VMILoadOp::verify() { + if (auto fullReadElems = getFullReadElemsAttr()) { + if (fullReadElems.getInt() <= 0) + return emitOpError("requires full_read_elems to be positive"); + } return verifyMemoryElementMatches(getOperation(), getSource().getType(), cast(getResult().getType()), "source"); @@ -1140,6 +1270,28 @@ void VMIGroupLoadOp::getEffects( effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); } +LogicalResult VMIGroupSlotLoadOp::verify() { + auto resultType = cast(getResult().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) + return emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; + } + return verifyNumGroups(getOperation(), resultType, + getNumGroupsAttr().getInt()); +} + +void VMIGroupSlotLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + LogicalResult VMIMaskedLoadOp::verify() { auto maskType = cast(getMask().getType()); auto passthruType = cast(getPassthru().getType()); @@ -1147,9 +1299,9 @@ LogicalResult VMIMaskedLoadOp::verify() { if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), resultType, "source"))) return failure(); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {passthruType, resultType}, - /*requireSameElement=*/true))) + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {passthruType, resultType}, + /*requireSameElement=*/true))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, resultType); } @@ -1178,9 +1330,9 @@ LogicalResult VMIGatherOp::verify() { getOperation(), {indicesType, passthruType, resultType}, /*requireSameElement=*/false))) return failure(); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {passthruType, resultType}, - /*requireSameElement=*/true))) + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {passthruType, resultType}, + /*requireSameElement=*/true))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, resultType); } @@ -1198,9 +1350,9 @@ LogicalResult VMIExpandLoadOp::verify() { if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), resultType, "source"))) return failure(); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {passthruType, resultType}, - /*requireSameElement=*/true))) + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {passthruType, resultType}, + /*requireSameElement=*/true))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, resultType); } @@ -1212,8 +1364,7 @@ void VMIExpandLoadOp::getEffects( } LogicalResult VMIStoreOp::verify() { - return verifyMemoryElementMatches(getOperation(), - getDestination().getType(), + return verifyMemoryElementMatches(getOperation(), getDestination().getType(), cast(getValue().getType()), "destination"); } @@ -1244,8 +1395,8 @@ LogicalResult VMIMaskedStoreOp::verify() { auto valueType = cast(getValue().getType()); auto maskType = cast(getMask().getType()); if (failed(verifyMemoryElementMatches(getOperation(), - getDestination().getType(), - valueType, "destination"))) + getDestination().getType(), valueType, + "destination"))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, valueType); } @@ -1261,8 +1412,8 @@ LogicalResult VMIScatterOp::verify() { auto indicesType = cast(getIndices().getType()); auto maskType = cast(getMask().getType()); if (failed(verifyMemoryElementMatches(getOperation(), - getDestination().getType(), - valueType, "destination"))) + getDestination().getType(), valueType, + "destination"))) return failure(); auto indexElementType = dyn_cast(indicesType.getElementType()); @@ -1270,9 +1421,9 @@ LogicalResult VMIScatterOp::verify() { indexElementType.isSigned()) return emitOpError("requires signless or unsigned 32-bit integer indices"); - if (failed(verifyAllSameVRegShapeAndLayout( - getOperation(), {valueType, indicesType}, - /*requireSameElement=*/false))) + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {valueType, indicesType}, + /*requireSameElement=*/false))) return failure(); return verifyMaskMatchesData(getOperation(), maskType, valueType); } @@ -1296,8 +1447,7 @@ void VMITileReadOp::getEffects( } LogicalResult VMITileWriteOp::verify() { - return verifyMemoryElementMatches(getOperation(), - getDestination().getType(), + return verifyMemoryElementMatches(getOperation(), getDestination().getType(), cast(getValue().getType()), "destination"); } @@ -1312,16 +1462,20 @@ LogicalResult VMIShuffleOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); if (sourceType.getElementType() != resultType.getElementType()) - return emitOpError("requires result element type to match source element type"); + return emitOpError( + "requires result element type to match source element type"); if (static_cast(getIndices().size()) != resultType.getElementCount()) - return emitOpError("requires shuffle index count to match result logical lane count"); + return emitOpError( + "requires shuffle index count to match result logical lane count"); for (int64_t index : getIndices()) { if (index < 0 || index >= sourceType.getElementCount()) - return emitOpError("requires every shuffle index to select an existing source logical lane"); + return emitOpError("requires every shuffle index to select an existing " + "source logical lane"); } if (isLayoutAssigned(sourceType) || isLayoutAssigned(resultType)) { if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) - return emitOpError("requires either both source and result to carry layout or neither to carry layout"); + return emitOpError("requires either both source and result to carry " + "layout or neither to carry layout"); } return success(); } @@ -1332,26 +1486,32 @@ LogicalResult VMIChannelSplitOp::verify() { return emitOpError("requires at least two channel results"); auto firstResultType = cast(getResults().front().getType()); if (sourceType.getElementCount() != - static_cast(getResults().size()) * firstResultType.getElementCount()) - return emitOpError("requires source lane count to equal result count times per-channel lane count"); + static_cast(getResults().size()) * + firstResultType.getElementCount()) + return emitOpError("requires source lane count to equal result count times " + "per-channel lane count"); for (Value result : getResults()) { auto resultType = cast(result.getType()); if (resultType.getElementCount() != firstResultType.getElementCount() || resultType.getElementType() != sourceType.getElementType()) - return emitOpError("requires every channel result to have equal lane count and source element type"); + return emitOpError("requires every channel result to have equal lane " + "count and source element type"); } bool anyLayout = isLayoutAssigned(sourceType); for (Value result : getResults()) anyLayout |= isLayoutAssigned(cast(result.getType())); if (anyLayout) { if (!isLayoutAssigned(sourceType)) - return emitOpError("requires layout-assigned channel_split source when any channel result has layout"); + return emitOpError("requires layout-assigned channel_split source when " + "any channel result has layout"); for (Value result : getResults()) { auto resultType = cast(result.getType()); if (!isLayoutAssigned(resultType)) - return emitOpError("requires every channel_split result to carry layout when source has layout"); + return emitOpError("requires every channel_split result to carry " + "layout when source has layout"); if (!cast(resultType.getLayout()).isContiguous()) - return emitOpError("requires layout-assigned channel_split results to be contiguous"); + return emitOpError( + "requires layout-assigned channel_split results to be contiguous"); } int64_t channels = getResults().size(); if (channels == 2 || channels == 4) { @@ -1359,7 +1519,8 @@ LogicalResult VMIChannelSplitOp::verify() { auto expectedLayout = VMILayoutAttr::getDeinterleaved(getContext(), channels); if (!sourceLayout.isContiguous() && sourceLayout != expectedLayout) - return emitOpError("requires layout-assigned channel_split source to be contiguous or deinterleaved by result count"); + return emitOpError("requires layout-assigned channel_split source to " + "be contiguous or deinterleaved by result count"); } } return success(); @@ -1374,24 +1535,29 @@ LogicalResult VMIChannelMergeOp::verify() { auto inputType = cast(input.getType()); if (inputType.getElementCount() != firstInputType.getElementCount() || inputType.getElementType() != firstInputType.getElementType()) - return emitOpError("requires all channel inputs to have the same lane count and element type"); + return emitOpError("requires all channel inputs to have the same lane " + "count and element type"); } - if (resultType.getElementCount() != - static_cast(getInputs().size()) * firstInputType.getElementCount() || + if (resultType.getElementCount() != static_cast(getInputs().size()) * + firstInputType.getElementCount() || resultType.getElementType() != firstInputType.getElementType()) - return emitOpError("requires result lane count and element type to match merged channels"); + return emitOpError( + "requires result lane count and element type to match merged channels"); bool anyLayout = isLayoutAssigned(resultType); for (Value input : getInputs()) anyLayout |= isLayoutAssigned(cast(input.getType())); if (anyLayout) { if (!isLayoutAssigned(resultType)) - return emitOpError("requires layout-assigned channel_merge result when any channel input has layout"); + return emitOpError("requires layout-assigned channel_merge result when " + "any channel input has layout"); for (Value input : getInputs()) { auto inputType = cast(input.getType()); if (!isLayoutAssigned(inputType)) - return emitOpError("requires every channel_merge input to carry layout when result has layout"); + return emitOpError("requires every channel_merge input to carry layout " + "when result has layout"); if (!cast(inputType.getLayout()).isContiguous()) - return emitOpError("requires layout-assigned channel_merge inputs to be contiguous"); + return emitOpError( + "requires layout-assigned channel_merge inputs to be contiguous"); } int64_t channels = getInputs().size(); if (channels == 2 || channels == 4) { @@ -1399,7 +1565,8 @@ LogicalResult VMIChannelMergeOp::verify() { auto expectedLayout = VMILayoutAttr::getDeinterleaved(getContext(), channels); if (!resultLayout.isContiguous() && resultLayout != expectedLayout) - return emitOpError("requires layout-assigned channel_merge result to be contiguous or deinterleaved by input count"); + return emitOpError("requires layout-assigned channel_merge result to " + "be contiguous or deinterleaved by input count"); } } return success(); @@ -1410,7 +1577,8 @@ LogicalResult VMIEnsureLayoutOp::verify() { auto resultType = cast(getResult().getType()); if (sourceType.getElementCount() != resultType.getElementCount() || sourceType.getElementType() != resultType.getElementType()) - return emitOpError("requires source and result to preserve VMI data shape and element type"); + return emitOpError("requires source and result to preserve VMI data shape " + "and element type"); if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) return emitOpError("requires source and result to be layout-assigned"); return success(); @@ -1421,7 +1589,8 @@ LogicalResult VMIEnsureMaskLayoutOp::verify() { auto resultType = cast(getResult().getType()); if (sourceType.getElementCount() != resultType.getElementCount() || sourceType.getGranularity() != resultType.getGranularity()) - return emitOpError("requires source and result to preserve VMI mask shape and granularity"); + return emitOpError("requires source and result to preserve VMI mask shape " + "and granularity"); if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) return emitOpError("requires source and result to be layout-assigned"); return success(); @@ -1431,13 +1600,15 @@ LogicalResult VMIEnsureMaskGranularityOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); if (sourceType.getElementCount() != resultType.getElementCount()) - return emitOpError("requires source and result to preserve VMI mask lane count"); + return emitOpError( + "requires source and result to preserve VMI mask lane count"); if (!isLayoutAssigned(sourceType) || !isLayoutAssigned(resultType)) return emitOpError("requires source and result to be layout-assigned"); if (sourceType.getLayout() != resultType.getLayout()) return emitOpError("requires source and result mask layouts to match"); if (sourceType.isPred() || resultType.isPred()) - return emitOpError("requires concrete source and result mask granularities"); + return emitOpError( + "requires concrete source and result mask granularities"); return success(); } @@ -1473,14 +1644,22 @@ FailureOr mlir::pto::getMaskLanesPerPart(StringRef granularity) { FailureOr mlir::pto::getVMIPhysicalArity(Type type) { FailureOr elementCount = getVMIElementCount(type); - FailureOr factor = getLayoutFactor(type); FailureOr lanesPerPart = getPhysicalLanesPerPart(type); - if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + FailureOr layout = getAssignedVMILayout(type); + if (failed(elementCount) || failed(lanesPerPart) || failed(layout)) return failure(); + if ((*layout).isGroupSlots() && (*layout).getSlots() > 0) + return divideCeilNonNegative((*layout).getNumGroups(), + (*layout).getSlots()); + + int64_t factor = (*layout).isDeinterleaved() ? (*layout).getFactor() : 1; + int64_t blockElems = + (*layout).isDeinterleaved() ? (*layout).getBlockElems() : 1; int64_t arity = 0; - for (int64_t part = 0; part < *factor; ++part) { - int64_t lanesInPart = getLogicalLanesInPart(*elementCount, *factor, part); + for (int64_t part = 0; part < factor; ++part) { + int64_t lanesInPart = + getDenseLogicalLanesInPart(*elementCount, factor, blockElems, part); arity += divideCeilNonNegative(lanesInPart, *lanesPerPart); } return arity; @@ -1490,16 +1669,21 @@ FailureOr mlir::pto::mapLogicalLaneToPhysical(Type type, int64_t logicalLane) { FailureOr elementCount = getVMIElementCount(type); FailureOr factor = getLayoutFactor(type); + FailureOr blockElems = getLayoutBlockElems(type); FailureOr lanesPerPart = getPhysicalLanesPerPart(type); - if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + if (failed(elementCount) || failed(factor) || failed(blockElems) || + failed(lanesPerPart)) return failure(); if (logicalLane < 0 || logicalLane >= *elementCount) return failure(); - int64_t part = logicalLane % *factor; - int64_t indexInPart = logicalLane / *factor; - return VMIPhysicalLane{part, indexInPart / *lanesPerPart, - indexInPart % *lanesPerPart}; + int64_t part = 0; + std::optional indexInPart = mapDenseLogicalLaneToPartIndex( + *elementCount, *factor, *blockElems, logicalLane, part); + if (!indexInPart) + return failure(); + return VMIPhysicalLane{part, *indexInPart / *lanesPerPart, + *indexInPart % *lanesPerPart}; } FailureOr mlir::pto::mapPhysicalLaneToLogical(Type type, int64_t part, @@ -1507,32 +1691,38 @@ FailureOr mlir::pto::mapPhysicalLaneToLogical(Type type, int64_t part, int64_t lane) { FailureOr elementCount = getVMIElementCount(type); FailureOr factor = getLayoutFactor(type); + FailureOr blockElems = getLayoutBlockElems(type); FailureOr lanesPerPart = getPhysicalLanesPerPart(type); - if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + if (failed(elementCount) || failed(factor) || failed(blockElems) || + failed(lanesPerPart)) return failure(); if (part < 0 || part >= *factor || chunk < 0 || lane < 0 || lane >= *lanesPerPart) return failure(); int64_t indexInPart = chunk * *lanesPerPart + lane; - int64_t logicalLane = indexInPart * *factor + part; - if (logicalLane >= *elementCount) + std::optional logicalLane = mapDensePartIndexToLogicalLane( + *elementCount, *factor, *blockElems, part, indexInPart); + if (!logicalLane) return failure(); - return logicalLane; + return *logicalLane; } -FailureOr mlir::pto::isPaddingLane(Type type, int64_t part, - int64_t chunk, int64_t lane) { +FailureOr mlir::pto::isPaddingLane(Type type, int64_t part, int64_t chunk, + int64_t lane) { FailureOr elementCount = getVMIElementCount(type); FailureOr factor = getLayoutFactor(type); + FailureOr blockElems = getLayoutBlockElems(type); FailureOr lanesPerPart = getPhysicalLanesPerPart(type); - if (failed(elementCount) || failed(factor) || failed(lanesPerPart)) + if (failed(elementCount) || failed(factor) || failed(blockElems) || + failed(lanesPerPart)) return failure(); if (part < 0 || part >= *factor || chunk < 0 || lane < 0 || lane >= *lanesPerPart) return failure(); - int64_t lanesInPart = getLogicalLanesInPart(*elementCount, *factor, part); + int64_t lanesInPart = + getDenseLogicalLanesInPart(*elementCount, *factor, *blockElems, part); int64_t indexInPart = chunk * *lanesPerPart + lane; return indexInPart >= lanesInPart; } diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 27d6b806fe..c95e8772ec 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -1,10 +1,12 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. //===- VMILayoutAssignment.cpp - Assign VMI layouts ----------------------===// //===----------------------------------------------------------------------===// @@ -63,6 +65,8 @@ struct MaskUseRequest { std::string granularity; }; +static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; + static unsigned getElementBitWidth(Type type) { if (isa(type)) return 64; @@ -82,6 +86,15 @@ static StringRef getMaskGranularityForElement(Type elementType) { } } +static std::optional getConstantIndexValue(Value value) { + if (auto constant = value.getDefiningOp()) + return constant.value(); + if (auto constant = value.getDefiningOp()) + if (auto integerAttr = dyn_cast(constant.getValue())) + return integerAttr.getInt(); + return std::nullopt; +} + static bool isLane0SplatShuffle(VMIShuffleOp op) { auto sourceType = cast(op.getSource().getType()); ArrayRef indices = op.getIndices(); @@ -93,12 +106,10 @@ bool containsVMIType(Type type) { if (isa(type)) return true; if (auto functionType = dyn_cast(type)) { - return llvm::any_of(functionType.getInputs(), [](Type input) { - return containsVMIType(input); - }) || - llvm::any_of(functionType.getResults(), [](Type result) { - return containsVMIType(result); - }); + return llvm::any_of(functionType.getInputs(), + [](Type input) { return containsVMIType(input); }) || + llvm::any_of(functionType.getResults(), + [](Type result) { return containsVMIType(result); }); } if (auto shapedType = dyn_cast(type)) return containsVMIType(shapedType.getElementType()); @@ -191,11 +202,10 @@ struct LayoutSolver { if (!lhsNode.requestedGranularity.empty() && !rhsNode.requestedGranularity.empty() && lhsNode.requestedGranularity != rhsNode.requestedGranularity) - return op->emitError() - << kVMIDiagLayoutContractPrefix - << "conflicting mask granularities " - << lhsNode.requestedGranularity << " and " - << rhsNode.requestedGranularity; + return op->emitError() << kVMIDiagLayoutContractPrefix + << "conflicting mask granularities " + << lhsNode.requestedGranularity << " and " + << rhsNode.requestedGranularity; rhsNode.parent = lhsRoot; if (!lhsNode.requestedLayout) @@ -228,6 +238,123 @@ struct LayoutSolver { return VMILayoutAttr::getGroupSlots(ctx, numGroups); } + VMILayoutAttr getPreferredGroupSlotsLayout(VMIVRegType type, + int64_t numGroups) { + if (VMILayoutAttr existing = type.getLayoutAttr()) + if (existing.isGroupSlots() && existing.getSlots() > 0) + return existing; + if (numGroups > 0 && type.getElementCount() % numGroups == 0) { + int64_t groupSize = type.getElementCount() / numGroups; + if (groupSize == 8) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + if (groupSize == 16) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + if (groupSize == 32) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + if (groupSize == 64) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); + } + return getGroupSlotsLayout(numGroups); + } + + VMILayoutAttr getPreferredGroupReduceSourceLayout(VMIVRegType type, + int64_t numGroups) { + if (VMILayoutAttr existing = type.getLayoutAttr()) + return existing; + if (numGroups > 0 && type.getElementCount() % numGroups == 0) { + int64_t groupSize = type.getElementCount() / numGroups; + if (groupSize == 16) + return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); + if (groupSize == 32) + return VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); + } + return getContiguousLayout(); + } + + VMILayoutAttr getPreferredGroupSlotLoadLayout(VMIVRegType type, + int64_t numGroups) { + if (VMILayoutAttr existing = type.getLayoutAttr()) + if (existing.isGroupSlots() && existing.getSlots() > 0) + return existing; + if (numGroups > 0 && type.getElementCount() % numGroups == 0) { + int64_t groupSize = type.getElementCount() / numGroups; + if (groupSize == 64) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); + } + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + } + + VMILayoutAttr getPreferredGroupLoadResultLayout(VMIGroupLoadOp op) { + auto type = cast(op.getResult().getType()); + if (VMILayoutAttr existing = type.getLayoutAttr()) + return existing; + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0 || type.getElementCount() % numGroups != 0) + return getContiguousLayout(); + + if (!type.getElementType().isF32()) + return getContiguousLayout(); + + int64_t groupSize = type.getElementCount() / numGroups; + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % 8 != 0) + return getContiguousLayout(); + + if (groupSize == 16) + return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); + if (groupSize == 32) + return VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); + + return getContiguousLayout(); + } + + LogicalResult validateGroupLoadLayoutPlan(VMIGroupLoadOp op) { + auto type = cast(op.getResult().getType()); + if (type.getLayoutAttr()) + return success(); + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0 || type.getElementCount() % numGroups != 0) + return success(); + if (!type.getElementType().isF32()) + return success(); + + int64_t groupSize = type.getElementCount() / numGroups; + if (groupSize != 16 && groupSize != 32) + return success(); + + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (rowStride && *rowStride > 0 && *rowStride % 8 == 0) + return success(); + + return op.emitError() + << kVMIDiagLayoutContractPrefix << "pto.vmi.group_load group_size " + << groupSize + << " requires constant positive row_stride divisible by 8 f32 " + "elements for the block8 stride plan; stable gather fallback is " + "not implemented"; + } + + VMILayoutAttr getPreferredGroupStoreUseLayout(Value value, + int64_t numGroups) { + auto type = dyn_cast(value.getType()); + if (!type) + return getContiguousLayout(); + if (VMILayoutAttr existing = type.getLayoutAttr()) + if (existing.isGroupSlots() && existing.getSlots() > 0) + return existing; + VMILayoutAttr solved = getDataLayout(value); + if (solved && solved.isGroupSlots() && solved.getNumGroups() == numGroups && + solved.getSlots() > 0) + return solved; + if (value.getDefiningOp()) + return getPreferredGroupSlotsLayout(type, numGroups); + if (value.getDefiningOp()) + return getPreferredGroupSlotLoadLayout(type, numGroups); + return getContiguousLayout(); + } + VMILayoutAttr getDataLayout(Value value) { unsigned id = addDataValue(value); if (id == ~0u) @@ -238,6 +365,35 @@ struct LayoutSolver { return getContiguousLayout(); } + VMILayoutAttr getExplicitDataLayout(Value value) { + unsigned id = addDataValue(value); + if (id == ~0u) + return {}; + return dataNodes[find(id)].naturalLayout; + } + + bool hasCompatibleTruncFUseForGroupReduce(Value value, int64_t groupSize) { + auto sourceType = dyn_cast(value.getType()); + if (!sourceType || !sourceType.getElementType().isF32()) + return false; + + for (OpOperand &use : value.getUses()) { + auto truncf = dyn_cast(use.getOwner()); + if (!truncf || use.getOperandNumber() != 0) + continue; + + auto resultType = dyn_cast(truncf.getResult().getType()); + if (!resultType) + continue; + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (groupSize == 16 && resultBits == 16) + return true; + if (groupSize == 32 && resultBits == 8) + return true; + } + return false; + } + LogicalResult requestMask(Value mask, VMILayoutAttr layout, StringRef granularity, Operation *op) { unsigned id = addMaskValue(mask); @@ -256,8 +412,8 @@ struct LayoutSolver { node.requestedGranularity != granularity) return op->emitError() << kVMIDiagLayoutContractPrefix - << "conflicting mask granularities " - << node.requestedGranularity << " and " << granularity; + << "conflicting mask granularities " << node.requestedGranularity + << " and " << granularity; node.requestedLayout = layout; node.requestedGranularity = granularity.str(); return success(); @@ -268,17 +424,54 @@ struct LayoutSolver { dataUseRequests.push_back(DataUseRequest{&operand, layout}); } - bool canAdoptConsumerRequestedLayout(Value value) { - if (!value.hasOneUse()) + bool canProducerAdoptConsumerLayout(Operation *op) { + if (!op) return false; + return isa(op); + } + + bool canAdoptConsumerRequestedLayout(Value value, + VMILayoutAttr requestedLayout) { Operation *definingOp = value.getDefiningOp(); - return definingOp && isa(definingOp); + if (!definingOp) + return false; + if (!isa(definingOp)) { + if (!requestedLayout || requestedLayout.isContiguous()) + return false; + if (!canProducerAdoptConsumerLayout(definingOp)) + return false; + } + if (value.hasOneUse()) + return true; + + unsigned matchingRequests = 0; + unsigned totalUses = 0; + for (OpOperand &use : value.getUses()) { + ++totalUses; + bool foundRequest = false; + for (DataUseRequest request : dataUseRequests) { + if (request.operand != &use) + continue; + if (request.layout != requestedLayout) + return false; + foundRequest = true; + } + if (!foundRequest) + return false; + ++matchingRequests; + } + return matchingRequests == totalUses; } LogicalResult applyConsumerDrivenDataLayouts() { for (DataUseRequest request : dataUseRequests) { Value value = request.operand->get(); - if (!canAdoptConsumerRequestedLayout(value)) + if (!canAdoptConsumerRequestedLayout(value, request.layout)) continue; unsigned id = addDataValue(value); if (id == ~0u) @@ -287,8 +480,7 @@ struct LayoutSolver { VMILayoutAttr existing = dataNodes[root].naturalLayout; if (existing && existing != request.layout) return request.operand->getOwner()->emitError() - << kVMIDiagLayoutContractPrefix - << "conflicting natural layouts " + << kVMIDiagLayoutContractPrefix << "conflicting natural layouts " << existing << " and " << request.layout; dataNodes[root].naturalLayout = request.layout; } @@ -324,6 +516,71 @@ struct LayoutSolver { return success(); } + bool shouldCommuteTruncFAfterGroupBroadcast(VMIGroupBroadcastOp broadcast) { + auto truncf = broadcast.getSource().getDefiningOp(); + if (!truncf) + return false; + + auto truncSourceType = dyn_cast(truncf.getSource().getType()); + auto truncResultType = dyn_cast(truncf.getResult().getType()); + auto broadcastResultType = + dyn_cast(broadcast.getResult().getType()); + if (!truncSourceType || !truncResultType || !broadcastResultType) + return false; + if (truncSourceType.getElementCount() != + truncResultType.getElementCount() || + truncResultType.getElementCount() != + broadcastResultType.getElementCount()) + return false; + + VMILayoutAttr sourceLayout = truncSourceType.getLayoutAttr(); + bool sourceIsGroupSlotValue = + (sourceLayout && sourceLayout.isGroupSlots()) || + truncf.getSource().getDefiningOp() || + truncf.getSource().getDefiningOp(); + if (!sourceIsGroupSlotValue) + return false; + + unsigned sourceBits = getElementBitWidth(truncSourceType.getElementType()); + unsigned resultBits = getElementBitWidth(truncResultType.getElementType()); + return truncSourceType.getElementType().isF32() && sourceBits > resultBits; + } + + LogicalResult commuteTruncFAfterGroupBroadcast() { + SmallVector broadcasts; + module.walk([&](VMIGroupBroadcastOp broadcast) { + if (shouldCommuteTruncFAfterGroupBroadcast(broadcast)) + broadcasts.push_back(broadcast); + }); + + OpBuilder builder(ctx); + for (VMIGroupBroadcastOp broadcast : broadcasts) { + auto truncf = broadcast.getSource().getDefiningOp(); + if (!truncf) + continue; + + auto truncSourceType = cast(truncf.getSource().getType()); + auto broadcastResultType = + cast(broadcast.getResult().getType()); + auto wideBroadcastType = + VMIVRegType::get(ctx, broadcastResultType.getElementCount(), + truncSourceType.getElementType(), + broadcastResultType.getLayoutAttr()); + + builder.setInsertionPoint(broadcast); + auto wideBroadcast = builder.create( + broadcast.getLoc(), wideBroadcastType, truncf.getSource(), + broadcast.getNumGroupsAttr()); + auto narrow = builder.create( + broadcast.getLoc(), broadcastResultType, wideBroadcast.getResult()); + broadcast.getResult().replaceAllUsesWith(narrow.getResult()); + broadcast.erase(); + if (truncf->use_empty()) + truncf.erase(); + } + return success(); + } + LogicalResult addConstraints() { WalkResult result = module.walk([&](Operation *op) -> WalkResult { if (auto maskAnd = dyn_cast(op)) { @@ -504,56 +761,117 @@ struct LayoutSolver { } if (auto compress = dyn_cast(op)) { requestDataUse(compress.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(compress.getResult(), - getContiguousLayout(), op))) + if (failed(setNaturalLayout(compress.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); requestDataUse(reduce.getInitMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(reduce.getResult(), - getContiguousLayout(), op))) + if (failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); requestDataUse(reduce.getInitMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(reduce.getResult(), - getContiguousLayout(), op))) + if (failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); requestDataUse(reduce.getInitMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(reduce.getResult(), - getContiguousLayout(), op))) + if (failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); requestDataUse(reduce.getInitMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(reduce.getResult(), - getContiguousLayout(), op))) + if (failed(setNaturalLayout(reduce.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { - requestDataUse(reduce.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(reduce.getResult(), - getGroupSlotsLayout( - reduce.getNumGroupsAttr().getInt()), - op))) + auto sourceType = cast(reduce.getSource().getType()); + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( + sourceType, reduce.getNumGroupsAttr().getInt()); + VMILayoutAttr solvedSourceLayout = + getExplicitDataLayout(reduce.getSource()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + if (solvedSourceLayout && numGroups > 0 && + sourceType.getElementCount() % numGroups == 0) { + int64_t groupSize = sourceType.getElementCount() / numGroups; + if (groupSize == 16 && solvedSourceLayout.isDeinterleaved() && + solvedSourceLayout.getFactor() == 2 && + (solvedSourceLayout.getBlockElems() == 1 || + solvedSourceLayout.getBlockElems() == 8)) + sourceLayout = solvedSourceLayout; + if (groupSize == 32 && solvedSourceLayout.isDeinterleaved() && + solvedSourceLayout.getFactor() == 4 && + (solvedSourceLayout.getBlockElems() == 1 || + solvedSourceLayout.getBlockElems() == 8)) + sourceLayout = solvedSourceLayout; + } else if (!sourceType.getLayoutAttr() && numGroups > 0 && + sourceType.getElementCount() % numGroups == 0) { + int64_t groupSize = sourceType.getElementCount() / numGroups; + if (hasCompatibleTruncFUseForGroupReduce(reduce.getSource(), + groupSize)) { + if (groupSize == 16) + sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); + if (groupSize == 32) + sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/1); + } + } + if (sourceLayout && sourceLayout.isDeinterleaved() && + sourceLayout.getFactor() == 4 && + sourceLayout.getBlockElems() == 8 && numGroups > 0 && + sourceType.getElementCount() % numGroups == 0) { + int64_t groupSize = sourceType.getElementCount() / numGroups; + if (groupSize == 32) { + if (auto groupMask = + reduce.getMask().getDefiningOp()) { + std::optional activeElems = + getConstantIndexValue(groupMask.getActiveElemsPerGroup()); + if (activeElems && *activeElems >= 0 && + *activeElems < groupSize) { + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_addf s32 block8 lowering does " + "not yet support partial create_group_mask " + "active_elems_per_group during layout assignment"; + return WalkResult::interrupt(); + } + } + } + } + requestDataUse(reduce.getSourceMutable(), sourceLayout); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceLayout, + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + reduce.getResult(), + getPreferredGroupSlotsLayout( + resultType, reduce.getNumGroupsAttr().getInt()), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto broadcast = dyn_cast(op)) { + auto sourceType = cast(broadcast.getSource().getType()); requestDataUse(broadcast.getSourceMutable(), - getGroupSlotsLayout( - broadcast.getNumGroupsAttr().getInt())); + getPreferredGroupSlotsLayout( + sourceType, broadcast.getNumGroupsAttr().getInt())); return WalkResult::advance(); } if (auto extf = dyn_cast(op)) { @@ -581,6 +899,14 @@ struct LayoutSolver { auto resultType = cast(truncf.getResult().getType()); unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); unsigned resultBits = getElementBitWidth(resultType.getElementType()); + VMILayoutAttr sourceLayout = getDataLayout(truncf.getSource()); + if (sourceBits == 32 && resultBits == 16 && sourceLayout && + sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { + requestDataUse(truncf.getSourceMutable(), sourceLayout); + if (failed(setNaturalLayout(truncf.getResult(), sourceLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (sourceBits == 32 && resultBits == 16) requestDataUse(truncf.getSourceMutable(), VMILayoutAttr::getDeinterleaved(ctx, 2)); @@ -599,8 +925,8 @@ struct LayoutSolver { } if (auto load = dyn_cast(op)) { requestDataUse(load.getPassthruMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), - op))) + if (failed( + setNaturalLayout(load.getResult(), getContiguousLayout(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -608,27 +934,37 @@ struct LayoutSolver { auto resultType = cast(gather.getResult().getType()); requestDataUse(gather.getIndicesMutable(), getContiguousLayout()); requestDataUse(gather.getPassthruMutable(), getContiguousLayout()); - if (failed(requestMaskUse(gather.getMaskMutable(), - getContiguousLayout(), - getMaskGranularityForElement( - resultType.getElementType()), - op))) + if (failed(requestMaskUse( + gather.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); - if (failed(setNaturalLayout(gather.getResult(), - getContiguousLayout(), op))) + if (failed(setNaturalLayout(gather.getResult(), getContiguousLayout(), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto load = dyn_cast(op)) { requestDataUse(load.getPassthruMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), - op))) + if (failed( + setNaturalLayout(load.getResult(), getContiguousLayout(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto load = dyn_cast(op)) { - if (failed(setNaturalLayout(load.getResult(), getContiguousLayout(), - op))) + if (failed(validateGroupLoadLayoutPlan(load))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + load.getResult(), getPreferredGroupLoadResultLayout(load), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + if (failed(setNaturalLayout( + load.getResult(), + getPreferredGroupSlotLoadLayout( + resultType, load.getNumGroupsAttr().getInt()), + op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -637,17 +973,18 @@ struct LayoutSolver { return WalkResult::advance(); } if (auto store = dyn_cast(op)) { - requestDataUse(store.getValueMutable(), getContiguousLayout()); + requestDataUse( + store.getValueMutable(), + getPreferredGroupStoreUseLayout(store.getValue(), + store.getNumGroupsAttr().getInt())); return WalkResult::advance(); } if (auto store = dyn_cast(op)) { auto valueType = cast(store.getValue().getType()); requestDataUse(store.getValueMutable(), getContiguousLayout()); - if (failed(requestMaskUse(store.getMaskMutable(), - getContiguousLayout(), - getMaskGranularityForElement( - valueType.getElementType()), - op))) + if (failed(requestMaskUse( + store.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(valueType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -655,22 +992,18 @@ struct LayoutSolver { auto valueType = cast(scatter.getValue().getType()); requestDataUse(scatter.getValueMutable(), getContiguousLayout()); requestDataUse(scatter.getIndicesMutable(), getContiguousLayout()); - if (failed(requestMaskUse(scatter.getMaskMutable(), - getContiguousLayout(), - getMaskGranularityForElement( - valueType.getElementType()), - op))) + if (failed(requestMaskUse( + scatter.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(valueType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto store = dyn_cast(op)) { auto valueType = cast(store.getValue().getType()); requestDataUse(store.getValueMutable(), getContiguousLayout()); - if (failed(requestMaskUse(store.getMaskMutable(), - getContiguousLayout(), - getMaskGranularityForElement( - valueType.getElementType()), - op))) + if (failed(requestMaskUse( + store.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(valueType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -680,16 +1013,14 @@ struct LayoutSolver { } if (auto split = dyn_cast(op)) { int64_t channels = split.getNumResults(); - VMICapabilityResult capability = - capabilities.supportsChannelCount("pto.vmi.channel_split", - channels); + VMICapabilityResult capability = capabilities.supportsChannelCount( + "pto.vmi.channel_split", channels); if (!capability.isSupported()) { split.emitError() << kVMIDiagUnsupportedPrefix << capability.reason; return WalkResult::interrupt(); } - requestDataUse( - split.getSourceMutable(), - VMILayoutAttr::getDeinterleaved(ctx, channels)); + requestDataUse(split.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, channels)); for (Value result : split.getResults()) if (failed(setNaturalLayout(result, getContiguousLayout(), op))) return WalkResult::interrupt(); @@ -697,9 +1028,8 @@ struct LayoutSolver { } if (auto merge = dyn_cast(op)) { int64_t channels = merge.getInputs().size(); - VMICapabilityResult capability = - capabilities.supportsChannelCount("pto.vmi.channel_merge", - channels); + VMICapabilityResult capability = capabilities.supportsChannelCount( + "pto.vmi.channel_merge", channels); if (!capability.isSupported()) { merge.emitError() << kVMIDiagUnsupportedPrefix << capability.reason; return WalkResult::interrupt(); @@ -771,9 +1101,8 @@ struct LayoutSolver { if (failed(addBranchConstraints(switchOp.getDefaultDestination(), switchOp.getDefaultOperands(), op))) return WalkResult::interrupt(); - for (auto [dest, operands] : - llvm::zip(switchOp.getCaseDestinations(), - switchOp.getCaseOperands())) { + for (auto [dest, operands] : llvm::zip(switchOp.getCaseDestinations(), + switchOp.getCaseOperands())) { if (failed(addBranchConstraints(dest, operands, op))) return WalkResult::interrupt(); } @@ -825,8 +1154,7 @@ struct LayoutSolver { for (Region *region : {&ifOp.getThenRegion(), &ifOp.getElseRegion()}) { if (region->empty()) continue; - auto yieldOp = - dyn_cast(region->front().getTerminator()); + auto yieldOp = dyn_cast(region->front().getTerminator()); if (!yieldOp || resultNo >= yieldOp.getNumOperands()) continue; if (failed(uniteEquivalentValues(result, yieldOp.getOperand(resultNo), @@ -852,8 +1180,8 @@ struct LayoutSolver { WalkResult result = executeOp.getRegion().walk([&](scf::YieldOp yieldOp) { if (yieldOp->getParentOp() != executeOp.getOperation()) return WalkResult::advance(); - if (failed(addYieldConstraints(executeOp->getResults(), yieldOp, - executeOp))) + if (failed( + addYieldConstraints(executeOp->getResults(), yieldOp, executeOp))) return WalkResult::interrupt(); return WalkResult::advance(); }); @@ -903,8 +1231,8 @@ struct LayoutSolver { whileOp))) return failure(); if (index < whileOp.getNumResults() && - failed(uniteEquivalentValues(anchor, whileOp.getResult(index), - whileOp))) + failed( + uniteEquivalentValues(anchor, whileOp.getResult(index), whileOp))) return failure(); } return success(); @@ -927,8 +1255,8 @@ struct LayoutSolver { failed(uniteEquivalentValues(anchor, results[index], forOp))) return failure(); if (yieldOp && index < yieldOp.getNumOperands() && - failed(uniteEquivalentValues(anchor, yieldOp.getOperand(index), - forOp))) + failed( + uniteEquivalentValues(anchor, yieldOp.getOperand(index), forOp))) return failure(); } return success(); @@ -963,7 +1291,8 @@ struct LayoutSolver { for (auto [index, operand] : llvm::enumerate(returnOp.getOperands())) { if (index >= firstOperands.size()) break; - if (failed(uniteEquivalentValues(firstOperands[index], operand, returnOp))) + if (failed( + uniteEquivalentValues(firstOperands[index], operand, returnOp))) return failure(); } return success(); @@ -1020,13 +1349,12 @@ struct LayoutSolver { } std::optional rematerializeDataUse(Value value, VMIVRegType resultType, - Location loc, - OpBuilder &builder) { + Location loc, OpBuilder &builder) { if (auto constant = value.getDefiningOp()) { auto denseAttr = dyn_cast(constant.getValue()); if (denseAttr && denseAttr.isSplat()) - return builder.create(loc, resultType, - constant.getValue()) + return builder + .create(loc, resultType, constant.getValue()) .getResult(); } if (auto broadcast = value.getDefiningOp()) @@ -1034,8 +1362,9 @@ struct LayoutSolver { .create(loc, resultType, broadcast.getValue()) .getResult(); if (auto iota = value.getDefiningOp()) - return builder.create(loc, resultType, iota.getBase(), - iota.getOrderAttr()) + return builder + .create(loc, resultType, iota.getBase(), + iota.getOrderAttr()) .getResult(); return std::nullopt; } @@ -1060,9 +1389,8 @@ struct LayoutSolver { VMIVRegType::get(ctx, sourceType.getElementCount(), sourceType.getElementType(), request.layout); builder.setInsertionPoint(request.operand->getOwner()); - std::optional rematerialized = - rematerializeDataUse(value, resultType, - request.operand->getOwner()->getLoc(), builder); + std::optional rematerialized = rematerializeDataUse( + value, resultType, request.operand->getOwner()->getLoc(), builder); if (rematerialized) { request.operand->set(*rematerialized); continue; @@ -1078,120 +1406,97 @@ struct LayoutSolver { WalkResult result = module.walk([&](Operation *op) -> WalkResult { if (auto cmpf = dyn_cast(op)) { auto lhsType = cast(cmpf.getLhs().getType()); - if (failed(requestMask(cmpf.getResult(), lhsType.getLayoutAttr(), - getMaskGranularityForElement( - lhsType.getElementType()), - op))) + if (failed(requestMask( + cmpf.getResult(), lhsType.getLayoutAttr(), + getMaskGranularityForElement(lhsType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto cmpi = dyn_cast(op)) { auto lhsType = cast(cmpi.getLhs().getType()); - if (failed(requestMask(cmpi.getResult(), lhsType.getLayoutAttr(), - getMaskGranularityForElement( - lhsType.getElementType()), - op))) + if (failed(requestMask( + cmpi.getResult(), lhsType.getLayoutAttr(), + getMaskGranularityForElement(lhsType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto select = dyn_cast(op)) { auto resultType = cast(select.getResult().getType()); - if (failed(requestMaskUse(select.getMaskMutable(), - resultType.getLayoutAttr(), - getMaskGranularityForElement( - resultType.getElementType()), - op))) + if (failed(requestMaskUse( + select.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto activePrefix = dyn_cast(op)) { - auto resultType = - cast(activePrefix.getResult().getType()); - if (failed(requestMaskUse(activePrefix.getMaskMutable(), - resultType.getLayoutAttr(), - getMaskGranularityForElement( - resultType.getElementType()), - op))) + auto resultType = cast(activePrefix.getResult().getType()); + if (failed(requestMaskUse( + activePrefix.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto compress = dyn_cast(op)) { auto resultType = cast(compress.getResult().getType()); - if (failed(requestMaskUse(compress.getMaskMutable(), - resultType.getLayoutAttr(), - getMaskGranularityForElement( - resultType.getElementType()), - op))) + if (failed(requestMaskUse( + compress.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); - if (failed(requestMaskUse(reduce.getMaskMutable(), - sourceType.getLayoutAttr(), - getMaskGranularityForElement( - sourceType.getElementType()), - op))) + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); - if (failed(requestMaskUse(reduce.getMaskMutable(), - sourceType.getLayoutAttr(), - getMaskGranularityForElement( - sourceType.getElementType()), - op))) + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); - if (failed(requestMaskUse(reduce.getMaskMutable(), - sourceType.getLayoutAttr(), - getMaskGranularityForElement( - sourceType.getElementType()), - op))) + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); - if (failed(requestMaskUse(reduce.getMaskMutable(), - sourceType.getLayoutAttr(), - getMaskGranularityForElement( - sourceType.getElementType()), - op))) + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); - if (failed(requestMaskUse(reduce.getMaskMutable(), - sourceType.getLayoutAttr(), - getMaskGranularityForElement( - sourceType.getElementType()), - op))) + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto load = dyn_cast(op)) { auto resultType = cast(load.getResult().getType()); - if (failed(requestMaskUse(load.getMaskMutable(), - resultType.getLayoutAttr(), - getMaskGranularityForElement( - resultType.getElementType()), - op))) + if (failed(requestMaskUse( + load.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto load = dyn_cast(op)) { auto resultType = cast(load.getResult().getType()); - if (failed(requestMaskUse(load.getMaskMutable(), - resultType.getLayoutAttr(), - getMaskGranularityForElement( - resultType.getElementType()), - op))) + if (failed(requestMaskUse( + load.getMaskMutable(), resultType.getLayoutAttr(), + getMaskGranularityForElement(resultType.getElementType()), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -1203,8 +1508,8 @@ struct LayoutSolver { void rewriteMaskTypes() { for (MaskNode &node : maskNodes) { MaskNode &root = maskNodes[findMask(maskIds.lookup(node.value))]; - VMILayoutAttr layout = root.requestedLayout ? root.requestedLayout - : getContiguousLayout(); + VMILayoutAttr layout = + root.requestedLayout ? root.requestedLayout : getContiguousLayout(); StringRef granularity = root.requestedGranularity.empty() ? StringRef("b32") : StringRef(root.requestedGranularity); @@ -1214,11 +1519,17 @@ struct LayoutSolver { } std::optional rematerializeMaskUse(Value value, VMIMaskType resultType, - Location loc, - OpBuilder &builder) { + Location loc, OpBuilder &builder) { if (auto createMask = value.getDefiningOp()) - return builder.create(loc, resultType, - createMask.getActiveLanes()) + return builder + .create(loc, resultType, createMask.getActiveLanes()) + .getResult(); + if (auto createGroupMask = value.getDefiningOp()) + return builder + .create( + loc, resultType, createGroupMask.getActiveElemsPerGroup(), + createGroupMask.getNumGroupsAttr(), + createGroupMask.getGroupSizeAttr()) .getResult(); if (auto constantMask = value.getDefiningOp()) return builder @@ -1245,9 +1556,9 @@ struct LayoutSolver { builder.setInsertionPoint(request.operand->getOwner()); Value current = value; VMIMaskType currentType = sourceType; - auto requestedType = VMIMaskType::get(ctx, sourceType.getElementCount(), - request.granularity, - request.layout); + auto requestedType = + VMIMaskType::get(ctx, sourceType.getElementCount(), + request.granularity, request.layout); if (sourceType != requestedType) { std::optional rematerialized = rematerializeMaskUse( value, requestedType, request.operand->getOwner()->getLoc(), @@ -1259,9 +1570,9 @@ struct LayoutSolver { } if (sourceLayout != request.layout) { - auto layoutType = VMIMaskType::get(ctx, currentType.getElementCount(), - currentType.getGranularity(), - request.layout); + auto layoutType = + VMIMaskType::get(ctx, currentType.getElementCount(), + currentType.getGranularity(), request.layout); auto ensureLayout = builder.create( request.operand->getOwner()->getLoc(), layoutType, current); current = ensureLayout.getResult(); @@ -1272,10 +1583,8 @@ struct LayoutSolver { auto granularityType = VMIMaskType::get(ctx, currentType.getElementCount(), request.granularity, request.layout); - auto ensureGranularity = - builder.create( - request.operand->getOwner()->getLoc(), granularityType, - current); + auto ensureGranularity = builder.create( + request.operand->getOwner()->getLoc(), granularityType, current); current = ensureGranularity.getResult(); } @@ -1285,6 +1594,143 @@ struct LayoutSolver { return success(); } + std::optional getGroupReduceSelectedPlan(VMIGroupReduceAddFOp op) { + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType) + return std::nullopt; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return std::nullopt; + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0 || sourceType.getElementCount() % numGroups != 0) + return std::nullopt; + int64_t groupSize = sourceType.getElementCount() / numGroups; + + if (sourceLayout.isContiguous()) { + if (groupSize == 8) + return StringRef("s8_reduce_contiguous"); + if (groupSize == 64) + return StringRef("s64_reduce_row_local"); + return std::nullopt; + } + + if (!sourceLayout.isDeinterleaved()) + return std::nullopt; + + if (groupSize == 16 && sourceLayout.getFactor() == 2) { + if (sourceLayout.getBlockElems() == 1) + return StringRef("s16_reduce_parity"); + if (sourceLayout.getBlockElems() == 8) + return StringRef("s16_reduce_block8"); + } + + if (groupSize == 32 && sourceLayout.getFactor() == 4) { + if (sourceLayout.getBlockElems() == 1) + return StringRef("s32_reduce_dintlv4"); + if (sourceLayout.getBlockElems() == 8) + return StringRef("s32_reduce_block8_stride"); + } + + return std::nullopt; + } + + std::optional getGroupSlotLoadSelectedPlan(VMIGroupSlotLoadOp op) { + auto resultType = dyn_cast(op.getResult().getType()); + if (!resultType) + return std::nullopt; + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || + layout.getNumGroups() != op.getNumGroupsAttr().getInt()) + return std::nullopt; + if (layout.getSlots() == 8) + return StringRef("group_slot_load_slots8_unit_stride"); + if (layout.getSlots() == 1) + return StringRef("group_slot_load_slots1_row_local"); + return std::nullopt; + } + + std::optional getGroupLoadSelectedPlan(VMIGroupLoadOp op) { + auto resultType = dyn_cast(op.getResult().getType()); + if (!resultType) + return std::nullopt; + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout) + return std::nullopt; + if (layout.isContiguous()) + return StringRef("group_load_contiguous_chunks"); + if (!layout.isDeinterleaved() || layout.getBlockElems() != 8) + return std::nullopt; + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0 || resultType.getElementCount() % numGroups != 0) + return std::nullopt; + int64_t groupSize = resultType.getElementCount() / numGroups; + if (groupSize == 16 && layout.getFactor() == 2) + return StringRef("s16_group_load_block8_stride"); + if (groupSize == 32 && layout.getFactor() == 4) + return StringRef("s32_group_load_block8_stride"); + return std::nullopt; + } + + std::optional + getGroupBroadcastSelectedPlan(VMIGroupBroadcastOp op) { + auto sourceType = dyn_cast(op.getSource().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!sourceType || !resultType) + return std::nullopt; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout || !sourceLayout.isGroupSlots() || + sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt() || + resultLayout.isGroupSlots()) + return std::nullopt; + if (sourceLayout.getSlots() == 8) + return StringRef("group_broadcast_slots8_vselr"); + if (sourceLayout.getSlots() == 1) + return StringRef("group_broadcast_slots1_vselr"); + return std::nullopt; + } + + std::optional getTruncFSelectedPlan(VMITruncFOp op) { + auto sourceType = dyn_cast(op.getSource().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!sourceType || !resultType) + return std::nullopt; + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout || sourceLayout != resultLayout || + !sourceLayout.isGroupSlots() || sourceLayout.getSlots() != 1) + return std::nullopt; + + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (sourceBits == 32 && resultBits == 16) + return StringRef("group_slot_cast_slots1_f32_to_f16"); + return std::nullopt; + } + + void attachSelectedPlanAttrs() { + Builder builder(ctx); + module.walk([&](Operation *op) { + std::optional plan; + if (auto reduce = dyn_cast(op)) + plan = getGroupReduceSelectedPlan(reduce); + else if (auto load = dyn_cast(op)) + plan = getGroupLoadSelectedPlan(load); + else if (auto load = dyn_cast(op)) + plan = getGroupSlotLoadSelectedPlan(load); + else if (auto broadcast = dyn_cast(op)) + plan = getGroupBroadcastSelectedPlan(broadcast); + else if (auto truncf = dyn_cast(op)) + plan = getTruncFSelectedPlan(truncf); + + if (plan) + op->setAttr(kVMISelectedPlanAttrName, builder.getStringAttr(*plan)); + }); + } + void rewriteFunctionType() { module.walk([&](func::FuncOp func) { if (func.empty()) @@ -1320,6 +1766,8 @@ struct LayoutSolver { } LogicalResult run() { + if (failed(commuteTruncFAfterGroupBroadcast())) + return failure(); if (failed(collect())) return failure(); if (failed(addConstraints())) @@ -1329,6 +1777,7 @@ struct LayoutSolver { rewriteDataTypes(); if (failed(insertDataUseMaterializations())) return failure(); + attachSelectedPlanAttrs(); if (failed(inferMaskRequests())) return failure(); rewriteMaskTypes(); @@ -1351,8 +1800,7 @@ struct LayoutSolver { }; struct VMILayoutAssignmentPass - : public mlir::pto::impl::VMILayoutAssignmentBase< - VMILayoutAssignmentPass> { + : public mlir::pto::impl::VMILayoutAssignmentBase { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutAssignmentPass) void runOnOperation() override { diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index cf91af1142..95141bada7 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -1,10 +1,12 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. //===- VMIToVPTO.cpp - Convert VMI to physical VPTO IR -------------------===// //===----------------------------------------------------------------------===// @@ -48,6 +50,8 @@ using namespace mlir::pto; namespace { +static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; + bool isVMIType(Type type) { return isa(type); } bool containsVMIType(Type type) { @@ -87,9 +91,8 @@ bool hasVMIType(Attribute attr) { return true; if (auto arrayAttr = dyn_cast(attr)) - return llvm::any_of(arrayAttr, [](Attribute element) { - return hasVMIType(element); - }); + return llvm::any_of(arrayAttr, + [](Attribute element) { return hasVMIType(element); }); if (auto dictAttr = dyn_cast(attr)) return llvm::any_of(dictAttr, [](NamedAttribute namedAttr) { @@ -130,9 +133,8 @@ bool isLayoutAssignedVMIType(Type type) { LogicalResult verifyLayoutAssignedVMITypeTree(Operation *op, Type type) { if (!isLayoutAssignedVMIType(type)) - return op->emitError() - << kVMIDiagPassInvariantPrefix - << "vmi-to-vpto requires layout-assigned VMI types"; + return op->emitError() << kVMIDiagPassInvariantPrefix + << "vmi-to-vpto requires layout-assigned VMI types"; if (auto functionType = dyn_cast(type)) { for (Type input : functionType.getInputs()) @@ -233,28 +235,28 @@ class VMIToVPTOTypeConverter final : public OneToNTypeConverter { public: VMIToVPTOTypeConverter() { addConversion([](Type type) { return type; }); - addConversion([](VMIVRegType type, SmallVectorImpl &results) - -> LogicalResult { - FailureOr arity = getVMIPhysicalArity(type); - FailureOr lanesPerPart = - getDataLanesPerPart(type.getElementType()); - if (failed(arity) || failed(lanesPerPart)) - return failure(); - for (int64_t i = 0; i < *arity; ++i) - results.push_back(VRegType::get(type.getContext(), *lanesPerPart, - type.getElementType())); - return success(); - }); - addConversion([](VMIMaskType type, SmallVectorImpl &results) - -> LogicalResult { - FailureOr arity = getVMIPhysicalArity(type); - if (failed(arity)) - return failure(); - for (int64_t i = 0; i < *arity; ++i) - results.push_back(MaskType::get(type.getContext(), - type.getGranularity())); - return success(); - }); + addConversion( + [](VMIVRegType type, SmallVectorImpl &results) -> LogicalResult { + FailureOr arity = getVMIPhysicalArity(type); + FailureOr lanesPerPart = + getDataLanesPerPart(type.getElementType()); + if (failed(arity) || failed(lanesPerPart)) + return failure(); + for (int64_t i = 0; i < *arity; ++i) + results.push_back(VRegType::get(type.getContext(), *lanesPerPart, + type.getElementType())); + return success(); + }); + addConversion( + [](VMIMaskType type, SmallVectorImpl &results) -> LogicalResult { + FailureOr arity = getVMIPhysicalArity(type); + if (failed(arity)) + return failure(); + for (int64_t i = 0; i < *arity; ++i) + results.push_back( + MaskType::get(type.getContext(), type.getGranularity())); + return success(); + }); TypeConverter::addSourceMaterialization(materializeVPTOToVMI); TypeConverter::addArgumentMaterialization(materializeVPTOToVMI); OneToNTypeConverter::addTargetMaterialization(materializeVMIToVPTO); @@ -284,8 +286,7 @@ FailureOr createAllTrueMaskForVReg(Location loc, VRegType vregType, return failure(); } -FailureOr getMaskTypeForVReg(VRegType vregType, - MLIRContext *ctx) { +FailureOr getMaskTypeForVReg(VRegType vregType, MLIRContext *ctx) { unsigned elementBits = pto::getPTOStorageElemBitWidth(vregType.getElementType()); if (elementBits == 8) @@ -341,10 +342,12 @@ FailureOr createPrefixMask(Location loc, MaskType maskType, return rewriter.create(loc, MaskType::get(ctx, "b8"), patternAttr) .getResult(); if (maskType.isB16()) - return rewriter.create(loc, MaskType::get(ctx, "b16"), patternAttr) + return rewriter + .create(loc, MaskType::get(ctx, "b16"), patternAttr) .getResult(); if (maskType.isB32()) - return rewriter.create(loc, MaskType::get(ctx, "b32"), patternAttr) + return rewriter + .create(loc, MaskType::get(ctx, "b32"), patternAttr) .getResult(); return failure(); } @@ -372,18 +375,17 @@ createRuntimePrefixMask(Location loc, MaskType maskType, Value activeLanes, return failure(); } -LogicalResult checkSupportedMaskableVReg( - const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, - std::string *reason = nullptr) { +LogicalResult +checkSupportedMaskableVReg(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, std::string *reason = nullptr) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); return failure(); }; - VMICapabilityResult elementCapability = - capabilities.supportsElementType(type.getElementType(), - VMIElementPurpose::PredicateMask); + VMICapabilityResult elementCapability = capabilities.supportsElementType( + type.getElementType(), VMIElementPurpose::PredicateMask); if (!elementCapability.isSupported()) return fail(elementCapability.reason); @@ -395,10 +397,11 @@ LogicalResult checkSupportedMaskableVReg( return success(); } -LogicalResult checkSupportedTargetElementVReg( - const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, - VMIElementPurpose purpose, StringRef elementContract, - std::string *reason = nullptr) { +LogicalResult +checkSupportedTargetElementVReg(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, VMIElementPurpose purpose, + StringRef elementContract, + std::string *reason = nullptr) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -416,19 +419,46 @@ LogicalResult checkSupportedTargetElementVReg( return success(); } -Value createI32Constant(Location loc, int64_t value, PatternRewriter &rewriter) { +Value createI32Constant(Location loc, int64_t value, + PatternRewriter &rewriter) { return rewriter.create(loc, value, 32); } +FailureOr createPrefixMaskForActiveLanes(Location loc, MaskType maskType, + int64_t activeLanes, + PatternRewriter &rewriter) { + if (activeLanes <= 0) + return createPrefixMask(loc, maskType, "PAT_ALLF", rewriter); + + switch (activeLanes) { + case 1: + case 2: + case 3: + case 4: + case 8: + case 16: + case 32: + case 64: + case 128: + return createPrefixMask( + loc, maskType, (Twine("PAT_VL") + Twine(activeLanes)).str(), rewriter); + default: { + FailureOr> dynamicMask = createRuntimePrefixMask( + loc, maskType, createI32Constant(loc, activeLanes, rewriter), rewriter); + if (failed(dynamicMask)) + return failure(); + return dynamicMask->first; + } + } +} + Value clampDynamicActiveLanes(Location loc, Value activeLanes, int64_t maxActiveLanes, PatternRewriter &rewriter) { - Value activeI32 = - rewriter.create(loc, rewriter.getI32Type(), - activeLanes); + Value activeI32 = rewriter.create( + loc, rewriter.getI32Type(), activeLanes); Value zeroI32 = createI32Constant(loc, 0, rewriter); - Value nonNegative = - rewriter.create(loc, activeI32, zeroI32); + Value nonNegative = rewriter.create(loc, activeI32, zeroI32); Value maxI32 = createI32Constant(loc, maxActiveLanes, rewriter); return rewriter.create(loc, nonNegative, maxI32); } @@ -441,9 +471,8 @@ Value createPartitionActiveLanes(Location loc, Value activeLanesI32, int64_t bias = factor - 1 - part; Value biased = activeLanesI32; if (bias != 0) - biased = - rewriter.create(loc, biased, - createI32Constant(loc, bias, rewriter)); + biased = rewriter.create( + loc, biased, createI32Constant(loc, bias, rewriter)); return rewriter.create( loc, biased, createI32Constant(loc, factor, rewriter)); } @@ -643,8 +672,7 @@ LogicalResult checkSupportedLayoutMaterialization( }; VMICapabilityResult layoutCapability = - capabilities.supportsLayoutConversion(sourceLayout, resultLayout, - Type{}); + capabilities.supportsLayoutConversion(sourceLayout, resultLayout, Type{}); if (!layoutCapability.isSupported()) return fail(layoutCapability.reason); @@ -682,10 +710,10 @@ LogicalResult checkSupportedLayoutMaterialization( return success(); if (failed(sourceFull)) - return fail(Twine("source ") + sourceReason + - "; source materialization " + sourceMaterializationReason); - return fail(Twine("result ") + resultReason + - "; result materialization " + resultMaterializationReason); + return fail(Twine("source ") + sourceReason + "; source materialization " + + sourceMaterializationReason); + return fail(Twine("result ") + resultReason + "; result materialization " + + resultMaterializationReason); } FailureOr getContiguousMaterializationPartCount(Type type, @@ -861,8 +889,7 @@ buildContiguousIdentityLaneAddressMap(int64_t constantOffset, return map; } -VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, - StringRef role, +VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, StringRef role, Value memoryValue = {}) { auto memrefType = dyn_cast(memoryType); if (!memrefType || memrefType.getLayout().isIdentity()) @@ -878,9 +905,10 @@ VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, return VMICapabilityResult::missingCapability(reason); } -VMIMemorySafeReadProof -computeSafeFullReadProof(Type sourceType, std::optional constantOffset, - VMIVRegType resultType) { +VMIMemorySafeReadProof computeSafeFullReadProof( + Type sourceType, std::optional constantOffset, + VMIVRegType resultType, + std::optional explicitFullReadElems = std::nullopt) { VMIMemorySafeReadProof proof; proof.constantOffset = constantOffset; @@ -893,9 +921,14 @@ computeSafeFullReadProof(Type sourceType, std::optional constantOffset, if (!constantOffset) return fail("requires constant index offset"); - FailureOr elements = getStaticMemRefElementCount(sourceType); - if (failed(elements)) - return fail("requires statically shaped memref source"); + std::optional elements = explicitFullReadElems; + if (!elements) { + FailureOr staticElements = getStaticMemRefElementCount(sourceType); + if (failed(staticElements)) + return fail("requires statically shaped memref source or explicit " + "full_read_elems"); + elements = *staticElements; + } proof.staticElementCount = *elements; if (*constantOffset < 0) @@ -920,11 +953,11 @@ computeSafeFullReadProof(Type sourceType, std::optional constantOffset, return proof; } -VMIMemoryAccessPlan -buildReadAccessPlan(const VMITargetCapabilityRegistry &capabilities, - Value source, Type sourceType, VMIVRegType resultType, - std::optional constantOffset, - VMIMemoryValidMaskKind validMask) { +VMIMemoryAccessPlan buildReadAccessPlan( + const VMITargetCapabilityRegistry &capabilities, Value source, + Type sourceType, VMIVRegType resultType, + std::optional constantOffset, VMIMemoryValidMaskKind validMask, + std::optional explicitFullReadElems = std::nullopt) { VMIMemoryAccessPlan plan; plan.baseType = sourceType; plan.valueType = resultType; @@ -933,19 +966,19 @@ buildReadAccessPlan(const VMITargetCapabilityRegistry &capabilities, plan.validMask = validMask; plan.permutation = VMIMemoryPermutationKind::Identity; plan.writeMask = VMIMemoryWriteMaskKind::AllTrue; - plan.safeReadProof = - computeSafeFullReadProof(sourceType, constantOffset, resultType); + plan.safeReadProof = computeSafeFullReadProof( + sourceType, constantOffset, resultType, explicitFullReadElems); plan.laneAddressMap = plan.safeReadProof.laneAddressMap; - plan.targetCapability = capabilities.supportsDirectMemory(sourceType, - "source"); + plan.targetCapability = + capabilities.supportsDirectMemory(sourceType, "source"); if (plan.targetCapability.isSupported()) plan.targetCapability = requireIdentityMemRefLayout(sourceType, "source", source); if (validMask == VMIMemoryValidMaskKind::ExplicitMask) plan.trueMaskedLoadCapability = capabilities.supportsTrueMaskedLoad(sourceType, resultType, Type{}); - plan.scratchFallbackCapability = - capabilities.supportsFallbackResource(VMIFallbackResourceKind::ScratchMemory); + plan.scratchFallbackCapability = capabilities.supportsFallbackResource( + VMIFallbackResourceKind::ScratchMemory); plan.guardedFallbackCapability = capabilities.supportsFallbackResource( VMIFallbackResourceKind::GuardedControlFlow); return plan; @@ -954,8 +987,7 @@ buildReadAccessPlan(const VMITargetCapabilityRegistry &capabilities, VMIMemoryAccessPlan buildWriteAccessPlan(const VMITargetCapabilityRegistry &capabilities, Value destination, Type destinationType, - VMIVRegType valueType, - VMIMemoryWriteMaskKind writeMask) { + VMIVRegType valueType, VMIMemoryWriteMaskKind writeMask) { VMIMemoryAccessPlan plan; plan.baseType = destinationType; plan.valueType = valueType; @@ -966,9 +998,8 @@ buildWriteAccessPlan(const VMITargetCapabilityRegistry &capabilities, plan.targetCapability = capabilities.supportsDirectMemory(destinationType, "destination"); if (plan.targetCapability.isSupported()) - plan.targetCapability = - requireIdentityMemRefLayout(destinationType, "destination", - destination); + plan.targetCapability = requireIdentityMemRefLayout( + destinationType, "destination", destination); return plan; } @@ -991,18 +1022,18 @@ void requireUnavailableReadFallback(VMIMemoryAccessPlan &plan) { maskedLoadReason + scratchReason + guardedReason); } -FailureOr -verifyFullOrSafeReadVRegChunks(Operation *op, VMIVRegType type, - Type sourceType, Value offset, - PatternRewriter &rewriter) { +FailureOr verifyFullOrSafeReadVRegChunks( + Operation *op, VMIVRegType type, Type sourceType, Value offset, + PatternRewriter &rewriter, + std::optional explicitFullReadElems = std::nullopt) { std::string fullChunkReason; FailureOr lanesPerPart = checkFullDataPhysicalChunks(type, &fullChunkReason); if (succeeded(lanesPerPart)) return *lanesPerPart; - VMIMemorySafeReadProof safeReadProof = - computeSafeFullReadProof(sourceType, getConstantIndexValue(offset), type); + VMIMemorySafeReadProof safeReadProof = computeSafeFullReadProof( + sourceType, getConstantIndexValue(offset), type, explicitFullReadElems); if (safeReadProof.proven) { lanesPerPart = getDataLanesPerPart(type.getElementType()); if (succeeded(lanesPerPart)) @@ -1018,16 +1049,16 @@ verifyFullOrSafeReadVRegChunks(Operation *op, VMIVRegType type, LogicalResult checkSupportedLoadShape( const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, Value source, Type sourceType, std::optional constantOffset, - std::string *reason) { + std::optional explicitFullReadElems, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); return failure(); }; - VMIMemoryAccessPlan accessPlan = - buildReadAccessPlan(capabilities, source, sourceType, type, - constantOffset, VMIMemoryValidMaskKind::AllTrue); + VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( + capabilities, source, sourceType, type, constantOffset, + VMIMemoryValidMaskKind::AllTrue, explicitFullReadElems); if (!accessPlan.targetCapability.isSupported()) return fail(accessPlan.targetCapability.reason); @@ -1039,14 +1070,14 @@ LogicalResult checkSupportedLoadShape( return success(); requireUnavailableReadFallback(accessPlan); return fail(Twine(fullChunkReason) + - "; safe-read proof failed: " + - accessPlan.safeReadProof.reason + + "; safe-read proof failed: " + accessPlan.safeReadProof.reason + "; fallback decision: " + accessPlan.fallbackDecision.reason); } -LogicalResult checkSupportedStoreShape( - const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, - Value destination, Type destinationType, std::string *reason) { +LogicalResult +checkSupportedStoreShape(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, Value destination, + Type destinationType, std::string *reason) { VMIMemoryAccessPlan accessPlan = buildWriteAccessPlan(capabilities, destination, destinationType, type, VMIMemoryWriteMaskKind::AllTrue); @@ -1083,8 +1114,7 @@ LogicalResult checkSupportedStoreShape( return fail(Twine("partial/tail store requires contiguous layout or " "deinterleaved layout that can materialize to contiguous; " "value ") + - fullChunkReason + ", materialization " + - materializationReason); + fullChunkReason + ", materialization " + materializationReason); } FailureOr getGroupSizeFromNumGroups(VMIVRegType type, @@ -1102,8 +1132,7 @@ FailureOr getGroupSizeFromNumGroups(VMIVRegType type, return type.getElementCount() / numGroups; } -LogicalResult checkSupportedGroupChunkShape(VMIVRegType type, - int64_t groupSize, +LogicalResult checkSupportedGroupChunkShape(VMIVRegType type, int64_t groupSize, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) @@ -1129,29 +1158,211 @@ LogicalResult checkSupportedGroupChunkShape(VMIVRegType type, return success(); } -LogicalResult checkSupportedGroupLoadShape( - const VMITargetCapabilityRegistry &capabilities, VMIGroupLoadOp op, - std::string *reason) { +LogicalResult +checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIGroupLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + auto resultType = cast(op.getResult().getType()); - FailureOr groupSize = - getGroupSizeFromNumGroups(resultType, op.getNumGroupsAttr().getInt(), - reason); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout) + return fail("requires assigned result layout"); + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + FailureOr groupSize = getGroupSizeFromNumGroups( + resultType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) return failure(); - if (failed(checkSupportedLoadShape(capabilities, resultType, op.getSource(), - op.getSource().getType(), - std::nullopt, reason))) - return failure(); - return checkSupportedGroupChunkShape(resultType, *groupSize, reason); + + if (resultLayout.isContiguous()) { + StringRef expectedPlan = "group_load_contiguous_chunks"; + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match result layout; expected '" + expectedPlan + + "'"); + if (failed(checkSupportedLoadShape(capabilities, resultType, op.getSource(), + op.getSource().getType(), std::nullopt, + std::nullopt, reason))) + return failure(); + return checkSupportedGroupChunkShape(resultType, *groupSize, reason); + } + + if (resultLayout.isDeinterleaved() && resultLayout.getBlockElems() == 8 && + resultType.getElementType().isF32()) { + StringRef expectedPlan; + if (*groupSize == 16 && resultLayout.getFactor() == 2) + expectedPlan = "s16_group_load_block8_stride"; + else if (*groupSize == 32 && resultLayout.getFactor() == 4) + expectedPlan = "s32_group_load_block8_stride"; + else + return fail("block8 strided group_load requires S=16/factor=2 or " + "S=32/factor=4"); + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match result layout; expected '" + expectedPlan + + "'"); + if (!isa(op.getSource().getType())) + return fail("block8 strided group_load requires !pto.ptr source"); + if (op.getNumGroupsAttr().getInt() % 8 != 0) + return fail("block8 strided group_load requires num_groups multiple " + "of 8"); + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % 8 != 0) + return fail("block8 strided group_load requires constant positive " + "row_stride divisible by 8 f32 elements"); + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("block8 strided group_load requires full physical " + "result chunks; ") + + fullChunkReason); + return success(); + } + + return fail("requires contiguous layout or deinterleaved block8 f32 layout"); } -LogicalResult checkSupportedGroupStoreShape( - const VMITargetCapabilityRegistry &capabilities, VMIGroupStoreOp op, +LogicalResult checkSupportedGroupSlotLoadShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || + layout.getNumGroups() != op.getNumGroupsAttr().getInt() || + layout.getSlots() <= 0) + return fail("requires explicit group_slots result layout matching " + "num_groups"); + + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + + StringRef expectedPlan; + if (layout.getSlots() == 8) + expectedPlan = "group_slot_load_slots8_unit_stride"; + else if (layout.getSlots() == 1) + expectedPlan = "group_slot_load_slots1_row_local"; + else + return fail("supports only slots=8 or slots=1 group_slot_load layouts"); + + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match result layout; expected '" + expectedPlan + + "'"); + + if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") + .isSupported()) + return fail("requires supported direct memory source"); + if (!isa(op.getSource().getType())) + return fail("requires !pto.ptr source for vsldb lowering"); + if (layout.getSlots() == 8) { + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (!stride || *stride != 1) + return fail("slots=8 group_slot_load requires constant unit " + "source_group_stride"); + return success(); + } + if (layout.getSlots() == 1) { + unsigned elementBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return fail("slots=1 group_slot_load requires an 8/16/32-bit element " + "type"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (!stride || *stride <= 0 || *stride % alignedStrideElems != 0) + return fail(Twine("slots=1 group_slot_load currently lowers as one " + "lane-0 vsldb per group and requires constant " + "positive source_group_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B load alignment; packed or unaligned " + "scalar load lowering is not implemented"); + return success(); + } + llvm_unreachable("unsupported group_slot_load slots should be rejected"); +} + +LogicalResult +checkSupportedGroupStoreShape(const VMITargetCapabilityRegistry &capabilities, + VMIGroupStoreOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + auto valueType = cast(op.getValue().getType()); - FailureOr groupSize = - getGroupSizeFromNumGroups(valueType, op.getNumGroupsAttr().getInt(), - reason); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (layout && layout.isGroupSlots()) { + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (layout.getNumGroups() != numGroups) + return fail("group_slots group_store requires layout num_groups to " + "match op num_groups"); + + VMIMemoryAccessPlan accessPlan = buildWriteAccessPlan( + capabilities, op.getDestination(), op.getDestination().getType(), + valueType, VMIMemoryWriteMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + + if (failed(checkSupportedMaskableVReg(capabilities, valueType, reason))) + return failure(); + + FailureOr arity = getVMIPhysicalArity(valueType); + if (failed(arity)) + return fail("requires computable physical arity"); + if (layout.getSlots() == 1) { + if (*arity != numGroups) + return fail("slots=1 group_store requires one physical part per " + "group"); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(valueType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return fail("slots=1 group_store requires an 8/16/32-bit element " + "type"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional rowStride = + getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % alignedStrideElems != 0) + return fail(Twine("slots=1 group_store currently lowers as one " + "lane-0 vsts per group and requires constant " + "positive row_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B store alignment; packed or unaligned " + "contiguous store lowering is not implemented"); + return success(); + } + if (layout.getSlots() == 8) { + std::optional rowStride = + getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride != 1) + return fail("slots=8 group_store currently requires constant unit " + "row_stride"); + if (*arity != ceilDivNonNegative(numGroups, 8)) + return fail("slots=8 group_store arity must equal ceil(num_groups / " + "8)"); + return success(); + } + return fail("group_slots group_store currently supports only slots=1 or " + "unit-stride slots=8"); + } + + FailureOr groupSize = getGroupSizeFromNumGroups( + valueType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) return failure(); if (failed(checkSupportedStoreShape(capabilities, valueType, @@ -1202,9 +1413,9 @@ checkSupportedMaskedLoadShape(const VMITargetCapabilityRegistry &capabilities, "; fallback decision: " + accessPlan.fallbackDecision.reason); } -LogicalResult checkSupportedGatherShape( - const VMITargetCapabilityRegistry &capabilities, VMIGatherOp op, - std::string *reason) { +LogicalResult +checkSupportedGatherShape(const VMITargetCapabilityRegistry &capabilities, + VMIGatherOp op, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -1260,8 +1471,7 @@ LogicalResult checkSupportedGatherShape( std::string passthruReason; std::string maskReason; if (failed(checkFullDataPhysicalChunks(resultType, &resultReason))) - return fail(Twine("result requires full physical chunks; ") + - resultReason); + return fail(Twine("result requires full physical chunks; ") + resultReason); if (failed(checkFullDataPhysicalChunks(indicesType, &indicesReason))) return fail(Twine("indices require full physical chunks; ") + indicesReason); @@ -1274,9 +1484,9 @@ LogicalResult checkSupportedGatherShape( return success(); } -LogicalResult checkSupportedScatterShape( - const VMITargetCapabilityRegistry &capabilities, VMIScatterOp op, - std::string *reason) { +LogicalResult +checkSupportedScatterShape(const VMITargetCapabilityRegistry &capabilities, + VMIScatterOp op, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -1300,9 +1510,9 @@ LogicalResult checkSupportedScatterShape( return fail("requires contiguous value, indices, and mask layouts"); VMICapabilityResult destinationCapability = - capabilities.supportsUBPointerMemory( - op.getDestination().getType(), "destination", "pto.vscatter", - "pto.vscatter writes only UB"); + capabilities.supportsUBPointerMemory(op.getDestination().getType(), + "destination", "pto.vscatter", + "pto.vscatter writes only UB"); if (!destinationCapability.isSupported()) return fail(destinationCapability.reason); @@ -1421,9 +1631,8 @@ checkSupportedExpandLoadShape(const VMITargetCapabilityRegistry &capabilities, return fail("requires contiguous result, passthru, and mask layouts"); std::string maskReason; - bool staticAllActive = - isStaticAllActiveMask(op.getMask(), resultType.getElementCount(), - &maskReason); + bool staticAllActive = isStaticAllActiveMask( + op.getMask(), resultType.getElementCount(), &maskReason); std::string fullChunkReason; if (staticAllActive && @@ -1435,16 +1644,16 @@ checkSupportedExpandLoadShape(const VMITargetCapabilityRegistry &capabilities, std::string allActivePathReason; if (!staticAllActive) { - allActivePathReason = maskReason.empty() ? "requires static all-active mask" - : maskReason; + allActivePathReason = + maskReason.empty() ? "requires static all-active mask" : maskReason; } else { requireUnavailableReadFallback(accessPlan); allActivePathReason = (Twine("requires full physical chunks or statically safe full-read " "footprint; value ") + fullChunkReason + ", safe-read proof " + - accessPlan.safeReadProof.reason + "; fallback decision: " + - accessPlan.fallbackDecision.reason) + accessPlan.safeReadProof.reason + + "; fallback decision: " + accessPlan.fallbackDecision.reason) .str(); } @@ -1487,13 +1696,14 @@ checkSupportedExpandLoadShape(const VMITargetCapabilityRegistry &capabilities, return success(); } -LogicalResult checkSupportedMaskedStoreShape( - const VMITargetCapabilityRegistry &capabilities, VMIVRegType valueType, - VMIMaskType maskType, Value destination, Type destinationType, - std::string *reason) { +LogicalResult +checkSupportedMaskedStoreShape(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType valueType, VMIMaskType maskType, + Value destination, Type destinationType, + std::string *reason) { VMIMemoryAccessPlan accessPlan = - buildWriteAccessPlan(capabilities, destination, destinationType, valueType, - VMIMemoryWriteMaskKind::ExplicitMask); + buildWriteAccessPlan(capabilities, destination, destinationType, + valueType, VMIMemoryWriteMaskKind::ExplicitMask); if (!accessPlan.targetCapability.isSupported()) { if (reason) *reason = accessPlan.targetCapability.reason; @@ -1535,10 +1745,10 @@ LogicalResult checkSupportedMaskedStoreShape( maskType, &maskMaterializationReason); if (failed(maskParts)) return fail(Twine("mask cannot materialize to contiguous; mask ") + - maskReason + ", materialization " + - maskMaterializationReason); + maskReason + ", materialization " + maskMaterializationReason); if (*valueParts != *maskParts) - return fail("requires value/mask contiguous materialization arity to match"); + return fail( + "requires value/mask contiguous materialization arity to match"); return success(); } @@ -1561,8 +1771,7 @@ FailureOr createContiguousStoreMask(Location loc, VMIVRegType vmiType, if (failed(lanesPerPart)) return failure(); - FailureOr activeLanes = - getContiguousActiveDataLanes(vmiType, chunk); + FailureOr activeLanes = getContiguousActiveDataLanes(vmiType, chunk); if (failed(activeLanes)) return failure(); if (*activeLanes == *lanesPerPart) @@ -1572,10 +1781,8 @@ FailureOr createContiguousStoreMask(Location loc, VMIVRegType vmiType, getMaskTypeForVReg(vregType, rewriter.getContext()); if (failed(maskType)) return failure(); - FailureOr> maskAndRemaining = - createRuntimePrefixMask(loc, *maskType, - createI32Constant(loc, *activeLanes, rewriter), - rewriter); + FailureOr> maskAndRemaining = createRuntimePrefixMask( + loc, *maskType, createI32Constant(loc, *activeLanes, rewriter), rewriter); if (failed(maskAndRemaining)) return failure(); return maskAndRemaining->first; @@ -1590,8 +1797,7 @@ FailureOr createMaskedStorePredicate(Location loc, VMIVRegType vmiType, if (failed(lanesPerPart)) return failure(); - FailureOr activeLanes = - getContiguousActiveDataLanes(vmiType, chunk); + FailureOr activeLanes = getContiguousActiveDataLanes(vmiType, chunk); if (failed(activeLanes)) return failure(); if (*activeLanes == *lanesPerPart) @@ -1651,8 +1857,7 @@ computeShuffleForwardingSourceParts(VMIShuffleOp op, std::string *reason) { continue; FailureOr resultLogicalLane = - mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, - lane); + mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, lane); if (failed(resultLogicalLane) || *resultLogicalLane >= static_cast(indices.size())) return fail("failed to map result lane"); @@ -1694,7 +1899,7 @@ struct ShuffleVselrPlan { }; FailureOr computeShuffleLane0SplatSourcePart(VMIShuffleOp op, - std::string *reason) { + std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -1721,7 +1926,8 @@ FailureOr computeShuffleLane0SplatSourcePart(VMIShuffleOp op, FailureOr> computeShuffleVselrPlans(VMIShuffleOp op, std::string *reason) { - auto fail = [&](const Twine &message) -> FailureOr> { + auto fail = + [&](const Twine &message) -> FailureOr> { if (reason) *reason = message.str(); return failure(); @@ -1761,8 +1967,7 @@ computeShuffleVselrPlans(VMIShuffleOp op, std::string *reason) { return fail("requires full physical result chunks"); FailureOr resultLogicalLane = - mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, - lane); + mapPhysicalLaneToLogical(resultType, resultPart, resultChunk, lane); if (failed(resultLogicalLane) || *resultLogicalLane >= static_cast(indices.size())) return fail("failed to map result lane"); @@ -1873,6 +2078,81 @@ computeConstantMaskMaterialization(VMIConstantMaskOp op, std::string *reason) { return materializations; } +FailureOr> +computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { + auto fail = [&](const Twine &message) + -> FailureOr> { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto activeConstant = + op.getActiveElemsPerGroup().getDefiningOp(); + if (!activeConstant) + return fail("requires constant active_elems_per_group"); + auto activeAttr = dyn_cast(activeConstant.getValue()); + if (!activeAttr) + return fail("active_elems_per_group must be an integer constant"); + + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || + !VMIMaskType::isConcreteGranularity(resultVMIType.getGranularity())) + return fail("requires concrete layout and granularity"); + + FailureOr lanesPerPart = + getMaskLanesPerPart(resultVMIType.getGranularity()); + if (failed(lanesPerPart)) + return fail("requires known physical mask lanes per part"); + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + int64_t groupSize = op.getGroupSizeAttr().getInt(); + if (numGroups <= 0 || groupSize <= 0 || + resultVMIType.getElementCount() != numGroups * groupSize) + return fail("requires result lane count to match num_groups * group_size"); + + int64_t activeElems = activeAttr.getInt(); + if (activeElems < 0) + activeElems = 0; + if (activeElems > groupSize) + activeElems = groupSize; + + int64_t factor = layout.isDeinterleaved() ? layout.getFactor() : 1; + SmallVector materializations; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0;; ++chunk) { + bool anyLane = false; + ConstantMaskChunkMaterialization materialization; + materialization.activeLanes.reserve(*lanesPerPart); + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = + isPaddingLane(resultVMIType, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) { + materialization.activeLanes.push_back(0); + continue; + } + anyLane = true; + + FailureOr logicalLane = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, lane); + if (failed(logicalLane)) + return fail("failed to map physical lane"); + int64_t laneInGroup = *logicalLane % groupSize; + materialization.activeLanes.push_back(laneInGroup < activeElems ? 1 + : 0); + } + if (!anyLane) + break; + materializations.push_back(std::move(materialization)); + } + } + + return materializations; +} + std::optional getPrefixActiveLaneCount(ArrayRef activeLanes) { bool seenInactive = false; int64_t activeCount = 0; @@ -1897,19 +2177,16 @@ FailureOr materializePrefixMask(Location loc, MaskType maskType, if (pattern) return createPatternMask(loc, maskType, *pattern, rewriter); - FailureOr> maskAndRemaining = - createRuntimePrefixMask(loc, maskType, - createI32Constant(loc, activeLanes, rewriter), - rewriter); + FailureOr> maskAndRemaining = createRuntimePrefixMask( + loc, maskType, createI32Constant(loc, activeLanes, rewriter), rewriter); if (failed(maskAndRemaining)) return failure(); return maskAndRemaining->first; } -FailureOr -materializeConstantMaskChunk(Location loc, MaskType maskType, - ArrayRef activeLanes, - PatternRewriter &rewriter) { +FailureOr materializeConstantMaskChunk(Location loc, MaskType maskType, + ArrayRef activeLanes, + PatternRewriter &rewriter) { FailureOr lanesPerPart = getMaskLanesPerPart(maskType.getGranularity()); if (failed(lanesPerPart) || @@ -1952,10 +2229,10 @@ materializeConstantMaskChunk(Location loc, MaskType maskType, Value notPrefixBegin = rewriter.create(loc, maskType, *prefixBegin, *allTrue) .getResult(); - runMask = - rewriter.create(loc, maskType, *prefixEnd, notPrefixBegin, - *allTrue) - .getResult(); + runMask = rewriter + .create(loc, maskType, *prefixEnd, notPrefixBegin, + *allTrue) + .getResult(); } if (!result) { @@ -1996,12 +2273,9 @@ Value createGroupChunkOffset(Location loc, Value baseOffset, Value rowStride, return createChunkOffset(loc, offset, inGroupLaneOffset, rewriter); } -LogicalResult checkContiguousFullGroupChunks(Operation *op, VMIVRegType type, - int64_t groupSize, - int64_t *lanesPerPart, - int64_t *groupCount, - int64_t *chunksPerGroup, - PatternRewriter &rewriter) { +LogicalResult checkContiguousFullGroupChunks( + Operation *op, VMIVRegType type, int64_t groupSize, int64_t *lanesPerPart, + int64_t *groupCount, int64_t *chunksPerGroup, PatternRewriter &rewriter) { auto fail = [&](const Twine &message) { return rewriter.notifyMatchFailure(op, message); }; @@ -2027,19 +2301,15 @@ LogicalResult checkContiguousFullGroupChunks(Operation *op, VMIVRegType type, return success(); } -LogicalResult checkFullGroupSlotSourceShape(Operation *op, VMIVRegType type, - int64_t groupSize, - int64_t numGroups, - int64_t *lanesPerPart, - int64_t *groupCount, - PatternRewriter &rewriter) { +LogicalResult checkFullGroupSlotSourceShape( + Operation *op, VMIVRegType type, int64_t groupSize, int64_t numGroups, + int64_t *lanesPerPart, int64_t *groupCount, PatternRewriter &rewriter) { auto fail = [&](const Twine &message) { return rewriter.notifyMatchFailure(op, message); }; VMILayoutAttr layout = type.getLayoutAttr(); - if (!layout || !layout.isGroupSlots() || - layout.getNumGroups() != numGroups) + if (!layout || !layout.isGroupSlots() || layout.getNumGroups() != numGroups) return fail("group slot op requires matching num_groups VMI layout"); if (failed(checkFullDataPhysicalChunks(type, nullptr))) return fail("group slot op requires full physical chunks"); @@ -2047,8 +2317,8 @@ LogicalResult checkFullGroupSlotSourceShape(Operation *op, VMIVRegType type, if (failed(lanes)) return fail("group slot op requires known physical lanes per part"); if (groupSize <= 0 || type.getElementCount() % groupSize != 0) - return fail( - "group slot op requires derived group size to evenly divide lane count"); + return fail("group slot op requires derived group size to evenly divide " + "lane count"); if (*lanes % groupSize != 0 && groupSize % *lanes != 0) return fail("group slot op requires group size to divide or be a " "multiple of physical lanes per part"); @@ -2058,13 +2328,9 @@ LogicalResult checkFullGroupSlotSourceShape(Operation *op, VMIVRegType type, return success(); } -LogicalResult checkFullGroupBroadcastResultShape(Operation *op, - VMIVRegType type, - int64_t groupSize, - int64_t lanesPerPart, - int64_t *layoutFactor, - int64_t *groupCount, - PatternRewriter &rewriter) { +LogicalResult checkFullGroupBroadcastResultShape( + Operation *op, VMIVRegType type, int64_t groupSize, int64_t lanesPerPart, + int64_t *layoutFactor, int64_t *groupCount, PatternRewriter &rewriter) { auto fail = [&](const Twine &message) { return rewriter.notifyMatchFailure(op, message); }; @@ -2076,8 +2342,7 @@ LogicalResult checkFullGroupBroadcastResultShape(Operation *op, return fail("group_broadcast result requires a dense VMI layout"); if (failed(checkFullDataPhysicalChunks(type, nullptr))) return fail("group_broadcast result requires full physical chunks"); - FailureOr resultLanes = - getDataLanesPerPart(type.getElementType()); + FailureOr resultLanes = getDataLanesPerPart(type.getElementType()); if (failed(resultLanes) || *resultLanes != lanesPerPart) return fail("group_broadcast result requires matching physical lanes"); if (groupSize <= 0 || type.getElementCount() % groupSize != 0) @@ -2092,9 +2357,13 @@ LogicalResult checkFullGroupBroadcastResultShape(Operation *op, return fail("group_broadcast contiguous result requires group size to " "divide or be a multiple of physical lanes per part"); } else { + bool blockFragmentSmallGroup = + layout.isDeinterleaved() && layout.getBlockElems() > 1 && + groupSize < lanesPerPart && lanesPerPart % layout.getBlockElems() == 0; int64_t logicalSpanPerResultChunk = lanesPerPart * *factor; - if (groupSize < lanesPerPart || - groupSize % logicalSpanPerResultChunk != 0) + if (!blockFragmentSmallGroup && + (groupSize < lanesPerPart || + groupSize % logicalSpanPerResultChunk != 0)) return fail("group_broadcast deinterleaved result requires every " "physical result chunk to stay within one logical group"); } @@ -2111,8 +2380,9 @@ FailureOr createZeroVector(Location loc, VRegType type, FailureOr mask = createAllTrueMaskForVReg(loc, type, rewriter); if (failed(zero) || failed(mask)) return failure(); - return rewriter.create(loc, type, *zero, *mask, - /*position=*/nullptr) + return rewriter + .create(loc, type, *zero, *mask, + /*position=*/nullptr) .getResult(); } @@ -2121,8 +2391,7 @@ FailureOr createLaneRangeMask(Location loc, MaskType maskType, PatternRewriter &rewriter) { FailureOr lanesPerPart = getMaskLanesPerPart(maskType.getGranularity()); - if (failed(lanesPerPart) || begin < 0 || begin > end || - end > *lanesPerPart) + if (failed(lanesPerPart) || begin < 0 || begin > end || end > *lanesPerPart) return failure(); SmallVector active(*lanesPerPart, 0); for (int64_t lane = begin; lane < end; ++lane) @@ -2132,34 +2401,38 @@ FailureOr createLaneRangeMask(Location loc, MaskType maskType, FailureOr createGroupSlotIndexVector(Location loc, VRegType indexType, int64_t groupSize, + int64_t baseGroupSlot, PatternRewriter &rewriter) { int64_t lanesPerPart = indexType.getElementCount(); - FailureOr zero = - createZeroVector(loc, indexType, rewriter); - FailureOr maskType = getMaskTypeForVReg(indexType, rewriter.getContext()); + FailureOr baseScalar = createScalarOffsetConstant( + loc, indexType.getElementType(), baseGroupSlot, rewriter); + FailureOr maskType = + getMaskTypeForVReg(indexType, rewriter.getContext()); FailureOr allMask = createAllTrueMaskForVReg(loc, indexType, rewriter); - if (failed(zero) || failed(maskType) || failed(allMask)) + if (failed(baseScalar) || failed(maskType) || failed(allMask)) return failure(); + Value result = rewriter + .create(loc, indexType, *baseScalar, *allMask, + /*position=*/nullptr) + .getResult(); if (groupSize >= lanesPerPart) - return *zero; + return result; if (lanesPerPart % groupSize != 0) return failure(); - Value result = *zero; int64_t groupsPerChunk = lanesPerPart / groupSize; for (int64_t localGroup = 1; localGroup < groupsPerChunk; ++localGroup) { FailureOr groupScalar = createScalarOffsetConstant( - loc, indexType.getElementType(), localGroup, rewriter); + loc, indexType.getElementType(), baseGroupSlot + localGroup, rewriter); FailureOr laneMask = createLaneRangeMask(loc, *maskType, localGroup * groupSize, (localGroup + 1) * groupSize, rewriter); if (failed(groupScalar) || failed(laneMask)) return failure(); - Value splat = - rewriter - .create(loc, indexType, *groupScalar, *allMask, - /*position=*/nullptr) - .getResult(); + Value splat = rewriter + .create(loc, indexType, *groupScalar, *allMask, + /*position=*/nullptr) + .getResult(); result = rewriter.create(loc, indexType, splat, result, *laneMask) .getResult(); } @@ -2189,8 +2462,7 @@ LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, int64_t numGroups = sourceType.getElementCount() / groupSize; if (!sourceLayout || !resultLayout || !maskLayout || !sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != numGroups || - !maskLayout.isContiguous()) + resultLayout.getNumGroups() != numGroups || !maskLayout.isContiguous()) return fail("vcgadd group_reduce_addf path requires contiguous source/mask " "layouts and matching num_groups result layout"); std::string sourceFullReason; @@ -2211,6 +2483,135 @@ LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, return success(); } +LogicalResult checkS16Block8GroupReduceShape(VMIGroupReduceAddFOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + if (!sourceType.getElementType().isF32() || + sourceType.getElementType() != resultType.getElementType()) + return fail("s16 block8 group_reduce_addf requires f32 source/result"); + + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize) || *groupSize != 16) + return fail("s16 block8 group_reduce_addf requires group size 16"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (!sourceLayout || !sourceLayout.isDeinterleaved() || + sourceLayout.getFactor() != 2 || + (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) + return fail("s16 group_reduce_addf requires source layout " + "deinterleaved=2 with block_elems=1 or block_elems=8"); + if (!maskLayout || !maskLayout.isDeinterleaved() || + maskLayout.getFactor() != 2 || + maskLayout.getBlockElems() != sourceLayout.getBlockElems()) + return fail("s16 group_reduce_addf requires matching mask layout " + "deinterleaved=2 with the same block_elems"); + if (!resultLayout || !resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) + return fail("s16 block8 group_reduce_addf requires " + "group_slots(num_groups, slots=8) result layout"); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("s16 block8 group_reduce_addf requires computable physical " + "arity"); + int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); + if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 2 || + *maskArity != *sourceArity) + return fail("s16 block8 group_reduce_addf requires two source/mask " + "parts per result part"); + + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + StringRef expectedPlan = sourceLayout.getBlockElems() == 1 + ? "s16_reduce_parity" + : "s16_reduce_block8"; + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match source/result layouts; expected '" + + expectedPlan + "'"); + + return success(); +} + +LogicalResult checkS32Block8GroupReduceShape(VMIGroupReduceAddFOp op, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + if (!sourceType.getElementType().isF32() || + sourceType.getElementType() != resultType.getElementType()) + return fail("s32 block8 group_reduce_addf requires f32 source/result"); + + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize) || *groupSize != 32) + return fail("s32 block8 group_reduce_addf requires group size 32"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (!sourceLayout || !sourceLayout.isDeinterleaved() || + sourceLayout.getFactor() != 4 || + (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) + return fail("s32 group_reduce_addf requires source layout " + "deinterleaved=4 with block_elems=1 or block_elems=8"); + if (!maskLayout || !maskLayout.isDeinterleaved() || + maskLayout.getFactor() != 4 || + maskLayout.getBlockElems() != sourceLayout.getBlockElems()) + return fail("s32 group_reduce_addf requires matching mask layout " + "deinterleaved=4 with the same block_elems"); + if (!resultLayout || !resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) + return fail("s32 block8 group_reduce_addf requires " + "group_slots(num_groups, slots=8) result layout"); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("s32 block8 group_reduce_addf requires computable physical " + "arity"); + int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); + if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 4 || + *maskArity != *sourceArity) + return fail("s32 block8 group_reduce_addf requires four source/mask " + "parts per result part"); + + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + StringRef expectedPlan = sourceLayout.getBlockElems() == 1 + ? "s32_reduce_dintlv4" + : "s32_reduce_block8_stride"; + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match source/result layouts; expected '" + + expectedPlan + "'"); + return success(); +} + std::optional getX2MemoryDistToken(Type elementType, StringRef prefix) { unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); @@ -2289,7 +2690,8 @@ struct OneToNVMIPackOpPattern : OneToNOpConversionPattern { } }; -LogicalResult verifyIdentityPartForwarding(Operation *op, ValueRange sourceParts, +LogicalResult verifyIdentityPartForwarding(Operation *op, + ValueRange sourceParts, TypeRange resultTypes, PatternRewriter &rewriter) { if (sourceParts.size() != resultTypes.size()) @@ -2303,12 +2705,10 @@ LogicalResult verifyIdentityPartForwarding(Operation *op, ValueRange sourceParts return success(); } -FailureOr> -materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, - TypeRange resultTypes, - VMILayoutAttr sourceLayout, - VMILayoutAttr resultLayout, - PatternRewriter &rewriter) { +FailureOr> materializeDataLayoutConversion( + Operation *op, ValueRange sourceParts, TypeRange resultTypes, + VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, + PatternRewriter &rewriter) { if (!sourceLayout || !resultLayout) { (void)rewriter.notifyMatchFailure( op, "layout materialization requires assigned source/result layouts"); @@ -2332,8 +2732,7 @@ materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || sourceParts.size() % 2 != 0) { (void)rewriter.notifyMatchFailure( - op, - "deinterleaved=2 layout materialization requires 2*N parts"); + op, "deinterleaved=2 layout materialization requires 2*N parts"); return failure(); } if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, @@ -2378,8 +2777,7 @@ materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || sourceParts.size() % 4 != 0) { (void)rewriter.notifyMatchFailure( - op, - "deinterleaved=4 layout materialization requires 4*N parts"); + op, "deinterleaved=4 layout materialization requires 4*N parts"); return failure(); } if (failed(verifyIdentityPartForwarding(op, sourceParts, resultTypes, @@ -2395,20 +2793,16 @@ materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, Value p1 = sourceParts[groups + i]; Value p2 = sourceParts[2 * groups + i]; Value p3 = sourceParts[3 * groups + i]; - auto even = - rewriter.create(op->getLoc(), resultTypes[4 * i], - resultTypes[4 * i + 1], p0, p2); - auto odd = - rewriter.create(op->getLoc(), resultTypes[4 * i], - resultTypes[4 * i + 1], p1, p3); - auto low = - rewriter.create(op->getLoc(), resultTypes[4 * i], - resultTypes[4 * i + 1], even.getLow(), - odd.getLow()); - auto high = - rewriter.create(op->getLoc(), resultTypes[4 * i + 2], - resultTypes[4 * i + 3], even.getHigh(), - odd.getHigh()); + auto even = rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p0, p2); + auto odd = rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], p1, p3); + auto low = rewriter.create(op->getLoc(), resultTypes[4 * i], + resultTypes[4 * i + 1], + even.getLow(), odd.getLow()); + auto high = rewriter.create( + op->getLoc(), resultTypes[4 * i + 2], resultTypes[4 * i + 3], + even.getHigh(), odd.getHigh()); results.append( {low.getLow(), low.getHigh(), high.getLow(), high.getHigh()}); } @@ -2422,21 +2816,19 @@ materializeDataLayoutConversion(Operation *op, ValueRange sourceParts, part2.reserve(groups); part3.reserve(groups); for (int64_t i = 0; i < groups; ++i) { - auto low = - rewriter.create(op->getLoc(), resultTypes[i], - resultTypes[groups + i], - sourceParts[4 * i], - sourceParts[4 * i + 1]); + auto low = rewriter.create( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[4 * i], sourceParts[4 * i + 1]); auto high = rewriter.create( op->getLoc(), resultTypes[2 * groups + i], resultTypes[3 * groups + i], sourceParts[4 * i + 2], sourceParts[4 * i + 3]); - auto even = rewriter.create( - op->getLoc(), resultTypes[i], resultTypes[2 * groups + i], - low.getLow(), high.getLow()); + auto even = rewriter.create(op->getLoc(), resultTypes[i], + resultTypes[2 * groups + i], + low.getLow(), high.getLow()); auto odd = rewriter.create( - op->getLoc(), resultTypes[groups + i], - resultTypes[3 * groups + i], low.getHigh(), high.getHigh()); + op->getLoc(), resultTypes[groups + i], resultTypes[3 * groups + i], + low.getHigh(), high.getHigh()); part0.push_back(even.getLow()); part1.push_back(odd.getLow()); part2.push_back(even.getHigh()); @@ -2497,12 +2889,10 @@ createPredicateIntlv(Location loc, Type lowType, Type highType, Value lhs, return failure(); } -FailureOr> -materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, - TypeRange resultTypes, - VMILayoutAttr sourceLayout, - VMILayoutAttr resultLayout, - PatternRewriter &rewriter) { +FailureOr> materializeMaskLayoutConversion( + Operation *op, ValueRange sourceParts, TypeRange resultTypes, + VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, + PatternRewriter &rewriter) { if (!sourceLayout || !resultLayout) { (void)rewriter.notifyMatchFailure( op, "mask layout materialization requires assigned source/result " @@ -2540,10 +2930,9 @@ materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, results.reserve(sourceParts.size()); if (deint2ToContiguous) { for (int64_t i = 0; i < groups; ++i) { - FailureOr> materialize = - createPredicateIntlv(op->getLoc(), resultTypes[2 * i], - resultTypes[2 * i + 1], sourceParts[i], - sourceParts[groups + i], rewriter); + FailureOr> materialize = createPredicateIntlv( + op->getLoc(), resultTypes[2 * i], resultTypes[2 * i + 1], + sourceParts[i], sourceParts[groups + i], rewriter); if (failed(materialize)) return rewriter.notifyMatchFailure( op, "unsupported predicate intlv mask type"); @@ -2555,10 +2944,9 @@ materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, part0.reserve(groups); part1.reserve(groups); for (int64_t i = 0; i < groups; ++i) { - FailureOr> materialize = - createPredicateDintlv(op->getLoc(), resultTypes[i], - resultTypes[groups + i], sourceParts[2 * i], - sourceParts[2 * i + 1], rewriter); + FailureOr> materialize = createPredicateDintlv( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[2 * i], sourceParts[2 * i + 1], rewriter); if (failed(materialize)) return rewriter.notifyMatchFailure( op, "unsupported predicate dintlv mask type"); @@ -2607,14 +2995,12 @@ materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, if (failed(even) || failed(odd)) return rewriter.notifyMatchFailure( op, "unsupported predicate intlv mask type"); - FailureOr> low = - createPredicateIntlv(op->getLoc(), resultTypes[4 * i], - resultTypes[4 * i + 1], even->first, - odd->first, rewriter); - FailureOr> high = - createPredicateIntlv(op->getLoc(), resultTypes[4 * i + 2], - resultTypes[4 * i + 3], even->second, - odd->second, rewriter); + FailureOr> low = createPredicateIntlv( + op->getLoc(), resultTypes[4 * i], resultTypes[4 * i + 1], + even->first, odd->first, rewriter); + FailureOr> high = createPredicateIntlv( + op->getLoc(), resultTypes[4 * i + 2], resultTypes[4 * i + 3], + even->second, odd->second, rewriter); if (failed(low) || failed(high)) return rewriter.notifyMatchFailure( op, "unsupported predicate intlv mask type"); @@ -2630,11 +3016,9 @@ materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, part2.reserve(groups); part3.reserve(groups); for (int64_t i = 0; i < groups; ++i) { - FailureOr> low = - createPredicateDintlv(op->getLoc(), resultTypes[i], - resultTypes[groups + i], - sourceParts[4 * i], sourceParts[4 * i + 1], - rewriter); + FailureOr> low = createPredicateDintlv( + op->getLoc(), resultTypes[i], resultTypes[groups + i], + sourceParts[4 * i], sourceParts[4 * i + 1], rewriter); FailureOr> high = createPredicateDintlv( op->getLoc(), resultTypes[2 * groups + i], resultTypes[3 * groups + i], sourceParts[4 * i + 2], @@ -2642,14 +3026,12 @@ materializeMaskLayoutConversion(Operation *op, ValueRange sourceParts, if (failed(low) || failed(high)) return rewriter.notifyMatchFailure( op, "unsupported predicate dintlv mask type"); - FailureOr> even = - createPredicateDintlv(op->getLoc(), resultTypes[i], - resultTypes[2 * groups + i], low->first, - high->first, rewriter); - FailureOr> odd = - createPredicateDintlv(op->getLoc(), resultTypes[groups + i], - resultTypes[3 * groups + i], low->second, - high->second, rewriter); + FailureOr> even = createPredicateDintlv( + op->getLoc(), resultTypes[i], resultTypes[2 * groups + i], + low->first, high->first, rewriter); + FailureOr> odd = createPredicateDintlv( + op->getLoc(), resultTypes[groups + i], resultTypes[3 * groups + i], + low->second, high->second, rewriter); if (failed(even) || failed(odd)) return rewriter.notifyMatchFailure( op, "unsupported predicate dintlv mask type"); @@ -2760,19 +3142,17 @@ FailureOr> materializeAdjacentMaskGranularityConversion( for (int64_t chunk = 0; chunk < *sourceChunks && produced < *resultChunks; ++chunk) { Value source = sourceParts[sourceOffset + chunk]; - results.push_back( - rewriter - .create(op->getLoc(), resultMaskType, source, - partAttr("LOWER")) - .getResult()); + results.push_back(rewriter + .create(op->getLoc(), resultMaskType, + source, partAttr("LOWER")) + .getResult()); ++produced; if (produced >= *resultChunks) break; - results.push_back( - rewriter - .create(op->getLoc(), resultMaskType, source, - partAttr("HIGHER")) - .getResult()); + results.push_back(rewriter + .create(op->getLoc(), resultMaskType, + source, partAttr("HIGHER")) + .getResult()); ++produced; } if (produced != *resultChunks) @@ -2786,18 +3166,16 @@ FailureOr> materializeAdjacentMaskGranularityConversion( return fail("narrowing mask granularity conversion ran out of " "source chunks"); Value lowerSource = sourceParts[sourceOffset + consumed++]; - Value packed = - rewriter - .create(op->getLoc(), resultMaskType, lowerSource, - partAttr("LOWER")) - .getResult(); + Value packed = rewriter + .create(op->getLoc(), resultMaskType, + lowerSource, partAttr("LOWER")) + .getResult(); if (consumed < *sourceChunks) { Value higherSource = sourceParts[sourceOffset + consumed++]; - Value higher = - rewriter - .create(op->getLoc(), resultMaskType, higherSource, - partAttr("HIGHER")) - .getResult(); + Value higher = rewriter + .create(op->getLoc(), resultMaskType, + higherSource, partAttr("HIGHER")) + .getResult(); if (!allTrue) { FailureOr mask = createAllTrueMask(op->getLoc(), resultMaskType, rewriter); @@ -2832,9 +3210,8 @@ FailureOr> materializeMaskGranularityConversion( VMIMaskType sourceType, VMIMaskType resultType, ValueRange sourceParts, PatternRewriter &rewriter) { std::string reason; - if (failed(checkSupportedMaskGranularityMaterialization(capabilities, - sourceType, - resultType, &reason))) { + if (failed(checkSupportedMaskGranularityMaterialization( + capabilities, sourceType, resultType, &reason))) { (void)rewriter.notifyMatchFailure(op, reason); return failure(); } @@ -2856,8 +3233,8 @@ FailureOr> materializeMaskGranularityConversion( VMIMaskType::get(op->getContext(), currentType.getElementCount(), nextGranularity, currentType.getLayoutAttr()); FailureOr> nextParts = - materializeAdjacentMaskGranularityConversion( - op, currentType, nextType, currentParts, rewriter); + materializeAdjacentMaskGranularityConversion(op, currentType, nextType, + currentParts, rewriter); if (failed(nextParts)) return failure(); currentType = nextType; @@ -2942,9 +3319,8 @@ struct OneToNVMIEnsureMaskGranularityOpPattern TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceType.getGranularity() != resultType.getGranularity()) { FailureOr> results = - materializeMaskGranularityConversion(op, capabilities, sourceType, - resultType, sourceParts, - rewriter); + materializeMaskGranularityConversion( + op, capabilities, sourceType, resultType, sourceParts, rewriter); if (failed(results)) return failure(); if (results->size() != resultTypes.size()) @@ -2969,8 +3345,7 @@ struct OneToNVMIEnsureMaskGranularityOpPattern const VMITargetCapabilityRegistry &capabilities; }; -struct OneToNVMIBroadcastOpPattern - : OneToNOpConversionPattern { +struct OneToNVMIBroadcastOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult @@ -2988,8 +3363,7 @@ struct OneToNVMIBroadcastOpPattern for (Type resultType : resultTypes) { auto vregType = dyn_cast(resultType); if (!vregType) - return rewriter.notifyMatchFailure(op, - "broadcast result must be vreg"); + return rewriter.notifyMatchFailure(op, "broadcast result must be vreg"); FailureOr mask = createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); if (failed(mask)) @@ -2997,11 +3371,10 @@ struct OneToNVMIBroadcastOpPattern op, "unsupported element type for broadcast mask"); StringAttr position = inputIsVReg ? rewriter.getStringAttr("LOWEST") : StringAttr{}; - results.push_back( - rewriter - .create(op.getLoc(), resultType, inputParts.front(), - *mask, position) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, + inputParts.front(), *mask, position) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -3019,18 +3392,15 @@ FailureOr createScalarOffsetConstant(Location loc, Type type, } if (auto floatType = dyn_cast(type)) { return rewriter - .create(loc, - rewriter.getFloatAttr(floatType, - static_cast( - value))) + .create( + loc, rewriter.getFloatAttr(floatType, static_cast(value))) .getResult(); } return failure(); } FailureOr createIotaChunkBase(Location loc, Value base, - int64_t laneOffset, - StringRef order, + int64_t laneOffset, StringRef order, PatternRewriter &rewriter) { if (laneOffset == 0) return base; @@ -3078,8 +3448,8 @@ FailureOr createIotaDeinterleavedChunk(Location loc, Type resultType, return failure(); FailureOr mask = createAllTrueMaskForVReg(loc, vregType, rewriter); - FailureOr zero = createScalarOffsetConstant(loc, base.getType(), 0, - rewriter); + FailureOr zero = + createScalarOffsetConstant(loc, base.getType(), 0, rewriter); FailureOr factorScalar = createScalarOffsetConstant(loc, base.getType(), factor, rewriter); if (failed(mask) || failed(zero) || failed(factorScalar)) @@ -3099,11 +3469,10 @@ FailureOr createIotaDeinterleavedChunk(Location loc, Type resultType, return failure(); if (order == "DESC") { - Value baseVector = - rewriter - .create(loc, resultType, *biasedBase, *mask, - /*position=*/nullptr) - .getResult(); + Value baseVector = rewriter + .create(loc, resultType, *biasedBase, *mask, + /*position=*/nullptr) + .getResult(); return rewriter.create(loc, resultType, baseVector, scaled, *mask) .getResult(); } @@ -3121,8 +3490,7 @@ struct OneToNVMIIotaOpPattern : OneToNOpConversionPattern { auto resultVMIType = cast(op.getResult().getType()); VMILayoutAttr layout = resultVMIType.getLayoutAttr(); if (!layout) - return rewriter.notifyMatchFailure(op, - "iota requires assigned layout"); + return rewriter.notifyMatchFailure(op, "iota requires assigned layout"); FailureOr lanesPerPart = getDataLanesPerPart(resultVMIType.getElementType()); @@ -3130,9 +3498,8 @@ struct OneToNVMIIotaOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure( op, "iota requires known physical lanes per part"); - FailureOr base = - getSingleValue(op, adaptor.getBase(), - "iota base must convert to one value", rewriter); + FailureOr base = getSingleValue( + op, adaptor.getBase(), "iota base must convert to one value", rewriter); if (failed(base)) return failure(); @@ -3167,8 +3534,8 @@ struct OneToNVMIIotaOpPattern : OneToNOpConversionPattern { for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { Type resultType = resultTypes[part * chunksPerPart + chunk]; FailureOr result = createIotaDeinterleavedChunk( - op.getLoc(), resultType, *base, factor, part, chunk, - *lanesPerPart, op.getOrderAttr(), rewriter); + op.getLoc(), resultType, *base, factor, part, chunk, *lanesPerPart, + op.getOrderAttr(), rewriter); if (failed(result)) return rewriter.notifyMatchFailure( op, "failed to materialize deinterleaved iota chunk"); @@ -3193,8 +3560,7 @@ struct OneToNVMIConstantOpPattern : OneToNOpConversionPattern { op, "only splat dense data constants are supported"); auto splatAttr = dyn_cast(denseAttr.getSplatValue()); if (!splatAttr) - return rewriter.notifyMatchFailure(op, - "splat constant must be typed"); + return rewriter.notifyMatchFailure(op, "splat constant must be typed"); Value scalar = rewriter.create(op.getLoc(), splatAttr).getResult(); @@ -3210,11 +3576,11 @@ struct OneToNVMIConstantOpPattern : OneToNOpConversionPattern { if (failed(mask)) return rewriter.notifyMatchFailure( op, "unsupported element type for constant mask"); - results.push_back( - rewriter - .create(op.getLoc(), resultType, scalar, *mask, - /*position=*/nullptr) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, scalar, + *mask, + /*position=*/nullptr) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -3224,8 +3590,7 @@ struct OneToNVMIConstantOpPattern : OneToNOpConversionPattern { struct OneToNVMIConstantMaskOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIConstantMaskOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIConstantMaskOp op, OpAdaptor adaptor, @@ -3235,8 +3600,7 @@ struct OneToNVMIConstantMaskOpPattern FailureOr> materializations = computeConstantMaskMaterialization(op, &reason); if (failed(materializations)) - return rewriter.notifyMatchFailure( - op, Twine("constant_mask ") + reason); + return rewriter.notifyMatchFailure(op, Twine("constant_mask ") + reason); SmallVector results; results.reserve(resultTypes.size()); @@ -3276,8 +3640,8 @@ struct OneToNVMICreateMaskOpPattern op.getActiveLanes().getDefiningOp(); auto resultVMIType = cast(op.getResult().getType()); VMILayoutAttr layout = resultVMIType.getLayoutAttr(); - if (!layout || !VMIMaskType::isConcreteGranularity( - resultVMIType.getGranularity())) + if (!layout || + !VMIMaskType::isConcreteGranularity(resultVMIType.getGranularity())) return rewriter.notifyMatchFailure( op, "create_mask requires concrete layout and granularity"); FailureOr lanesPerPart = @@ -3306,8 +3670,8 @@ struct OneToNVMICreateMaskOpPattern SmallVector results; results.reserve(resultTypes.size()); for (int64_t part = 0; part < factor; ++part) { - Value remaining = createPartitionActiveLanes( - op.getLoc(), activeI32, factor, part, rewriter); + Value remaining = createPartitionActiveLanes(op.getLoc(), activeI32, + factor, part, rewriter); for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { Type resultType = resultTypes[part * chunksPerPart + chunk]; auto maskType = dyn_cast(resultType); @@ -3408,6 +3772,49 @@ struct OneToNVMICreateMaskOpPattern } }; +struct OneToNVMICreateGroupMaskOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMICreateGroupMaskOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMICreateGroupMaskOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + std::string reason; + FailureOr> materializations = + computeGroupMaskMaterialization(op, &reason); + if (failed(materializations)) + return rewriter.notifyMatchFailure(op, + Twine("create_group_mask ") + reason); + + SmallVector results; + results.reserve(resultTypes.size()); + for (const ConstantMaskChunkMaterialization &materialization : + *materializations) { + if (results.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask produced too many physical masks"); + auto maskType = dyn_cast(resultTypes[results.size()]); + if (!maskType) + return rewriter.notifyMatchFailure( + op, "create_group_mask result must be mask"); + FailureOr mask = materializeConstantMaskChunk( + op.getLoc(), maskType, materialization.activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize create_group_mask physical chunk"); + results.push_back(*mask); + } + + if (results.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask physical result count mismatch"); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; @@ -3423,8 +3830,12 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { "load offset must convert to one value", rewriter); if (failed(source) || failed(offset)) return failure(); + std::optional explicitFullReadElems; + if (auto attr = op.getFullReadElemsAttr()) + explicitFullReadElems = attr.getInt(); FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( - op, resultVMIType, (*source).getType(), *offset, rewriter); + op, resultVMIType, op.getSource().getType(), *offset, rewriter, + explicitFullReadElems); if (failed(lanesPerPart)) return failure(); @@ -3448,9 +3859,9 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { op, "vldsx2 requires matching low/high result types"); Value chunkOffset = createChunkOffset( op.getLoc(), *offset, group * 2 * *lanesPerPart, rewriter); - auto load = rewriter.create( - op.getLoc(), lowType, highType, *source, chunkOffset, - rewriter.getStringAttr(*dist)); + auto load = rewriter.create(op.getLoc(), lowType, highType, + *source, chunkOffset, + rewriter.getStringAttr(*dist)); lows.push_back(load.getLow()); highs.push_back(load.getHigh()); } @@ -3469,14 +3880,14 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { auto vregType = dyn_cast(resultType); if (!vregType) return rewriter.notifyMatchFailure(op, "load result must be vreg"); - Value chunkOffset = createChunkOffset( - op.getLoc(), *offset, index * *lanesPerPart, rewriter); - contiguousParts.push_back( - rewriter - .create(op.getLoc(), resultType, - /*updated_base=*/Type{}, *source, chunkOffset, - /*dist=*/nullptr) - .getResult()); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); + contiguousParts.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, + *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); } FailureOr> results = materializeDataLayoutConversion( @@ -3491,10 +3902,8 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { } }; -struct OneToNVMIGroupLoadOpPattern - : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIGroupLoadOp>::OneToNOpConversionPattern; +struct OneToNVMIGroupLoadOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIGroupLoadOp op, OpAdaptor adaptor, @@ -3502,19 +3911,103 @@ struct OneToNVMIGroupLoadOpPattern auto resultVMIType = cast(op.getResult().getType()); FailureOr source = getSingleValue(op, adaptor.getSource(), - "group_load source must convert to one value", - rewriter); + "group_load source must convert to one value", rewriter); FailureOr offset = getSingleValue(op, adaptor.getOffset(), - "group_load offset must convert to one value", - rewriter); - FailureOr rowStride = - getSingleValue(op, adaptor.getRowStride(), - "group_load row_stride must convert to one value", - rewriter); + "group_load offset must convert to one value", rewriter); + FailureOr rowStride = getSingleValue( + op, adaptor.getRowStride(), + "group_load row_stride must convert to one value", rewriter); if (failed(source) || failed(offset) || failed(rowStride)) return failure(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 8 && + resultVMIType.getElementType().isF32()) { + FailureOr groupSize = getGroupSizeFromNumGroups( + resultVMIType, op.getNumGroupsAttr().getInt()); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, "group_load requires num_groups to evenly divide lane count"); + if ((*groupSize != 16 || resultLayout.getFactor() != 2) && + (*groupSize != 32 || resultLayout.getFactor() != 4)) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires S=16/factor=2 or S=32/factor=4"); + if (op.getNumGroupsAttr().getInt() % 8 != 0) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires num_groups multiple of 8"); + std::optional constantRowStride = + getConstantIndexValue(op.getRowStride()); + if (!constantRowStride || *constantRowStride <= 0 || + *constantRowStride % 8 != 0) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires constant positive row_stride " + "divisible by 8 f32 elements"); + if (!isa((*source).getType())) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires !pto.ptr source"); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t factor = resultLayout.getFactor(); + FailureOr chunksPerPart = getDataChunksInPart(resultVMIType, 0); + if (failed(chunksPerPart) || *chunksPerPart <= 0) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires known chunks per part"); + for (int64_t part = 1; part < factor; ++part) { + FailureOr currentChunks = + getDataChunksInPart(resultVMIType, part); + if (failed(currentChunks) || *currentChunks != *chunksPerPart) + return rewriter.notifyMatchFailure( + op, "block8 group_load requires uniform chunks per part"); + } + if (static_cast(resultTypes.size()) != factor * *chunksPerPart) + return rewriter.notifyMatchFailure(op, + "block8 group_load arity mismatch"); + + auto makeI16 = [&](int64_t value) -> Value { + return rewriter.create(op.getLoc(), value, 16); + }; + Value blockStride = makeI16(*constantRowStride / 8); + Value zeroI16 = makeI16(0); + auto makePtr = [&](Value elementOffset) -> Value { + return rewriter + .create(op.getLoc(), (*source).getType(), *source, + elementOffset) + .getResult(); + }; + + SmallVector results; + results.reserve(resultTypes.size()); + constexpr int64_t kGroupsPerBlock8Load = 8; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0; chunk < *chunksPerPart; ++chunk) { + int64_t flatIndex = part * *chunksPerPart + chunk; + auto vregType = dyn_cast(resultTypes[flatIndex]); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "block8 group_load result must be vreg"); + FailureOr allMask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to create block8 group_load mask"); + Value chunkOffset = createGroupChunkOffset( + op.getLoc(), *offset, *rowStride, chunk * kGroupsPerBlock8Load, + part * resultLayout.getBlockElems(), rewriter); + Value chunkBase = makePtr(chunkOffset); + results.push_back(rewriter + .create(op.getLoc(), vregType, + chunkBase, blockStride, + zeroI16, *allMask) + .getResult()); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + int64_t lanesPerPart = 0; int64_t groupCount = 0; int64_t chunksPerGroup = 0; @@ -3523,14 +4016,13 @@ struct OneToNVMIGroupLoadOpPattern if (failed(groupSize)) return rewriter.notifyMatchFailure( op, "group_load requires num_groups to evenly divide lane count"); - if (failed(checkContiguousFullGroupChunks( - op, resultVMIType, *groupSize, &lanesPerPart, &groupCount, - &chunksPerGroup, rewriter))) + if (failed(checkContiguousFullGroupChunks(op, resultVMIType, *groupSize, + &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) return failure(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); - if (static_cast(resultTypes.size()) != - groupCount * chunksPerGroup) + if (static_cast(resultTypes.size()) != groupCount * chunksPerGroup) return rewriter.notifyMatchFailure(op, "group_load arity mismatch"); SmallVector results; @@ -3542,15 +4034,156 @@ struct OneToNVMIGroupLoadOpPattern "group_load result must be vreg"); int64_t group = index / chunksPerGroup; int64_t chunkInGroup = index % chunksPerGroup; - Value chunkOffset = createGroupChunkOffset( - op.getLoc(), *offset, *rowStride, group, - chunkInGroup * lanesPerPart, rewriter); - results.push_back( - rewriter - .create(op.getLoc(), resultType, - /*updated_base=*/Type{}, *source, chunkOffset, - /*dist=*/nullptr) - .getResult()); + Value chunkOffset = + createGroupChunkOffset(op.getLoc(), *offset, *rowStride, group, + chunkInGroup * lanesPerPart, rewriter); + results.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, + chunkOffset, + /*dist=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIGroupSlotLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupSlotLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupSlotLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || layout.getSlots() <= 0) + return rewriter.notifyMatchFailure( + op, "group_slot_load requires explicit group_slots layout"); + + FailureOr source = getSingleValue( + op, adaptor.getSource(), + "group_slot_load source must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), + "group_slot_load offset must convert to one value", rewriter); + FailureOr sourceGroupStride = getSingleValue( + op, adaptor.getSourceGroupStride(), + "group_slot_load source_group_stride must convert to one value", + rewriter); + if (failed(source) || failed(offset) || failed(sourceGroupStride)) + return failure(); + if (!isa((*source).getType())) + return rewriter.notifyMatchFailure( + op, "group_slot_load requires !pto.ptr source"); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + int64_t slots = layout.getSlots(); + int64_t expectedArity = ceilDivNonNegative(numGroups, slots); + if (static_cast(resultTypes.size()) != expectedArity) + return rewriter.notifyMatchFailure(op, "group_slot_load arity mismatch"); + + auto makeI16 = [&](int64_t value) -> Value { + return rewriter.create(op.getLoc(), value, 16); + }; + Value zeroI16 = makeI16(0); + auto makePtr = [&](Value elementOffset) -> Value { + return rewriter + .create(op.getLoc(), (*source).getType(), *source, + elementOffset) + .getResult(); + }; + + SmallVector results; + results.reserve(resultTypes.size()); + + if (slots == 8) { + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (!stride || *stride != 1) + return rewriter.notifyMatchFailure( + op, "slots=8 group_slot_load requires constant unit stride"); + if (resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "slots=8 group_slot_load expects one physical result"); + auto resultType = dyn_cast(resultTypes.front()); + if (!resultType) + return rewriter.notifyMatchFailure( + op, "group_slot_load result must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(resultType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_slot_load mask"); + FailureOr oneBlockMask = + createPrefixMask(op.getLoc(), *maskType, "PAT_VL1", rewriter); + if (failed(oneBlockMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_slot_load mask"); + Value slotBase = makePtr(*offset); + results.push_back(rewriter + .create(op.getLoc(), resultType, slotBase, + zeroI16, zeroI16, *oneBlockMask) + .getResult()); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + if (slots != 1) + return rewriter.notifyMatchFailure( + op, "group_slot_load supports only slots=8 or slots=1"); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return rewriter.notifyMatchFailure( + op, "slots=1 group_slot_load requires supported element width"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional constantStride = + getConstantIndexValue(op.getSourceGroupStride()); + if (!constantStride || *constantStride <= 0 || + *constantStride % alignedStrideElems != 0) + return rewriter.notifyMatchFailure( + op, Twine("slots=1 group_slot_load requires constant positive " + "source_group_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B lane-0 vsldb alignment"); + + for (auto [group, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "group_slot_load result must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_slot_load mask"); + FailureOr oneBlockMask = + createPrefixMask(op.getLoc(), *maskType, "PAT_VL1", rewriter); + if (failed(oneBlockMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_slot_load mask"); + Value groupOffset = *offset; + if (group != 0) { + Value groupIndex = + rewriter.create(op.getLoc(), group); + Value rowOffset = rewriter + .create( + op.getLoc(), *sourceGroupStride, groupIndex) + .getResult(); + groupOffset = + rewriter.create(op.getLoc(), groupOffset, rowOffset) + .getResult(); + } + Value slotBase = makePtr(groupOffset); + results.push_back(rewriter + .create(op.getLoc(), vregType, slotBase, + zeroI16, zeroI16, *oneBlockMask) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -3560,21 +4193,18 @@ struct OneToNVMIGroupLoadOpPattern struct OneToNVMIMaskedLoadOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIMaskedLoadOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIMaskedLoadOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto resultVMIType = cast(op.getResult().getType()); - FailureOr source = - getSingleValue(op, adaptor.getSource(), - "masked_load source must convert to one value", - rewriter); - FailureOr offset = - getSingleValue(op, adaptor.getOffset(), - "masked_load offset must convert to one value", - rewriter); + FailureOr source = getSingleValue( + op, adaptor.getSource(), "masked_load source must convert to one value", + rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), "masked_load offset must convert to one value", + rewriter); if (failed(source) || failed(offset)) return failure(); @@ -3588,22 +4218,21 @@ struct OneToNVMIMaskedLoadOpPattern TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (maskParts.size() != passthruParts.size() || passthruParts.size() != resultTypes.size()) - return rewriter.notifyMatchFailure( - op, "masked_load physical arity mismatch"); + return rewriter.notifyMatchFailure(op, + "masked_load physical arity mismatch"); SmallVector results; results.reserve(resultTypes.size()); - for (auto [index, maskPassthruAndType] : - llvm::enumerate(llvm::zip_equal(maskParts, passthruParts, - resultTypes))) { + for (auto [index, maskPassthruAndType] : llvm::enumerate( + llvm::zip_equal(maskParts, passthruParts, resultTypes))) { auto [mask, passthru, resultType] = maskPassthruAndType; if (!isa(mask.getType()) || passthru.getType() != resultType || !isa(resultType)) return rewriter.notifyMatchFailure( op, "masked_load physical part type mismatch"); - Value chunkOffset = createChunkOffset( - op.getLoc(), *offset, index * *lanesPerPart, rewriter); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); Value loaded = rewriter .create(op.getLoc(), resultType, @@ -3645,22 +4274,19 @@ struct OneToNVMIGatherOpPattern : OneToNOpConversionPattern { SmallVector results; results.reserve(resultTypes.size()); for (auto [indices, mask, passthru, resultType] : - llvm::zip_equal(indicesParts, maskParts, passthruParts, - resultTypes)) { + llvm::zip_equal(indicesParts, maskParts, passthruParts, resultTypes)) { if (!isa(indices.getType()) || !isa(mask.getType()) || passthru.getType() != resultType || !isa(resultType)) - return rewriter.notifyMatchFailure(op, - "gather physical part type mismatch"); + return rewriter.notifyMatchFailure( + op, "gather physical part type mismatch"); - Value gathered = - rewriter - .create(op.getLoc(), resultType, *source, indices, - mask) - .getResult(); + Value gathered = rewriter + .create(op.getLoc(), resultType, + *source, indices, mask) + .getResult(); results.push_back( rewriter - .create(op.getLoc(), resultType, gathered, passthru, - mask) + .create(op.getLoc(), resultType, gathered, passthru, mask) .getResult()); } @@ -3671,21 +4297,18 @@ struct OneToNVMIGatherOpPattern : OneToNOpConversionPattern { struct OneToNVMIExpandLoadOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIExpandLoadOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIExpandLoadOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto resultVMIType = cast(op.getResult().getType()); - FailureOr source = - getSingleValue(op, adaptor.getSource(), - "expand_load source must convert to one value", - rewriter); - FailureOr offset = - getSingleValue(op, adaptor.getOffset(), - "expand_load offset must convert to one value", - rewriter); + FailureOr source = getSingleValue( + op, adaptor.getSource(), "expand_load source must convert to one value", + rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), "expand_load offset must convert to one value", + rewriter); if (failed(source) || failed(offset)) return failure(); @@ -3700,16 +4323,16 @@ struct OneToNVMIExpandLoadOpPattern results.reserve(resultTypes.size()); for (auto [index, resultType] : llvm::enumerate(resultTypes)) { if (!isa(resultType)) - return rewriter.notifyMatchFailure( - op, "expand_load result must be vreg"); - Value chunkOffset = createChunkOffset( - op.getLoc(), *offset, index * *lanesPerPart, rewriter); - results.push_back( - rewriter - .create(op.getLoc(), resultType, - /*updated_base=*/Type{}, *source, chunkOffset, - /*dist=*/nullptr) - .getResult()); + return rewriter.notifyMatchFailure(op, + "expand_load result must be vreg"); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); + results.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, + chunkOffset, + /*dist=*/nullptr) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -3725,7 +4348,8 @@ struct OneToNVMIExpandLoadOpPattern auto resultType = dyn_cast(resultTypes.front()); auto maskType = dyn_cast(maskParts.front().getType()); - if (!resultType || !maskType || passthruParts.front().getType() != resultType) + if (!resultType || !maskType || + passthruParts.front().getType() != resultType) return rewriter.notifyMatchFailure( op, "runtime expand_load requires physical result/passthru/mask"); @@ -3733,11 +4357,10 @@ struct OneToNVMIExpandLoadOpPattern if (!baseType) return rewriter.notifyMatchFailure(op, "runtime expand_load requires ptr"); - Value gatherBase = - rewriter - .create(op.getLoc(), (*source).getType(), *source, - *offset) - .getResult(); + Value gatherBase = rewriter + .create(op.getLoc(), (*source).getType(), + *source, *offset) + .getResult(); auto indexType = VRegType::get(rewriter.getContext(), resultType.getElementCount(), rewriter.getI32Type()); @@ -3754,19 +4377,17 @@ struct OneToNVMIExpandLoadOpPattern .getResult(); Value indices = rewriter - .create(op.getLoc(), indexType, carrier, - maskParts.front()) + .create(op.getLoc(), indexType, carrier, maskParts.front()) .getResult(); Value gathered = rewriter .create(op.getLoc(), resultType, gatherBase, indices, maskParts.front()) .getResult(); - Value result = - rewriter - .create(op.getLoc(), resultType, gathered, - passthruParts.front(), maskParts.front()) - .getResult(); + Value result = rewriter + .create(op.getLoc(), resultType, gathered, + passthruParts.front(), maskParts.front()) + .getResult(); rewriter.replaceOp(op, SmallVector{result}, adaptor.getResultMapping()); return success(); @@ -3854,18 +4475,16 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { if (*activeLanes == 0) continue; } - FailureOr mask = fullPhysicalChunks - ? createAllTrueMaskForVReg(op.getLoc(), - vregType, rewriter) - : createContiguousStoreMask(op.getLoc(), - valueVMIType, - index, vregType, - rewriter); + FailureOr mask = + fullPhysicalChunks + ? createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter) + : createContiguousStoreMask(op.getLoc(), valueVMIType, index, + vregType, rewriter); if (failed(mask)) return rewriter.notifyMatchFailure( op, "unsupported element type for store mask"); - Value chunkOffset = createChunkOffset( - op.getLoc(), *offset, index * *lanesPerPart, rewriter); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); rewriter.create(op.getLoc(), /*updated_base=*/Type{}, value, *destination, chunkOffset, /*dist=*/nullptr, *mask); @@ -3878,44 +4497,133 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { struct OneToNVMIGroupStoreOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIGroupStoreOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIGroupStoreOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto valueVMIType = cast(op.getValue().getType()); + VMILayoutAttr layout = valueVMIType.getLayoutAttr(); + + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "group_store destination must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), "group_store offset must convert to one value", + rewriter); + FailureOr rowStride = getSingleValue( + op, adaptor.getRowStride(), + "group_store row_stride must convert to one value", rewriter); + if (failed(destination) || failed(offset) || failed(rowStride)) + return failure(); + + if (layout && layout.isGroupSlots() && layout.getSlots() == 1 && + layout.getNumGroups() == op.getNumGroupsAttr().getInt()) { + ValueRange valueParts = adaptor.getValue(); + if (static_cast(valueParts.size()) != layout.getNumGroups()) + return rewriter.notifyMatchFailure( + op, "slots=1 group_store arity mismatch"); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(valueVMIType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return rewriter.notifyMatchFailure( + op, "slots=1 group_store requires supported element width"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional constantRowStride = + getConstantIndexValue(op.getRowStride()); + if (!constantRowStride || *constantRowStride <= 0 || + *constantRowStride % alignedStrideElems != 0) + return rewriter.notifyMatchFailure( + op, Twine("slots=1 group_store requires constant positive " + "row_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B lane-0 vsts alignment"); + + for (auto [group, value] : llvm::enumerate(valueParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_store mask"); + FailureOr mask = + createPrefixMask(op.getLoc(), *maskType, "PAT_VL1", rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to create slots=1 group_store mask"); + Value groupOffset = + createGroupChunkOffset(op.getLoc(), *offset, *rowStride, group, + /*chunkLaneOffset=*/0, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + groupOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } + + if (layout && layout.isGroupSlots() && layout.getSlots() == 8 && + layout.getNumGroups() == op.getNumGroupsAttr().getInt()) { + int64_t numGroups = layout.getNumGroups(); + std::optional constantRowStride = + getConstantIndexValue(op.getRowStride()); + if (!constantRowStride || *constantRowStride != 1) + return rewriter.notifyMatchFailure( + op, "slots=8 group_store requires constant unit row_stride"); + + ValueRange valueParts = adaptor.getValue(); + if (static_cast(valueParts.size()) != + ceilDivNonNegative(numGroups, 8)) + return rewriter.notifyMatchFailure( + op, "slots=8 group_store arity mismatch"); + + for (auto [slotBlock, value] : llvm::enumerate(valueParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_store mask"); + int64_t activeGroups = std::min(8, numGroups - slotBlock * 8); + FailureOr mask = createPrefixMaskForActiveLanes( + op.getLoc(), *maskType, activeGroups, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to create slots=8 group_store mask"); + Value groupOffset = createGroupChunkOffset( + op.getLoc(), *offset, *rowStride, slotBlock * 8, + /*chunkLaneOffset=*/0, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + groupOffset, /*dist=*/nullptr, *mask); + } + + rewriter.eraseOp(op); + return success(); + } + int64_t lanesPerPart = 0; int64_t groupCount = 0; int64_t chunksPerGroup = 0; - FailureOr groupSize = getGroupSizeFromNumGroups( - valueVMIType, op.getNumGroupsAttr().getInt()); + FailureOr groupSize = + getGroupSizeFromNumGroups(valueVMIType, op.getNumGroupsAttr().getInt()); if (failed(groupSize)) return rewriter.notifyMatchFailure( op, "group_store requires num_groups to evenly divide lane count"); - if (failed(checkContiguousFullGroupChunks( - op, valueVMIType, *groupSize, &lanesPerPart, &groupCount, - &chunksPerGroup, rewriter))) - return failure(); - - FailureOr destination = - getSingleValue(op, adaptor.getDestination(), - "group_store destination must convert to one value", - rewriter); - FailureOr offset = - getSingleValue(op, adaptor.getOffset(), - "group_store offset must convert to one value", - rewriter); - FailureOr rowStride = - getSingleValue(op, adaptor.getRowStride(), - "group_store row_stride must convert to one value", - rewriter); - if (failed(destination) || failed(offset) || failed(rowStride)) + if (failed(checkContiguousFullGroupChunks(op, valueVMIType, *groupSize, + &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) return failure(); ValueRange valueParts = adaptor.getValue(); - if (static_cast(valueParts.size()) != - groupCount * chunksPerGroup) + if (static_cast(valueParts.size()) != groupCount * chunksPerGroup) return rewriter.notifyMatchFailure(op, "group_store arity mismatch"); for (auto [index, value] : llvm::enumerate(valueParts)) { @@ -3930,9 +4638,9 @@ struct OneToNVMIGroupStoreOpPattern op, "unsupported element type for group_store mask"); int64_t group = index / chunksPerGroup; int64_t chunkInGroup = index % chunksPerGroup; - Value chunkOffset = createGroupChunkOffset( - op.getLoc(), *offset, *rowStride, group, - chunkInGroup * lanesPerPart, rewriter); + Value chunkOffset = + createGroupChunkOffset(op.getLoc(), *offset, *rowStride, group, + chunkInGroup * lanesPerPart, rewriter); rewriter.create(op.getLoc(), /*updated_base=*/Type{}, value, *destination, chunkOffset, /*dist=*/nullptr, *mask); @@ -3945,8 +4653,7 @@ struct OneToNVMIGroupStoreOpPattern struct OneToNVMIMaskedStoreOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIMaskedStoreOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIMaskedStoreOp op, OpAdaptor adaptor, @@ -3958,14 +4665,12 @@ struct OneToNVMIMaskedStoreOpPattern return rewriter.notifyMatchFailure( op, "masked_store requires known physical lanes per part"); - FailureOr destination = - getSingleValue(op, adaptor.getDestination(), - "masked_store destination must convert to one value", - rewriter); - FailureOr offset = - getSingleValue(op, adaptor.getOffset(), - "masked_store offset must convert to one value", - rewriter); + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "masked_store destination must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), + "masked_store offset must convert to one value", rewriter); if (failed(destination) || failed(offset)) return failure(); @@ -4019,8 +4724,8 @@ struct OneToNVMIMaskedStoreOpPattern if (failed(storeMask)) return rewriter.notifyMatchFailure( op, "failed to materialize masked_store predicate"); - Value chunkOffset = createChunkOffset( - op.getLoc(), *offset, index * *lanesPerPart, rewriter); + Value chunkOffset = createChunkOffset(op.getLoc(), *offset, + index * *lanesPerPart, rewriter); rewriter.create(op.getLoc(), /*updated_base=*/Type{}, value, *destination, chunkOffset, /*dist=*/nullptr, *storeMask); @@ -4037,10 +4742,9 @@ struct OneToNVMIScatterOpPattern : OneToNOpConversionPattern { LogicalResult matchAndRewrite(VMIScatterOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { - FailureOr destination = - getSingleValue(op, adaptor.getDestination(), - "scatter destination must convert to one value", - rewriter); + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "scatter destination must convert to one value", rewriter); if (failed(destination)) return failure(); @@ -4049,8 +4753,7 @@ struct OneToNVMIScatterOpPattern : OneToNOpConversionPattern { ValueRange maskParts = adaptor.getMask(); if (valueParts.size() != indicesParts.size() || valueParts.size() != maskParts.size()) - return rewriter.notifyMatchFailure(op, - "scatter physical arity mismatch"); + return rewriter.notifyMatchFailure(op, "scatter physical arity mismatch"); for (auto [value, indices, mask] : llvm::zip_equal(valueParts, indicesParts, maskParts)) { @@ -4067,8 +4770,7 @@ struct OneToNVMIScatterOpPattern : OneToNOpConversionPattern { } }; -struct OneToNVMITileReadOpPattern - : OneToNOpConversionPattern { +struct OneToNVMITileReadOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult @@ -4107,9 +4809,9 @@ struct OneToNVMITileReadOpPattern op, "vldsx2 requires matching low/high result types"); Value chunkOffset = createChunkOffset( op.getLoc(), zero, group * 2 * *lanesPerPart, rewriter); - auto load = rewriter.create( - op.getLoc(), lowType, highType, *source, chunkOffset, - rewriter.getStringAttr(*dist)); + auto load = rewriter.create(op.getLoc(), lowType, highType, + *source, chunkOffset, + rewriter.getStringAttr(*dist)); lows.push_back(load.getLow()); highs.push_back(load.getHigh()); } @@ -4128,14 +4830,14 @@ struct OneToNVMITileReadOpPattern auto vregType = dyn_cast(resultType); if (!vregType) return rewriter.notifyMatchFailure(op, "tile_read result must be vreg"); - Value chunkOffset = createChunkOffset( - op.getLoc(), zero, index * *lanesPerPart, rewriter); - contiguousParts.push_back( - rewriter - .create(op.getLoc(), resultType, - /*updated_base=*/Type{}, *source, chunkOffset, - /*dist=*/nullptr) - .getResult()); + Value chunkOffset = + createChunkOffset(op.getLoc(), zero, index * *lanesPerPart, rewriter); + contiguousParts.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, + *source, chunkOffset, + /*dist=*/nullptr) + .getResult()); } FailureOr> results = materializeDataLayoutConversion( @@ -4150,8 +4852,7 @@ struct OneToNVMITileReadOpPattern } }; -struct OneToNVMITileWriteOpPattern - : OneToNOpConversionPattern { +struct OneToNVMITileWriteOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult @@ -4231,18 +4932,16 @@ struct OneToNVMITileWriteOpPattern if (*activeLanes == 0) continue; } - FailureOr mask = fullPhysicalChunks - ? createAllTrueMaskForVReg(op.getLoc(), - vregType, rewriter) - : createContiguousStoreMask(op.getLoc(), - valueVMIType, - index, vregType, - rewriter); + FailureOr mask = + fullPhysicalChunks + ? createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter) + : createContiguousStoreMask(op.getLoc(), valueVMIType, index, + vregType, rewriter); if (failed(mask)) return rewriter.notifyMatchFailure( op, "unsupported element type for tile_write mask"); - Value chunkOffset = createChunkOffset( - op.getLoc(), zero, index * *lanesPerPart, rewriter); + Value chunkOffset = + createChunkOffset(op.getLoc(), zero, index * *lanesPerPart, rewriter); rewriter.create(op.getLoc(), /*updated_base=*/Type{}, value, *destination, chunkOffset, /*dist=*/nullptr, *mask); @@ -4257,10 +4956,10 @@ template struct OneToNVMIBinaryOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< - SourceOp>::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { ValueRange lhsParts = adaptor.getLhs(); ValueRange rhsParts = adaptor.getRhs(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); @@ -4275,8 +4974,8 @@ struct OneToNVMIBinaryOpPattern : OneToNOpConversionPattern { auto vregType = dyn_cast(resultType); if (!vregType || lhs.getType() != resultType || rhs.getType() != resultType) - return rewriter.notifyMatchFailure(op, - "physical binary part type mismatch"); + return rewriter.notifyMatchFailure( + op, "physical binary part type mismatch"); FailureOr mask = createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); if (failed(mask)) @@ -4322,8 +5021,8 @@ struct OneToNVMIFmaOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure(op, "unsupported element type for fma"); results.push_back( - rewriter.create(op.getLoc(), resultType, acc, lhs, rhs, - *mask) + rewriter + .create(op.getLoc(), resultType, acc, lhs, rhs, *mask) .getResult()); } @@ -4336,10 +5035,10 @@ template struct OneToNVMIUnaryOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< - SourceOp>::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceParts.size() != resultTypes.size()) @@ -4347,7 +5046,8 @@ struct OneToNVMIUnaryOpPattern : OneToNOpConversionPattern { SmallVector results; results.reserve(resultTypes.size()); - for (auto [source, resultType] : llvm::zip_equal(sourceParts, resultTypes)) { + for (auto [source, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { auto vregType = dyn_cast(resultType); if (!vregType || source.getType() != resultType) return rewriter.notifyMatchFailure(op, @@ -4371,10 +5071,10 @@ template struct OneToNVMIMaskBinaryOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< - SourceOp>::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { ValueRange lhsParts = adaptor.getLhs(); ValueRange rhsParts = adaptor.getRhs(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); @@ -4398,8 +5098,8 @@ struct OneToNVMIMaskBinaryOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure( op, "unsupported mask type for all-true mask binary seed"); results.push_back( - rewriter.create(op.getLoc(), resultType, lhs, rhs, - *seedMask) + rewriter + .create(op.getLoc(), resultType, lhs, rhs, *seedMask) .getResult()); } @@ -4412,10 +5112,10 @@ template struct OneToNVMIMaskUnaryOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< - SourceOp>::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceParts.size() != resultTypes.size()) @@ -4449,10 +5149,10 @@ template struct OneToNVMICmpOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite(SourceOp op, typename OneToNOpConversionPattern< - SourceOp>::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite( + SourceOp op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { std::optional cmpMode = getVPTOCmpMode(op.getPredicate()); if (!cmpMode) return op.emitOpError() @@ -4484,11 +5184,11 @@ struct OneToNVMICmpOpPattern : OneToNOpConversionPattern { if (failed(seedMask)) return rewriter.notifyMatchFailure( op, "unsupported mask type for all-true cmp seed"); - results.push_back( - rewriter - .create(op.getLoc(), resultType, lhs, rhs, *seedMask, - rewriter.getStringAttr(*cmpMode)) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, lhs, rhs, + *seedMask, + rewriter.getStringAttr(*cmpMode)) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -4519,11 +5219,10 @@ struct OneToNVMISelectOpPattern : OneToNOpConversionPattern { falseValue.getType() != resultType || !isa(resultType)) return rewriter.notifyMatchFailure( op, "physical select part type mismatch"); - results.push_back( - rewriter - .create(op.getLoc(), resultType, trueValue, falseValue, - mask) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, trueValue, + falseValue, mask) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -4562,18 +5261,17 @@ struct OneToNVMIActivePrefixIndexOpPattern return rewriter.notifyMatchFailure( op, "unsupported element type for active_prefix_index seed mask"); - Value zero = rewriter.create( - op.getLoc(), 0, intType.getWidth()); + Value zero = rewriter.create(op.getLoc(), 0, + intType.getWidth()); Value carrier = rewriter .create(op.getLoc(), resultType, zero, *seedMask, /*position=*/nullptr) .getResult(); - Value result = - rewriter - .create(op.getLoc(), resultType, carrier, - maskParts.front()) - .getResult(); + Value result = rewriter + .create(op.getLoc(), resultType, carrier, + maskParts.front()) + .getResult(); rewriter.replaceOp(op, SmallVector{result}, adaptor.getResultMapping()); return success(); @@ -4600,11 +5298,10 @@ struct OneToNVMICompressOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure( op, "compress requires physical source/mask/result parts"); - Value result = - rewriter - .create(op.getLoc(), resultType, sourceParts.front(), - maskParts.front()) - .getResult(); + Value result = rewriter + .create(op.getLoc(), resultType, + sourceParts.front(), maskParts.front()) + .getResult(); rewriter.replaceOp(op, SmallVector{result}, adaptor.getResultMapping()); return success(); @@ -4619,14 +5316,12 @@ struct OneToNVMICompressStoreOpPattern LogicalResult matchAndRewrite(VMICompressStoreOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { - FailureOr destination = - getSingleValue(op, adaptor.getDestination(), - "compress_store destination must convert to one value", - rewriter); - FailureOr offset = - getSingleValue(op, adaptor.getOffset(), - "compress_store offset must convert to one value", - rewriter); + FailureOr destination = getSingleValue( + op, adaptor.getDestination(), + "compress_store destination must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), + "compress_store offset must convert to one value", rewriter); if (failed(destination) || failed(offset)) return failure(); @@ -4648,14 +5343,12 @@ struct OneToNVMICompressStoreOpPattern .create(op.getLoc(), (*destination).getType(), *destination, *offset) .getResult(); - Value squeezed = - rewriter - .create(op.getLoc(), valueType, valueParts.front(), - maskParts.front()) - .getResult(); - auto align = - rewriter.create(op.getLoc(), - AlignType::get(rewriter.getContext())); + Value squeezed = rewriter + .create(op.getLoc(), valueType, + valueParts.front(), maskParts.front()) + .getResult(); + auto align = rewriter.create( + op.getLoc(), AlignType::get(rewriter.getContext())); auto store = rewriter.create( op.getLoc(), align.getResult().getType(), align.getResult(), squeezed, storeBase, rewriter.getStringAttr("POST_UPDATE")); @@ -4667,8 +5360,7 @@ struct OneToNVMICompressStoreOpPattern struct OneToNVMIReduceAddIOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIReduceAddIOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIReduceAddIOp op, OpAdaptor adaptor, @@ -4708,16 +5400,16 @@ struct OneToNVMIReduceAddIOpPattern op, "failed to create reduce_addi first-lane mask"); Value accumulator = initParts.front(); - for (auto [sourcePart, maskPart] : llvm::zip_equal(sourceParts, maskParts)) { + for (auto [sourcePart, maskPart] : + llvm::zip_equal(sourceParts, maskParts)) { Value reduced = - rewriter.create(op.getLoc(), resultType, sourcePart, - maskPart) - .getResult(); - accumulator = rewriter - .create(op.getLoc(), resultType, reduced, accumulator, - *firstLaneMask) + .create(op.getLoc(), resultType, sourcePart, maskPart) .getResult(); + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); } rewriter.replaceOp(op, SmallVector{accumulator}, @@ -4728,8 +5420,7 @@ struct OneToNVMIReduceAddIOpPattern struct OneToNVMIReduceAddFOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIReduceAddFOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIReduceAddFOp op, OpAdaptor adaptor, @@ -4769,16 +5460,16 @@ struct OneToNVMIReduceAddFOpPattern op, "failed to create reduce_addf first-lane mask"); Value accumulator = initParts.front(); - for (auto [sourcePart, maskPart] : llvm::zip_equal(sourceParts, maskParts)) { + for (auto [sourcePart, maskPart] : + llvm::zip_equal(sourceParts, maskParts)) { Value reduced = - rewriter.create(op.getLoc(), resultType, sourcePart, - maskPart) - .getResult(); - accumulator = rewriter - .create(op.getLoc(), resultType, reduced, accumulator, - *firstLaneMask) + .create(op.getLoc(), resultType, sourcePart, maskPart) .getResult(); + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); } rewriter.replaceOp(op, SmallVector{accumulator}, @@ -4808,8 +5499,7 @@ struct OneToNVMIGroupReduceAddFOpPattern op, "group_reduce_addf requires num_groups to evenly divide lane count"); if (succeeded(checkVcgaddGroupReduceShape( - sourceVMIType, maskVMIType, resultVMIType, - *groupSize, nullptr))) { + sourceVMIType, maskVMIType, resultVMIType, *groupSize, nullptr))) { if (sourceParts.size() != maskParts.size() || sourceParts.size() != resultTypes.size() || sourceParts.empty()) return rewriter.notifyMatchFailure( @@ -4832,10 +5522,63 @@ struct OneToNVMIGroupReduceAddFOpPattern SmallVector results; results.reserve(resultTypes.size()); for (auto [sourceIndex, sourcePart] : llvm::enumerate(sourceParts)) { + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, + maskParts[sourceIndex]) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + if (succeeded(checkS16Block8GroupReduceShape(op, nullptr))) { + int64_t resultPartCount = resultTypes.size(); + if (static_cast(sourceParts.size()) != resultPartCount * 2 || + maskParts.size() != sourceParts.size()) + return rewriter.notifyMatchFailure( + op, "s16 block8 group_reduce_addf arity mismatch"); + + SmallVector results; + results.reserve(resultPartCount); + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "s16 block8 group_reduce_addf requires physical vreg/mask"); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + + for (int64_t resultIndex = 0; resultIndex < resultPartCount; + ++resultIndex) { + int64_t activeGroups = + std::min(8, numGroups - resultIndex * 8); + FailureOr combineMask = createPrefixMaskForActiveLanes( + op.getLoc(), maskType, activeGroups, rewriter); + if (failed(combineMask)) + return rewriter.notifyMatchFailure( + op, "failed to create s16 block8 combine mask"); + Value loSource = sourceParts[resultIndex]; + Value hiSource = sourceParts[resultPartCount + resultIndex]; + Value loMask = maskParts[resultIndex]; + Value hiMask = maskParts[resultPartCount + resultIndex]; + Type physicalResultType = resultTypes[resultIndex]; + if (physicalResultType != resultType || + loSource.getType() != resultType || + hiSource.getType() != resultType || loMask.getType() != maskType || + hiMask.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "s16 block8 group_reduce_addf requires uniform physical " + "types"); + Value lo = + rewriter.create(op.getLoc(), resultType, loSource, loMask) + .getResult(); + Value hi = + rewriter.create(op.getLoc(), resultType, hiSource, hiMask) + .getResult(); results.push_back( rewriter - .create(op.getLoc(), resultType, sourcePart, - maskParts[sourceIndex]) + .create(op.getLoc(), resultType, lo, hi, *combineMask) .getResult()); } @@ -4843,12 +5586,71 @@ struct OneToNVMIGroupReduceAddFOpPattern return success(); } + if (succeeded(checkS32Block8GroupReduceShape(op, nullptr))) { + int64_t resultPartCount = resultTypes.size(); + if (static_cast(sourceParts.size()) != resultPartCount * 4 || + maskParts.size() != sourceParts.size()) + return rewriter.notifyMatchFailure( + op, "s32 block8 group_reduce_addf arity mismatch"); + + SmallVector results; + results.reserve(resultPartCount); + auto resultType = dyn_cast(resultTypes.front()); + auto maskType = dyn_cast(maskParts.front().getType()); + if (!resultType || !maskType) + return rewriter.notifyMatchFailure( + op, "s32 block8 group_reduce_addf requires physical vreg/mask"); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + + for (int64_t resultIndex = 0; resultIndex < resultPartCount; + ++resultIndex) { + int64_t activeGroups = + std::min(8, numGroups - resultIndex * 8); + FailureOr combineMask = createPrefixMaskForActiveLanes( + op.getLoc(), maskType, activeGroups, rewriter); + if (failed(combineMask)) + return rewriter.notifyMatchFailure( + op, "failed to create s32 block8 combine mask"); + SmallVector partials; + partials.reserve(4); + for (int64_t part = 0; part < 4; ++part) { + int64_t sourceIndex = part * resultPartCount + resultIndex; + Value source = sourceParts[sourceIndex]; + Value mask = maskParts[sourceIndex]; + Type physicalResultType = resultTypes[resultIndex]; + if (physicalResultType != resultType || + source.getType() != resultType || mask.getType() != maskType) + return rewriter.notifyMatchFailure( + op, "s32 block8 group_reduce_addf requires uniform physical " + "types"); + partials.push_back( + rewriter.create(op.getLoc(), resultType, source, mask) + .getResult()); + } + Value sum01 = rewriter + .create(op.getLoc(), resultType, partials[0], + partials[1], *combineMask) + .getResult(); + Value sum23 = rewriter + .create(op.getLoc(), resultType, partials[2], + partials[3], *combineMask) + .getResult(); + results.push_back(rewriter + .create(op.getLoc(), resultType, sum01, + sum23, *combineMask) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + int64_t lanesPerPart = 0; int64_t groupCount = 0; int64_t chunksPerGroup = 0; - if (failed(checkContiguousFullGroupChunks( - op, sourceVMIType, *groupSize, &lanesPerPart, &groupCount, - &chunksPerGroup, rewriter))) + if (failed(checkContiguousFullGroupChunks(op, sourceVMIType, *groupSize, + &lanesPerPart, &groupCount, + &chunksPerGroup, rewriter))) return failure(); if (sourceParts.size() != maskParts.size() || static_cast(sourceParts.size()) != @@ -4901,11 +5703,10 @@ struct OneToNVMIGroupReduceAddFOpPattern .create(op.getLoc(), resultType, sourceParts[index], maskParts[index]) .getResult(); - *accumulator = - rewriter - .create(op.getLoc(), resultType, reduced, - *accumulator, *firstLaneMask) - .getResult(); + *accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + *accumulator, *firstLaneMask) + .getResult(); } int64_t destChunk = group * chunksPerGroup; @@ -4960,8 +5761,8 @@ struct OneToNVMIGroupBroadcastOpPattern auto firstSourceType = dyn_cast(sourceParts.front().getType()); if (!firstSourceType) - return rewriter.notifyMatchFailure( - op, "group_broadcast source must be vreg"); + return rewriter.notifyMatchFailure(op, + "group_broadcast source must be vreg"); unsigned indexBits = pto::getPTOStorageElemBitWidth(firstSourceType.getElementType()); if (indexBits != 8 && indexBits != 16 && indexBits != 32) @@ -4971,20 +5772,18 @@ struct OneToNVMIGroupBroadcastOpPattern auto indexType = VRegType::get(rewriter.getContext(), firstSourceType.getElementCount(), indexElementType); - std::optional groupSlotIndex; FailureOr allMask = createAllTrueMaskForVReg(op.getLoc(), firstSourceType, rewriter); if (failed(allMask)) return rewriter.notifyMatchFailure( op, "failed to create group_broadcast all mask"); - if (*groupSize < lanesPerPart) { - FailureOr index = createGroupSlotIndexVector( - op.getLoc(), indexType, *groupSize, rewriter); - if (failed(index)) - return rewriter.notifyMatchFailure( - op, "failed to create group_broadcast group-slot index vector"); - groupSlotIndex = *index; - } + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + int64_t selectionGroupSize = *groupSize; + if (resultLayoutFactor != 1 && resultLayout && + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && + *groupSize < lanesPerPart) + selectionGroupSize = resultLayout.getBlockElems(); SmallVector results; results.resize(resultTypes.size()); @@ -4994,44 +5793,102 @@ struct OneToNVMIGroupBroadcastOpPattern return rewriter.notifyMatchFailure( op, "group_broadcast requires uniform physical vreg types"); int64_t sourceChunk = flatIndex; + int64_t baseGroupSlot = 0; if (resultLayoutFactor == 1) { if (*groupSize >= lanesPerPart) { int64_t chunksPerGroup = *groupSize / lanesPerPart; int64_t group = flatIndex / chunksPerGroup; sourceChunk = group * chunksPerGroup; + } else { + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast small-group source requires explicit " + "group_slots slots or derivable legacy slot count"); + slots = groupCount / sourceParts.size(); + } + int64_t groupsPerResultChunk = lanesPerPart / *groupSize; + int64_t firstGroup = flatIndex * groupsPerResultChunk; + sourceChunk = firstGroup / slots; + baseGroupSlot = firstGroup % slots; } } else { - int64_t runningFlatIndex = 0; - bool found = false; - for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { - FailureOr chunks = getDataChunksInPart(resultVMIType, part); - if (failed(chunks)) - return rewriter.notifyMatchFailure( - op, "group_broadcast failed to enumerate result chunks"); - for (int64_t chunk = 0; chunk < *chunks; ++chunk, ++runningFlatIndex) { - if (runningFlatIndex != static_cast(flatIndex)) - continue; - FailureOr firstLogical = - mapPhysicalLaneToLogical(resultVMIType, part, chunk, 0); - FailureOr lastLogical = mapPhysicalLaneToLogical( - resultVMIType, part, chunk, lanesPerPart - 1); - if (failed(firstLogical) || failed(lastLogical)) + bool blockFragmentSmallGroup = + resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() > 1 && *groupSize < lanesPerPart; + if (blockFragmentSmallGroup) { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = + getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) return rewriter.notifyMatchFailure( - op, "group_broadcast failed to map result chunk lanes"); - int64_t firstGroup = *firstLogical / *groupSize; - int64_t lastGroup = *lastLogical / *groupSize; - if (firstGroup != lastGroup) + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + int64_t groupsPerResultChunk = + lanesPerPart / resultLayout.getBlockElems(); + int64_t firstGroup = chunk * groupsPerResultChunk; + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, + "group_broadcast block-fragment source requires explicit " + "group_slots slots or derivable legacy slot count"); + slots = groupCount / sourceParts.size(); + } + sourceChunk = firstGroup / slots; + baseGroupSlot = firstGroup % slots; + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } else { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = + getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) return rewriter.notifyMatchFailure( - op, "group_broadcast result chunk crosses logical groups"); - int64_t chunksPerGroup = *groupSize / lanesPerPart; - sourceChunk = firstGroup * chunksPerGroup; - found = true; - break; + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + FailureOr firstLogical = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, 0); + FailureOr lastLogical = mapPhysicalLaneToLogical( + resultVMIType, part, chunk, lanesPerPart - 1); + if (failed(firstLogical) || failed(lastLogical)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to map result chunk lanes"); + int64_t firstGroup = *firstLogical / *groupSize; + int64_t lastGroup = *lastLogical / *groupSize; + if (firstGroup != lastGroup) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk crosses logical groups"); + int64_t chunksPerGroup = *groupSize / lanesPerPart; + sourceChunk = firstGroup * chunksPerGroup; + found = true; + break; + } } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); } - if (!found) - return rewriter.notifyMatchFailure( - op, "group_broadcast result chunk index is out of range"); } if (*groupSize >= lanesPerPart) { if (sourceChunk < 0 || @@ -5040,11 +5897,15 @@ struct OneToNVMIGroupBroadcastOpPattern op, "group_broadcast source chunk is out of range"); results[flatIndex] = rewriter - .create(op.getLoc(), resultType, sourceParts[sourceChunk], - *allMask, rewriter.getStringAttr("LOWEST")) + .create(op.getLoc(), resultType, + sourceParts[sourceChunk], *allMask, + rewriter.getStringAttr("LOWEST")) .getResult(); } else { - if (resultLayoutFactor != 1) + bool blockFragmentSmallGroup = resultLayout && + resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() > 1; + if (resultLayoutFactor != 1 && !blockFragmentSmallGroup) return rewriter.notifyMatchFailure( op, "group_broadcast small-group deinterleaved result is not " "supported"); @@ -5052,6 +5913,12 @@ struct OneToNVMIGroupBroadcastOpPattern sourceChunk >= static_cast(sourceParts.size())) return rewriter.notifyMatchFailure( op, "group_broadcast source chunk is out of range"); + FailureOr groupSlotIndex = createGroupSlotIndexVector( + op.getLoc(), indexType, selectionGroupSize, baseGroupSlot, + rewriter); + if (failed(groupSlotIndex)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); results[flatIndex] = rewriter .create(op.getLoc(), resultType, @@ -5066,15 +5933,13 @@ struct OneToNVMIGroupBroadcastOpPattern }; template -struct OneToNVMIReduceMinMaxFOpPattern - : OneToNOpConversionPattern { +struct OneToNVMIReduceMinMaxFOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult - matchAndRewrite( + LogicalResult matchAndRewrite( SourceOp op, typename OneToNOpConversionPattern::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + OneToNPatternRewriter &rewriter) const override { ValueRange sourceParts = adaptor.getSource(); ValueRange initParts = adaptor.getInit(); ValueRange maskParts = adaptor.getMask(); @@ -5112,15 +5977,14 @@ struct OneToNVMIReduceMinMaxFOpPattern Value accumulator = initParts.front(); for (auto [sourcePart, maskPart] : llvm::zip_equal(sourceParts, maskParts)) { - Value reduced = - rewriter.create(op.getLoc(), resultType, sourcePart, - maskPart) - .getResult(); - accumulator = - rewriter - .create(op.getLoc(), resultType, reduced, accumulator, - *firstLaneMask) - .getResult(); + Value reduced = rewriter + .create(op.getLoc(), resultType, + sourcePart, maskPart) + .getResult(); + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); } rewriter.replaceOp(op, SmallVector{accumulator}, @@ -5156,9 +6020,8 @@ struct OneToNVMIExtFOpPattern : OneToNOpConversionPattern { for (Type resultType : resultTypes) { auto resultVRegType = dyn_cast(resultType); if (!resultVRegType || - (resultVRegTypes.empty() - ? !resultVRegType.getElementType().isF32() - : resultVRegType != resultVRegTypes.front())) + (resultVRegTypes.empty() ? !resultVRegType.getElementType().isF32() + : resultVRegType != resultVRegTypes.front())) return rewriter.notifyMatchFailure( op, "unsupported physical extf result type"); resultVRegTypes.push_back(resultVRegType); @@ -5185,8 +6048,7 @@ struct OneToNVMIExtFOpPattern : OneToNOpConversionPattern { FailureOr mask = createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); if (failed(mask)) - return rewriter.notifyMatchFailure(op, - "failed to build extf seed mask"); + return rewriter.notifyMatchFailure(op, "failed to build extf seed mask"); SmallVector results; results.reserve(resultTypes.size()); @@ -5214,12 +6076,59 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { LogicalResult matchAndRewrite(VMITruncFOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + resultLayout.isGroupSlots()) { + if (sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + !sourceVMIType.getElementType().isF32() || + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) != + 16 || + sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot truncf shape"); + + SmallVector results; + results.reserve(resultTypes.size()); + StringAttr rnd = rewriter.getStringAttr("R"); + StringAttr sat = rewriter.getStringAttr("SAT"); + StringAttr even = rewriter.getStringAttr("EVEN"); + FailureOr lane0Mask = createPrefixMask( + op.getLoc(), MaskType::get(rewriter.getContext(), "b32"), "PAT_VL1", + rewriter); + if (failed(lane0Mask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot truncf lane0 mask"); + for (auto [sourcePart, physicalResultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto sourceType = dyn_cast(sourcePart.getType()); + auto resultType = dyn_cast(physicalResultType); + if (!sourceType || !sourceType.getElementType().isF32() || + !resultType || + pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 16) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot truncf physical type"); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *lane0Mask, rnd, sat, + even) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + if ((sourceParts.size() != 2 && sourceParts.size() != 4) || resultTypes.size() != 1) return rewriter.notifyMatchFailure( - op, "only f32 deinterleaved=2/4 to 16/8-bit contiguous truncf is supported"); + op, "only f32 deinterleaved=2/4 to 16/8-bit contiguous truncf is " + "supported"); auto sourceType0 = dyn_cast(sourceParts.front().getType()); auto resultType = dyn_cast(resultTypes.front()); @@ -5252,36 +6161,33 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { FailureOr resultMask = createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); if (failed(sourceMask) || failed(resultMask)) - return rewriter.notifyMatchFailure(op, - "failed to build truncf masks"); + return rewriter.notifyMatchFailure(op, "failed to build truncf masks"); StringAttr rnd = rewriter.getStringAttr("R"); StringAttr sat = rewriter.getStringAttr("SAT"); SmallVector partials; partials.reserve(parts.size()); for (auto [sourcePart, part] : llvm::zip_equal(sourceParts, parts)) { - partials.push_back( - rewriter - .create(op.getLoc(), resultType, sourcePart, *sourceMask, - rnd, sat, rewriter.getStringAttr(part)) - .getResult()); + partials.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *sourceMask, rnd, sat, + rewriter.getStringAttr(part)) + .getResult()); } Value merged = partials.front(); for (Value partial : llvm::drop_begin(partials)) - merged = - rewriter - .create(op.getLoc(), resultType, merged, partial, - *resultMask) - .getResult(); + merged = rewriter + .create(op.getLoc(), resultType, merged, partial, + *resultMask) + .getResult(); rewriter.replaceOp(op, merged, adaptor.getResultMapping()); return success(); } }; -struct OneToNVMIBitcastOpPattern - : OneToNOpConversionPattern { +struct OneToNVMIBitcastOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult @@ -5290,8 +6196,7 @@ struct OneToNVMIBitcastOpPattern ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceParts.size() != resultTypes.size()) - return rewriter.notifyMatchFailure(op, - "physical bitcast arity mismatch"); + return rewriter.notifyMatchFailure(op, "physical bitcast arity mismatch"); SmallVector results; results.reserve(resultTypes.size()); @@ -5312,8 +6217,7 @@ struct OneToNVMIBitcastOpPattern struct OneToNVMIChannelSplitOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIChannelSplitOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIChannelSplitOp op, OpAdaptor adaptor, @@ -5342,9 +6246,9 @@ struct OneToNVMIChannelSplitOpPattern } TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(); - FailureOr> results = materializeDataLayoutConversion( - op, adaptor.getSource(), resultTypes, sourceLayout, channelLayout, - rewriter); + FailureOr> results = + materializeDataLayoutConversion(op, adaptor.getSource(), resultTypes, + sourceLayout, channelLayout, rewriter); if (failed(results)) return failure(); @@ -5355,8 +6259,7 @@ struct OneToNVMIChannelSplitOpPattern struct OneToNVMIChannelMergeOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIChannelMergeOp>::OneToNOpConversionPattern; + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(VMIChannelMergeOp op, OpAdaptor adaptor, @@ -5417,8 +6320,8 @@ struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { results.push_back(sourceParts[sourceFlatIndex]); } - if (failed(verifyIdentityPartForwarding(op, results, resultTypes, - rewriter))) + if (failed( + verifyIdentityPartForwarding(op, results, resultTypes, rewriter))) return failure(); rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -5448,11 +6351,11 @@ struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { if (failed(mask)) return rewriter.notifyMatchFailure( op, "failed to create shuffle lane0 splat mask"); - results.push_back( - rewriter - .create(op.getLoc(), resultType, sourcePart, *mask, - rewriter.getStringAttr("LOWEST")) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *mask, + rewriter.getStringAttr("LOWEST")) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -5463,12 +6366,11 @@ struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { FailureOr> vselrPlans = computeShuffleVselrPlans(op, &vselrReason); if (failed(vselrPlans)) - return rewriter.notifyMatchFailure( - op, Twine("shuffle vselr ") + vselrReason); + return rewriter.notifyMatchFailure(op, + Twine("shuffle vselr ") + vselrReason); if (vselrPlans->size() != resultTypes.size()) - return rewriter.notifyMatchFailure(op, - "shuffle vselr arity mismatch"); + return rewriter.notifyMatchFailure(op, "shuffle vselr arity mismatch"); SmallVector results; results.reserve(resultTypes.size()); @@ -5496,8 +6398,8 @@ struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { auto indexElementType = IntegerType::get(rewriter.getContext(), indexBits); Type indexType = - VRegType::get(rewriter.getContext(), - sourceVRegType.getElementCount(), indexElementType); + VRegType::get(rewriter.getContext(), sourceVRegType.getElementCount(), + indexElementType); FailureOr base = createScalarOffsetConstant( op.getLoc(), indexElementType, plan.baseLane, rewriter); if (failed(base)) @@ -5508,11 +6410,11 @@ struct OneToNVMIShuffleOpPattern : OneToNOpConversionPattern { Value indexVector = rewriter.create(op.getLoc(), indexType, *base, orderAttr) .getResult(); - results.push_back( - rewriter - .create(op.getLoc(), resultType, - sourceParts[plan.sourceFlatIndex], indexVector) - .getResult()); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourceParts[plan.sourceFlatIndex], + indexVector) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -5590,17 +6492,15 @@ struct OneToNCFCondBranchOpPattern const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); unsigned operandIndex = 1; for (unsigned i = 0, e = op.getNumTrueOperands(); i < e; ++i) - llvm::append_range( - trueOperands, - operandMapping.getConvertedValues(flatOperands, operandIndex++)); + llvm::append_range(trueOperands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); for (unsigned i = 0, e = op.getNumFalseOperands(); i < e; ++i) - llvm::append_range( - falseOperands, - operandMapping.getConvertedValues(flatOperands, operandIndex++)); + llvm::append_range(falseOperands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); - rewriter.replaceOpWithNewOp( - op, condition.front(), trueDest, trueOperands, falseDest, - falseOperands); + rewriter.replaceOpWithNewOp(op, condition.front(), + trueDest, trueOperands, + falseDest, falseOperands); return success(); } }; @@ -5613,9 +6513,8 @@ struct OneToNCFSwitchOpPattern : OneToNOpConversionPattern { OneToNPatternRewriter &rewriter) const override { auto *converter = getTypeConverter(); llvm::DenseMap convertedBlocks; - Block *defaultDest = - convertBranchDestBlock(op.getDefaultDestination(), rewriter, - *converter, convertedBlocks); + Block *defaultDest = convertBranchDestBlock( + op.getDefaultDestination(), rewriter, *converter, convertedBlocks); SmallVector caseDests; caseDests.reserve(op.getCaseDestinations().size()); @@ -5633,7 +6532,8 @@ struct OneToNCFSwitchOpPattern : OneToNOpConversionPattern { ValueRange flag = adaptor.getFlag(); if (flag.size() != 1) - return rewriter.notifyMatchFailure(op, "flag converted to multiple values"); + return rewriter.notifyMatchFailure(op, + "flag converted to multiple values"); SmallVector defaultOperands; SmallVector> caseOperandStorage; @@ -5643,18 +6543,16 @@ struct OneToNCFSwitchOpPattern : OneToNOpConversionPattern { unsigned operandIndex = 1; for (unsigned i = 0, e = op.getDefaultOperands().size(); i < e; ++i) - llvm::append_range( - defaultOperands, - operandMapping.getConvertedValues(flatOperands, operandIndex++)); + llvm::append_range(defaultOperands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); caseOperandStorage.reserve(op.getCaseOperandSegments().size()); caseOperands.reserve(op.getCaseOperandSegments().size()); for (int32_t segmentSize : op.getCaseOperandSegments()) { SmallVector operands; for (int32_t i = 0; i < segmentSize; ++i) - llvm::append_range( - operands, - operandMapping.getConvertedValues(flatOperands, operandIndex++)); + llvm::append_range(operands, operandMapping.getConvertedValues( + flatOperands, operandIndex++)); caseOperandStorage.push_back(std::move(operands)); } for (SmallVector &operands : caseOperandStorage) @@ -5694,7 +6592,8 @@ struct OneToNSCFExecuteRegionOpPattern struct OneToNSCFIndexSwitchOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + using OneToNOpConversionPattern< + scf::IndexSwitchOp>::OneToNOpConversionPattern; LogicalResult matchAndRewrite(scf::IndexSwitchOp op, OpAdaptor adaptor, @@ -5712,11 +6611,9 @@ struct OneToNSCFIndexSwitchOpPattern return failure(); auto newOp = rewriter.create( - op.getLoc(), resultTypes, arg.front(), op.getCases(), - op.getNumCases()); + op.getLoc(), resultTypes, arg.front(), op.getCases(), op.getNumCases()); newOp->setAttrs(op->getAttrs()); - rewriter.inlineRegionBefore(op.getDefaultRegion(), - newOp.getDefaultRegion(), + rewriter.inlineRegionBefore(op.getDefaultRegion(), newOp.getDefaultRegion(), newOp.getDefaultRegion().end()); for (auto [srcRegion, dstRegion] : llvm::zip(op.getCaseRegions(), newOp.getCaseRegions())) @@ -5731,80 +6628,59 @@ void populateVMIOneToNConversionPatterns( const VMITargetCapabilityRegistry &capabilities) { populateFuncTypeConversionPatterns(typeConverter, patterns); scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); - patterns - .add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, - patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext()); patterns.add( typeConverter, patterns.getContext()); - patterns.add, - OneToNVMIMaskBinaryOpPattern, - OneToNVMIMaskBinaryOpPattern, - OneToNVMIMaskUnaryOpPattern, - OneToNVMILoadOpPattern, - OneToNVMIGroupLoadOpPattern, - OneToNVMIMaskedLoadOpPattern, - OneToNVMIGatherOpPattern, - OneToNVMIExpandLoadOpPattern, - OneToNVMIStoreOpPattern, - OneToNVMIGroupStoreOpPattern, - OneToNVMIMaskedStoreOpPattern, - OneToNVMIScatterOpPattern, - OneToNVMITileReadOpPattern, - OneToNVMITileWriteOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIFmaOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIBinaryOpPattern, - OneToNVMIUnaryOpPattern, - OneToNVMICmpOpPattern, - OneToNVMICmpOpPattern, - OneToNVMISelectOpPattern, - OneToNVMIActivePrefixIndexOpPattern, - OneToNVMICompressOpPattern, - OneToNVMICompressStoreOpPattern, - OneToNVMIReduceAddIOpPattern, - OneToNVMIReduceAddFOpPattern, - OneToNVMIGroupReduceAddFOpPattern, - OneToNVMIGroupBroadcastOpPattern, - OneToNVMIReduceMinMaxFOpPattern, - OneToNVMIReduceMinMaxFOpPattern, - OneToNVMIExtFOpPattern, - OneToNVMITruncFOpPattern, - OneToNVMIBitcastOpPattern, - OneToNVMIChannelSplitOpPattern, - OneToNVMIChannelMergeOpPattern, - OneToNVMIShuffleOpPattern>(typeConverter, - patterns.getContext()); + patterns.add< + OneToNVMIEnsureLayoutOpPattern, OneToNVMIEnsureMaskLayoutOpPattern, + OneToNVMIBroadcastOpPattern, OneToNVMIIotaOpPattern, + OneToNVMIConstantOpPattern, OneToNVMIConstantMaskOpPattern, + OneToNVMICreateMaskOpPattern, OneToNVMICreateGroupMaskOpPattern, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskBinaryOpPattern, + OneToNVMIMaskUnaryOpPattern, OneToNVMILoadOpPattern, + OneToNVMIGroupLoadOpPattern, OneToNVMIGroupSlotLoadOpPattern, + OneToNVMIMaskedLoadOpPattern, OneToNVMIGatherOpPattern, + OneToNVMIExpandLoadOpPattern, OneToNVMIStoreOpPattern, + OneToNVMIGroupStoreOpPattern, OneToNVMIMaskedStoreOpPattern, + OneToNVMIScatterOpPattern, OneToNVMITileReadOpPattern, + OneToNVMITileWriteOpPattern, OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, OneToNVMIFmaOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIBinaryOpPattern, + OneToNVMIUnaryOpPattern, + OneToNVMICmpOpPattern, OneToNVMICmpOpPattern, + OneToNVMISelectOpPattern, OneToNVMIActivePrefixIndexOpPattern, + OneToNVMICompressOpPattern, OneToNVMICompressStoreOpPattern, + OneToNVMIReduceAddIOpPattern, OneToNVMIReduceAddFOpPattern, + OneToNVMIGroupReduceAddFOpPattern, OneToNVMIGroupBroadcastOpPattern, + OneToNVMIReduceMinMaxFOpPattern, + OneToNVMIReduceMinMaxFOpPattern, + OneToNVMIExtFOpPattern, OneToNVMITruncFOpPattern, + OneToNVMIBitcastOpPattern, OneToNVMIChannelSplitOpPattern, + OneToNVMIChannelMergeOpPattern, OneToNVMIShuffleOpPattern>( + typeConverter, patterns.getContext()); patterns.add( typeConverter, patterns.getContext(), capabilities); } @@ -5812,9 +6688,8 @@ void populateVMIOneToNConversionPatterns( LogicalResult verifyNoResidualVMIIR(ModuleOp module) { WalkResult result = module.walk([&](Operation *op) { if (isa(op)) { - op->emitError() - << kVMIDiagResidualOpPrefix - << "unrealized conversion cast remains after vmi-to-vpto"; + op->emitError() << kVMIDiagResidualOpPrefix + << "unrealized conversion cast remains after vmi-to-vpto"; return WalkResult::interrupt(); } if (auto createMask = dyn_cast(op)) { @@ -5837,9 +6712,8 @@ LogicalResult verifyNoResidualVMIIR(ModuleOp module) { } } if (isVMIOp(op) || hasVMIType(op)) { - op->emitError() - << kVMIDiagResidualOpPrefix - << "failed to convert all VMI ops/types to VPTO"; + op->emitError() << kVMIDiagResidualOpPrefix + << "failed to convert all VMI ops/types to VPTO"; return WalkResult::interrupt(); } return WalkResult::advance(); @@ -5856,8 +6730,7 @@ LogicalResult checkSupportedExtFShape(VMIExtFOp op) { FailureOr resultArity = getVMIPhysicalArity(resultType); if (!sourceLayout || !resultLayout || failed(sourceArity) || failed(resultArity) || !sourceLayout.isContiguous() || - !resultLayout.isDeinterleaved() || - !resultType.getElementType().isF32()) + !resultLayout.isDeinterleaved() || !resultType.getElementType().isF32()) return failure(); unsigned sourceBits = @@ -5871,7 +6744,14 @@ LogicalResult checkSupportedExtFShape(VMIExtFOp op) { return failure(); } -LogicalResult checkSupportedTruncFShape(VMITruncFOp op) { +LogicalResult checkSupportedTruncFShape(VMITruncFOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); @@ -5879,18 +6759,45 @@ LogicalResult checkSupportedTruncFShape(VMITruncFOp op) { FailureOr sourceArity = getVMIPhysicalArity(sourceType); FailureOr resultArity = getVMIPhysicalArity(resultType); if (!sourceLayout || !resultLayout || failed(sourceArity) || - failed(resultArity) || !sourceLayout.isDeinterleaved() || - !resultLayout.isContiguous() || !sourceType.getElementType().isF32() || - *resultArity != 1) - return failure(); + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); unsigned resultBits = pto::getPTOStorageElemBitWidth(resultType.getElementType()); + + if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { + if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || + sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + !sourceType.getElementType().isF32() || resultBits != 16 || + *sourceArity != *resultArity) + return fail("group-slot truncf requires matching " + "group_slots(num_groups=G, slots=1) source/result layouts, " + "f32 source, f16 result, and matching physical arity"); + + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + StringRef expectedPlan = "group_slot_cast_slots1_f32_to_f16"; + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match source/result layouts; expected '" + + expectedPlan + "'"); + return success(); + } + + if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || + !sourceType.getElementType().isF32() || *resultArity != 1) + return fail("requires f32 deinterleaved source and contiguous result"); + if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) return success(); if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8) return success(); - return failure(); + return fail("unsupported deinterleaved truncf factor, arity, or result " + "element width"); } FailureOr> @@ -5900,8 +6807,7 @@ getPhysicalLogicalBitFootprint(VMIVRegType type) { return failure(); FailureOr factor = getDataLayoutFactor(type); - FailureOr lanesPerPart = - getDataLanesPerPart(type.getElementType()); + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); if (failed(factor) || failed(lanesPerPart)) return failure(); @@ -5925,8 +6831,7 @@ getPhysicalLogicalBitFootprint(VMIVRegType type) { return bits; } -LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, - std::string *reason) { +LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -5967,9 +6872,10 @@ LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, return success(); } -LogicalResult checkSupportedChannelSplitShape( - const VMITargetCapabilityRegistry &capabilities, VMIChannelSplitOp op, - std::string *reason = nullptr) { +LogicalResult +checkSupportedChannelSplitShape(const VMITargetCapabilityRegistry &capabilities, + VMIChannelSplitOp op, + std::string *reason = nullptr) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -6024,9 +6930,10 @@ LogicalResult checkSupportedChannelSplitShape( return success(); } -LogicalResult checkSupportedChannelMergeShape( - const VMITargetCapabilityRegistry &capabilities, VMIChannelMergeOp op, - std::string *reason = nullptr) { +LogicalResult +checkSupportedChannelMergeShape(const VMITargetCapabilityRegistry &capabilities, + VMIChannelMergeOp op, + std::string *reason = nullptr) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -6176,9 +7083,9 @@ LogicalResult checkSupportedCompressStoreShape( return fail("requires contiguous value and mask layouts"); VMICapabilityResult destinationCapability = - capabilities.supportsUBPointerMemory( - op.getDestination().getType(), "destination", "pto.vstur", - "pto.vstur stores only to UB"); + capabilities.supportsUBPointerMemory(op.getDestination().getType(), + "destination", "pto.vstur", + "pto.vstur stores only to UB"); if (!destinationCapability.isSupported()) return fail(destinationCapability.reason); @@ -6275,6 +7182,15 @@ LogicalResult checkSupportedGroupReduceAddFShape( VMILayoutAttr maskLayout = maskType.getLayoutAttr(); if (!sourceLayout || !resultLayout || !maskLayout) return fail("requires assigned source, mask, and result layouts"); + + FailureOr groupSize = getGroupSizeFromNumGroups( + sourceType, op.getNumGroupsAttr().getInt(), reason); + if (failed(groupSize)) + return failure(); + if (succeeded(checkS16Block8GroupReduceShape(op, reason))) + return success(); + if (succeeded(checkS32Block8GroupReduceShape(op, reason))) + return success(); if (!sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || resultLayout.getNumGroups() != op.getNumGroupsAttr().getInt() || !maskLayout.isContiguous()) @@ -6296,15 +7212,43 @@ LogicalResult checkSupportedGroupReduceAddFShape( return fail("requires computable source/result/mask physical arity"); if (*sourceArity != *resultArity || *sourceArity != *maskArity) return fail("requires source/result/mask physical arity to match"); - FailureOr groupSize = - getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt(), - reason); - if (failed(groupSize)) + if (succeeded(checkVcgaddGroupReduceShape(sourceType, maskType, resultType, + *groupSize, nullptr))) { + if (resultLayout.getSlots() > 0) { + auto selectedPlan = + op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + StringRef expectedPlan = "s8_reduce_contiguous"; + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match result layout; expected '" + + expectedPlan + "'"); + } + return success(); + } + if (failed(checkSupportedGroupChunkShape(sourceType, *groupSize, reason))) return failure(); - if (succeeded(checkVcgaddGroupReduceShape( - sourceType, maskType, resultType, *groupSize, nullptr))) + if (resultLayout.getSlots() <= 0) return success(); - return checkSupportedGroupChunkShape(sourceType, *groupSize, reason); + + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + StringRef expectedPlan; + if (sourceLayout.isContiguous() && *groupSize == 64 && + resultLayout.getSlots() == 1) + expectedPlan = "s64_reduce_row_local"; + else + return fail("explicit group_slots group_reduce_addf chunk path has no " + "registered selected_plan for the assigned layouts"); + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match result layout; expected '" + expectedPlan + + "'"); + return success(); } LogicalResult checkSupportedGroupBroadcastShape( @@ -6335,6 +7279,27 @@ LogicalResult checkSupportedGroupBroadcastShape( if (resultLayout.isGroupSlots()) return fail("requires dense result layout"); + if (sourceLayout.getSlots() > 0) { + auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); + if (!selectedPlan) + return fail("requires vmi.selected_plan selected by " + "vmi-layout-assignment"); + + StringRef expectedPlan; + if (sourceLayout.getSlots() == 8) + expectedPlan = "group_broadcast_slots8_vselr"; + else if (sourceLayout.getSlots() == 1) + expectedPlan = "group_broadcast_slots1_vselr"; + else + return fail("supports only slots=8 or slots=1 group_broadcast source " + "layouts"); + + if (selectedPlan.getValue() != expectedPlan) + return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + + "' does not match source layout; expected '" + expectedPlan + + "'"); + } + std::string fullChunkReason; if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) return fail(Twine("requires full source physical chunks; ") + @@ -6350,9 +7315,8 @@ LogicalResult checkSupportedGroupBroadcastShape( if (failed(lanesPerPart) || failed(resultLanesPerPart) || *lanesPerPart != *resultLanesPerPart) return fail("requires matching physical lanes per part"); - FailureOr groupSize = - getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt(), - reason); + FailureOr groupSize = getGroupSizeFromNumGroups( + sourceType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) return failure(); if (*lanesPerPart % *groupSize != 0 && *groupSize % *lanesPerPart != 0) @@ -6364,9 +7328,14 @@ LogicalResult checkSupportedGroupBroadcastShape( return fail("requires known result layout factor"); if (*resultFactor == 1) return success(); + bool blockFragmentSmallGroup = + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && + *groupSize < *lanesPerPart && + *lanesPerPart % resultLayout.getBlockElems() == 0; + if (blockFragmentSmallGroup) + return success(); int64_t logicalSpanPerResultChunk = *lanesPerPart * *resultFactor; - if (*groupSize < *lanesPerPart || - *groupSize % logicalSpanPerResultChunk != 0) + if (*groupSize < *lanesPerPart || *groupSize % logicalSpanPerResultChunk != 0) return fail("deinterleaved result requires every physical result chunk to " "stay within one logical group"); return success(); @@ -6382,9 +7351,8 @@ checkSupportedFmaShape(const VMITargetCapabilityRegistry &capabilities, }; auto lhsType = cast(op.getLhs().getType()); - VMICapabilityResult elementCapability = - capabilities.supportsElementType(lhsType.getElementType(), - VMIElementPurpose::VMula); + VMICapabilityResult elementCapability = capabilities.supportsElementType( + lhsType.getElementType(), VMIElementPurpose::VMula); if (!elementCapability.isSupported()) return fail(elementCapability.reason); @@ -6408,9 +7376,8 @@ checkSupportedReluShape(const VMITargetCapabilityRegistry &capabilities, if (failed(checkSupportedMaskableVReg(capabilities, resultType, reason))) return failure(); - VMICapabilityResult elementCapability = - capabilities.supportsElementType(resultType.getElementType(), - VMIElementPurpose::VRelu); + VMICapabilityResult elementCapability = capabilities.supportsElementType( + resultType.getElementType(), VMIElementPurpose::VRelu); if (!elementCapability.isSupported()) return fail(elementCapability.reason); @@ -6447,17 +7414,19 @@ void emitEnsureLayoutMaterializationError(VMIEnsureLayoutOp ensure, "packing plan"; } -LogicalResult verifySupportedVMIToVPTOOps( - ModuleOp module, const VMITargetCapabilityRegistry &capabilities, - bool enableStableGatherMaskedLoad) { - auto emitMemoryUnsupported = [&](Operation *op, StringRef opName, - VMIVRegType type, Value source, - std::optional constantOffset) - -> WalkResult { +LogicalResult +verifySupportedVMIToVPTOOps(ModuleOp module, + const VMITargetCapabilityRegistry &capabilities, + bool enableStableGatherMaskedLoad) { + auto emitMemoryUnsupported = + [&](Operation *op, StringRef opName, VMIVRegType type, Value source, + std::optional constantOffset, + std::optional explicitFullReadElems = + std::nullopt) -> WalkResult { std::string reason; if (succeeded(checkSupportedLoadShape(capabilities, type, source, source.getType(), constantOffset, - &reason))) + explicitFullReadElems, &reason))) return WalkResult::advance(); op->emitError() @@ -6486,13 +7455,13 @@ LogicalResult verifySupportedVMIToVPTOOps( [&](Operation *op, StringRef opName, VMIVRegType type, VMIElementPurpose purpose, StringRef elementContract) -> WalkResult { std::string reason; - if (succeeded(checkSupportedTargetElementVReg( - capabilities, type, purpose, elementContract, &reason))) + if (succeeded(checkSupportedTargetElementVReg(capabilities, type, purpose, + elementContract, &reason))) return WalkResult::advance(); op->emitError() - << kVMIDiagUnsupportedPrefix << opName - << " direct lowering requires " << elementContract + << kVMIDiagUnsupportedPrefix << opName << " direct lowering requires " + << elementContract << " and physical vreg parts with b8/b16/b32 predicate masks (" << reason << ")"; return WalkResult::interrupt(); @@ -6532,10 +7501,15 @@ LogicalResult verifySupportedVMIToVPTOOps( return WalkResult::interrupt(); } - if (auto load = dyn_cast(op)) + if (auto load = dyn_cast(op)) { + std::optional explicitFullReadElems; + if (auto attr = load.getFullReadElemsAttr()) + explicitFullReadElems = attr.getInt(); return emitMemoryUnsupported( op, "pto.vmi.load", cast(load.getResult().getType()), - load.getSource(), getConstantIndexValue(load.getOffset())); + load.getSource(), getConstantIndexValue(load.getOffset()), + explicitFullReadElems); + } if (auto load = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedGroupLoadShape(capabilities, load, &reason))) @@ -6548,6 +7522,20 @@ LogicalResult verifySupportedVMIToVPTOOps( << reason << ")"; return WalkResult::interrupt(); } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupSlotLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_slot_load requires explicit group_slots result " + "layout matching num_groups, a supported UB pointer source, " + "and either slots=8 with constant unit source_group_stride or " + "slots=1 row-local lowering (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto load = dyn_cast(op)) { if (enableStableGatherMaskedLoad) { load.emitError() @@ -6557,8 +7545,7 @@ LogicalResult verifySupportedVMIToVPTOOps( return WalkResult::interrupt(); } std::string reason; - if (succeeded(checkSupportedMaskedLoadShape(capabilities, load, - &reason))) + if (succeeded(checkSupportedMaskedLoadShape(capabilities, load, &reason))) return WalkResult::advance(); load.emitError() << kVMIDiagUnsupportedPrefix @@ -6582,8 +7569,7 @@ LogicalResult verifySupportedVMIToVPTOOps( } if (auto load = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedExpandLoadShape(capabilities, load, - &reason))) + if (succeeded(checkSupportedExpandLoadShape(capabilities, load, &reason))) return WalkResult::advance(); load.emitError() << kVMIDiagUnsupportedPrefix @@ -6611,8 +7597,8 @@ LogicalResult verifySupportedVMIToVPTOOps( } if (auto store = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedGroupStoreShape(capabilities, store, - &reason))) + if (succeeded( + checkSupportedGroupStoreShape(capabilities, store, &reason))) return WalkResult::advance(); store.emitError() << kVMIDiagUnsupportedPrefix @@ -6640,8 +7626,7 @@ LogicalResult verifySupportedVMIToVPTOOps( } if (auto scatter = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedScatterShape(capabilities, scatter, - &reason))) + if (succeeded(checkSupportedScatterShape(capabilities, scatter, &reason))) return WalkResult::advance(); scatter.emitError() << kVMIDiagUnsupportedPrefix @@ -6660,8 +7645,7 @@ LogicalResult verifySupportedVMIToVPTOOps( if (auto tileWrite = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedStoreShape( - capabilities, - cast(tileWrite.getValue().getType()), + capabilities, cast(tileWrite.getValue().getType()), tileWrite.getDestination(), tileWrite.getDestination().getType(), &reason))) return WalkResult::advance(); @@ -6728,88 +7712,62 @@ LogicalResult verifySupportedVMIToVPTOOps( if (auto addf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.addf", - cast(addf.getResult().getType()), - VMIElementPurpose::F16BF16F32, - "f16/bf16/f32 element type"); + op, "pto.vmi.addf", cast(addf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); if (auto addi = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.addi", - cast( - addi.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.addi", cast(addi.getResult().getType())); if (auto subf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.subf", - cast(subf.getResult().getType()), - VMIElementPurpose::F16BF16F32, - "f16/bf16/f32 element type"); + op, "pto.vmi.subf", cast(subf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); if (auto subi = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.subi", - cast( - subi.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.subi", cast(subi.getResult().getType())); if (auto mulf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.mulf", - cast(mulf.getResult().getType()), - VMIElementPurpose::F16BF16F32, - "f16/bf16/f32 element type"); + op, "pto.vmi.mulf", cast(mulf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); if (auto muli = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.muli", - cast( - muli.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.muli", cast(muli.getResult().getType())); if (auto divf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.divf", - cast(divf.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.divf", cast(divf.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); if (auto minf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.minf", - cast(minf.getResult().getType()), - VMIElementPurpose::F16BF16F32, - "f16/bf16/f32 element type"); + op, "pto.vmi.minf", cast(minf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); if (auto maxf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.maxf", - cast(maxf.getResult().getType()), - VMIElementPurpose::F16BF16F32, - "f16/bf16/f32 element type"); + op, "pto.vmi.maxf", cast(maxf.getResult().getType()), + VMIElementPurpose::F16BF16F32, "f16/bf16/f32 element type"); if (auto negf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.negf", - cast(negf.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.negf", cast(negf.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); if (auto absf = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.absf", - cast(absf.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.absf", cast(absf.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); if (auto absi = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.absi", - cast(absi.getResult().getType()), + op, "pto.vmi.absi", cast(absi.getResult().getType()), VMIElementPurpose::SignlessOrSignedI8I16I32, "signless/signed i8/i16/i32 element type"); if (auto sqrt = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.sqrt", - cast(sqrt.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.sqrt", cast(sqrt.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); if (auto exp = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.exp", - cast(exp.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.exp", cast(exp.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); if (auto ln = dyn_cast(op)) return emitTargetElementUnsupported( - op, "pto.vmi.ln", - cast(ln.getResult().getType()), - VMIElementPurpose::F16F32, - "f16/f32 element type"); + op, "pto.vmi.ln", cast(ln.getResult().getType()), + VMIElementPurpose::F16F32, "f16/f32 element type"); if (auto relu = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedReluShape(capabilities, relu, &reason))) @@ -6822,32 +7780,27 @@ LogicalResult verifySupportedVMIToVPTOOps( return WalkResult::interrupt(); } if (auto andi = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.andi", - cast( - andi.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.andi", cast(andi.getResult().getType())); if (auto ori = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.ori", - cast(ori.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.ori", cast(ori.getResult().getType())); if (auto xori = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.xori", - cast( - xori.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.xori", cast(xori.getResult().getType())); if (auto shli = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.shli", - cast( - shli.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.shli", cast(shli.getResult().getType())); if (auto shrui = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.shrui", - cast( - shrui.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.shrui", cast(shrui.getResult().getType())); if (auto notOp = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.not", - cast( - notOp.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.not", cast(notOp.getResult().getType())); if (auto select = dyn_cast(op)) - return emitMaskableUnsupported(op, "pto.vmi.select", - cast( - select.getResult().getType())); + return emitMaskableUnsupported( + op, "pto.vmi.select", + cast(select.getResult().getType())); if (auto cmpf = dyn_cast(op)) { WalkResult target = emitTargetElementUnsupported( @@ -6874,8 +7827,8 @@ LogicalResult verifySupportedVMIToVPTOOps( if (auto activePrefix = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedActivePrefixIndexShape(activePrefix, - &reason))) + if (succeeded( + checkSupportedActivePrefixIndexShape(activePrefix, &reason))) return WalkResult::advance(); activePrefix.emitError() << kVMIDiagUnsupportedPrefix @@ -7013,14 +7966,18 @@ LogicalResult verifySupportedVMIToVPTOOps( } if (auto truncf = dyn_cast(op)) { - if (succeeded(checkSupportedTruncFShape(truncf))) + std::string reason; + if (succeeded(checkSupportedTruncFShape(truncf, &reason))) return WalkResult::advance(); truncf.emitError() << kVMIDiagUnsupportedPrefix << "pto.vmi.truncf supports only f32 deinterleaved=2 source parts " "to one contiguous f16 result chunk or f32 deinterleaved=4 " - "source parts to one contiguous fp8-like result chunk"; + "source parts to one contiguous fp8-like result chunk, or f32 " + "group_slots(num_groups=G, slots=1) to f16 " + "group_slots(num_groups=G, slots=1) with selected_plan (" + << reason << ")"; return WalkResult::interrupt(); } @@ -7041,8 +7998,8 @@ LogicalResult verifySupportedVMIToVPTOOps( if (auto split = dyn_cast(op)) { int64_t channels = split.getNumResults(); std::string reason; - if (succeeded(checkSupportedChannelSplitShape(capabilities, split, - &reason))) + if (succeeded( + checkSupportedChannelSplitShape(capabilities, split, &reason))) return WalkResult::advance(); if (channels != 2 && channels != 4) @@ -7062,8 +8019,8 @@ LogicalResult verifySupportedVMIToVPTOOps( if (auto merge = dyn_cast(op)) { int64_t channels = merge.getInputs().size(); std::string reason; - if (succeeded(checkSupportedChannelMergeShape(capabilities, merge, - &reason))) + if (succeeded( + checkSupportedChannelMergeShape(capabilities, merge, &reason))) return WalkResult::advance(); if (channels != 2 && channels != 4) @@ -7119,8 +8076,7 @@ LogicalResult verifySupportedVMIToVPTOOps( return failure(result.wasInterrupted()); } -struct VMIToVPTOPass - : public mlir::pto::impl::VMIToVPTOBase { +struct VMIToVPTOPass : public mlir::pto::impl::VMIToVPTOBase { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMIToVPTOPass) void runOnOperation() override { @@ -7130,8 +8086,8 @@ struct VMIToVPTOPass return; } VMITargetCapabilityRegistry capabilities; - if (failed(verifySupportedVMIToVPTOOps( - module, capabilities, enableStableGatherMaskedLoad))) { + if (failed(verifySupportedVMIToVPTOOps(module, capabilities, + enableStableGatherMaskedLoad))) { signalPassFailure(); return; } @@ -7140,13 +8096,11 @@ struct VMIToVPTOPass VMIToVPTOTypeConverter typeConverter; RewritePatternSet patterns(context); - populateVMIOneToNConversionPatterns(typeConverter, patterns, - capabilities); + populateVMIOneToNConversionPatterns(typeConverter, patterns, capabilities); if (failed(applyPartialOneToNConversion(module, typeConverter, std::move(patterns)))) { - module.emitError() - << kVMIDiagResidualOpPrefix - << "failed to convert all VMI ops/types to VPTO"; + module.emitError() << kVMIDiagResidualOpPrefix + << "failed to convert all VMI ops/types to VPTO"; signalPassFailure(); return; } diff --git a/test/lit/vmi/vmi_create_group_mask_invalid.pto b/test/lit/vmi/vmi_create_group_mask_invalid.pto new file mode 100644 index 0000000000..0c3aec3d65 --- /dev/null +++ b/test/lit/vmi/vmi_create_group_mask_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_create_group_mask_lane_count_invalid() { + %c12 = arith.constant 12 : index + // CHECK: pto.vmi.create_group_mask + // CHECK-SAME: requires result lane count to equal num_groups * group_size + %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<127xpred> + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto new file mode 100644 index 0000000000..dce36f1b5d --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto @@ -0,0 +1,75 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_broadcast_dense_group_users( + %base: !pto.ptr, + %copy_out: !pto.ptr, + %sum_out: !pto.ptr, + %off: index, + %scale: f32) { + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %scale_v = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.load %base[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %copy = pto.vmi.addf %x, %scale_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %copy, %copy_out[%off] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %mask = pto.vmi.create_group_mask %c32 + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %prod = pto.vmi.mulf %x, %scale_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %prod, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_broadcast_dense_group_users( +// ASSIGN: %[[SCALE:.*]] = pto.vmi.broadcast +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[COPY:.*]] = pto.vmi.addf %[[X]], %[[SCALE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[COPY_DENSE:.*]] = pto.vmi.ensure_layout %[[COPY]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[COPY_DENSE]] +// ASSIGN: pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[PROD:.*]] = pto.vmi.mulf %[[X]], %[[SCALE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[PROD]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_broadcast_dense_group_users( +// LOWER-COUNT-4: pto.vdup +// LOWER-COUNT-4: pto.vmul +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto new file mode 100644 index 0000000000..49f2c5e2a8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func private @consume(%x: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } + + func.func @caller(%base: !pto.ptr, + %out: !pto.ptr, + %off: index) { + %c32 = arith.constant 32 : index + %x = pto.vmi.load %base[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_group_mask %c32 + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + call @consume(%x, %mask, %out, %off) + : (!pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>, + !pto.ptr, index) -> () + return + } +} + +// ASSIGN-LABEL: func.func private @consume( +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-LABEL: func.func @caller( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: call @consume(%[[X]], %[[MASK]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> + +// LOWER-LABEL: func.func private @consume( +// LOWER-SAME: !pto.vreg<64xf32> +// LOWER-SAME: !pto.mask +// LOWER: pto.vdintlv +// LOWER: pto.pdintlv_b32 +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts +// LOWER-LABEL: func.func @caller( +// LOWER: call @consume( +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto new file mode 100644 index 0000000000..f4790b5432 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_create_group_mask_s16( + %base: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %x = pto.vmi.group_load %base[%off], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_create_group_mask_s16( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s16( +// LOWER: pto.pset_b32 "PAT_ALL" +// LOWER: pto.plt_b32 +// LOWER: pto.pnot +// LOWER: pto.pand +// LOWER: pto.por +// LOWER-COUNT-2: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER-NOT: PAT_M4 +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto new file mode 100644 index 0000000000..147245c484 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto @@ -0,0 +1,77 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_dense_f16_to_f32_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x16 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.store %x32, %dst[%off] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + + func.func @vmi_layout_assignment_dense_f32_to_f16_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x32 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %x16 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %x16, %dst[%off] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_f16_to_f32_store( +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[DENSE:.*]] = pto.vmi.ensure_layout %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[DENSE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_dense_f16_to_f32_store( +// LOWER: pto.vlds +// LOWER: pto.vcvt {{.*}} {part = "EVEN"} +// LOWER: pto.vcvt {{.*}} {part = "ODD"} +// LOWER: pto.vintlv +// LOWER-COUNT-2: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_f32_to_f16_store( +// ASSIGN: %[[X32:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: %[[X16:.*]] = pto.vmi.truncf %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X16]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_dense_f32_to_f16_store( +// LOWER: pto.vldsx2 +// LOWER: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto new file mode 100644 index 0000000000..a93ae52c17 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_dense_group_reduce_multi_consumer( + %src: !pto.ptr, + %sum_out: !pto.ptr, + %copy_out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + pto.vmi.store %x, %copy_out[%off] + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_group_reduce_multi_consumer( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN: pto.vmi.store %[[X]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_dense_group_reduce_multi_consumer( +// LOWER-COUNT-4: pto.vlds +// LOWER: pto.vdintlv +// LOWER: pto.vdintlv +// LOWER: pto.vdintlv +// LOWER: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts +// LOWER-COUNT-4: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto new file mode 100644 index 0000000000..af6623a995 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_dense_store_group_slots_invalid( + %source: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>, + %dst: !pto.ptr, + %off: index) { + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> + // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.store operand #0 has type + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: requires + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: failed helper conversion + // CHECK-SAME: unsupported source/result layout pair + pto.vmi.store %sum, %dst[%off] + : !pto.vmi.vreg<64xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto new file mode 100644 index 0000000000..e43d2e5591 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -0,0 +1,65 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_f32_f8_store_reduce( + %src: !pto.ptr, + %sum: !pto.ptr, + %out8: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %x32 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %sumv = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sumv, %sum[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %x8 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %x8, %out8[%off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( +// ASSIGN: %[[X32:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_dintlv4" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN: %[[X8:.*]] = pto.vmi.truncf %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X8]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( +// LOWER-COUNT-4: pto.vlds +// LOWER-COUNT-4: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd +// LOWER: pto.vsts +// LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// LOWER-COUNT-3: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto b/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto new file mode 100644 index 0000000000..0ce6b6b295 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_f8_compute_f8( + %src: !pto.ptr, + %scale: f32, + %dst: !pto.ptr, + %off: index) { + %x8 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %x32 = pto.vmi.extf %x8 + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %y32 = pto.vmi.mulf %x32, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %y8 = pto.vmi.truncf %y32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %y8, %dst[%off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_f8_compute_f8( +// ASSIGN: %[[X8:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X8]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALE:.*]] = pto.vmi.broadcast +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Y32:.*]] = pto.vmi.mulf %[[X32]], %[[SCALE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Y8:.*]] = pto.vmi.truncf %[[Y32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[Y8]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_f8_compute_f8( +// LOWER: pto.vlds +// LOWER-COUNT-4: pto.vcvt {{.*}} {part = "P{{[0-3]}}"} +// LOWER-COUNT-4: pto.vdup +// LOWER-COUNT-4: pto.vmul +// LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// LOWER-COUNT-3: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto new file mode 100644 index 0000000000..7df6946741 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_broadcast_multi_consumer( + %src: !pto.ptr, + %sum_out: !pto.ptr, + %dense_out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %b_for_mul = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %y = pto.vmi.mulf %x, %b_for_mul + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %ysum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + %b_for_cast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %h = pto.vmi.truncf %b_for_cast + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %h, %dense_out[%off] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_broadcast_multi_consumer( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B_MUL:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B_MUL]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[YSUM:.*]] = pto.vmi.group_reduce_addf %[[Y]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN: pto.vmi.group_store %[[YSUM]] +// ASSIGN: %[[B_CAST:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B_CAST_SPLIT:.*]] = pto.vmi.ensure_layout %[[B_CAST]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[B_CAST_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[H]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_broadcast_multi_consumer( +// LOWER: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vselr +// LOWER: pto.vselr +// LOWER: pto.vmul +// LOWER: pto.vmul +// LOWER: pto.vcgadd +// LOWER: pto.vsts +// LOWER: pto.vselr +// LOWER: pto.vselr +// LOWER: pto.vcvt +// LOWER: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto new file mode 100644 index 0000000000..7c1e569bf3 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_broadcast_slots8( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<1024xf32> { + %out = pto.vmi.group_broadcast %source {num_groups = 128} + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32> + return %out : !pto.vmi.vreg<1024xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_broadcast_slots8( +// CHECK-SAME: -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_broadcast +// CHECK-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// CHECK-SAME: -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load.pto b/test/lit/vmi/vmi_layout_assignment_group_load.pto new file mode 100644 index 0000000000..2a90d02d08 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load( + %source: !pto.ptr, + %row_stride: index) -> !pto.vmi.vreg<512xf32> { + %c0 = arith.constant 0 : index + %out = pto.vmi.group_load %source[%c0], %row_stride {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_load( +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_load +// CHECK-SAME: vmi.selected_plan = "group_load_contiguous_chunks" +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto new file mode 100644 index 0000000000..c928df5320 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_block8_truncf_invalid( + %src: !pto.ptr, + %sum_dst: !pto.ptr, + %dense_dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride24 = arith.constant 24 : index + %c128 = arith.constant 128 : index + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.group_load %src[%off], %stride24 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.truncf operand #0 has type + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: requires + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: failed helper conversion + // CHECK-SAME: unsupported source/result layout pair + %h = pto.vmi.truncf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %h, %dense_dst[%off] + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto new file mode 100644 index 0000000000..113467b492 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_s16_compact_stride12_invalid( + %base: !pto.ptr, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride12 = arith.constant 12 : index + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load group_size 16 requires constant positive row_stride divisible by 8 f32 elements for the block8 stride plan + // CHECK-SAME: stable gather fallback is not implemented + %x = pto.vmi.group_load %base[%off], %stride12 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto new file mode 100644 index 0000000000..67215442e5 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto @@ -0,0 +1,50 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_load_s16_stride_store( + %base: !pto.ptr, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride24 = arith.constant 24 : index + %x = pto.vmi.group_load %base[%off], %stride24 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s16_stride_store( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: vmi.selected_plan = "s16_group_load_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s16_stride_store( +// LOWER-COUNT-2: pto.vsldb +// LOWER-COUNT-2: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto new file mode 100644 index 0000000000..ed2ed892f9 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_s16_unaligned_stride_invalid( + %base: !pto.ptr, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride20 = arith.constant 20 : index + %x = pto.vmi.group_load %base[%off], %stride20 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load group_size 16 requires constant positive row_stride divisible by 8 f32 elements for the block8 stride plan; stable gather fallback is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto new file mode 100644 index 0000000000..c97a35855b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( + %base: !pto.ptr, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride40 = arith.constant 40 : index + %x = pto.vmi.group_load %base[%off], %stride40 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %ysum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: vmi.selected_plan = "s32_group_load_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[B:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK2:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[YSUM:.*]] = pto.vmi.group_reduce_addf %[[Y]], %[[MASK2]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[YSUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( +// LOWER-COUNT-4: pto.vsldb +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-4: pto.vselr +// LOWER-COUNT-4: pto.vmul +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto new file mode 100644 index 0000000000..0f506a3a1f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_load_s32_stride_store( + %base: !pto.ptr, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride40 = arith.constant 40 : index + %x = pto.vmi.group_load %base[%off], %stride40 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_store( +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: vmi.selected_plan = "s32_group_load_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_store( +// LOWER-COUNT-4: pto.vsldb +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd +// LOWER: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto new file mode 100644 index 0000000000..7cd5ffd85d --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto @@ -0,0 +1,31 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_load_s32_unaligned_stride_invalid( + %base: !pto.ptr, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %stride34 = arith.constant 34 : index + %x = pto.vmi.group_load %base[%off], %stride34 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load group_size 32 requires constant positive row_stride divisible by 8 f32 elements for the block8 stride plan; stable gather fallback is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto new file mode 100644 index 0000000000..3bea54d83f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s12_invalid( + %source: !pto.vmi.vreg<96xf32>, + %mask: !pto.vmi.mask<96xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.group_reduce_addf lowers through pto.vcgadd + // CHECK-SAME: num_groups deriving a group size aligned to physical chunks + // CHECK-SAME: found padding lane in physical chunk + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<96xf32>, !pto.vmi.mask<96xpred> + -> !pto.vmi.vreg<96xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<96xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto new file mode 100644 index 0000000000..c4652169d4 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s16_store( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_store( +// ASSIGN-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[MASK:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE_SPLIT:.*]] = pto.vmi.ensure_layout %[[SOURCE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_store( +// LOWER: %[[LO:.*]], %[[HI:.*]] = pto.vdintlv +// LOWER: %[[MLO:.*]], %[[MHI:.*]] = pto.pdintlv_b32 +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: %[[SLO:.*]] = pto.vcgadd %[[LO]], %[[MLO]] +// LOWER: %[[SHI:.*]] = pto.vcgadd %[[HI]], %[[MHI]] +// LOWER: %[[SUM:.*]] = pto.vadd %[[SLO]], %[[SHI]], %[[VL8]] +// LOWER: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts %[[SUM]], %arg4[%arg5], %[[STORE_MASK]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto new file mode 100644 index 0000000000..e9a3e7c9e9 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %sum32 = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + %rows = pto.vmi.group_broadcast %sum16 {num_groups = 8} + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %rows, %dst[%off] : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store( +// ASSIGN-SAME: %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg1: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE:.*]] = pto.vmi.ensure_layout %arg0 +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %arg1 +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B32:.*]] = pto.vmi.group_broadcast %[[SUM32]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B32_SPLIT:.*]] = pto.vmi.ensure_layout %[[B32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B16:.*]] = pto.vmi.truncf %[[B32_SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[B16]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store( +// LOWER: pto.vcgadd +// LOWER: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vselr +// LOWER: pto.vselr +// LOWER: pto.vcvt +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto new file mode 100644 index 0000000000..9fb03c80b2 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( + %source: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %source, %broadcast + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scaled_sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( +// ASSIGN-SAME: %[[SOURCE:arg[0-9]+]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf %[[SOURCE]], %[[BROADCAST]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf %[[SCALED]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( +// LOWER-DAG: %[[C2:.*]] = arith.constant 2 : i32 +// LOWER-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// LOWER-DAG: %[[C6:.*]] = arith.constant 6 : i32 +// LOWER: pto.vselr +// LOWER: pto.vdup %[[C2]] +// LOWER: pto.vselr +// LOWER: pto.vdup %[[C4]] +// LOWER: pto.vselr +// LOWER: pto.vdup %[[C6]] +// LOWER: pto.vselr +// LOWER-COUNT-4: pto.vmul +// LOWER: pto.vsts {{.*}}, %arg8[%arg9], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto new file mode 100644 index 0000000000..1d61b4196e --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_multitile_store( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 16, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 16} + : !pto.vmi.vreg<512xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_multitile_store( +// ASSIGN-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[MASK:.*]]: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE_SPLIT:.*]] = pto.vmi.ensure_layout %[[SOURCE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.mask<512xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_multitile_store( +// LOWER-COUNT-8: pto.vdintlv +// LOWER-COUNT-8: pto.pdintlv_b32 +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER-COUNT-8: pto.vcgadd +// LOWER: %[[STORE_MASK0:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts {{.*}}, %arg16[%arg17], %[[STORE_MASK0]] +// LOWER: %[[STORE_MASK1:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts {{.*}}, %arg16[{{.*}}], %[[STORE_MASK1]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto new file mode 100644 index 0000000000..b51dd875b5 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_store( + %source: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_store( +// ASSIGN-SAME: %[[SOURCE:.*]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[MASK:.*]]: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SOURCE_SPLIT:.*]] = pto.vmi.ensure_layout %[[SOURCE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_store( +// LOWER-COUNT-4: pto.vdintlv +// LOWER-COUNT-4: pto.pdintlv_b32 +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[VL8]] +// LOWER: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts {{.*}}, %arg8[%arg9], %[[STORE_MASK]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto new file mode 100644 index 0000000000..0a7550d004 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( + %src: memref<256xf32>, %dst: !pto.ptr, %off: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c192 = arith.constant 192 : index + %x = pto.vmi.load %src[%c0] + : memref<256xf32> -> !pto.vmi.vreg<192xf32> + %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} + : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> + -> !pto.vmi.vreg<192xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} + : !pto.vmi.vreg<192xf32>, !pto.ptr + return + } + + func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c192 = arith.constant 192 : index + %x = pto.vmi.load %src[%c0] {full_read_elems = 256} + : !pto.ptr -> !pto.vmi.vreg<192xf32> + %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} + : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> + -> !pto.vmi.vreg<192xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} + : !pto.vmi.vreg<192xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( +// LOWER-DAG: %[[C6:.*]] = arith.constant 6 : i32 +// LOWER-DAG: %[[C48:.*]] = arith.constant 48 : i32 +// LOWER-COUNT-4: pto.vlds +// LOWER-COUNT-3: pto.vdintlv +// LOWER-COUNT-4: pto.plt_b32 %[[C48]] : i32 -> !pto.mask, i32 +// LOWER: %[[SLOTS:.*]], %{{.*}} = pto.plt_b32 %[[C6]] : i32 -> !pto.mask, i32 +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vadd {{.*}}, {{.*}}, %[[SLOTS]] +// LOWER: %[[STORE:.*]], %{{.*}} = pto.plt_b32 %[[C6]] : i32 -> !pto.mask, i32 +// LOWER: pto.vsts {{.*}}, {{.*}}, %[[STORE]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( +// ASSIGN: %[[PX:.*]] = pto.vmi.load +// ASSIGN-SAME: {full_read_elems = 256 : i64} +// ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> +// ASSIGN: %[[PMASK:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_reduce_addf %[[PX]], %[[PMASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( +// LOWER-COUNT-4: pto.vlds +// LOWER-COUNT-3: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto new file mode 100644 index 0000000000..c66ff0eb3c --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid( + %source: !pto.vmi.vreg<192xf32>, + %mask: !pto.vmi.mask<192xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.group_reduce_addf operand #0 has type + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: requires + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: requires source and result to have the same physical arity + // CHECK-SAME: partial/tail layout materialization requires an explicit packing plan + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 6, reassoc} + : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> + -> !pto.vmi.vreg<192xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} + : !pto.vmi.vreg<192xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto new file mode 100644 index 0000000000..2e4c9dd02f --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s64( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>) -> !pto.vmi.vreg<512xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_s64( +// CHECK-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// CHECK-SAME: vmi.selected_plan = "s64_reduce_row_local" +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto new file mode 100644 index 0000000000..6fffb7c636 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c8 = arith.constant 8 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + %scaled = pto.vmi.mulf %source, %broadcast + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<512xf32> + %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %scaled_sum, %dst[%off], %c8 {num_groups = 8} + : !pto.vmi.vreg<512xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots1_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( +// LOWER-COUNT-8: pto.vcadd +// LOWER-COUNT-8: pto.vdup {{.*}} {position = "LOWEST"} +// LOWER-COUNT-8: pto.vmul +// LOWER-COUNT-8: pto.vcadd +// LOWER-COUNT-8: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto new file mode 100644 index 0000000000..ec8816fbeb --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s64_tail_store( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c8 = arith.constant 8 : index + %c384 = arith.constant 384 : index + %mask = pto.vmi.create_mask %c384 : index -> !pto.vmi.mask<384xpred> + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<384xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} + : !pto.vmi.vreg<384xf32>, !pto.vmi.mask<384xpred> + -> !pto.vmi.vreg<384xf32> + pto.vmi.group_store %sum, %dst[%off], %c8 {num_groups = 6} + : !pto.vmi.vreg<384xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_tail_store( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<384xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] +// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<384xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_tail_store( +// LOWER-COUNT-6: pto.vlds +// LOWER-COUNT-6: pto.vcadd +// LOWER-COUNT-6: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto new file mode 100644 index 0000000000..bf38aee552 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_s64_truncf( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c16 = arith.constant 16 : index + %sum32 = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + %sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf16> + pto.vmi.group_store %sum16, %dst[%off], %c16 {num_groups = 8} + : !pto.vmi.vreg<512xf16>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_truncf( +// ASSIGN-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM16:.*]] = pto.vmi.truncf %[[SUM32]] +// ASSIGN-SAME: vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" +// ASSIGN-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM16]] +// ASSIGN-SAME: !pto.vmi.vreg<512xf16, #pto.vmi.layout>, !pto.ptr, !pto.mask -> !pto.vreg<128xf16> +// LOWER: pto.pge_b16 "PAT_VL1" +// LOWER: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8.pto new file mode 100644 index 0000000000..c3e876be05 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_slots8( + %source: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<64xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_slots8( +// CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// CHECK-SAME: vmi.selected_plan = "s8_reduce_contiguous" +// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto new file mode 100644 index 0000000000..1329965530 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_slots8_store( + %source: !pto.vmi.vreg<64xf32>, + %mask: !pto.vmi.mask<64xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<64xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_slots8_store( +// ASSIGN-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg1: !pto.vmi.mask<64xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// ASSIGN-SAME: vmi.selected_plan = "s8_reduce_contiguous" +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_slots8_store( +// LOWER: %[[SUM:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// LOWER: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER: pto.vsts %[[SUM]], %arg2[%arg3], %[[STORE_MASK]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto new file mode 100644 index 0000000000..9f4349d40e --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto @@ -0,0 +1,58 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_load_slots8( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<128xf32> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_layout_assignment_group_slot_load_slots1( + %src: !pto.ptr, %off: index, %stride: index) + -> !pto.vmi.vreg<512xf32> { + %out = pto.vmi.group_slot_load %src[%off], %stride {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } + + func.func @vmi_layout_assignment_group_slot_load_slots8_store( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8( +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load +// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots1( +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load +// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots1_row_local" +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8_store( +// CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load +// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: pto.vmi.group_store %[[OUT]] +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto new file mode 100644 index 0000000000..a96b847256 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto @@ -0,0 +1,76 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_slot_load_dual_layout( + %rhs_base: !pto.ptr, + %source16: !pto.vmi.vreg<128xf32>, + %mask16: !pto.vmi.mask<128xpred>, + %source64: !pto.vmi.vreg<512xf32>, + %mask64: !pto.vmi.mask<512xpred>, + %out16: !pto.ptr, + %out64: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %rhs16 = pto.vmi.group_slot_load %rhs_base[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum16 = pto.vmi.group_reduce_addf %source16, %mask16 + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %outv16 = pto.vmi.addf %sum16, %rhs16 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %outv16, %out16[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + + %rhs64 = pto.vmi.group_slot_load %rhs_base[%off], %c8 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + %sum64 = pto.vmi.group_reduce_addf %source64, %mask64 + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + %outv64 = pto.vmi.addf %sum64, %rhs64 + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %outv64, %out64[%off], %c8 {num_groups = 8} + : !pto.vmi.vreg<512xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slot_load_dual_layout( +// ASSIGN: %[[RHS16:.*]] = pto.vmi.group_slot_load +// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM16:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.addf %[[SUM16]], %[[RHS16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[RHS64:.*]] = pto.vmi.group_slot_load +// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots1_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: %[[SUM64:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.addf %[[SUM64]], %[[RHS64]] +// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slot_load_dual_layout( +// LOWER: pto.vsldb +// LOWER: pto.vsts {{.*}}, %arg21[%arg23], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-COUNT-8: pto.vsldb +// LOWER-COUNT-8: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto new file mode 100644 index 0000000000..e6e459c435 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid( + %src: !pto.ptr, %off: index, %stride: index) { + // CHECK: VMI-UNSUPPORTED: pto.vmi.group_slot_load + // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group + // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements + // CHECK-SAME: packed or unaligned scalar load lowering is not implemented + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" + // CHECK-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %out = pto.vmi.group_slot_load %src[%off], %stride {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto new file mode 100644 index 0000000000..f8d7bc8af8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid( + %src: !pto.ptr, %off: index) { + %c2 = arith.constant 2 : index + // CHECK: VMI-UNSUPPORTED: pto.vmi.group_slot_load + // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group + // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements + // CHECK-SAME: packed or unaligned scalar load lowering is not implemented + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" + // CHECK-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %out = pto.vmi.group_slot_load %src[%off], %c2 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto new file mode 100644 index 0000000000..d327a7b8bc --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slots_cf_join( + %cond: i1, + %src: !pto.ptr, + %rhs: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sum = scf.if %cond -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %a = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + scf.yield %a : !pto.vmi.vreg<128xf32> + } else { + %b = pto.vmi.group_slot_load %rhs[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + scf.yield %b : !pto.vmi.vreg<128xf32> + } + %bias = pto.vmi.group_slot_load %rhs[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %out = pto.vmi.addf %sum, %bias + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slots_cf_join( +// CHECK: %[[IF:.*]] = scf.if +// CHECK-SAME: -> (!pto.vreg<64xf32>) +// CHECK: pto.vldsx2 +// CHECK: pto.vcgadd +// CHECK: pto.vcgadd +// CHECK: scf.yield {{.*}} : !pto.vreg<64xf32> +// CHECK: else +// CHECK: pto.vsldb +// CHECK: scf.yield {{.*}} : !pto.vreg<64xf32> +// CHECK: %[[BIAS:.*]] = pto.vsldb +// CHECK: pto.vadd %[[IF]], %[[BIAS]] +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto new file mode 100644 index 0000000000..d0ac525849 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_slots_fanout( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %sum_dst: !pto.ptr, + %out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %scaled = pto.vmi.mulf %source, %broadcast + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %scaled_sum, %out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_fanout( +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] +// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf %[[SCALED]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SCALED_SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slots_fanout( +// LOWER-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// LOWER: %[[FIRST_SUM:.*]] = pto.vadd {{.*}}, {{.*}}, {{.*}} : !pto.vreg<64xf32> +// LOWER: pto.vsts %[[FIRST_SUM]], %arg4[%arg6], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER: pto.vselr %[[FIRST_SUM]] +// LOWER: pto.vdup %[[C4]] +// LOWER: pto.vselr %[[FIRST_SUM]] +// LOWER-COUNT-2: pto.vmul +// LOWER: pto.vsts {{.*}}, %arg5[%arg6], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto new file mode 100644 index 0000000000..e4b48121bc --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_slots_scf_for( + %init: !pto.ptr, + %base: !pto.ptr, + %out: !pto.ptr, + %off: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %acc0 = pto.vmi.group_slot_load %init[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %acc = scf.for %i = %c0 to %c2 step %c1 + iter_args(%arg = %acc0) -> (!pto.vmi.vreg<128xf32>) { + %x = pto.vmi.group_load %base[%off], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c16 + {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %next = pto.vmi.addf %arg, %sum + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + pto.vmi.group_store %acc, %out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_scf_for( +// ASSIGN: %[[ACC0:.*]] = pto.vmi.group_slot_load +// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[ACC:.*]] = scf.for +// ASSIGN-SAME: iter_args(%[[ARG:.*]] = %[[ACC0]]) +// ASSIGN-SAME: -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) +// ASSIGN: %[[X:.*]] = pto.vmi.group_load +// ASSIGN-SAME: vmi.selected_plan = "s16_group_load_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.addf %[[ARG]], %[[SUM]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: scf.yield +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[ACC]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slots_scf_for( +// LOWER: pto.vsldb +// LOWER: scf.for +// LOWER-COUNT-2: pto.vcgadd +// LOWER: pto.vadd +// LOWER: scf.yield +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto new file mode 100644 index 0000000000..452ee085ac --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_store_slots1_unit_stride_invalid( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + // CHECK: VMI-UNSUPPORTED: pto.vmi.group_store + // CHECK-SAME: slots=1 group_store currently lowers as one lane-0 vsts per group + // CHECK-SAME: requires constant positive row_stride divisible by 8 elements + // CHECK-SAME: packed or unaligned contiguous store lowering is not implemented + // CHECK: note: see current operation: "pto.vmi.group_store" + // CHECK-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<512xf32>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto new file mode 100644 index 0000000000..8a74de4097 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_mask_granularity_f32_f16_store( + %src: !pto.ptr, + %out32: !pto.ptr, + %out16: !pto.ptr, + %off: index) { + %c96 = arith.constant 96 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c96 : index -> !pto.vmi.mask<128xpred> + pto.vmi.masked_store %x, %out32[%off], %mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, !pto.vmi.mask<128xpred> + %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.masked_store %h, %out16[%off], %mask + : !pto.vmi.vreg<128xf16>, !pto.ptr, !pto.vmi.mask<128xpred> + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_mask_granularity_f32_f16_store( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[M32:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.masked_store %[[X]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[X_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[M16:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.masked_store %[[H]] +// ASSIGN-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_mask_granularity_f32_f16_store( +// LOWER: pto.vlds +// LOWER: pto.vlds +// LOWER: pto.pge_b32 "PAT_ALL" +// LOWER: pto.pge_b32 "PAT_VL32" +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vdintlv +// LOWER: pto.vcvt +// LOWER: pto.vcvt +// LOWER: pto.vor +// LOWER: pto.plt_b16 +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_mask_select_store.pto b/test/lit/vmi/vmi_layout_assignment_mask_select_store.pto new file mode 100644 index 0000000000..62ef723511 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_mask_select_store.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_mask_select_store( + %src: !pto.ptr, + %rhs: !pto.ptr, + %dense: !pto.ptr, + %masked: !pto.ptr, + %off: index) { + %c48 = arith.constant 48 : index + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %y = pto.vmi.load %rhs[%off] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %mask = pto.vmi.create_mask %c48 : index -> !pto.vmi.mask<64xpred> + %sum = pto.vmi.addf %x, %y + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + %passthrough = pto.vmi.select %mask, %sum, %x + : !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32>, + !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + pto.vmi.store %passthrough, %dense[%off] + : !pto.vmi.vreg<64xf32>, !pto.ptr + pto.vmi.masked_store %sum, %masked[%off], %mask + : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.mask<64xpred> + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_mask_select_store( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: %[[Y:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<64xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[X]], %[[Y]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: %[[PASS:.*]] = pto.vmi.select %[[MASK]], %[[SUM]], %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[PASS]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN: pto.vmi.masked_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr, !pto.vmi.mask<64xb32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_mask_layout +// ASSIGN-NOT: pto.vmi.ensure_mask_granularity + +// LOWER-LABEL: func.func @vmi_layout_assignment_mask_select_store( +// LOWER: pto.vlds +// LOWER: pto.vlds +// LOWER: pto.plt_b32 +// LOWER: pto.vadd +// LOWER: pto.vsel +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto new file mode 100644 index 0000000000..4004ff6fcc --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto @@ -0,0 +1,66 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_masked_load_dense_group_users( + %base: !pto.ptr, + %copy_out: !pto.ptr, + %sum_out: !pto.ptr, + %off: index) { + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %mask = pto.vmi.create_mask %c256 + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %base[%off], %mask, %zero + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %copy_out[%off] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_masked_load_dense_group_users( +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[ZERO:.*]] = pto.vmi.broadcast +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[X:.*]] = pto.vmi.masked_load +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X]] +// ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_masked_load_dense_group_users( +// LOWER-COUNT-4: pto.vsel +// LOWER-COUNT-4: pto.vsts +// LOWER: pto.vdintlv +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vadd +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto new file mode 100644 index 0000000000..bad43bb869 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_masked_load_group_tail_s32( + %base: !pto.ptr, + %sum_out: !pto.ptr, + %off: index) { + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1 = arith.constant 1 : index + %c25 = arith.constant 25 : index + %mask = pto.vmi.create_group_mask %c25 + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %base[%off], %mask, %zero + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED: pto.vmi.group_reduce_addf +// CHECK-SAME: s32 block8 lowering does not yet support partial create_group_mask active_elems_per_group during layout assignment +// CHECK-NOT: vmi.selected_plan = "s32_reduce_block8_stride" diff --git a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto new file mode 100644 index 0000000000..a2d4cab4d9 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_non_load_s32_reduce( + %base: !pto.ptr, + %bias: f32, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %a = pto.vmi.load %base[%off] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %biasv = pto.vmi.broadcast %bias : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.addf %a, %biasv + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_non_load_s32_reduce( +// ASSIGN-SAME: %[[BASE:arg[0-9]+]]: !pto.ptr +// ASSIGN-SAME: %[[BIAS:arg[0-9]+]]: f32 +// ASSIGN: %[[A:.*]] = pto.vmi.load %[[BASE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[BIASV:.*]] = pto.vmi.broadcast %[[BIAS]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[X:.*]] = pto.vmi.addf %[[A]], %[[BIASV]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_non_load_s32_reduce( +// LOWER-COUNT-4: pto.vdup %arg1 +// LOWER-COUNT-4: pto.vadd {{.*}}, {{.*}}, {{.*}} : !pto.vreg<64xf32> +// LOWER: %[[VL8:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// LOWER-COUNT-4: pto.vcgadd +// LOWER-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[VL8]] +// LOWER: pto.vsts {{.*}}, %arg2[%arg3], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto new file mode 100644 index 0000000000..3005e53c0a --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_packed_group_slots_truncf_invalid( + %source: !pto.vmi.vreg<128xf32>, + %mask: !pto.vmi.mask<128xpred>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.truncf operand #0 has type + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: requires + // CHECK-SAME: #pto.vmi.layout + // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: failed helper conversion + // CHECK-SAME: unsupported source/result layout pair + %h = pto.vmi.truncf %sum + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.group_store %h, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf16>, !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto new file mode 100644 index 0000000000..01e8e55caf --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_widen_f16_store_reduce( + %src: !pto.ptr, + %sum: !pto.ptr, + %dense: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %x16 = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sumv = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sumv, %sum[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + pto.vmi.store %x32, %dense[%off] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_widen_f16_store_reduce( +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] +// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_parity" +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SUM]] +// ASSIGN: %[[X32_DENSE:.*]] = pto.vmi.ensure_layout %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X32_DENSE]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_widen_f16_store_reduce( +// LOWER: pto.vlds +// LOWER: pto.vcvt +// LOWER: pto.vcvt +// LOWER: pto.vcgadd +// LOWER: pto.vcgadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_group_slots_invalid.pto new file mode 100644 index 0000000000..f354adb6e8 --- /dev/null +++ b/test/lit/vmi/vmi_layout_group_slots_invalid.pto @@ -0,0 +1,18 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_group_slots_invalid( + %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + return + } +} + +// CHECK: #pto.vmi.layout requires slots to be positive and divide num_groups when specified diff --git a/test/lit/vmi/vmi_load_full_read_elems_invalid.pto b/test/lit/vmi/vmi_load_full_read_elems_invalid.pto new file mode 100644 index 0000000000..102efd4f0e --- /dev/null +++ b/test/lit/vmi/vmi_load_full_read_elems_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @vmi_load_full_read_elems_invalid(%src: !pto.ptr) { + %c0 = arith.constant 0 : index + %value = pto.vmi.load %src[%c0] {full_read_elems = 0} + : !pto.ptr -> !pto.vmi.vreg<100xf32> + return + } +} + +// CHECK: 'pto.vmi.load' op requires full_read_elems to be positive diff --git a/test/lit/vmi/vmi_op_verifier_basic.pto b/test/lit/vmi/vmi_op_verifier_basic.pto index bff24c6e07..3ba8eb29dc 100644 --- a/test/lit/vmi/vmi_op_verifier_basic.pto +++ b/test/lit/vmi/vmi_op_verifier_basic.pto @@ -16,6 +16,7 @@ module { %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index %f32 = arith.constant 1.000000e+00 : f32 %f16 = arith.constant 1.000000e+00 : f16 %active = arith.constant 64 : index @@ -40,6 +41,10 @@ module { %trunc = pto.vmi.truncf %ext : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> %loaded = pto.vmi.load %ptr[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %slot_loaded = pto.vmi.group_slot_load %ptr[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %slot_loaded, %ptr[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr pto.vmi.store %loaded, %ptr[%c0] : !pto.vmi.vreg<128xf32>, !pto.ptr %tile_read = pto.vmi.tile_read %tile : memref<128xf32> -> !pto.vmi.vreg<128xf32> pto.vmi.tile_write %tile_read, %tile : !pto.vmi.vreg<128xf32>, memref<128xf32> @@ -94,6 +99,8 @@ module { // CHECK: pto.vmi.extf // CHECK: pto.vmi.truncf // CHECK: pto.vmi.load +// CHECK: pto.vmi.group_slot_load +// CHECK: pto.vmi.group_store // CHECK: pto.vmi.store // CHECK: pto.vmi.tile_read // CHECK: pto.vmi.tile_write diff --git a/test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto b/test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto new file mode 100644 index 0000000000..950215e5e4 --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func private @callee(%x: !pto.vmi.vreg<128xf32>) + -> !pto.vmi.vreg<128xf32> { + %sum = pto.vmi.addf %x, %x + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + return %sum : !pto.vmi.vreg<128xf32> + } + + func.func @caller(%value: f32, %dst: !pto.ptr, %off: index) { + pto.vecscope { + %x = pto.vmi.broadcast %value : f32 -> !pto.vmi.vreg<128xf32> + %r = func.call @callee(%x) + : (!pto.vmi.vreg<128xf32>) -> !pto.vmi.vreg<128xf32> + pto.vmi.store %r, %dst[%off] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + return + } + } +} + +// CHECK: cannot infer resultless pto.vecscope because VPTO vector-scope data cannot have external users +// CHECK-SAME: escaping value type is '!pto.vreg<64xf32>' diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto new file mode 100644 index 0000000000..3a96e94d67 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_slots8( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_broadcast %source + {num_groups = 128, vmi.selected_plan = "group_broadcast_slots8_vselr"} + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_slots8( +// CHECK-COUNT-16: pto.vselr +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto new file mode 100644 index 0000000000..a03cdfd9df --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) { + %out = pto.vmi.group_broadcast %source {num_groups = 128} + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_broadcast requires full source chunks +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto new file mode 100644 index 0000000000..563f939f77 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_load_missing_plan_invalid( + %source: !pto.ptr, + %row_stride: index) { + %c0 = arith.constant 0 : index + %out = pto.vmi.group_load %source[%c0], %row_stride {num_groups = 2} + : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_load requires contiguous full result chunks +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_ops.pto b/test/lit/vmi/vmi_to_vpto_group_ops.pto index 6a10e168dd..e757c583f6 100644 --- a/test/lit/vmi/vmi_to_vpto_group_ops.pto +++ b/test/lit/vmi/vmi_to_vpto_group_ops.pto @@ -15,7 +15,8 @@ module { %row_stride: index, %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) { %c0 = arith.constant 0 : index - %v = pto.vmi.group_load %src[%c0], %row_stride {num_groups = 2} + %v = pto.vmi.group_load %src[%c0], %row_stride + {num_groups = 2, vmi.selected_plan = "group_load_contiguous_chunks"} : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> %r = pto.vmi.group_reduce_addf %v, %mask {num_groups = 2, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto new file mode 100644 index 0000000000..ee12b742e8 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_s64( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc, vmi.selected_plan = "s64_reduce_row_local"} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64( +// CHECK-DAG: %[[VL1:.*]] = pto.pge_b32 "PAT_VL1" +// CHECK: pto.vcadd +// CHECK: pto.vadd +// CHECK: pto.vsel {{.*}}, {{.*}}, %[[VL1]] +// CHECK: pto.vcadd +// CHECK: pto.vsel {{.*}}, {{.*}}, %[[VL1]] +// CHECK: return {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto new file mode 100644 index 0000000000..96d975ab7d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_s64_missing_plan_invalid( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_reduce_addf lowers through pto.vcgadd +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto new file mode 100644 index 0000000000..305c488dd5 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_slots8( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc, vmi.selected_plan = "s8_reduce_contiguous"} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_slots8( +// CHECK: %[[OUT:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto new file mode 100644 index 0000000000..b67cb34f2d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_slots8_missing_plan_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_reduce_addf lowers through pto.vcgadd +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto new file mode 100644 index 0000000000..5927f63069 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_load_slots8( + %src: !pto.ptr, %off: index) -> !pto.vreg<64xf32> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 + {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_group_slot_load_slots1( + %src: !pto.ptr, %off: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %c8 = arith.constant 8 : index + %out = pto.vmi.group_slot_load %src[%off], %c8 + {num_groups = 8, vmi.selected_plan = "group_slot_load_slots1_row_local"} + : !pto.ptr + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @vmi_to_vpto_group_slot_load_slots8_store( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 + {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_slots8( +// CHECK-DAG: %[[MASK:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[BASE:.*]] = pto.addptr %arg0, %arg1 : -> +// CHECK: %[[OUT:.*]] = pto.vsldb %[[BASE]], {{.*}}, {{.*}}, %[[MASK]] : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_slots1( +// CHECK-COUNT-8: pto.vsldb + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_slots8_store( +// CHECK: %[[LOAD_MASK:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[BASE:.*]] = pto.addptr %arg0, %arg2 : -> +// CHECK: %[[OUT:.*]] = pto.vsldb %[[BASE]], {{.*}}, {{.*}}, %[[LOAD_MASK]] +// CHECK: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// CHECK: pto.vsts %[[OUT]], %arg1[%arg2], %[[STORE_MASK]] : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto new file mode 100644 index 0000000000..f442e2fbbe --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_load_missing_plan_invalid( + %src: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_slot_load requires explicit group_slots result layout +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto new file mode 100644 index 0000000000..10d9a2d3fa --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_load_nonunit_slots8_invalid( + %src: !pto.ptr, %off: index, %stride: index) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %out = pto.vmi.group_slot_load %src[%off], %stride + {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_slot_load requires explicit group_slots result layout +// CHECK: slots=8 group_slot_load requires constant unit source_group_stride diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto new file mode 100644 index 0000000000..d24f504e67 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_truncf_slots1( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %narrow = pto.vmi.truncf %source + {vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16"} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<512xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_truncf_slots1( +// CHECK-DAG: %[[VL1:.*]] = pto.pge_b32 "PAT_VL1" +// CHECK-COUNT-8: pto.vcvt {{.*}}, %[[VL1]] {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto new file mode 100644 index 0000000000..f265dc0912 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) { + %narrow = pto.vmi.truncf %source + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> + "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<512xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.truncf supports only +// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto new file mode 100644 index 0000000000..305b039d72 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_store_slots8_nonunit_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index, + %row_stride: index) { + pto.vmi.group_store %value, %dst[%off], %row_stride {num_groups = 8} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK-SAME: pto.vmi.group_store +// CHECK-SAME: slots=8 group_store currently requires constant unit row_stride diff --git a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto index 7d302805d6..a0cc8215cb 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto @@ -252,9 +252,8 @@ module { // CHECK-SAME: %[[QDST:[^,]+]]: !pto.ptr // CHECK: scf.for // CHECK: scf.for -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> // CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> // CHECK: pto.vor {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> @@ -302,7 +301,6 @@ module { // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: scf.if // CHECK: pto.plt_b8 {{.*}} : i32 -> !pto.mask, i32 // CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask // CHECK-NOT: pto.vmi. diff --git a/test/lit/vmi/vmi_type_attr_parse.pto b/test/lit/vmi/vmi_type_attr_parse.pto index 04613c441d..5798114cc7 100644 --- a/test/lit/vmi/vmi_type_attr_parse.pto +++ b/test/lit/vmi/vmi_type_attr_parse.pto @@ -11,17 +11,23 @@ module attributes { pto.vmi_contiguous = #pto.vmi.layout, pto.vmi_deinterleaved2 = #pto.vmi.layout, - pto.vmi_deinterleaved4 = #pto.vmi.layout + pto.vmi_deinterleaved4 = #pto.vmi.layout, + pto.vmi_deinterleaved4_block8 = + #pto.vmi.layout, + pto.vmi_group_slots8 = #pto.vmi.layout } { func.func @vmi_type_attr_parse( %surface: !pto.vmi.vreg<128xf32>, %contiguous: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %wide2: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %wide4: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %wide4_block8: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %group_slots8: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %surface_mask: !pto.vmi.mask<128xpred>, %mask_b8: !pto.vmi.mask<128xb8, #pto.vmi.layout>, %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, - %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %mask_b32_block8: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { return } } @@ -29,12 +35,17 @@ module attributes { // CHECK: pto.vmi_contiguous = #pto.vmi.layout // CHECK: pto.vmi_deinterleaved2 = #pto.vmi.layout // CHECK: pto.vmi_deinterleaved4 = #pto.vmi.layout +// CHECK: pto.vmi_deinterleaved4_block8 = #pto.vmi.layout +// CHECK: pto.vmi_group_slots8 = #pto.vmi.layout // CHECK-LABEL: func.func @vmi_type_attr_parse( // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xpred> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb8, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb16, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py b/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py b/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py new file mode 100644 index 0000000000..7df1eedef3 --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +SCALE = np.float32(0.5) +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + base = np.linspace(-0.875, 0.625, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src + SCALE + golden_sum = np.sum(src * SCALE, axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto b/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto new file mode 100644 index 0000000000..3881dfc10f --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_broadcast_dense_group_users_kernel(%src_gm: !pto.ptr, + %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 5.000000e-01 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %scale_vec = pto.vmi.broadcast %scale : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %copy = pto.vmi.addf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %copy, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + + %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %prod = pto.vmi.mulf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %prod, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp b/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp new file mode 100644 index 0000000000..21e26d6cf5 --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_broadcast_dense_group_users_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_broadcast_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream) { + vmi_broadcast_dense_group_users_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp b/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp new file mode 100644 index 0000000000..b43a794cdb --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_broadcast_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_broadcast_dense_group_users_kernel(srcDevice, copyDevice, sumDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags b/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/compare.py b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/compare.py new file mode 100644 index 0000000000..837961af76 --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/compare.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + for name in ("v2", "v3"): + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/golden.py b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/golden.py new file mode 100644 index 0000000000..6e5edd801a --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + sum_out = np.full(ROWS, SENTINEL, dtype=np.float32) + copy_out = np.full(INPUT_ELEMS, SENTINEL, dtype=np.float32) + golden_sum = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.875, 0.625, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.0625) + src[begin : begin + GROUP_SIZE] = values + golden_sum[row] = np.sum(values, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + sum_out.tofile(output_dir / "v2.bin") + copy_out.tofile(output_dir / "v3.bin") + golden_sum.tofile(output_dir / "golden_v2.bin") + src.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto new file mode 100644 index 0000000000..2d0dcd2c64 --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dense_group_reduce_multi_consumer_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %copy_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/launch.cpp b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/launch.cpp new file mode 100644 index 0000000000..1249378267 --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dense_group_reduce_multi_consumer_kernel(__gm__ float *src, + __gm__ float *sum, + __gm__ float *copy); + +void LaunchVmi_dense_group_reduce_multi_consumer_kernel(float *src, float *sum, + float *copy, + void *stream) { + vmi_dense_group_reduce_multi_consumer_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)sum, (__gm__ float *)copy); +} diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/main.cpp b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/main.cpp new file mode 100644 index 0000000000..0482d8339d --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dense_group_reduce_multi_consumer_kernel(float *src, float *sum, + float *copy, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kSumElems = kRows; + constexpr size_t kCopyElems = kInputElems; + size_t srcBytes = kInputElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + size_t copyBytes = kCopyElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *sumHost = nullptr; + float *sumDevice = nullptr; + float *copyHost = nullptr; + float *copyDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", copyBytes, copyHost, copyBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dense_group_reduce_multi_consumer_kernel(srcDevice, sumDevice, + copyDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", copyHost, copyBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(copyDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(copyHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/ptoas.flags b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/compare.py b/test/vpto/cases/vmi/f32-to-f8-store-reduce/compare.py new file mode 100644 index 0000000000..d00c9b8b26 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_f32(name: str, atol: float, rtol: float) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + close = golden.shape == output.shape and np.allclose(golden, output, atol=atol, rtol=rtol) + if close: + return True + diff = np.nonzero(~np.isclose(golden, output, atol=atol, rtol=rtol))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed {name} idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}") + return False + + +def main() -> None: + if not check_f32("v2", 1e-4, 1e-4) or not check_u8("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/golden.py b/test/vpto/cases/vmi/f32-to-f8-store-reduce/golden.py new file mode 100644 index 0000000000..9034fe8d42 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/golden.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32) + golden_out8 = np.empty((ROWS, GROUP_SIZE), dtype=np.uint8) + for row in range(ROWS): + value_idx = row % len(VALUES) + if row == 0: + src[row, :] = np.tile(VALUES, GROUP_SIZE // len(VALUES)) + golden_out8[row, :] = np.tile(F8E4M3FN_BYTES, GROUP_SIZE // len(F8E4M3FN_BYTES)) + else: + src[row, :] = VALUES[value_idx] + golden_out8[row, :] = F8E4M3FN_BYTES[value_idx] + + golden_sum = np.sum(src, axis=1, dtype=np.float32) + sum_out = np.full(ROWS, SENTINEL_F32, dtype=np.float32) + out8 = np.full(ROWS * GROUP_SIZE, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + sum_out.tofile(output_dir / "v2.bin") + out8.tofile(output_dir / "v3.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_out8.reshape(-1).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto b/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto new file mode 100644 index 0000000000..6f68510ede --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_f32_to_f8_store_reduce_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %out8_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out8_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out8_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x32 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %x8 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %x8, %ub_out8_f8[%c0] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out8_u8, %out8_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/launch.cpp b/test/vpto/cases/vmi/f32-to-f8-store-reduce/launch.cpp new file mode 100644 index 0000000000..eef7fac9d0 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_f32_to_f8_store_reduce_kernel(__gm__ float *src, __gm__ float *sum, + __gm__ uint8_t *out8); + +void LaunchVmi_f32_to_f8_store_reduce_kernel(float *src, float *sum, + uint8_t *out8, void *stream) { + vmi_f32_to_f8_store_reduce_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)sum, (__gm__ uint8_t *)out8); +} diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/main.cpp b/test/vpto/cases/vmi/f32-to-f8-store-reduce/main.cpp new file mode 100644 index 0000000000..1e3e7e8a86 --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_f32_to_f8_store_reduce_kernel(float *src, float *sum, + uint8_t *out8, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kSrcElems = kRows * kGroupSize; + constexpr size_t kSumElems = kRows; + constexpr size_t kOut8Elems = kSrcElems; + size_t srcBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + size_t out8Bytes = kOut8Elems * sizeof(uint8_t); + float *srcHost = nullptr; + float *sumHost = nullptr; + uint8_t *out8Host = nullptr; + float *srcDevice = nullptr; + float *sumDevice = nullptr; + uint8_t *out8Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&out8Host), out8Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&out8Device, out8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", out8Bytes, out8Host, out8Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(out8Device, out8Bytes, out8Host, out8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_f32_to_f8_store_reduce_kernel(srcDevice, sumDevice, out8Device, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(out8Host, out8Bytes, out8Device, out8Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", out8Host, out8Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(out8Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(out8Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/ptoas.flags b/test/vpto/cases/vmi/f32-to-f8-store-reduce/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/f8-compute-f8/compare.py b/test/vpto/cases/vmi/f8-compute-f8/compare.py new file mode 100644 index 0000000000..68c53a335e --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.uint8) + output = np.fromfile("v2.bin", dtype=np.uint8) + if golden.shape != output.shape or not np.array_equal(golden, output): + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/f8-compute-f8/golden.py b/test/vpto/cases/vmi/f8-compute-f8/golden.py new file mode 100644 index 0000000000..e150b09545 --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/golden.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 256 +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +F8E4M3FN_TIMES2 = np.array([0x00, 0x40, 0xC0, 0x38, 0x48, 0xC8, 0x50, 0xD0], dtype=np.uint8) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(F8E4M3FN_BYTES) - 1) // len(F8E4M3FN_BYTES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8) + dst = np.full(ELEMS, 0xA5, dtype=np.uint8) + golden = np.tile(F8E4M3FN_TIMES2, repeats)[:ELEMS].astype(np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/f8-compute-f8/kernel.pto b/test/vpto/cases/vmi/f8-compute-f8/kernel.pto new file mode 100644 index 0000000000..568cf5fbde --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/kernel.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_f8_compute_f8_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %scale = arith.constant 2.000000e+00 : f32 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst_u8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst_f8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x8 = pto.vmi.load %ub_src_f8[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %x32 = pto.vmi.extf %x8 + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<256xf32> + %y32 = pto.vmi.mulf %x32, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %y8 = pto.vmi.truncf %y32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %y8, %ub_dst_f8[%c0] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst_u8, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/f8-compute-f8/launch.cpp b/test/vpto/cases/vmi/f8-compute-f8/launch.cpp new file mode 100644 index 0000000000..63b5269670 --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_f8_compute_f8_kernel(__gm__ uint8_t *src, __gm__ uint8_t *dst); + +void LaunchVmi_f8_compute_f8_kernel(uint8_t *src, uint8_t *dst, + void *stream) { + vmi_f8_compute_f8_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)dst); +} diff --git a/test/vpto/cases/vmi/f8-compute-f8/main.cpp b/test/vpto/cases/vmi/f8-compute-f8/main.cpp new file mode 100644 index 0000000000..fffc2d6e65 --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/main.cpp @@ -0,0 +1,76 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_f8_compute_f8_kernel(uint8_t *src, uint8_t *dst, void *stream); + +int main() { + constexpr size_t kElems = 256; + size_t bytes = kElems * sizeof(uint8_t); + uint8_t *srcHost = nullptr; + uint8_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", bytes, srcHost, bytes); + ReadFile("./v2.bin", bytes, dstHost, bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, bytes, srcHost, bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, bytes, dstHost, bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_f8_compute_f8_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, bytes, dstDevice, bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/f8-compute-f8/ptoas.flags b/test/vpto/cases/vmi/f8-compute-f8/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/f8-compute-f8/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/compare.py b/test/vpto/cases/vmi/group-broadcast-multi-consumer/compare.py new file mode 100644 index 0000000000..da96a2ff71 --- /dev/null +++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/compare.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_sum = np.fromfile("golden_v2.bin", dtype=np.float32) + output_sum = np.fromfile("v2.bin", dtype=np.float32) + if golden_sum.shape != output_sum.shape or not np.allclose(golden_sum, output_sum, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden_sum, output_sum, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v2 idx={idx} " + f"golden={golden_sum[idx] if idx >= 0 else 'n/a'} " + f"output={output_sum[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + golden_dense = np.fromfile("golden_v3.bin", dtype=np.float16) + output_dense = np.fromfile("v3.bin", dtype=np.float16) + if golden_dense.shape != output_dense.shape or not np.array_equal(golden_dense, output_dense): + diff = np.nonzero(golden_dense.view(np.uint16) != output_dense.view(np.uint16))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v3 idx={idx} " + f"golden={golden_dense[idx] if idx >= 0 else 'n/a'} " + f"output={output_dense[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/golden.py b/test/vpto/cases/vmi/group-broadcast-multi-consumer/golden.py new file mode 100644 index 0000000000..a238aaf082 --- /dev/null +++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/golden.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +ELEMS = ROWS * GROUP_SIZE +SEED = 29 +SENTINEL = np.float16(-17.5) +SUM_SENTINEL = np.float32(-911.0) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-2.0, 2.0, size=ELEMS).astype(np.float32) + sum_out = np.full(ROWS, SUM_SENTINEL, dtype=np.float32) + dense = np.full(ELEMS, SENTINEL, dtype=np.float16) + golden_sum = np.empty(ROWS, dtype=np.float32) + golden_dense = np.full(ELEMS, SENTINEL, dtype=np.float16) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = src[begin : begin + GROUP_SIZE] + row_sum = np.sum(values, dtype=np.float32) + golden_sum[row] = np.sum(values * row_sum, dtype=np.float32) + golden_dense[begin : begin + GROUP_SIZE] = row_sum.astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + sum_out.tofile(output_dir / "v2.bin") + dense.tofile(output_dir / "v3.bin") + golden_sum.tofile(output_dir / "golden_v2.bin") + golden_dense.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto b/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto new file mode 100644 index 0000000000..3c14b7fc38 --- /dev/null +++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto @@ -0,0 +1,69 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_broadcast_multi_consumer_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %dense_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dense = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %b_for_mul = pto.vmi.group_broadcast %sum32 {num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %y = pto.vmi.mulf %x, %b_for_mul + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %ysum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + %b_for_cast = pto.vmi.group_broadcast %sum32 {num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %h = pto.vmi.truncf %b_for_cast + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %h, %ub_dense[%c0] : !pto.vmi.vreg<128xf16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dense, %dense_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/launch.cpp b/test/vpto/cases/vmi/group-broadcast-multi-consumer/launch.cpp new file mode 100644 index 0000000000..2a562a57e3 --- /dev/null +++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_broadcast_multi_consumer_kernel(__gm__ float *src, __gm__ float *sum, + __gm__ half *dense); + +void LaunchVmi_group_broadcast_multi_consumer_kernel(float *src, float *sum, + uint16_t *dense, + void *stream) { + vmi_group_broadcast_multi_consumer_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)sum, (__gm__ half *)dense); +} diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/main.cpp b/test/vpto/cases/vmi/group-broadcast-multi-consumer/main.cpp new file mode 100644 index 0000000000..dc39a0c47d --- /dev/null +++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_broadcast_multi_consumer_kernel(float *src, float *sum, + uint16_t *dense, + void *stream); + +int main() { + constexpr size_t kElems = 128; + constexpr size_t kRows = 8; + size_t srcBytes = kElems * sizeof(float); + size_t sumBytes = kRows * sizeof(float); + size_t denseBytes = kElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *sumHost = nullptr; + float *sumDevice = nullptr; + uint16_t *denseHost = nullptr; + uint16_t *denseDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&denseHost), denseBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&denseDevice, denseBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", denseBytes, denseHost, denseBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(denseDevice, denseBytes, denseHost, denseBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_broadcast_multi_consumer_kernel(srcDevice, sumDevice, + denseDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(denseHost, denseBytes, denseDevice, denseBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", denseHost, denseBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(denseDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(denseHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/ptoas.flags b/test/vpto/cases/vmi/group-broadcast-multi-consumer/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/compare.py b/test/vpto/cases/vmi/group-load-s16-stride-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s16-stride-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/golden.py b/test/vpto/cases/vmi/group-load-s16-stride-store/golden.py new file mode 100644 index 0000000000..5c25033808 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s16-stride-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +ROW_STRIDE = 24 +INPUT_ELEMS = ROWS * ROW_STRIDE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.full(INPUT_ELEMS, np.float32(-9.0), dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.5, 0.25, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * ROW_STRIDE + values = base_row + np.float32(row) * np.float32(0.125) + src[begin : begin + GROUP_SIZE] = values + golden[row] = np.sum(values, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto b/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto new file mode 100644 index 0000000000..f28676f8d5 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_load_s16_stride_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c24 = arith.constant 24 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c768_i64 = arith.constant 768 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c768_i64 + nburst(%c1_i64, %c768_i64, %c768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.group_load %ub_src[%c0], %c24 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/launch.cpp b/test/vpto/cases/vmi/group-load-s16-stride-store/launch.cpp new file mode 100644 index 0000000000..ef8fa0d082 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s16-stride-store/launch.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_load_s16_stride_store_kernel(__gm__ float *src, __gm__ float *dst); + +void LaunchVmi_group_load_s16_stride_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_load_s16_stride_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/main.cpp b/test/vpto/cases/vmi/group-load-s16-stride-store/main.cpp new file mode 100644 index 0000000000..414e34200e --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s16-stride-store/main.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_load_s16_stride_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kRowStride = 24; + constexpr size_t kInputElems = kRows * kRowStride; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_load_s16_stride_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/ptoas.flags b/test/vpto/cases/vmi/group-load-s16-stride-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s16-stride-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/compare.py b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/golden.py b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/golden.py new file mode 100644 index 0000000000..8cb473640d --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +ROW_STRIDE = 40 +INPUT_ELEMS = ROWS * ROW_STRIDE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.zeros(INPUT_ELEMS, dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-1.0, 1.0, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * ROW_STRIDE + values = base_row + np.float32(row) * np.float32(0.125) + src[begin : begin + GROUP_SIZE] = values + reduction = np.sum(values, dtype=np.float32) + golden[row] = np.sum(values * reduction, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto new file mode 100644 index 0000000000..cf2aea21d7 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto @@ -0,0 +1,59 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_load_s32_stride_broadcast_reduce_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c40 = arith.constant 40 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1280_i64 = arith.constant 1280 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1280_i64 + nburst(%c1_i64, %c1280_i64, %c1280_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.group_load %ub_src[%c0], %c40 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %broadcast + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scaled_sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/launch.cpp b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/launch.cpp new file mode 100644 index 0000000000..d9218a9389 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_load_s32_stride_broadcast_reduce_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_load_s32_stride_broadcast_reduce_kernel(float *src, + float *dst, + void *stream) { + vmi_group_load_s32_stride_broadcast_reduce_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/main.cpp b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/main.cpp new file mode 100644 index 0000000000..b994c2192f --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/main.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_load_s32_stride_broadcast_reduce_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kRowStride = 40; + constexpr size_t kInputElems = kRows * kRowStride; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_load_s32_stride_broadcast_reduce_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/ptoas.flags b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/compare.py b/test/vpto/cases/vmi/group-load-s32-stride-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/golden.py b/test/vpto/cases/vmi/group-load-s32-stride-store/golden.py new file mode 100644 index 0000000000..efe2d5f3b9 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +ROW_STRIDE = 40 +INPUT_ELEMS = ROWS * ROW_STRIDE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.full(INPUT_ELEMS, np.float32(-9.0), dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.75, 0.5, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * ROW_STRIDE + values = base_row + np.float32(row) * np.float32(0.0625) + src[begin : begin + GROUP_SIZE] = values + golden[row] = np.sum(values, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/kernel.pto b/test/vpto/cases/vmi/group-load-s32-stride-store/kernel.pto new file mode 100644 index 0000000000..7afde7d6f5 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/kernel.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_load_s32_stride_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c40 = arith.constant 40 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1280_i64 = arith.constant 1280 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1280_i64 + nburst(%c1_i64, %c1280_i64, %c1280_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.group_load %ub_src[%c0], %c40 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/launch.cpp b/test/vpto/cases/vmi/group-load-s32-stride-store/launch.cpp new file mode 100644 index 0000000000..9443a9cfb3 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/launch.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_load_s32_stride_store_kernel(__gm__ float *src, __gm__ float *dst); + +void LaunchVmi_group_load_s32_stride_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_load_s32_stride_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/main.cpp b/test/vpto/cases/vmi/group-load-s32-stride-store/main.cpp new file mode 100644 index 0000000000..b67ef78981 --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/main.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_load_s32_stride_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kRowStride = 40; + constexpr size_t kInputElems = kRows * kRowStride; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_load_s32_stride_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/ptoas.flags b/test/vpto/cases/vmi/group-load-s32-stride-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/compare.py b/test/vpto/cases/vmi/group-reduce-basic-store/compare.py new file mode 100644 index 0000000000..dc3a89703c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/compare.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(output_name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(output_name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + + if golden.shape != output.shape: + print(f"[ERROR] compare failed {output_name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {output_name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v4.bin", "golden_v4.bin") + check("v5.bin", "golden_v5.bin") + check("v6.bin", "golden_v6.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/golden.py b/test/vpto/cases/vmi/group-reduce-basic-store/golden.py new file mode 100644 index 0000000000..24071a1b49 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +SENTINEL = np.float32(-777.0) + + +def fill_matrix(cols: int, base_start: float, row_step: float) -> np.ndarray: + base = np.linspace(base_start, base_start + 1.0, cols, dtype=np.float32) + out = np.empty((ROWS, cols), dtype=np.float32) + for row in range(ROWS): + out[row, :] = base + np.float32(row) * np.float32(row_step) + return out + + +def write_case(output_dir: Path, matrix: np.ndarray, src_name: str, dst_name: str, golden_name: str) -> None: + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.sum(matrix, axis=1, dtype=np.float32).astype(np.float32) + matrix.reshape(-1).tofile(output_dir / src_name) + dst.tofile(output_dir / dst_name) + golden.tofile(output_dir / golden_name) + + +def generate(output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + write_case(output_dir, fill_matrix(8, -0.5, 0.03125), "v1.bin", "v4.bin", "golden_v4.bin") + write_case(output_dir, fill_matrix(16, -0.75, 0.046875), "v2.bin", "v5.bin", "golden_v5.bin") + write_case(output_dir, fill_matrix(32, -0.875, 0.0625), "v3.bin", "v6.bin", "golden_v6.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-basic-store/kernel.pto new file mode 100644 index 0000000000..4db72772c1 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/kernel.pto @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_basic_store_kernel(%src8_gm: !pto.ptr, + %src16_gm: !pto.ptr, + %src32_gm: !pto.ptr, + %dst8_gm: !pto.ptr, + %dst16_gm: !pto.ptr, + %dst32_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src16 = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_src32 = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_dst8 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst16 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_dst32 = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src8_gm, %ub_src8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src16_gm, %ub_src16, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src32_gm, %ub_src32, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask8 = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> + %x8 = pto.vmi.load %ub_src8[%c0] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %sum8 = pto.vmi.group_reduce_addf %x8, %mask8 {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf32> + pto.vmi.group_store %sum8, %ub_dst8[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<64xf32>, !pto.ptr + + %mask16 = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x16 = pto.vmi.load %ub_src16[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum16 = pto.vmi.group_reduce_addf %x16, %mask16 {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum16, %ub_dst16[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + + %mask32 = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x32 = pto.vmi.load %ub_src32[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum32 = pto.vmi.group_reduce_addf %x32, %mask32 {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum32, %ub_dst32[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst8, %dst8_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dst16, %dst16_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dst32, %dst32_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-basic-store/launch.cpp new file mode 100644 index 0000000000..a7304f9a15 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_basic_store_kernel(__gm__ float *src8, + __gm__ float *src16, + __gm__ float *src32, + __gm__ float *dst8, + __gm__ float *dst16, + __gm__ float *dst32); + +void LaunchVmi_group_reduce_basic_store_kernel(float *src8, float *src16, + float *src32, float *dst8, + float *dst16, float *dst32, + void *stream) { + vmi_group_reduce_basic_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src8, (__gm__ float *)src16, (__gm__ float *)src32, + (__gm__ float *)dst8, (__gm__ float *)dst16, (__gm__ float *)dst32); +} diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/main.cpp b/test/vpto/cases/vmi/group-reduce-basic-store/main.cpp new file mode 100644 index 0000000000..4ddb71365b --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/main.cpp @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_basic_store_kernel(float *src8, float *src16, + float *src32, float *dst8, + float *dst16, float *dst32, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kSrc8Elems = kRows * 8; + constexpr size_t kSrc16Elems = kRows * 16; + constexpr size_t kSrc32Elems = kRows * 32; + constexpr size_t kOutputElems = kRows; + size_t src8Bytes = kSrc8Elems * sizeof(float); + size_t src16Bytes = kSrc16Elems * sizeof(float); + size_t src32Bytes = kSrc32Elems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *src8Host = nullptr; + float *src16Host = nullptr; + float *src32Host = nullptr; + float *dst8Host = nullptr; + float *dst16Host = nullptr; + float *dst32Host = nullptr; + float *src8Device = nullptr; + float *src16Device = nullptr; + float *src32Device = nullptr; + float *dst8Device = nullptr; + float *dst16Device = nullptr; + float *dst32Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&src8Host), src8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&src16Host), src16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&src32Host), src32Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst8Host), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst16Host), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst32Host), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&src8Device, src8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&src16Device, src16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&src32Device, src32Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst8Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst16Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst32Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", src8Bytes, src8Host, src8Bytes); + ReadFile("./v2.bin", src16Bytes, src16Host, src16Bytes); + ReadFile("./v3.bin", src32Bytes, src32Host, src32Bytes); + ReadFile("./v4.bin", dstBytes, dst8Host, dstBytes); + ReadFile("./v5.bin", dstBytes, dst16Host, dstBytes); + ReadFile("./v6.bin", dstBytes, dst32Host, dstBytes); + ACL_CHECK(aclrtMemcpy(src8Device, src8Bytes, src8Host, src8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(src16Device, src16Bytes, src16Host, src16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(src32Device, src32Bytes, src32Host, src32Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst8Device, dstBytes, dst8Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst16Device, dstBytes, dst16Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst32Device, dstBytes, dst32Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_basic_store_kernel( + src8Device, src16Device, src32Device, dst8Device, dst16Device, + dst32Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dst8Host, dstBytes, dst8Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dst16Host, dstBytes, dst16Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dst32Host, dstBytes, dst32Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", dst8Host, dstBytes); + WriteFile("./v5.bin", dst16Host, dstBytes); + WriteFile("./v6.bin", dst32Host, dstBytes); + +cleanup: + aclrtFree(src8Device); + aclrtFree(src16Device); + aclrtFree(src32Device); + aclrtFree(dst8Device); + aclrtFree(dst16Device); + aclrtFree(dst32Device); + aclrtFreeHost(src8Host); + aclrtFreeHost(src16Host); + aclrtFreeHost(src32Host); + aclrtFreeHost(dst8Host); + aclrtFreeHost(dst16Host); + aclrtFreeHost(dst32Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-basic-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-basic-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/golden.py b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/golden.py new file mode 100644 index 0000000000..05510a7bd9 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.125) + src[begin : begin + GROUP_SIZE] = values + reduction = np.sum(values, dtype=np.float32) + golden[row] = np.sum(values * reduction, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto new file mode 100644 index 0000000000..e41c4d656d --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s16_broadcast_reduce_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/launch.cpp new file mode 100644 index 0000000000..f180d41359 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s16_broadcast_reduce_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s16_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream) { + vmi_group_reduce_s16_broadcast_reduce_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/main.cpp new file mode 100644 index 0000000000..f3b88b52fa --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/main.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 16; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_broadcast_reduce_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/compare.py new file mode 100644 index 0000000000..17b5e600cc --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/compare.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + close = np.isclose(golden, output, atol=1e-4, rtol=1e-4) + if golden.shape != output.shape or not np.all(close): + diff = np.nonzero(~close)[0] + idx = int(diff[0]) if diff.size else -1 + g = golden[idx] if idx >= 0 and idx < golden.size else "n/a" + o = output[idx] if idx >= 0 and idx < output.size else "n/a" + print(f"[ERROR] compare failed idx={idx} golden={g} output={o}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/golden.py b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/golden.py new file mode 100644 index 0000000000..f8e59f415f --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +ACTIVE = 12 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32) + active_base = np.linspace(-0.5, 0.75, ACTIVE, dtype=np.float32) + inactive_base = np.linspace(21.0, 24.0, GROUP_SIZE - ACTIVE, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + for row in range(ROWS): + src[row, :ACTIVE] = active_base + np.float32(row) * np.float32(0.046875) + src[row, ACTIVE:] = inactive_base + np.float32(row) * np.float32(1.5) + reduction = np.sum(src[row, :ACTIVE], dtype=np.float32) + golden[row] = np.sum(src[row, :ACTIVE] * reduction, dtype=np.float32) + + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.astype(np.float32).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto new file mode 100644 index 0000000000..56f042af1e --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel( + %src_gm: !pto.ptr, %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.group_load %ub_src[%c0], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/launch.cpp new file mode 100644 index 0000000000..bd5cc88024 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel( + __gm__ float *src, __gm__ float *dst); + +void LaunchVmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel( + float *src, float *dst, void *stream) { + vmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel<<<1, nullptr, + stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/main.cpp new file mode 100644 index 0000000000..b87811e20c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel( + float *src, float *dst, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 16; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_group_mask_broadcast_reduce_store_kernel( + srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/compare.py new file mode 100644 index 0000000000..17b5e600cc --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/compare.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + close = np.isclose(golden, output, atol=1e-4, rtol=1e-4) + if golden.shape != output.shape or not np.all(close): + diff = np.nonzero(~close)[0] + idx = int(diff[0]) if diff.size else -1 + g = golden[idx] if idx >= 0 and idx < golden.size else "n/a" + o = output[idx] if idx >= 0 and idx < output.size else "n/a" + print(f"[ERROR] compare failed idx={idx} golden={g} output={o}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/golden.py b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/golden.py new file mode 100644 index 0000000000..808e7e271f --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +ACTIVE = 12 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32) + active_base = np.linspace(-0.75, 0.375, ACTIVE, dtype=np.float32) + inactive_base = np.linspace(25.0, 28.0, GROUP_SIZE - ACTIVE, dtype=np.float32) + for row in range(ROWS): + src[row, :ACTIVE] = active_base + np.float32(row) * np.float32(0.0625) + src[row, ACTIVE:] = inactive_base + np.float32(row) * np.float32(2.0) + + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.sum(src[:, :ACTIVE], axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/kernel.pto new file mode 100644 index 0000000000..c07f2782fd --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/kernel.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s16_group_mask_tail_store_kernel( + %src_gm: !pto.ptr, %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.group_load %ub_src[%c0], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/launch.cpp new file mode 100644 index 0000000000..745e836949 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s16_group_mask_tail_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s16_group_mask_tail_store_kernel(float *src, + float *dst, + void *stream) { + vmi_group_reduce_s16_group_mask_tail_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/main.cpp new file mode 100644 index 0000000000..3d55e6ccfa --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/main.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_group_mask_tail_store_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 16; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_group_mask_tail_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/compare.py new file mode 100644 index 0000000000..17b5e600cc --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/compare.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + close = np.isclose(golden, output, atol=1e-4, rtol=1e-4) + if golden.shape != output.shape or not np.all(close): + diff = np.nonzero(~close)[0] + idx = int(diff[0]) if diff.size else -1 + g = golden[idx] if idx >= 0 and idx < golden.size else "n/a" + o = output[idx] if idx >= 0 and idx < output.size else "n/a" + print(f"[ERROR] compare failed idx={idx} golden={g} output={o}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/golden.py b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/golden.py new file mode 100644 index 0000000000..d3f358ba45 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +ACTIVE = 12 +ROW_STRIDE = 24 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.full(ROWS * ROW_STRIDE, np.float32(99.0), dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + active_base = np.linspace(-0.625, 0.5, ACTIVE, dtype=np.float32) + inactive_base = np.linspace(31.0, 35.0, GROUP_SIZE - ACTIVE, dtype=np.float32) + for row in range(ROWS): + begin = row * ROW_STRIDE + src[begin : begin + ACTIVE] = active_base + np.float32(row) * np.float32(0.03125) + src[begin + ACTIVE : begin + GROUP_SIZE] = inactive_base + np.float32(row) + golden[row] = np.sum(src[begin : begin + ACTIVE], dtype=np.float32) + + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.astype(np.float32).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/kernel.pto new file mode 100644 index 0000000000..b53a1a51ff --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/kernel.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s16_stride_group_mask_tail_store_kernel( + %src_gm: !pto.ptr, %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c24 = arith.constant 24 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c768_i64 = arith.constant 768 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c768_i64 + nburst(%c1_i64, %c768_i64, %c768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.group_load %ub_src[%c0], %c24 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c12 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/launch.cpp new file mode 100644 index 0000000000..ef2e2aaef2 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s16_stride_group_mask_tail_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s16_stride_group_mask_tail_store_kernel( + float *src, float *dst, void *stream) { + vmi_group_reduce_s16_stride_group_mask_tail_store_kernel<<<1, nullptr, + stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/main.cpp new file mode 100644 index 0000000000..4a6af8cac7 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_stride_group_mask_tail_store_kernel( + float *src, float *dst, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kRowStride = 24; + constexpr size_t kInputElems = kRows * kRowStride; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_stride_group_mask_tail_store_kernel( + srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py new file mode 100644 index 0000000000..39f37ccd7c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float16) + output = np.fromfile("v2.bin", dtype=np.float16) + if golden.shape != output.shape or not np.array_equal(golden, output): + diff = np.nonzero(golden.view(np.uint16) != output.view(np.uint16))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/golden.py b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/golden.py new file mode 100644 index 0000000000..2010556d20 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +ELEMS = ROWS * GROUP_SIZE +SEED = 29 +SENTINEL = np.float16(-17.5) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-2.0, 2.0, size=ELEMS).astype(np.float32) + dst = np.full(ELEMS, SENTINEL, dtype=np.float16) + golden = np.full(ELEMS, SENTINEL, dtype=np.float16) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = src[begin : begin + GROUP_SIZE] + row_sum = np.sum(values, dtype=np.float32).astype(np.float16) + golden[begin : begin + GROUP_SIZE] = row_sum + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto new file mode 100644 index 0000000000..29193f5d6b --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s16_truncf_broadcast_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + %rows = pto.vmi.group_broadcast %sum16 {num_groups = 8} + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf16> + pto.vmi.store %rows, %ub_dst[%c0] : !pto.vmi.vreg<128xf16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/launch.cpp new file mode 100644 index 0000000000..21b6e43c3d --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s16_truncf_broadcast_store_kernel(__gm__ float *src, + __gm__ half *dst); + +void LaunchVmi_group_reduce_s16_truncf_broadcast_store_kernel(float *src, + uint16_t *dst, + void *stream) { + vmi_group_reduce_s16_truncf_broadcast_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ half *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/main.cpp new file mode 100644 index 0000000000..13fe482440 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/main.cpp @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s16_truncf_broadcast_store_kernel(float *src, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kElems = 128; + size_t srcBytes = kElems * sizeof(float); + size_t dstBytes = kElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint16_t *dstHost = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s16_truncf_broadcast_store_kernel(srcDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/golden.py new file mode 100644 index 0000000000..1614628a0b --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +INPUT_ELEMS = ROWS * GROUP_SIZE +BIAS = np.float32(0.25) +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.5, 0.75, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.03125) + src[begin : begin + GROUP_SIZE] = values + golden[row] = np.sum(values + BIAS, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto new file mode 100644 index 0000000000..d21fb5efd2 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_add_bias_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %bias = arith.constant 2.500000e-01 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %biasv = pto.vmi.broadcast %bias : f32 -> !pto.vmi.vreg<256xf32> + %biased = pto.vmi.addf %x, %biasv + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %biased, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/launch.cpp new file mode 100644 index 0000000000..b5526b9b23 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_add_bias_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_add_bias_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_reduce_s32_add_bias_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/main.cpp new file mode 100644 index 0000000000..5c85668ceb --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_add_bias_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_add_bias_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/golden.py new file mode 100644 index 0000000000..aef1ece1b4 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.875, 0.625, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.0625) + src[begin : begin + GROUP_SIZE] = values + reduction = np.sum(values, dtype=np.float32) + golden[row] = np.sum(values * reduction, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto new file mode 100644 index 0000000000..f51fe89924 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_broadcast_reduce_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/launch.cpp new file mode 100644 index 0000000000..e8decb88f5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_broadcast_reduce_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream) { + vmi_group_reduce_s32_broadcast_reduce_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/main.cpp new file mode 100644 index 0000000000..eba17dbdd0 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/main.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_broadcast_reduce_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/golden.py new file mode 100644 index 0000000000..409f321f7d --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 32 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.5, 0.75, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.03125) + src[begin : begin + GROUP_SIZE] = values + golden[row] = np.sum(values, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto new file mode 100644 index 0000000000..de08d084e6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto @@ -0,0 +1,63 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_cf_join_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %zero = arith.constant 0.000000e+00 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %cond = arith.cmpi eq, %c0, %c0 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = scf.if %cond -> (!pto.vmi.vreg<256xf32>) { + %then_x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + scf.yield %then_x : !pto.vmi.vreg<256xf32> + } else { + %else_x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32> + %else_y = pto.vmi.addf %else_x, %zero_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + scf.yield %else_y : !pto.vmi.vreg<256xf32> + } + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/launch.cpp new file mode 100644 index 0000000000..4204a6ca52 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_cf_join_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_cf_join_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_reduce_s32_cf_join_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/main.cpp new file mode 100644 index 0000000000..a504036a2e --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_cf_join_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_cf_join_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/golden.py new file mode 100644 index 0000000000..a00c19efbe --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 16 +GROUP_SIZE = 32 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.5, 0.75, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.03125) + src[begin : begin + GROUP_SIZE] = values + golden[row] = np.sum(values, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/kernel.pto new file mode 100644 index 0000000000..758691c5cf --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/kernel.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_multitile_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 16, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 16} + : !pto.vmi.vreg<512xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/launch.cpp new file mode 100644 index 0000000000..88c109d7d0 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_multitile_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_multitile_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_reduce_s32_multitile_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/main.cpp new file mode 100644 index 0000000000..f30ea2a367 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_multitile_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 16; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_multitile_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/compare.py b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/compare.py new file mode 100644 index 0000000000..8c5fc67aca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/compare.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + close = np.isclose(golden, output, atol=1e-4, rtol=1e-4) + if golden.shape != output.shape or not np.all(close): + diff = np.nonzero(~close)[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py new file mode 100644 index 0000000000..cf80936861 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +PHYSICAL_ROWS = 8 +ACTIVE_ROWS = 6 +GROUP_SIZE = 32 +INPUT_ELEMS = PHYSICAL_ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + dst = np.full(PHYSICAL_ROWS, SENTINEL, dtype=np.float32) + golden = np.full(PHYSICAL_ROWS, SENTINEL, dtype=np.float32) + + base_row = np.linspace(-0.875, 0.625, GROUP_SIZE, dtype=np.float32) + for row in range(PHYSICAL_ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.0625) + src[begin : begin + GROUP_SIZE] = values + if row < ACTIVE_ROWS: + golden[row] = np.sum(values, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto new file mode 100644 index 0000000000..fabed4ee8b --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s32_tail_full_tile_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c192 = arith.constant 192 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> + %x = pto.vmi.load %ub_src[%c0] {full_read_elems = 256} + : !pto.ptr -> !pto.vmi.vreg<192xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} + : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> + -> !pto.vmi.vreg<192xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 6} + : !pto.vmi.vreg<192xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/launch.cpp new file mode 100644 index 0000000000..5dd1b3c148 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s32_tail_full_tile_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s32_tail_full_tile_store_kernel(float *src, + float *dst, + void *stream) { + vmi_group_reduce_s32_tail_full_tile_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/main.cpp new file mode 100644 index 0000000000..5cd1b690d2 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/main.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s32_tail_full_tile_store_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kPhysicalRows = 8; + constexpr size_t kGroupSize = 32; + constexpr size_t kInputElems = kPhysicalRows * kGroupSize; + constexpr size_t kOutputElems = kPhysicalRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s32_tail_full_tile_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/compare.py b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/golden.py b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/golden.py new file mode 100644 index 0000000000..24fa390b6c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 64 +INPUT_ELEMS = ROWS * GROUP_SIZE +OUTPUT_STRIDE = 8 +OUTPUT_ELEMS = ROWS * OUTPUT_STRIDE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + dst = np.full(OUTPUT_ELEMS, SENTINEL, dtype=np.float32) + golden = np.full(OUTPUT_ELEMS, SENTINEL, dtype=np.float32) + + base_row = np.linspace(-0.5, 0.5, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.03125) + src[begin : begin + GROUP_SIZE] = values + reduction = np.sum(values, dtype=np.float32) + golden[row * OUTPUT_STRIDE] = np.sum(values * reduction, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/kernel.pto new file mode 100644 index 0000000000..bcb027a753 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/kernel.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s64_broadcast_reduce_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<512xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %ysum, %ub_dst[%c0], %c8 {num_groups = 8} + : !pto.vmi.vreg<512xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/launch.cpp new file mode 100644 index 0000000000..ba45139736 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s64_broadcast_reduce_store_kernel(__gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s64_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream) { + vmi_group_reduce_s64_broadcast_reduce_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/main.cpp new file mode 100644 index 0000000000..91e2c97119 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/main.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s64_broadcast_reduce_store_kernel(float *src, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 64; + constexpr size_t kOutputStride = 8; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows * kOutputStride; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s64_broadcast_reduce_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/compare.py b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/compare.py new file mode 100644 index 0000000000..be861f3da8 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.float32) + output = np.fromfile("v3.bin", dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + print("[INFO] compare passed") + return + + if golden.shape != output.shape: + print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/golden.py b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/golden.py new file mode 100644 index 0000000000..6d0d25229a --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 64 +RHS_STRIDE = 8 +OUTPUT_STRIDE = 8 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32) + base_row = np.linspace(-0.5, 0.5, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + src[row, :] = base_row + np.float32(row) * np.float32(0.03125) + + rhs = np.linspace(-0.75, 0.75, ROWS * RHS_STRIDE, dtype=np.float32) + dst = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float32) + golden = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float32) + for row in range(ROWS): + golden[row * OUTPUT_STRIDE] = ( + np.sum(src[row, :], dtype=np.float32) + rhs[row * RHS_STRIDE] + ) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto new file mode 100644 index 0000000000..04338c1c1b --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s64_slot_add_store_kernel(%src_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32> + %rhs = pto.vmi.group_slot_load %ub_rhs[%c0], %c8 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<512xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + %out = pto.vmi.addf %sum, %rhs + : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<512xf32> + pto.vmi.group_store %out, %ub_dst[%c0], %c8 {num_groups = 8} + : !pto.vmi.vreg<512xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/launch.cpp new file mode 100644 index 0000000000..7225148ff7 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s64_slot_add_store_kernel(__gm__ float *src, + __gm__ float *rhs, + __gm__ float *dst); + +void LaunchVmi_group_reduce_s64_slot_add_store_kernel(float *src, float *rhs, + float *dst, + void *stream) { + vmi_group_reduce_s64_slot_add_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)rhs, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/main.cpp new file mode 100644 index 0000000000..1f5acfaa5c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s64_slot_add_store_kernel(float *src, float *rhs, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 64; + constexpr size_t kRhsStride = 8; + constexpr size_t kOutputStride = 8; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kRhsElems = kRows * kRhsStride; + constexpr size_t kOutputElems = kRows * kOutputStride; + size_t srcBytes = kInputElems * sizeof(float); + size_t rhsBytes = kRhsElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *rhsHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *rhsDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", rhsBytes, rhsHost, rhsBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s64_slot_add_store_kernel(srcDevice, rhsDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(rhsDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(rhsHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/compare.py b/test/vpto/cases/vmi/group-reduce-s64-tail-store/compare.py new file mode 100644 index 0000000000..17b5e600cc --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/compare.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + close = np.isclose(golden, output, atol=1e-4, rtol=1e-4) + if golden.shape != output.shape or not np.all(close): + diff = np.nonzero(~close)[0] + idx = int(diff[0]) if diff.size else -1 + g = golden[idx] if idx >= 0 and idx < golden.size else "n/a" + o = output[idx] if idx >= 0 and idx < output.size else "n/a" + print(f"[ERROR] compare failed idx={idx} golden={g} output={o}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/golden.py b/test/vpto/cases/vmi/group-reduce-s64-tail-store/golden.py new file mode 100644 index 0000000000..83ac2d015e --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 6 +GROUP_SIZE = 64 +OUTPUT_STRIDE = 8 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32) + base = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.046875) + + dst = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float32) + golden = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float32) + for row in range(ROWS): + golden[row * OUTPUT_STRIDE] = np.sum(src[row, :], dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.astype(np.float32).tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto new file mode 100644 index 0000000000..5167c9198a --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s64_tail_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c384 = arith.constant 384 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c192_i64 = arith.constant 192 : i64 + %c1536_i64 = arith.constant 1536 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1536_i64 + nburst(%c1_i64, %c1536_i64, %c1536_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c192_i64 + nburst(%c1_i64, %c192_i64, %c192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c384 : index -> !pto.vmi.mask<384xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<384xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} + : !pto.vmi.vreg<384xf32>, !pto.vmi.mask<384xpred> + -> !pto.vmi.vreg<384xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c8 {num_groups = 6} + : !pto.vmi.vreg<384xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c192_i64 + nburst(%c1_i64, %c192_i64, %c192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s64-tail-store/launch.cpp new file mode 100644 index 0000000000..afdf98b76d --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/launch.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s64_tail_store_kernel(__gm__ float *src, __gm__ float *dst); + +void LaunchVmi_group_reduce_s64_tail_store_kernel(float *src, float *dst, + void *stream) { + vmi_group_reduce_s64_tail_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s64-tail-store/main.cpp new file mode 100644 index 0000000000..3223b3561b --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/main.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s64_tail_store_kernel(float *src, float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 6; + constexpr size_t kGroupSize = 64; + constexpr size_t kOutputStride = 8; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows * kOutputStride; + size_t srcBytes = kInputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *dstHost = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s64_tail_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s64-tail-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/compare.py b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/compare.py new file mode 100644 index 0000000000..cce2c778b9 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/compare.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float16) + output = np.fromfile("v2.bin", dtype=np.float16) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + diff = np.nonzero(golden.view(np.uint16) != output.view(np.uint16))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/golden.py b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/golden.py new file mode 100644 index 0000000000..62b6de2d6e --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/golden.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 64 +OUTPUT_STRIDE = 16 +SENTINEL = np.float16(-17.5) + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, GROUP_SIZE), dtype=np.float32) + base = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.046875) + + dst = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float16) + golden = np.full(ROWS * OUTPUT_STRIDE, SENTINEL, dtype=np.float16) + for row in range(ROWS): + row_sum = np.sum(src[row, :], dtype=np.float32) + golden[row * OUTPUT_STRIDE] = np.float16(row_sum) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto new file mode 100644 index 0000000000..6436738080 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_s64_truncf_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c512 : index -> !pto.vmi.mask<512xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32> + %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + %sum16 = pto.vmi.truncf %sum32 + : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf16> + pto.vmi.group_store %sum16, %ub_dst[%c0], %c16 {num_groups = 8} + : !pto.vmi.vreg<512xf16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/launch.cpp new file mode 100644 index 0000000000..bd0c1e4fa2 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/launch.cpp @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_s64_truncf_store_kernel(__gm__ float *src, __gm__ half *dst); + +void LaunchVmi_group_reduce_s64_truncf_store_kernel(float *src, uint16_t *dst, + void *stream) { + vmi_group_reduce_s64_truncf_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ half *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/main.cpp b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/main.cpp new file mode 100644 index 0000000000..941a7d4622 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/main.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_s64_truncf_store_kernel(float *src, uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kSrcElems = 512; + constexpr size_t kDstElems = 128; + size_t srcBytes = kSrcElems * sizeof(float); + size_t dstBytes = kDstElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *srcDevice = nullptr; + uint16_t *dstHost = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_s64_truncf_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/compare.py b/test/vpto/cases/vmi/group-reduce-slot-add-store/compare.py new file mode 100644 index 0000000000..edcf881e8d --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/compare.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v4.bin", "golden_v4.bin") + check("v5.bin", "golden_v5.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/golden.py b/test/vpto/cases/vmi/group-reduce-slot-add-store/golden.py new file mode 100644 index 0000000000..7e57da8318 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +S16 = 16 +S32 = 32 +SENTINEL = np.float32(-777.0) + + +def fill_matrix(rows: int, cols: int, base_start: float, row_step: float) -> np.ndarray: + base = np.linspace(base_start, base_start + 1.0, cols, dtype=np.float32) + out = np.empty((rows, cols), dtype=np.float32) + for row in range(rows): + out[row, :] = base + np.float32(row) * np.float32(row_step) + return out + + +def generate(output_dir: Path) -> None: + src16 = fill_matrix(ROWS, S16, -0.75, 0.03125) + src32 = fill_matrix(ROWS, S32, -0.875, 0.0625) + rhs = np.linspace(-0.25, 0.625, ROWS, dtype=np.float32) + dst16 = np.full(ROWS, SENTINEL, dtype=np.float32) + dst32 = np.full(ROWS, SENTINEL, dtype=np.float32) + + golden16 = np.sum(src16, axis=1, dtype=np.float32).astype(np.float32) + rhs + golden32 = np.sum(src32, axis=1, dtype=np.float32).astype(np.float32) + rhs + + output_dir.mkdir(parents=True, exist_ok=True) + src16.reshape(-1).tofile(output_dir / "v1.bin") + src32.reshape(-1).tofile(output_dir / "v2.bin") + rhs.tofile(output_dir / "v3.bin") + dst16.tofile(output_dir / "v4.bin") + dst32.tofile(output_dir / "v5.bin") + golden16.astype(np.float32).tofile(output_dir / "golden_v4.bin") + golden32.astype(np.float32).tofile(output_dir / "golden_v5.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto new file mode 100644 index 0000000000..291251e0bf --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_slot_add_store_kernel(%src16_gm: !pto.ptr, + %src32_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %dst16_gm: !pto.ptr, + %dst32_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src16 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src32 = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst16 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_dst32 = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src16_gm, %ub_src16, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src32_gm, %ub_src32, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask16 = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x16 = pto.vmi.load %ub_src16[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %rhs16 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum16 = pto.vmi.group_reduce_addf %x16, %mask16 {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %out16 = pto.vmi.addf %sum16, %rhs16 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %out16, %ub_dst16[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + + %mask32 = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x32 = pto.vmi.load %ub_src32[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %rhs32 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum32 = pto.vmi.group_reduce_addf %x32, %mask32 {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %out32 = pto.vmi.addf %sum32, %rhs32 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %out32, %ub_dst32[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst16, %dst16_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dst32, %dst32_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-slot-add-store/launch.cpp new file mode 100644 index 0000000000..ba7b786e51 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/launch.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_slot_add_store_kernel(__gm__ float *src16, + __gm__ float *src32, + __gm__ float *rhs, + __gm__ float *dst16, + __gm__ float *dst32); + +void LaunchVmi_group_reduce_slot_add_store_kernel(float *src16, float *src32, + float *rhs, float *dst16, + float *dst32, void *stream) { + vmi_group_reduce_slot_add_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src16, (__gm__ float *)src32, (__gm__ float *)rhs, + (__gm__ float *)dst16, (__gm__ float *)dst32); +} diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/main.cpp b/test/vpto/cases/vmi/group-reduce-slot-add-store/main.cpp new file mode 100644 index 0000000000..111426c192 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/main.cpp @@ -0,0 +1,113 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_slot_add_store_kernel(float *src16, float *src32, + float *rhs, float *dst16, + float *dst32, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kS16 = 16; + constexpr size_t kS32 = 32; + constexpr size_t kSrc16Elems = kRows * kS16; + constexpr size_t kSrc32Elems = kRows * kS32; + constexpr size_t kRhsElems = kRows; + constexpr size_t kOutputElems = kRows; + size_t src16Bytes = kSrc16Elems * sizeof(float); + size_t src32Bytes = kSrc32Elems * sizeof(float); + size_t rhsBytes = kRhsElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *src16Host = nullptr; + float *src32Host = nullptr; + float *rhsHost = nullptr; + float *dst16Host = nullptr; + float *dst32Host = nullptr; + float *src16Device = nullptr; + float *src32Device = nullptr; + float *rhsDevice = nullptr; + float *dst16Device = nullptr; + float *dst32Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&src16Host), src16Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&src32Host), src32Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst16Host), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dst32Host), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&src16Device, src16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&src32Device, src32Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst16Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dst32Device, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", src16Bytes, src16Host, src16Bytes); + ReadFile("./v2.bin", src32Bytes, src32Host, src32Bytes); + ReadFile("./v3.bin", rhsBytes, rhsHost, rhsBytes); + ReadFile("./v4.bin", dstBytes, dst16Host, dstBytes); + ReadFile("./v5.bin", dstBytes, dst32Host, dstBytes); + ACL_CHECK(aclrtMemcpy(src16Device, src16Bytes, src16Host, src16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(src32Device, src32Bytes, src32Host, src32Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst16Device, dstBytes, dst16Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dst32Device, dstBytes, dst32Host, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_slot_add_store_kernel( + src16Device, src32Device, rhsDevice, dst16Device, dst32Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dst16Host, dstBytes, dst16Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dst32Host, dstBytes, dst32Device, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", dst16Host, dstBytes); + WriteFile("./v5.bin", dst32Host, dstBytes); + +cleanup: + aclrtFree(src16Device); + aclrtFree(src32Device); + aclrtFree(rhsDevice); + aclrtFree(dst16Device); + aclrtFree(dst32Device); + aclrtFreeHost(src16Host); + aclrtFreeHost(src32Host); + aclrtFreeHost(rhsHost); + aclrtFreeHost(dst16Host); + aclrtFreeHost(dst32Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-slot-add-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/compare.py b/test/vpto/cases/vmi/group-slots-cf-join-store/compare.py new file mode 100644 index 0000000000..60aeab3da6 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return True + close = np.isclose(golden, output, atol=1e-4, rtol=1e-4) + diff = np.nonzero(~close)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def main() -> None: + if not check("v3") or not check("v4"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/golden.py b/test/vpto/cases/vmi/group-slots-cf-join-store/golden.py new file mode 100644 index 0000000000..fa1fc04fe6 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + rhs = np.linspace(-0.375, 0.5, ROWS, dtype=np.float32) + dst_reduce = np.full(ROWS, SENTINEL, dtype=np.float32) + dst_slot = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_reduce = np.empty(ROWS, dtype=np.float32) + golden_slot = rhs + rhs + + base_row = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.125) + src[begin : begin + GROUP_SIZE] = values + golden_reduce[row] = np.sum(values, dtype=np.float32) + rhs[row] + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + dst_reduce.tofile(output_dir / "v3.bin") + dst_slot.tofile(output_dir / "v4.bin") + golden_reduce.tofile(output_dir / "golden_v3.bin") + golden_slot.astype(np.float32).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto b/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto new file mode 100644 index 0000000000..7fcdd382c8 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_slots_cf_join_store_kernel(%src_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %dst_reduce_gm: !pto.ptr, + %dst_slot_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dst_reduce = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_dst_slot = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %cond_true = arith.cmpi eq, %c0, %c0 : index + %cond_false = arith.cmpi ne, %c0, %c0 : index + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + + %reduce_join = scf.if %cond_true -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + scf.yield %sum : !pto.vmi.vreg<128xf32> + } else { + %slot = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + scf.yield %slot : !pto.vmi.vreg<128xf32> + } + %bias0 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %reduce_out = pto.vmi.addf %reduce_join, %bias0 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %reduce_out, %ub_dst_reduce[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + + %slot_join = scf.if %cond_false -> !pto.vmi.vreg<128xf32> { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + scf.yield %sum : !pto.vmi.vreg<128xf32> + } else { + %slot = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + scf.yield %slot : !pto.vmi.vreg<128xf32> + } + %bias1 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %slot_out = pto.vmi.addf %slot_join, %bias1 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %slot_out, %ub_dst_slot[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst_reduce, %dst_reduce_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dst_slot, %dst_slot_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/launch.cpp b/test/vpto/cases/vmi/group-slots-cf-join-store/launch.cpp new file mode 100644 index 0000000000..add61550a6 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_slots_cf_join_store_kernel(__gm__ float *src, __gm__ float *rhs, + __gm__ float *dstReduce, + __gm__ float *dstSlot); + +void LaunchVmi_group_slots_cf_join_store_kernel(float *src, float *rhs, + float *dstReduce, + float *dstSlot, void *stream) { + vmi_group_slots_cf_join_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)rhs, (__gm__ float *)dstReduce, + (__gm__ float *)dstSlot); +} diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/main.cpp b/test/vpto/cases/vmi/group-slots-cf-join-store/main.cpp new file mode 100644 index 0000000000..fb8d6ace69 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/main.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_slots_cf_join_store_kernel(float *src, float *rhs, + float *dstReduce, + float *dstSlot, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kGroupSize = 16; + constexpr size_t kInputElems = kRows * kGroupSize; + constexpr size_t kOutputElems = kRows; + size_t srcBytes = kInputElems * sizeof(float); + size_t rhsBytes = kOutputElems * sizeof(float); + size_t dstBytes = kOutputElems * sizeof(float); + float *srcHost = nullptr; + float *rhsHost = nullptr; + float *dstReduceHost = nullptr; + float *dstSlotHost = nullptr; + float *srcDevice = nullptr; + float *rhsDevice = nullptr; + float *dstReduceDevice = nullptr; + float *dstSlotDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstReduceHost), dstBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstSlotHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstReduceDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstSlotDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", rhsBytes, rhsHost, rhsBytes); + ReadFile("./v3.bin", dstBytes, dstReduceHost, dstBytes); + ReadFile("./v4.bin", dstBytes, dstSlotHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstReduceDevice, dstBytes, dstReduceHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstSlotDevice, dstBytes, dstSlotHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_slots_cf_join_store_kernel(srcDevice, rhsDevice, + dstReduceDevice, dstSlotDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstReduceHost, dstBytes, dstReduceDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(dstSlotHost, dstBytes, dstSlotDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstReduceHost, dstBytes); + WriteFile("./v4.bin", dstSlotHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(rhsDevice); + aclrtFree(dstReduceDevice); + aclrtFree(dstSlotDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(rhsHost); + aclrtFreeHost(dstReduceHost); + aclrtFreeHost(dstSlotHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/ptoas.flags b/test/vpto/cases/vmi/group-slots-cf-join-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/compare.py b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/compare.py new file mode 100644 index 0000000000..49180d97de --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return True + close = np.isclose(golden, output, atol=1e-4, rtol=1e-4) + diff = np.nonzero(~close)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def main() -> None: + if not check("v2") or not check("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/golden.py b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/golden.py new file mode 100644 index 0000000000..146d0d1fd2 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/golden.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +INPUT_ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty(INPUT_ELEMS, dtype=np.float32) + sum_out = np.full(ROWS, SENTINEL, dtype=np.float32) + out = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_sum = np.empty(ROWS, dtype=np.float32) + golden_out = np.empty(ROWS, dtype=np.float32) + + base_row = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float32) + for row in range(ROWS): + begin = row * GROUP_SIZE + values = base_row + np.float32(row) * np.float32(0.125) + src[begin : begin + GROUP_SIZE] = values + reduction = np.sum(values, dtype=np.float32) + golden_sum[row] = reduction + golden_out[row] = np.sum(values * reduction, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + sum_out.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_sum.tofile(output_dir / "golden_v2.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto new file mode 100644 index 0000000000..0660b1e0a3 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_slots_fanout_store_broadcast_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %sum_gm, %ub_sum, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + %b = pto.vmi.group_broadcast %sum {num_groups = 8} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %y = pto.vmi.mulf %x, %b + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %ysum, %ub_out[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/launch.cpp b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/launch.cpp new file mode 100644 index 0000000000..9a0667aae1 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_slots_fanout_store_broadcast_kernel(__gm__ float *src, + __gm__ float *sum, + __gm__ float *out); + +void LaunchVmi_group_slots_fanout_store_broadcast_kernel(float *src, + float *sum, + float *out, + void *stream) { + vmi_group_slots_fanout_store_broadcast_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)sum, (__gm__ float *)out); +} diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/main.cpp b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/main.cpp new file mode 100644 index 0000000000..f7b0fee4b8 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_slots_fanout_store_broadcast_kernel(float *src, + float *sum, + float *out, + void *stream); + +int main() { + constexpr size_t kSrcElems = 128; + constexpr size_t kOutElems = 8; + size_t srcBytes = kSrcElems * sizeof(float); + size_t sumBytes = kOutElems * sizeof(float); + size_t outBytes = kOutElems * sizeof(float); + float *srcHost = nullptr; + float *sumHost = nullptr; + float *outHost = nullptr; + float *srcDevice = nullptr; + float *sumDevice = nullptr; + float *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_slots_fanout_store_broadcast_kernel(srcDevice, sumDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/ptoas.flags b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/compare.py b/test/vpto/cases/vmi/group-slots-scf-for-store/compare.py new file mode 100644 index 0000000000..be861f3da8 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.float32) + output = np.fromfile("v3.bin", dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + print("[INFO] compare passed") + return + + if golden.shape != output.shape: + print(f"[ERROR] compare failed: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/golden.py b/test/vpto/cases/vmi/group-slots-scf-for-store/golden.py new file mode 100644 index 0000000000..a62c83071c --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 16 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + init = np.linspace(-0.25, 0.625, ROWS, dtype=np.float32) + base = np.linspace(-0.75, 0.25, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + dst = np.full(ROWS, SENTINEL, dtype=np.float32) + golden = init + np.float32(2.0) * np.sum(src, axis=1, dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + init.tofile(output_dir / "v1.bin") + src.reshape(-1).tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto b/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto new file mode 100644 index 0000000000..8ae0c03444 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto @@ -0,0 +1,68 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_slots_scf_for_store_kernel(%init_gm: !pto.ptr, + %src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_init = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %init_gm, %ub_init, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %acc0 = pto.vmi.group_slot_load %ub_init[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %acc = scf.for %i = %c0 to %c2 step %c1 + iter_args(%arg = %acc0) -> (!pto.vmi.vreg<128xf32>) { + %x = pto.vmi.group_load %ub_src[%c0], %c16 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_group_mask %c16 + {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + %next = pto.vmi.addf %arg, %sum + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + pto.vmi.group_store %acc, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/launch.cpp b/test/vpto/cases/vmi/group-slots-scf-for-store/launch.cpp new file mode 100644 index 0000000000..6837a88fd4 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_slots_scf_for_store_kernel(__gm__ float *init, __gm__ float *src, + __gm__ float *dst); + +void LaunchVmi_group_slots_scf_for_store_kernel(float *init, float *src, + float *dst, void *stream) { + vmi_group_slots_scf_for_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)init, (__gm__ float *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/main.cpp b/test/vpto/cases/vmi/group-slots-scf-for-store/main.cpp new file mode 100644 index 0000000000..555d105f43 --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_slots_scf_for_store_kernel(float *init, float *src, + float *dst, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 16; + constexpr size_t kInitElems = kRows; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kDstElems = kRows; + size_t initBytes = kInitElems * sizeof(float); + size_t srcBytes = kSrcElems * sizeof(float); + size_t dstBytes = kDstElems * sizeof(float); + float *initHost = nullptr; + float *srcHost = nullptr; + float *dstHost = nullptr; + float *initDevice = nullptr; + float *srcDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&initHost), initBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&initDevice, initBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", initBytes, initHost, initBytes); + ReadFile("./v2.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(initDevice, initBytes, initHost, initBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_slots_scf_for_store_kernel(initDevice, srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(initDevice); + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(initHost); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/ptoas.flags b/test/vpto/cases/vmi/group-slots-scf-for-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/compare.py b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/compare.py new file mode 100644 index 0000000000..24d554e100 --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/compare.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_f32() -> bool: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-5, rtol=1e-5): + return True + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v2 idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def check_f16() -> bool: + golden = np.fromfile("golden_v3.bin", dtype=np.float16) + output = np.fromfile("v3.bin", dtype=np.float16) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden.view(np.uint16) != output.view(np.uint16))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v3 idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def main() -> None: + if not check_f32() or not check_f16(): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/golden.py b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/golden.py new file mode 100644 index 0000000000..6a28077ea8 --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/golden.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 128 +ACTIVE = 96 +SEED = 29 +SENTINEL32 = np.float32(-901.25) +SENTINEL16 = np.float16(-17.5) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32) + out32 = np.full(ELEMS, SENTINEL32, dtype=np.float32) + out16 = np.full(ELEMS, SENTINEL16, dtype=np.float16) + golden32 = out32.copy() + golden16 = out16.copy() + golden32[:ACTIVE] = src[:ACTIVE] + golden16[:ACTIVE] = src[:ACTIVE].astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + out32.tofile(output_dir / "v2.bin") + out16.tofile(output_dir / "v3.bin") + golden32.tofile(output_dir / "golden_v2.bin") + golden16.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/kernel.pto b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/kernel.pto new file mode 100644 index 0000000000..f9362793ec --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/kernel.pto @@ -0,0 +1,60 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_mask_granularity_f32_f16_store_kernel(%src_gm: !pto.ptr, + %out32_gm: !pto.ptr, + %out16_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c96 = arith.constant 96 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out32 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out16 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out32_gm, %ub_out32, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out16_gm, %ub_out16, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c96 : index -> !pto.vmi.mask<128xpred> + pto.vmi.masked_store %x, %ub_out32[%c0], %mask + : !pto.vmi.vreg<128xf32>, !pto.ptr, !pto.vmi.mask<128xpred> + %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + pto.vmi.masked_store %h, %ub_out16[%c0], %mask + : !pto.vmi.vreg<128xf16>, !pto.ptr, !pto.vmi.mask<128xpred> + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out32, %out32_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out16, %out16_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/launch.cpp b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/launch.cpp new file mode 100644 index 0000000000..de0c069797 --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_mask_granularity_f32_f16_store_kernel(__gm__ float *src, + __gm__ float *out32, + __gm__ half *out16); + +void LaunchVmi_mask_granularity_f32_f16_store_kernel(float *src, float *out32, + uint16_t *out16, + void *stream) { + vmi_mask_granularity_f32_f16_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)out32, (__gm__ half *)out16); +} diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/main.cpp b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/main.cpp new file mode 100644 index 0000000000..2a65d8c46d --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_mask_granularity_f32_f16_store_kernel(float *src, float *out32, + uint16_t *out16, + void *stream); + +int main() { + constexpr size_t kElems = 128; + size_t srcBytes = kElems * sizeof(float); + size_t out32Bytes = kElems * sizeof(float); + size_t out16Bytes = kElems * sizeof(uint16_t); + float *srcHost = nullptr; + float *out32Host = nullptr; + uint16_t *out16Host = nullptr; + float *srcDevice = nullptr; + float *out32Device = nullptr; + uint16_t *out16Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&out32Host), out32Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&out16Host), out16Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&out32Device, out32Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&out16Device, out16Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", out32Bytes, out32Host, out32Bytes); + ReadFile("./v3.bin", out16Bytes, out16Host, out16Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(out32Device, out32Bytes, out32Host, out32Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(out16Device, out16Bytes, out16Host, out16Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_mask_granularity_f32_f16_store_kernel(srcDevice, out32Device, + out16Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(out32Host, out32Bytes, out32Device, out32Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(out16Host, out16Bytes, out16Device, out16Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", out32Host, out32Bytes); + WriteFile("./v3.bin", out16Host, out16Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(out32Device); + aclrtFree(out16Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(out32Host); + aclrtFreeHost(out16Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/mask-granularity-f32-f16-store/ptoas.flags b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/mask-granularity-f32-f16-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/mask-select-store/compare.py b/test/vpto/cases/vmi/mask-select-store/compare.py new file mode 100644 index 0000000000..b9e3290e76 --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/compare.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + for name in ("v3", "v4"): + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-5, rtol=1e-5): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-5, rtol=1e-5))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/mask-select-store/golden.py b/test/vpto/cases/vmi/mask-select-store/golden.py new file mode 100644 index 0000000000..19ce1ebe2c --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 64 +ACTIVE = 48 +SEED = 29 +SENTINEL = np.float32(-901.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-8.0, 8.0, size=ELEMS).astype(np.float32) + rhs = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float32) + dense = np.full(ELEMS, SENTINEL, dtype=np.float32) + masked = np.full(ELEMS, SENTINEL, dtype=np.float32) + summed = (src + rhs).astype(np.float32) + golden_dense = src.copy() + golden_dense[:ACTIVE] = summed[:ACTIVE] + golden_masked = masked.copy() + golden_masked[:ACTIVE] = summed[:ACTIVE] + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + rhs.tofile(output_dir / "v2.bin") + dense.tofile(output_dir / "v3.bin") + masked.tofile(output_dir / "v4.bin") + golden_dense.tofile(output_dir / "golden_v3.bin") + golden_masked.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/mask-select-store/kernel.pto b/test/vpto/cases/vmi/mask-select-store/kernel.pto new file mode 100644 index 0000000000..51538fd4e0 --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/kernel.pto @@ -0,0 +1,71 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_mask_select_store_kernel(%src_gm: !pto.ptr, + %rhs_gm: !pto.ptr, + %dense_gm: !pto.ptr, + %masked_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c48 = arith.constant 48 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_rhs = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dense = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_masked = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %rhs_gm, %ub_rhs, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dense_gm, %ub_dense, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %masked_gm, %ub_masked, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %rhs = pto.vmi.load %ub_rhs[%c0] : !pto.ptr -> !pto.vmi.vreg<64xf32> + %mask = pto.vmi.create_mask %c48 : index -> !pto.vmi.mask<64xpred> + %sum = pto.vmi.addf %x, %rhs + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<64xf32> + %passthrough = pto.vmi.select %mask, %sum, %x + : !pto.vmi.mask<64xpred>, !pto.vmi.vreg<64xf32>, + !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + pto.vmi.store %passthrough, %ub_dense[%c0] + : !pto.vmi.vreg<64xf32>, !pto.ptr + pto.vmi.masked_store %sum, %ub_masked[%c0], %mask + : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.mask<64xpred> + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dense, %dense_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_masked, %masked_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/mask-select-store/launch.cpp b/test/vpto/cases/vmi/mask-select-store/launch.cpp new file mode 100644 index 0000000000..d75d0da804 --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_mask_select_store_kernel(__gm__ float *src, __gm__ float *rhs, + __gm__ float *dense, __gm__ float *masked); + +void LaunchVmi_mask_select_store_kernel(float *src, float *rhs, float *dense, + float *masked, void *stream) { + vmi_mask_select_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)rhs, (__gm__ float *)dense, + (__gm__ float *)masked); +} diff --git a/test/vpto/cases/vmi/mask-select-store/main.cpp b/test/vpto/cases/vmi/mask-select-store/main.cpp new file mode 100644 index 0000000000..07648040d0 --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/main.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_mask_select_store_kernel(float *src, float *rhs, float *dense, + float *masked, void *stream); + +int main() { + constexpr size_t kElems = 64; + size_t srcBytes = kElems * sizeof(float); + size_t rhsBytes = kElems * sizeof(float); + size_t denseBytes = kElems * sizeof(float); + size_t maskedBytes = kElems * sizeof(float); + float *srcHost = nullptr; + float *rhsHost = nullptr; + float *denseHost = nullptr; + float *maskedHost = nullptr; + float *srcDevice = nullptr; + float *rhsDevice = nullptr; + float *denseDevice = nullptr; + float *maskedDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&rhsHost), rhsBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&denseHost), denseBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&maskedHost), maskedBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&rhsDevice, rhsBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&denseDevice, denseBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&maskedDevice, maskedBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", rhsBytes, rhsHost, rhsBytes); + ReadFile("./v3.bin", denseBytes, denseHost, denseBytes); + ReadFile("./v4.bin", maskedBytes, maskedHost, maskedBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(rhsDevice, rhsBytes, rhsHost, rhsBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(denseDevice, denseBytes, denseHost, denseBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(maskedDevice, maskedBytes, maskedHost, maskedBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_mask_select_store_kernel(srcDevice, rhsDevice, denseDevice, + maskedDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(denseHost, denseBytes, denseDevice, denseBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(maskedHost, maskedBytes, maskedDevice, maskedBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", denseHost, denseBytes); + WriteFile("./v4.bin", maskedHost, maskedBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(rhsDevice); + aclrtFree(denseDevice); + aclrtFree(maskedDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(rhsHost); + aclrtFreeHost(denseHost); + aclrtFreeHost(maskedHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/mask-select-store/ptoas.flags b/test/vpto/cases/vmi/mask-select-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/mask-select-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/compare.py b/test/vpto/cases/vmi/masked-load-dense-group-users/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/golden.py b/test/vpto/cases/vmi/masked-load-dense-group-users/golden.py new file mode 100644 index 0000000000..41f1b1b714 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + base = np.linspace(-0.875, 0.625, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src.copy() + golden_sum = np.sum(src, axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto b/test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto new file mode 100644 index 0000000000..503068186e --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_masked_load_dense_group_users_kernel(%src_gm: !pto.ptr, + %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %zero = arith.constant 0.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %ub_src[%c0], %mask, %zero_vec + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/launch.cpp b/test/vpto/cases/vmi/masked-load-dense-group-users/launch.cpp new file mode 100644 index 0000000000..306dddada0 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_masked_load_dense_group_users_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_masked_load_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream) { + vmi_masked_load_dense_group_users_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/main.cpp b/test/vpto/cases/vmi/masked-load-dense-group-users/main.cpp new file mode 100644 index 0000000000..089794a818 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/main.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_masked_load_dense_group_users_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_masked_load_dense_group_users_kernel(srcDevice, copyDevice, + sumDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/ptoas.flags b/test/vpto/cases/vmi/masked-load-dense-group-users/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/compare.py b/test/vpto/cases/vmi/scf-for-loop-carried-store/compare.py new file mode 100644 index 0000000000..28299087e5 --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/compare.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float32) + output = np.fromfile("v2.bin", dtype=np.float32) + if golden.shape != output.shape or not np.allclose(golden, output, atol=1e-4, rtol=1e-4): + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed idx={idx} golden={golden[idx] if idx >= 0 else 'n/a'} output={output[idx] if idx >= 0 else 'n/a'}") + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/golden.py b/test/vpto/cases/vmi/scf-for-loop-carried-store/golden.py new file mode 100644 index 0000000000..bc9c97fdee --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/golden.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 128 +SEED = 37 +SENTINEL = np.float32(-123.25) + + +def generate(output_dir: Path, seed: int) -> None: + rng = np.random.default_rng(seed) + src = rng.uniform(-4.0, 4.0, size=ELEMS).astype(np.float16) + dst = np.full(ELEMS, SENTINEL, dtype=np.float32) + golden = src.astype(np.float32) * np.float32(4.0) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + parser.add_argument("--seed", type=int, default=SEED) + args = parser.parse_args() + generate(args.output_dir, args.seed) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/kernel.pto b/test/vpto/cases/vmi/scf-for-loop-carried-store/kernel.pto new file mode 100644 index 0000000000..3398ef3318 --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/kernel.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_scf_for_loop_carried_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %packed = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf16> + %init = pto.vmi.extf %packed : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %result = scf.for %i = %c0 to %c2 step %c1 + iter_args(%acc = %init) -> (!pto.vmi.vreg<128xf32>) { + %next = pto.vmi.addf %acc, %acc + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + scf.yield %next : !pto.vmi.vreg<128xf32> + } + pto.vmi.store %result, %ub_dst[%c0] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/launch.cpp b/test/vpto/cases/vmi/scf-for-loop-carried-store/launch.cpp new file mode 100644 index 0000000000..b0902d1207 --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/launch.cpp @@ -0,0 +1,32 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_scf_for_loop_carried_store_kernel(__gm__ uint16_t *src, __gm__ float *dst); + +void LaunchVmi_scf_for_loop_carried_store_kernel(uint16_t *src, float *dst, + void *stream) { + vmi_scf_for_loop_carried_store_kernel<<<1, nullptr, stream>>>( + (__gm__ uint16_t *)src, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/main.cpp b/test/vpto/cases/vmi/scf-for-loop-carried-store/main.cpp new file mode 100644 index 0000000000..f45b070260 --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/main.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_scf_for_loop_carried_store_kernel(uint16_t *src, float *dst, + void *stream); + +int main() { + constexpr size_t kElems = 128; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t dstBytes = kElems * sizeof(float); + uint16_t *srcHost = nullptr; + uint16_t *srcDevice = nullptr; + float *dstHost = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_scf_for_loop_carried_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/scf-for-loop-carried-store/ptoas.flags b/test/vpto/cases/vmi/scf-for-loop-carried-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/scf-for-loop-carried-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/compare.py b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/compare.py new file mode 100644 index 0000000000..c964405de5 --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/compare.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, atol: float, rtol: float) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=atol, rtol=rtol): + return True + close = np.isclose(golden, output, atol=atol, rtol=rtol) + diff = np.nonzero(~close)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def main() -> None: + if not check("v2", 1e-4, 1e-4) or not check("v3", 0.0, 0.0): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/golden.py b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/golden.py new file mode 100644 index 0000000000..b41d0e8681 --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/golden.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +GROUP_SIZE = 16 +ELEMS = ROWS * GROUP_SIZE +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, GROUP_SIZE), dtype=np.float16) + base = np.linspace(-0.625, 0.875, GROUP_SIZE, dtype=np.float16) + for row in range(ROWS): + src[row, :] = base + np.float16(row * 0.125) + + dense = np.full(ELEMS, SENTINEL, dtype=np.float32) + sum_out = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_dense = src.astype(np.float32).reshape(-1) + golden_sum = np.empty(ROWS, dtype=np.float32) + for row in range(ROWS): + golden_sum[row] = np.sum(src[row, :].astype(np.float32), dtype=np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + sum_out.tofile(output_dir / "v2.bin") + dense.tofile(output_dir / "v3.bin") + golden_sum.tofile(output_dir / "golden_v2.bin") + golden_dense.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto new file mode 100644 index 0000000000..9f3dfeabb4 --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_widen_f16_to_f32_store_reduce_kernel(%src_gm: !pto.ptr, + %sum_gm: !pto.ptr, + %dense_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_dense = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %sum_gm, %ub_sum, %c0_i64, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dense_gm, %ub_dense, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x16 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf32>, !pto.ptr + pto.vmi.store %x32, %ub_dense[%c0] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_dense, %dense_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/launch.cpp b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/launch.cpp new file mode 100644 index 0000000000..b0ee12da2b --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_widen_f16_to_f32_store_reduce_kernel(__gm__ half *src, __gm__ float *sum, + __gm__ float *dense); + +void LaunchVmi_widen_f16_to_f32_store_reduce_kernel(uint16_t *src, float *sum, + float *dense, + void *stream) { + vmi_widen_f16_to_f32_store_reduce_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)sum, (__gm__ float *)dense); +} diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/main.cpp b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/main.cpp new file mode 100644 index 0000000000..96a4a102f8 --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/main.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_widen_f16_to_f32_store_reduce_kernel(uint16_t *src, float *sum, + float *dense, void *stream); + +int main() { + constexpr size_t kSrcElems = 128; + constexpr size_t kSumElems = 8; + constexpr size_t kDenseElems = 128; + size_t srcBytes = kSrcElems * sizeof(uint16_t); + size_t sumBytes = kSumElems * sizeof(float); + size_t denseBytes = kDenseElems * sizeof(float); + uint16_t *srcHost = nullptr; + float *sumHost = nullptr; + float *denseHost = nullptr; + uint16_t *srcDevice = nullptr; + float *sumDevice = nullptr; + float *denseDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&denseHost), denseBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&denseDevice, denseBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", sumBytes, sumHost, sumBytes); + ReadFile("./v3.bin", denseBytes, denseHost, denseBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(denseDevice, denseBytes, denseHost, denseBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_widen_f16_to_f32_store_reduce_kernel(srcDevice, sumDevice, + denseDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(denseHost, denseBytes, denseDevice, denseBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", sumHost, sumBytes); + WriteFile("./v3.bin", denseHost, denseBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(sumDevice); + aclrtFree(denseDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(sumHost); + aclrtFreeHost(denseHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/ptoas.flags b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From aecea73380b1784c3164dd7e744c0180bee5b5fc Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 11:45:23 +0800 Subject: [PATCH 07/54] Support S32 partial grouped mask lowering --- .../vmi-layout-assignment-implementation.md | 38 ++++---- .../vmi-layout-assignment-lowering-design.md | 9 +- docs/designs/vmi-layout-lowering-cases.md | 66 ++++++------- lib/PTO/Transforms/VMILayoutAssignment.cpp | 22 ----- lib/PTO/Transforms/VMIToVPTO.cpp | 59 ++++++++++- ..._assignment_masked_load_group_tail_s32.pto | 21 +++- .../compare.py | 40 ++++++++ .../golden.py | 51 ++++++++++ .../kernel.pto | 62 ++++++++++++ .../launch.cpp | 33 +++++++ .../main.cpp | 97 +++++++++++++++++++ .../ptoas.flags | 1 + 12 files changed, 413 insertions(+), 86 deletions(-) create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/ptoas.flags diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index e6b39dd984..12fedece13 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -758,8 +758,11 @@ Current audit result: ```text 3.44 partial S=32 create_group_mask: - decision moved to vmi-layout-assignment. vmi-to-vpto no longer walks from - group_reduce_addf to the mask defining op to reject the plan. + assignment writes explicit contiguous and deinterleaved mask values. When + lowering the deinterleaved create_group_mask itself, vmi-to-vpto first + materializes contiguous grouped predicate chunks and then applies predicate + pdintlv in the same tree shape as the data vdintlv. It still does not walk + from group_reduce_addf to the mask defining op to choose or reject the plan. masked_load: direct lowering is load + vsel. It does not inspect the mask producer to @@ -870,13 +873,13 @@ the case catalog. Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-39 CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-40 CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh -PASS=39 FAIL=0 -summary: .tmp/vmi-runtime-batch-39/parallel-summary.tsv +PASS=40 FAIL=0 +summary: .tmp/vmi-runtime-batch-40/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-39.log + .tmp/vmi-runtime-batch-40.log result: no matches ``` @@ -1054,6 +1057,17 @@ runtime SIM: load must run through layout assignment before VPTO/LLVM emission. ``` +Current checked-in coverage for 3.44 masked_load grouped tail feeding S=32 +reduce: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto + +runtime SIM: + test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store +``` + Current checked-in runtime coverage for 3.12 control-flow join before S=32 `group_reduce`: @@ -1297,7 +1311,6 @@ Diagnostic-only cases: 3.25.1 full ptoas emission for private VMI callees that return VPTO vector values 3.25.2 public/external VMI boundary 3.30 unsafe masked_load tail without stable masked/gather fallback -3.44 masked_load grouped tail with S=32 partial create_group_mask ``` Current checked-in diagnostic coverage for 3.9/3.13/3.14: @@ -1326,7 +1339,6 @@ lit: test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto - test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto ``` Known implementation gaps before all catalog cases can become runtime SIM @@ -1339,16 +1351,6 @@ dynamic grouped masks: yet. Do not replace grouped masks with prefix create_mask; that would change the semantics. -S=32 partial grouped masks: - 3.44 `masked_load` grouped tail with `active_elems_per_group < 32` is - diagnostic-only for the current S=32 block8 reduce path, and the diagnostic - is emitted by `vmi-layout-assignment` before a selected plan is written. A - runtime probe of the previously allowed lowering did not preserve the logical - 25-lane row sum. A second probe with `active_elems_per_group = 25` produced - row 0 `golden=-3.6290324` but `output=-3.6592741`, and the row-wise error - grew monotonically. This combination must stay unsupported until the - deinterleaved grouped-mask materialization is fixed and validated by SIM. - remaining function runtime coverage: 3.25.1 internal function boundary specialization has layout-assignment and vmi-to-vpto lit coverage, but full ptoas emission still fails after diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 710ab267a7..d0f71f0daf 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -552,9 +552,12 @@ diagnostic embellishment: Anything else is a layout-assignment responsibility. In particular, an unsupported producer/consumer combination must be rejected before assignment -writes a selected plan. Section 3.44 is the model: partial S=32 grouped masks -are diagnosed in `vmi-layout-assignment`, not by `vmi-to-vpto` walking from -`group_reduce_addf` to the mask producer. +writes a selected plan. Section 3.44 is the model for supported partial S=32 +grouped masks: assignment emits explicit contiguous and deinterleaved mask +values, and `vmi-to-vpto` lowers the deinterleaved mask op itself through +contiguous grouped-mask materialization followed by predicate deinterleave. It +does not walk from `group_reduce_addf` to the mask producer to choose or reject +the plan. ## 9. Physical Value Ordering diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index b111397fc9..93a6c1dc57 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -196,7 +196,7 @@ the immediately following complete endpoints. 3.41 non-rematerializable value with incompatible users complete/materialization 3.42 group_slots scf.for loop-carried accumulator complete 3.43 internal function argument boundary materialization complete/design -3.44 masked_load grouped tail feeding S=32 reduce complete/design +3.44 masked_load grouped tail feeding S=32 reduce complete ``` ### 3.1 `f16 -> f32 -> store` @@ -5167,25 +5167,7 @@ Assigned layouts: !pto.vmi.vreg<256xf32, #pto.vmi.layout> ``` -Current implementation result: - -```text -VMI-UNSUPPORTED: pto.vmi.group_reduce_addf s32 block8 lowering does not yet -support partial create_group_mask active_elems_per_group during layout -assignment -``` - -This must remain a layout-assignment diagnostic until the S=32 block8 -grouped-mask lowering is proven against runtime SIM. Assignment must not write -`vmi.selected_plan = "s32_reduce_block8_stride"` for this case and leave -`vmi-to-vpto` to discover the partial mask by walking the mask defining op. A -`masked_load` can be lowered contiguously and then materialized to -`deinterleaved = 4, block_elems = 8`, but the grouped reduce still needs a -physically correct `create_group_mask` for `active_elems_per_group = 25`. -Allowing the current S=32 block8 path to proceed would not preserve the logical -memory result below. - -Intended VPTO lowering shape after the grouped-mask issue is fixed: +Lowering: ```text %all_b32 = pto.pge_b32 "PAT_ALL" @@ -5209,15 +5191,16 @@ Intended VPTO lowering shape after the grouped-mask issue is fixed: %x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo %x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi -// Correct deinterleaved grouped mask for active columns 0..24: -// part 0 covers columns 0..7 for every row: all active -// part 1 covers columns 8..15 for every row: all active -// part 2 covers columns 16..23 for every row: all active -// part 3 covers columns 24..31 for every row: one active lane per row -%mask_p0 = pto.pset_b32 "PAT_ALL" -%mask_p1 = pto.pset_b32 "PAT_ALL" -%mask_p2 = pto.pset_b32 "PAT_ALL" -%mask_p3 = materialize one lane per 8-lane row block +// The reduce-side grouped mask is not built by guessing the final sparse +// predicate image. It is first materialized as the same contiguous grouped +// mask used by masked_load, then converted to the reduce layout with predicate +// deinterleave. This keeps predicate reordering identical to the data +// reordering above. +%rm0, %rm1, %rm2, %rm3 = materialize contiguous create_group_mask(c25, S=32) +%rm01_lo, %rm01_hi = pto.pdintlv_b32 %rm0, %rm1 +%rm23_lo, %rm23_hi = pto.pdintlv_b32 %rm2, %rm3 +%mask_p0, %mask_p2 = pto.pdintlv_b32 %rm01_lo, %rm23_lo +%mask_p1, %mask_p3 = pto.pdintlv_b32 %rm01_hi, %rm23_hi %s0 = pto.vcgadd %x_p0, %mask_p0 : !pto.vreg<64xf32> %s1 = pto.vcgadd %x_p1, %mask_p1 : !pto.vreg<64xf32> @@ -5244,12 +5227,21 @@ Required assignment rule: ```text `masked_load` and `group_reduce` must share the same grouped mask layout. The passthrough value defines inactive loaded lanes, while the reduce mask defines -participation. Assignment may select a deinterleaved S=32 load plan only when -the rounded physical reads are memory-safe; otherwise it must diagnose or use a -future stable gather fallback. - -Current implementation additionally diagnoses the S=32 block8 partial grouped -mask itself. This is deliberate: the case is not implemented until the -deinterleaved grouped-mask materialization and `vcgadd` interpretation are -validated end to end by SIM. +participation. Assignment materializes two explicit mask values when needed: +one contiguous value for `masked_load`, and one deinterleaved value for +`group_reduce_addf`. `vmi-to-vpto` lowers the deinterleaved +`create_group_mask` by materializing the contiguous grouped predicate chunks +and then applying `pdintlv_b32` in the same tree shape as the data +`vdintlv`. It does not walk from `group_reduce_addf` to the mask producer to +choose or reject the selected plan. + +Assignment may select a deinterleaved S=32 load plan only when the rounded +physical reads are memory-safe; otherwise it must diagnose or use a future +stable gather fallback. + +Runtime coverage: + +```text +test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store +``` ``` diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index c95e8772ec..9352ffce76 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -832,28 +832,6 @@ struct LayoutSolver { VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/1); } } - if (sourceLayout && sourceLayout.isDeinterleaved() && - sourceLayout.getFactor() == 4 && - sourceLayout.getBlockElems() == 8 && numGroups > 0 && - sourceType.getElementCount() % numGroups == 0) { - int64_t groupSize = sourceType.getElementCount() / numGroups; - if (groupSize == 32) { - if (auto groupMask = - reduce.getMask().getDefiningOp()) { - std::optional activeElems = - getConstantIndexValue(groupMask.getActiveElemsPerGroup()); - if (activeElems && *activeElems >= 0 && - *activeElems < groupSize) { - reduce.emitError() - << kVMIDiagUnsupportedPrefix - << "pto.vmi.group_reduce_addf s32 block8 lowering does " - "not yet support partial create_group_mask " - "active_elems_per_group during layout assignment"; - return WalkResult::interrupt(); - } - } - } - } requestDataUse(reduce.getSourceMutable(), sourceLayout); if (failed(requestMaskUse( reduce.getMaskMutable(), sourceLayout, diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 95141bada7..85dbec5f1e 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -2079,7 +2079,9 @@ computeConstantMaskMaterialization(VMIConstantMaskOp op, std::string *reason) { } FailureOr> -computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { +computeGroupMaskMaterializationForType(VMICreateGroupMaskOp op, + VMIMaskType resultVMIType, + std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr> { if (reason) @@ -2095,7 +2097,6 @@ computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { if (!activeAttr) return fail("active_elems_per_group must be an integer constant"); - auto resultVMIType = cast(op.getResult().getType()); VMILayoutAttr layout = resultVMIType.getLayoutAttr(); if (!layout || !VMIMaskType::isConcreteGranularity(resultVMIType.getGranularity())) @@ -2153,6 +2154,12 @@ computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { return materializations; } +FailureOr> +computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { + return computeGroupMaskMaterializationForType( + op, cast(op.getResult().getType()), reason); +} + std::optional getPrefixActiveLaneCount(ArrayRef activeLanes) { bool seenInactive = false; int64_t activeCount = 0; @@ -3781,6 +3788,54 @@ struct OneToNVMICreateGroupMaskOpPattern matchAndRewrite(VMICreateGroupMaskOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 4 && resultLayout.getBlockElems() == 8) { + VMILayoutAttr contiguousLayout = + VMILayoutAttr::getContiguous(op.getContext()); + auto contiguousType = + VMIMaskType::get(op.getContext(), resultVMIType.getElementCount(), + resultVMIType.getGranularity(), contiguousLayout); + std::string contiguousReason; + FailureOr> + contiguousMaterializations = computeGroupMaskMaterializationForType( + op, contiguousType, &contiguousReason); + if (failed(contiguousMaterializations)) + return rewriter.notifyMatchFailure( + op, Twine("create_group_mask ") + contiguousReason); + + SmallVector contiguousParts; + contiguousParts.reserve(contiguousMaterializations->size()); + for (const ConstantMaskChunkMaterialization &materialization : + *contiguousMaterializations) { + if (contiguousParts.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask produced too many contiguous masks"); + auto maskType = dyn_cast(resultTypes[contiguousParts.size()]); + if (!maskType) + return rewriter.notifyMatchFailure( + op, "create_group_mask result must be mask"); + FailureOr mask = materializeConstantMaskChunk( + op.getLoc(), maskType, materialization.activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize create_group_mask contiguous chunk"); + contiguousParts.push_back(*mask); + } + + if (contiguousParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask contiguous physical result count mismatch"); + FailureOr> results = materializeMaskLayoutConversion( + op, contiguousParts, resultTypes, contiguousLayout, resultLayout, + rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } + std::string reason; FailureOr> materializations = computeGroupMaskMaterialization(op, &reason); diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto index bad43bb869..33ee79cb57 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto @@ -6,7 +6,8 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_masked_load_group_tail_s32( @@ -34,6 +35,18 @@ module { } } -// CHECK: VMI{{-}}UNSUPPORTED: pto.vmi.group_reduce_addf -// CHECK-SAME: s32 block8 lowering does not yet support partial create_group_mask active_elems_per_group during layout assignment -// CHECK-NOT: vmi.selected_plan = "s32_reduce_block8_stride" +// ASSIGN: pto.vmi.create_group_mask +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.ensure_layout +// ASSIGN-SAME: #pto.vmi.layout +// ASSIGN-SAME: #pto.vmi.layout +// ASSIGN: pto.vmi.create_group_mask +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" +// LOWER: pto.pdintlv_b32 +// LOWER: pto.pdintlv_b32 +// LOWER: pto.pdintlv_b32 +// LOWER: pto.pdintlv_b32 +// LOWER: pto.vcgadd +// LOWER: pto.vsts diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/compare.py b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/golden.py b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/golden.py new file mode 100644 index 0000000000..df3f6a24dc --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +ACTIVE = 25 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + active_base = np.linspace(-0.875, 0.625, ACTIVE, dtype=np.float32) + inactive_base = np.linspace(19.0, 22.5, COLS - ACTIVE, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :ACTIVE] = active_base + np.float32(row) * np.float32(0.03125) + src[row, ACTIVE:] = inactive_base + np.float32(row) * np.float32(1.75) + + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src.copy() + golden_copy[:, ACTIVE:] = np.float32(0.0) + golden_sum = np.sum(src[:, :ACTIVE], axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto new file mode 100644 index 0000000000..37a10109ee --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_masked_load_group_tail_s32_reduce_store_kernel( + %src_gm: !pto.ptr, %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c25 = arith.constant 25 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %zero = arith.constant 0.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_group_mask %c25 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %ub_src[%c0], %mask, %zero_vec + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/launch.cpp b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/launch.cpp new file mode 100644 index 0000000000..5b39bc3962 --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_masked_load_group_tail_s32_reduce_store_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_masked_load_group_tail_s32_reduce_store_kernel(float *src, float *copy, + float *sum, void *stream) { + vmi_masked_load_group_tail_s32_reduce_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/main.cpp b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/main.cpp new file mode 100644 index 0000000000..f9f224885e --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/main.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_masked_load_group_tail_s32_reduce_store_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_masked_load_group_tail_s32_reduce_store_kernel(srcDevice, copyDevice, + sumDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/ptoas.flags b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From 3d740079a6ba978e509278a9994c8effd87e1a23 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:13:15 +0800 Subject: [PATCH 08/54] Support dynamic S32 grouped mask lowering --- .../vmi-layout-assignment-implementation.md | 19 +- .../vmi-layout-assignment-lowering-design.md | 4 +- docs/designs/vmi-layout-lowering-cases.md | 74 ++++++- lib/PTO/Transforms/VMIToVPTO.cpp | 189 +++++++++++++++--- ...signment_create_group_mask_s32_dynamic.pto | 61 ++++++ 5 files changed, 316 insertions(+), 31 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 12fedece13..f0dc821444 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -763,6 +763,9 @@ Current audit result: materializes contiguous grouped predicate chunks and then applies predicate pdintlv in the same tree shape as the data vdintlv. It still does not walk from group_reduce_addf to the mask defining op to choose or reject the plan. + The dynamic active_elems_per_group form is also op-local: vmi-to-vpto lowers + contiguous chunks with vci/vshrs/vshls/vsub/vcmps, then uses the same + predicate pdintlv tree for S=32 deinterleaved masks. masked_load: direct lowering is load + vsel. It does not inspect the mask producer to @@ -904,7 +907,7 @@ layout/rematerialization: mask/tail: 3.11.1, 3.15.1, 3.15.2, 3.21, 3.24, 3.26, 3.29, - 3.30, 3.44 + 3.30, 3.44, 3.45 strided/group-slot memory: 3.27, 3.28, 3.37, 3.39 @@ -1063,6 +1066,7 @@ reduce: ```text lit: test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto + test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto runtime SIM: test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store @@ -1345,11 +1349,14 @@ Known implementation gaps before all catalog cases can become runtime SIM coverage: ```text -dynamic grouped masks: - pto.vmi.create_group_mask exists and supports constant - active_elems_per_group. Dynamic active_elems_per_group is not implemented - yet. Do not replace grouped masks with prefix create_mask; that would change - the semantics. +dynamic grouped mask runtime source: + vmi-to-vpto supports dynamic active_elems_per_group for contiguous b32 + grouped masks and S=32 deinterleaved=4/block_elems=8 masks. Full runtime SIM + coverage still needs a supported scalar source for active_elems_per_group in + vector kernels. Direct GM pto.ldg crashed the Bisheng vector backend in this + test shape, and UB pto.load_scalar reached an invalid scalar LSU address in + the SIM. Do not replace grouped masks with prefix create_mask; that would + change the semantics. remaining function runtime coverage: 3.25.1 internal function boundary specialization has layout-assignment and diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index d0f71f0daf..c13944d348 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -557,7 +557,9 @@ grouped masks: assignment emits explicit contiguous and deinterleaved mask values, and `vmi-to-vpto` lowers the deinterleaved mask op itself through contiguous grouped-mask materialization followed by predicate deinterleave. It does not walk from `group_reduce_addf` to the mask producer to choose or reject -the plan. +the plan. Dynamic `active_elems_per_group` follows the same rule: the +`create_group_mask` op lowers its own SSA scalar with vci/vshrs/vshls/vsub/vcmps +for contiguous chunks before any predicate deinterleave. ## 9. Physical Value Ordering diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 93a6c1dc57..e44f32a97e 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -197,6 +197,7 @@ the immediately following complete endpoints. 3.42 group_slots scf.for loop-carried accumulator complete 3.43 internal function argument boundary materialization complete/design 3.44 masked_load grouped tail feeding S=32 reduce complete +3.45 dynamic S=32 create_group_mask complete/lit ``` ### 3.1 `f16 -> f32 -> store` @@ -5224,7 +5225,6 @@ for r = 0..7: Required assignment rule: -```text `masked_load` and `group_reduce` must share the same grouped mask layout. The passthrough value defines inactive loaded lanes, while the reduce mask defines participation. Assignment materializes two explicit mask values when needed: @@ -5244,4 +5244,76 @@ Runtime coverage: ```text test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store ``` + +### 3.45 Dynamic S=32 `create_group_mask` + +This is the dynamic-shape form of section 3.44. The active column count is an +SSA `index`, not a constant. The semantic mask is still grouped: + +```text +lane i active iff (i % 32) < active_cols +``` + +VMI input: + +```text +%mask = pto.vmi.create_group_mask %active_cols + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +``` + +Assigned layouts: + +```text +%mask for masked_load: + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%mask for S=32 group_reduce: + !pto.vmi.mask<256xb32, + #pto.vmi.layout> +``` + +Contiguous VPTO lowering for one b32 physical chunk: + +```text +%active_i32 = arith.index_cast %active_cols : index to i32 +%active_nonneg = arith.maxsi %active_i32, %c0_i32 : i32 +%active_clamped = arith.minui %active_nonneg, %c32_i32 : i32 + +%all = pto.pset_b32 "PAT_ALL" : !pto.mask +%lane = pto.vci %c0_i32 : i32 -> !pto.vreg<64xi32> +%row = pto.vshrs %lane, %c5_i16, %all + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%row_base = pto.vshls %row, %c5_i16, %all + : !pto.vreg<64xi32>, i16, !pto.mask -> !pto.vreg<64xi32> +%col = pto.vsub %lane, %row_base, %all + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask + -> !pto.vreg<64xi32> +%m = pto.vcmps %col, %active_clamped, %all, "lt" + : !pto.vreg<64xi32>, i32, !pto.mask -> !pto.mask ``` + +For `deinterleaved = 4, block_elems = 8`, lowering first emits four contiguous +chunks with the sequence above, then applies the same predicate deinterleave +tree used by section 3.44: + +```text +%rm0, %rm1, %rm2, %rm3 = dynamic contiguous grouped masks +%rm01_lo, %rm01_hi = pto.pdintlv_b32 %rm0, %rm1 +%rm23_lo, %rm23_hi = pto.pdintlv_b32 %rm2, %rm3 +%mask_p0, %mask_p2 = pto.pdintlv_b32 %rm01_lo, %rm23_lo +%mask_p1, %mask_p3 = pto.pdintlv_b32 %rm01_hi, %rm23_hi +``` + +The current lit coverage validates the IR lowering: + +```text +test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto +``` + +Runtime SIM coverage is intentionally not listed yet. A direct runtime case +needs a supported way to feed a dynamic scalar `active_cols` into a vector +kernel. Experiments with GM `pto.ldg` and UB `pto.load_scalar` either crashed +the Bisheng vector backend or produced an invalid scalar LSU address in the +SIM. That is an ABI/source-materialization gap, not a `create_group_mask` +layout-lowering gap. diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 85dbec5f1e..36ccc21f3f 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -424,6 +424,11 @@ Value createI32Constant(Location loc, int64_t value, return rewriter.create(loc, value, 32); } +Value createI16Constant(Location loc, int64_t value, + PatternRewriter &rewriter) { + return rewriter.create(loc, value, 16); +} + FailureOr createPrefixMaskForActiveLanes(Location loc, MaskType maskType, int64_t activeLanes, PatternRewriter &rewriter) { @@ -477,6 +482,17 @@ Value createPartitionActiveLanes(Location loc, Value activeLanesI32, loc, biased, createI32Constant(loc, factor, rewriter)); } +std::optional getPowerOfTwoLog2(int64_t value) { + if (value <= 0 || (value & (value - 1)) != 0) + return std::nullopt; + int64_t log2 = 0; + while (value > 1) { + value >>= 1; + ++log2; + } + return log2; +} + std::optional getPrefixPattern(int64_t activeLanes, int64_t lanesPerPart) { if (activeLanes <= 0) @@ -2160,6 +2176,96 @@ computeGroupMaskMaterialization(VMICreateGroupMaskOp op, std::string *reason) { op, cast(op.getResult().getType()), reason); } +FailureOr> materializeDynamicContiguousGroupMask( + VMICreateGroupMaskOp op, Value activeElemsPerGroup, + VMIMaskType contiguousVMIType, TypeRange resultTypes, + PatternRewriter &rewriter) { + auto fail = [&](const Twine &message) -> FailureOr> { + (void)rewriter.notifyMatchFailure(op, message); + return failure(); + }; + + VMILayoutAttr layout = contiguousVMIType.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return fail("dynamic create_group_mask requires contiguous seed layout"); + if (contiguousVMIType.getGranularity() != "b32") + return fail("dynamic create_group_mask currently requires b32 " + "granularity"); + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + int64_t groupSize = op.getGroupSizeAttr().getInt(); + if (numGroups <= 0 || groupSize <= 0 || + contiguousVMIType.getElementCount() != numGroups * groupSize) + return fail("dynamic create_group_mask requires result lane count to " + "match num_groups * group_size"); + + FailureOr lanesPerPart = + getMaskLanesPerPart(contiguousVMIType.getGranularity()); + FailureOr arity = getVMIPhysicalArity(contiguousVMIType); + if (failed(lanesPerPart) || failed(arity) || *arity < 1) + return fail("dynamic create_group_mask requires computable physical " + "mask chunks"); + if (static_cast(resultTypes.size()) != *arity) + return fail("dynamic create_group_mask physical result count mismatch"); + if (groupSize > *lanesPerPart || (*lanesPerPart % groupSize) != 0) + return fail("dynamic create_group_mask currently requires group_size to " + "divide one physical b32 predicate chunk"); + + std::optional shift = getPowerOfTwoLog2(groupSize); + if (!shift) + return fail("dynamic create_group_mask currently requires power-of-two " + "group_size"); + + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + Type i32 = rewriter.getI32Type(); + auto indexVectorType = VRegType::get(ctx, *lanesPerPart, i32); + Value activeI32 = + clampDynamicActiveLanes(loc, activeElemsPerGroup, groupSize, rewriter); + + SmallVector results; + results.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto maskType = dyn_cast(resultType); + if (!maskType || !maskType.isB32()) + return fail("dynamic create_group_mask result must be b32 mask"); + + FailureOr allMask = createAllTrueMask(loc, maskType, rewriter); + if (failed(allMask)) + return fail("failed to create dynamic create_group_mask all mask"); + + Value zero = createI32Constant(loc, 0, rewriter); + Value lane = + rewriter.create(loc, indexVectorType, zero, StringAttr{}) + .getResult(); + + Value col = lane; + if (groupSize != *lanesPerPart) { + Value shiftScalar = createI16Constant(loc, *shift, rewriter); + Value group = rewriter + .create(loc, indexVectorType, lane, + shiftScalar, *allMask) + .getResult(); + Value groupBase = rewriter + .create(loc, indexVectorType, group, + shiftScalar, *allMask) + .getResult(); + col = rewriter + .create(loc, indexVectorType, lane, groupBase, + *allMask) + .getResult(); + } + + results.push_back(rewriter + .create(loc, maskType, col, activeI32, + *allMask, + rewriter.getStringAttr("lt")) + .getResult()); + } + + return results; +} + std::optional getPrefixActiveLaneCount(ArrayRef activeLanes) { bool seenInactive = false; int64_t activeCount = 0; @@ -3797,31 +3903,50 @@ struct OneToNVMICreateGroupMaskOpPattern auto contiguousType = VMIMaskType::get(op.getContext(), resultVMIType.getElementCount(), resultVMIType.getGranularity(), contiguousLayout); - std::string contiguousReason; - FailureOr> - contiguousMaterializations = computeGroupMaskMaterializationForType( - op, contiguousType, &contiguousReason); - if (failed(contiguousMaterializations)) - return rewriter.notifyMatchFailure( - op, Twine("create_group_mask ") + contiguousReason); - SmallVector contiguousParts; - contiguousParts.reserve(contiguousMaterializations->size()); - for (const ConstantMaskChunkMaterialization &materialization : - *contiguousMaterializations) { - if (contiguousParts.size() >= resultTypes.size()) - return rewriter.notifyMatchFailure( - op, "create_group_mask produced too many contiguous masks"); - auto maskType = dyn_cast(resultTypes[contiguousParts.size()]); - if (!maskType) + auto activeConstant = + op.getActiveElemsPerGroup().getDefiningOp(); + if (activeConstant) { + std::string contiguousReason; + FailureOr> + contiguousMaterializations = computeGroupMaskMaterializationForType( + op, contiguousType, &contiguousReason); + if (failed(contiguousMaterializations)) return rewriter.notifyMatchFailure( - op, "create_group_mask result must be mask"); - FailureOr mask = materializeConstantMaskChunk( - op.getLoc(), maskType, materialization.activeLanes, rewriter); - if (failed(mask)) - return rewriter.notifyMatchFailure( - op, "failed to materialize create_group_mask contiguous chunk"); - contiguousParts.push_back(*mask); + op, Twine("create_group_mask ") + contiguousReason); + + contiguousParts.reserve(contiguousMaterializations->size()); + for (const ConstantMaskChunkMaterialization &materialization : + *contiguousMaterializations) { + if (contiguousParts.size() >= resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "create_group_mask produced too many contiguous masks"); + auto maskType = + dyn_cast(resultTypes[contiguousParts.size()]); + if (!maskType) + return rewriter.notifyMatchFailure( + op, "create_group_mask result must be mask"); + FailureOr mask = materializeConstantMaskChunk( + op.getLoc(), maskType, materialization.activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize create_group_mask contiguous chunk"); + contiguousParts.push_back(*mask); + } + } else { + FailureOr active = getSingleValue( + op, adaptor.getActiveElemsPerGroup(), + "create_group_mask active_elems_per_group must convert to one " + "value", + rewriter); + if (failed(active)) + return failure(); + FailureOr> dynamicParts = + materializeDynamicContiguousGroupMask(op, *active, contiguousType, + resultTypes, rewriter); + if (failed(dynamicParts)) + return failure(); + contiguousParts = std::move(*dynamicParts); } if (contiguousParts.size() != resultTypes.size()) @@ -3836,6 +3961,24 @@ struct OneToNVMICreateGroupMaskOpPattern return success(); } + auto activeConstant = + op.getActiveElemsPerGroup().getDefiningOp(); + if (!activeConstant && resultLayout && resultLayout.isContiguous()) { + FailureOr active = getSingleValue( + op, adaptor.getActiveElemsPerGroup(), + "create_group_mask active_elems_per_group must convert to one value", + rewriter); + if (failed(active)) + return failure(); + FailureOr> results = + materializeDynamicContiguousGroupMask(op, *active, resultVMIType, + resultTypes, rewriter); + if (failed(results)) + return failure(); + rewriter.replaceOp(op, *results, adaptor.getResultMapping()); + return success(); + } + std::string reason; FailureOr> materializations = computeGroupMaskMaterialization(op, &reason); diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto new file mode 100644 index 0000000000..f68b4d5509 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( + %base: !pto.ptr, + %sum_out: !pto.ptr, + %off: index, + %active_cols: index) { + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1 = arith.constant 1 : index + %mask = pto.vmi.create_group_mask %active_cols + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %base[%off], %mask, %zero + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( +// ASSIGN-SAME: %[[ACTIVE:arg[0-9]+]]: index) +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask %[[ACTIVE]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK1:.*]] = pto.vmi.create_group_mask %[[ACTIVE]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_reduce_addf +// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" + +// LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( +// LOWER: arith.index_cast +// LOWER: arith.maxsi +// LOWER: arith.minui +// LOWER: pto.vci +// LOWER: pto.vshrs +// LOWER: pto.vshls +// LOWER: pto.vsub +// LOWER-COUNT-8: pto.vcmps +// LOWER-COUNT-4: pto.pdintlv_b32 +// LOWER-COUNT-4: pto.vcgadd +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast From 81778a1e94a206f5b5941e5289d49b63382662cc Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:15:28 +0800 Subject: [PATCH 09/54] Clarify VMI layout case coverage gaps --- .../vmi-layout-assignment-implementation.md | 18 ++++-- .../vmi-layout-assignment-lowering-design.md | 57 ++++++++++++++++++- 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index f0dc821444..a4fb146317 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -876,13 +876,13 @@ the case catalog. Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-40 CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-dynamic-mask CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh PASS=40 FAIL=0 -summary: .tmp/vmi-runtime-batch-40/parallel-summary.tsv +summary: .tmp/vmi-runtime-batch-dynamic-mask/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-40.log + .tmp/vmi-runtime-batch-dynamic-mask.log result: no matches ``` @@ -1066,12 +1066,22 @@ reduce: ```text lit: test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto - test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto runtime SIM: test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store ``` +Current checked-in lit coverage for 3.45 dynamic S=32 `create_group_mask`: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto + +runtime SIM: + blocked by the current dynamic scalar source gap for vector kernels; see + known implementation gaps below +``` + Current checked-in runtime coverage for 3.12 control-flow join before S=32 `group_reduce`: diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index c13944d348..da58668057 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -72,6 +72,7 @@ control flow: mask and tail: prefix mask group-periodic mask + dynamic group-periodic mask masked_load tail with explicit passthrough instead of padding masked_load grouped tail feeding group_reduce masked select/store @@ -86,6 +87,59 @@ strided memory: group_store slots=1 with non-unit output stride ``` +### 1.1 Case-Set Sufficiency + +The current case set is sufficient to define the first implementation of layout +assignment and lowering. It covers every decision axis that has changed the +design so far: + +```text +physical dense layout: + contiguous, deinterleaved=2/4, block_elems=1/8 + +sparse result layout: + group_slots(G, slots=8) for packed VCG results + group_slots(G, slots=1) for row-local S=64 results + +producer-driven layout: + load, group_load, group_slot_load, broadcast, create_mask, + create_group_mask + +consumer-driven pressure: + dense store, group_reduce, group_store, group_broadcast, truncf, + elementwise/select, masked_load/masked_store + +conflict resolution: + cheap rematerialization, explicit ensure_layout, explicit diagnostics + +control-flow propagation: + scf.if, scf.for iter_args/results, internal/private function boundaries, + public ABI rejection + +memory legality: + full_tile_readable proof, grouped masks, predicate granularity, aligned + strided group memory, stable gather diagnostic +``` + +No extra layout kind should be added unless a new case proves that the existing +layouts and plans cannot express the logical behavior. The remaining open +items are not missing layout semantics: + +```text +dynamic active_elems_per_group runtime source: + create_group_mask layout lowering is defined and has lit coverage; runtime + SIM still needs a supported scalar source/ABI for vector kernels. + +private vector function runtime: + assignment/lowering semantics are defined; full ptoas runtime depends on + backend support or an inlining policy for physical VPTO vector callees. + +diagnostic-only cases: + compact S=12 gather fallback, packed slots=8 width-changing cast, public VMI + ABI, unsafe masked_load tail, and unaligned/dynamic group memory remain + explicit capability boundaries. +``` + ## 2. Layout Domain Layout is a property of a layout-assigned VMI value, not a property inferred by @@ -626,5 +680,6 @@ The design is complete only when: 3. every unsupported case has a precise capability diagnostic 4. every control-flow/function boundary either specializes layout or diagnoses 5. every mask has explicit data layout and predicate granularity -6. every case has an end-to-end test and simulator validation +6. every positive case has end-to-end lit coverage +7. every simulator-supported positive case has simulator validation ``` From f25128cdedabc1e119332cf1b85032eefdfe9177 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:18:40 +0800 Subject: [PATCH 10/54] Record VMI layout coverage audit --- .../vmi-layout-assignment-implementation.md | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index a4fb146317..fbfef70804 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -937,6 +937,39 @@ Aggregate catalog headings are covered through their endpoint subcases: 3.25.2 public/external boundary diagnostics ``` +Current coverage audit result: + +```text +SIM-backed positive endpoints: + 3.1, 3.2, 3.3, 3.4, 3.5.1, 3.5.2, 3.5.3, + 3.6.1, 3.6.2, 3.6.3, 3.7.1, 3.7.2, 3.7.3, + 3.8, 3.10, 3.11.1, 3.12, 3.15.1, 3.15.2, + 3.16.1 positive, 3.16.2 positive, 3.17, 3.18, + 3.19.1, 3.20, 3.21, 3.22, 3.23, 3.24, 3.26, + 3.27 positive, 3.28 positive, 3.29, 3.31, 3.32, + 3.33, 3.34, 3.35, 3.36, 3.37, 3.38, 3.39, + 3.40, 3.41, 3.42, 3.44 + +lit-backed positive endpoints with runtime gap: + 3.25.1 private/internal function boundary + 3.43 internal function argument boundary materialization + 3.45 dynamic S=32 create_group_mask + +diagnostic endpoints: + 3.7.4, 3.9, 3.11.2, 3.13, 3.14, 3.15.3, + 3.16.1 non-unit slots=8 source stride, + 3.16.2 dynamic/unaligned slots=1 source stride, + 3.19.2, 3.25.2, 3.27 unaligned source_group_stride, + 3.30 unsafe masked_load tail + +repository evidence: + all concrete lit/runtime paths listed below exist + all 40 runtime case directories contain kernel.pto, launch.cpp, main.cpp, + golden.py, and compare.py + latest broad VMI runtime sweep passed: PASS=40 FAIL=0 + latest full VMI lit sweep passed: 312/312 +``` + Current checked-in coverage for 3.3 dense f8->f32->compute->f8: ```text From 327dd9cf469b348ef22751265ec054cb5db0a920 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:31:17 +0800 Subject: [PATCH 11/54] Add dynamic S32 group mask runtime coverage --- .../vmi-layout-assignment-implementation.md | 34 +++---- .../vmi-layout-assignment-lowering-design.md | 6 +- docs/designs/vmi-layout-lowering-cases.md | 15 ++- .../compare.py | 40 ++++++++ .../golden.py | 51 ++++++++++ .../kernel.pto | 64 ++++++++++++ .../launch.cpp | 35 +++++++ .../main.cpp | 99 +++++++++++++++++++ .../ptoas.flags | 1 + 9 files changed, 315 insertions(+), 30 deletions(-) create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/compare.py create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/golden.py create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/launch.cpp create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/main.cpp create mode 100644 test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/ptoas.flags diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index fbfef70804..b05d6e1552 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -876,13 +876,13 @@ the case catalog. Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-dynamic-mask CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-dynamic-scalar CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh -PASS=40 FAIL=0 -summary: .tmp/vmi-runtime-batch-dynamic-mask/parallel-summary.tsv +PASS=41 FAIL=0 +summary: .tmp/vmi-runtime-batch-dynamic-scalar/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-dynamic-mask.log + .tmp/vmi-runtime-batch-dynamic-scalar.log result: no matches ``` @@ -948,12 +948,11 @@ SIM-backed positive endpoints: 3.19.1, 3.20, 3.21, 3.22, 3.23, 3.24, 3.26, 3.27 positive, 3.28 positive, 3.29, 3.31, 3.32, 3.33, 3.34, 3.35, 3.36, 3.37, 3.38, 3.39, - 3.40, 3.41, 3.42, 3.44 + 3.40, 3.41, 3.42, 3.44, 3.45 lit-backed positive endpoints with runtime gap: 3.25.1 private/internal function boundary 3.43 internal function argument boundary materialization - 3.45 dynamic S=32 create_group_mask diagnostic endpoints: 3.7.4, 3.9, 3.11.2, 3.13, 3.14, 3.15.3, @@ -964,9 +963,9 @@ diagnostic endpoints: repository evidence: all concrete lit/runtime paths listed below exist - all 40 runtime case directories contain kernel.pto, launch.cpp, main.cpp, + all 41 runtime case directories contain kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py - latest broad VMI runtime sweep passed: PASS=40 FAIL=0 + latest broad VMI runtime sweep passed: PASS=41 FAIL=0 latest full VMI lit sweep passed: 312/312 ``` @@ -1104,15 +1103,19 @@ runtime SIM: test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store ``` -Current checked-in lit coverage for 3.45 dynamic S=32 `create_group_mask`: +Current checked-in coverage for 3.45 dynamic S=32 `create_group_mask`: ```text lit: test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto runtime SIM: - blocked by the current dynamic scalar source gap for vector kernels; see - known implementation gaps below + test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store + +runtime scalar source: + active_cols is passed as a kernel i32 scalar argument and cast to index inside + vecscope before pto.vmi.create_group_mask. This is an explicit scalar ABI, + not a value recovered by vmi-to-vpto from producer/consumer context. ``` Current checked-in runtime coverage for 3.12 control-flow join before S=32 @@ -1392,15 +1395,6 @@ Known implementation gaps before all catalog cases can become runtime SIM coverage: ```text -dynamic grouped mask runtime source: - vmi-to-vpto supports dynamic active_elems_per_group for contiguous b32 - grouped masks and S=32 deinterleaved=4/block_elems=8 masks. Full runtime SIM - coverage still needs a supported scalar source for active_elems_per_group in - vector kernels. Direct GM pto.ldg crashed the Bisheng vector backend in this - test shape, and UB pto.load_scalar reached an invalid scalar LSU address in - the SIM. Do not replace grouped masks with prefix create_mask; that would - change the semantics. - remaining function runtime coverage: 3.25.1 internal function boundary specialization has layout-assignment and vmi-to-vpto lit coverage, but full ptoas emission still fails after diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index da58668057..99a1a34c6c 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -127,8 +127,10 @@ items are not missing layout semantics: ```text dynamic active_elems_per_group runtime source: - create_group_mask layout lowering is defined and has lit coverage; runtime - SIM still needs a supported scalar source/ABI for vector kernels. + create_group_mask layout lowering is defined and has both lit and SIM + coverage. The supported runtime source is a kernel scalar argument cast to + index inside vecscope; vmi-to-vpto does not recover this value from GM/UB + scalar loads or surrounding context. private vector function runtime: assignment/lowering semantics are defined; full ptoas runtime depends on diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index e44f32a97e..160b25a398 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -197,7 +197,7 @@ the immediately following complete endpoints. 3.42 group_slots scf.for loop-carried accumulator complete 3.43 internal function argument boundary materialization complete/design 3.44 masked_load grouped tail feeding S=32 reduce complete -3.45 dynamic S=32 create_group_mask complete/lit +3.45 dynamic S=32 create_group_mask complete ``` ### 3.1 `f16 -> f32 -> store` @@ -5305,15 +5305,14 @@ tree used by section 3.44: %mask_p1, %mask_p3 = pto.pdintlv_b32 %rm01_hi, %rm23_hi ``` -The current lit coverage validates the IR lowering: +Current coverage validates both IR lowering and runtime behavior: ```text test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto +test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store ``` -Runtime SIM coverage is intentionally not listed yet. A direct runtime case -needs a supported way to feed a dynamic scalar `active_cols` into a vector -kernel. Experiments with GM `pto.ldg` and UB `pto.load_scalar` either crashed -the Bisheng vector backend or produced an invalid scalar LSU address in the -SIM. That is an ABI/source-materialization gap, not a `create_group_mask` -layout-lowering gap. +The runtime case passes `active_cols` as a kernel scalar argument and casts it +to `index` inside `pto.vecscope`. This keeps scalar materialization outside +`vmi-to-vpto`; the lowering pass only consumes the current +`create_group_mask` operand. diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/compare.py b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/golden.py b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/golden.py new file mode 100644 index 0000000000..df3f6a24dc --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/golden.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +ACTIVE = 25 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + active_base = np.linspace(-0.875, 0.625, ACTIVE, dtype=np.float32) + inactive_base = np.linspace(19.0, 22.5, COLS - ACTIVE, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :ACTIVE] = active_base + np.float32(row) * np.float32(0.03125) + src[row, ACTIVE:] = inactive_base + np.float32(row) * np.float32(1.75) + + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src.copy() + golden_copy[:, ACTIVE:] = np.float32(0.0) + golden_sum = np.sum(src[:, :ACTIVE], axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto new file mode 100644 index 0000000000..8e9ebed693 --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_create_group_mask_s32_reduce_store_kernel( + %src_gm: !pto.ptr, %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr, %active_cols_i32: i32) + attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %zero = arith.constant 0.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %active_cols = arith.index_cast %active_cols_i32 : i32 to index + %mask = pto.vmi.create_group_mask %active_cols + {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %zero_vec = pto.vmi.broadcast %zero : f32 -> !pto.vmi.vreg<256xf32> + %x = pto.vmi.masked_load %ub_src[%c0], %mask, %zero_vec + : !pto.ptr, !pto.vmi.mask<256xpred>, + !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/launch.cpp b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/launch.cpp new file mode 100644 index 0000000000..5865140b26 --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_create_group_mask_s32_reduce_store_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum, int activeCols); + +void LaunchVmi_dynamic_create_group_mask_s32_reduce_store_kernel(float *src, float *copy, + float *sum, int activeCols, + void *stream) { + vmi_dynamic_create_group_mask_s32_reduce_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum, + activeCols); +} diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/main.cpp b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/main.cpp new file mode 100644 index 0000000000..7bd86defb1 --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/main.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_create_group_mask_s32_reduce_store_kernel(float *src, float *copy, + float *sum, int activeCols, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr int kActiveCols = 25; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_create_group_mask_s32_reduce_store_kernel( + srcDevice, copyDevice, sumDevice, kActiveCols, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/ptoas.flags b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From 9e60c1effaa3b83373a0135584f7397cf5ff1898 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:35:03 +0800 Subject: [PATCH 12/54] Detail VMI layout assignment request rules --- .../vmi-layout-assignment-implementation.md | 177 ++++++++++++++++++ .../vmi-layout-assignment-lowering-design.md | 111 +++++++++++ 2 files changed, 288 insertions(+) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index b05d6e1552..3d0cab8215 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -421,6 +421,109 @@ compact S=12 logical S=16: diagnose if gather fallback is disabled/missing ``` +### 6.3.1 Request Builders + +Implement request generation as small per-op builders. The builders produce +candidate plans and use-site requests; they do not rewrite IR. + +```text +buildStoreRequests: + ordinary store -> dense contiguous request unless a layout-aware store plan is + selected + group_store -> group_slots(G,K) request plus stride/alignment capability + checks + +buildCastRequests: + extf f16->f32 -> source contiguous, result deinterleaved=2 + extf f8->f32 -> source contiguous, result deinterleaved=4 + truncf f32->f16 -> source deinterleaved=2/block_elems=1, result contiguous + truncf f32->f8 -> source deinterleaved=4/block_elems=1, result contiguous + group_slots slots=1 f32->f16 -> slot-preserving plan + group_slots slots=8 width-changing cast -> diagnostic unless a packed plan + exists + +buildGroupReduceRequests: + derive S = logical_lanes / num_groups + S=8 -> contiguous source, group_slots(G,8) result + S=16 -> deinterleaved=2/block_elems=1 or block_elems=8 source, + group_slots(G,8) result + S=32 -> deinterleaved=4/block_elems=1 or block_elems=8 source, + group_slots(G,8) result + S=64 -> contiguous source, group_slots(G,1) result + other S -> diagnostic unless an explicit fallback plan is enabled + +buildGroupMemoryRequests: + group_load S=16/S=32 with aligned constant stride -> block_elems=8 plan + group_load row-local full chunks -> contiguous plan + group_slot_load unit stride -> group_slots(G,8) + group_slot_load aligned row-local stride -> group_slots(G,1) + unsupported dynamic/unaligned grouped memory -> diagnostic + +buildMaskRequests: + mask layout follows each consuming data layout + predicate granularity follows each consuming element type + create_mask/create_group_mask may be cloned for incompatible mask layout or + granularity requests + +buildControlFlowRequests: + region yields, branch operands, loop iter_args, call operands, and returns + create equality requests on the carried VMI layout variable +``` + +Request builders must record the requesting op. Diagnostics and inserted +helpers are use-site operations, so the user can see which consumer forced a +layout. + +### 6.3.2 Producer Classes + +The solver uses producer classes to decide whether a conflict can be solved by +cloning, equivalence propagation, or materialization. + +```text +cheap rematerializable producers: + load when address operands dominate the clone site, no intervening may-alias + write exists, and any full_read_elems proof is preserved + broadcast + create_mask + create_group_mask + group_broadcast + group_slot_load when the same address/no-alias/proof conditions as load hold + and the selected memory plan is legal at the clone site + +layout-transparent producers: + add/sub/mul/fma/min/max/neg/abs + select + bitcast + integer bitwise and shift ops + +fixed-layout producers: + extf/truncf physical conversion plans + group_load block-fragment plans + group_reduce result group_slots + masked_load when the physical memory-safety proof fixes a full-read plan +``` + +Conflict policy: + +```text +cheap producer: + clone for each incompatible request when cloning does not duplicate a + side-effect, cross an aliasing write, or duplicate an illegal memory read + +layout-transparent producer: + merge into the consumer-requested equivalence class; insert materialization + only at incompatible uses + +fixed-layout producer: + use registered materialization only; otherwise diagnose +``` + +This is the rule that keeps case 3.32 legal: a plain `load` can be assigned to +`deinterleaved=4, block_elems=1` for both `truncf f32->f8` and S=32 +`group_reduce`. It also keeps case 3.19.2 diagnostic: a strided `group_load` +that selected `block_elems=8` is fixed unless a block8-to-parity +materialization or rematerialized memory plan is registered. + ### 6.4 Solving And Rewriting Algorithm: @@ -451,6 +554,80 @@ Every ensure_* helper has a registered materialization plan. Every function/call signature carrying VMI is specialized or diagnosed. ``` +### 6.5 Rewrite Artifacts + +Assignment rewrites the IR so that later lowering has no hidden choices. + +```text +type rewrite: + every VMI data/mask result and block argument receives a layout attr + +selected_plan rewrite: + context-sensitive ops receive vmi.selected_plan + examples: group_reduce_addf, group_load, group_slot_load, group_broadcast, + group_slot cast, full-read masked_load plans + +clone rewrite: + cheap producers are cloned before their divergent use sites + each clone receives its own layout and selected_plan + +ensure rewrite: + non-cheap values use pto.vmi.ensure_layout or ensure_mask_layout at the use + site, with source and target layouts visible in the types + +granularity rewrite: + one semantic mask used by f32 and f16 consumers gets + ensure_mask_granularity or cloned mask producers + +control-flow rewrite: + scf.if/scf.for yields and block arguments are rewritten to one agreed layout; + materialization is inserted before yield when branches differ + +function rewrite: + private VMI functions are specialized or get callee-entry ensure_layout + public/external VMI functions are diagnosed +``` + +Canonical assigned IR shape for a conflicting load: + +```text +%x = pto.vmi.load ... {vmi.selected_plan = "load_dintlv4"} + : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_dense = pto.vmi.ensure_layout %x + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +pto.vmi.store %x_dense, ... +``` + +Canonical assigned IR shape for a cloned cheap producer: + +```text +%x_s16 = pto.vmi.load ... {vmi.selected_plan = "load_dintlv2"} + : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%x_s32 = pto.vmi.load ... {vmi.selected_plan = "load_dintlv4"} + : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +Canonical assigned IR shape for `group_broadcast` multi-use: + +```text +%b0 = pto.vmi.group_broadcast %slots + {vmi.selected_plan = "group_broadcast_slots8_vselr"} + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +%b1 = pto.vmi.group_broadcast %slots + {vmi.selected_plan = "group_broadcast_slots8_vselr"} + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +If the assigned IR does not have one of these explicit shapes, `vmi-to-vpto` +must reject it instead of attempting to recover the missing decision. + ## 7. OneToN Type Conversion `vmi-to-vpto` should use OneToN conversion for VMI values. diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 99a1a34c6c..5e43f6d9ec 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -422,6 +422,117 @@ one mask used by f32 and f16 consumers: vmi-to-vpto consumes the assigned per-use mask materialization ``` +### 5.5 Case-Driven Request Matrix + +The first implementation should build requests from the following finite table. +This table is deliberately case-derived; adding a new request kind requires a +new catalog case or a proof that it is equivalent to one listed here. + +```text +dense store: + requests dense contiguous source + if source is deinterleaved, assignment must insert ensure_layout or select a + store plan such as vstsx2 that consumes the assigned layout explicitly + +truncf f32 -> f16: + requests source deinterleaved=2, block_elems=1 + requests result contiguous f16 + +truncf f32 -> f8: + requests source deinterleaved=4, block_elems=1 + requests result contiguous f8 + +group_reduce S=8: + requests source contiguous + requests result group_slots(num_groups, slots=8) + +group_reduce S=16: + requests source deinterleaved=2, block_elems=1 or block_elems=8 + requests result group_slots(num_groups, slots=8) + +group_reduce S=32: + requests source deinterleaved=4, block_elems=1 or block_elems=8 + requests result group_slots(num_groups, slots=8) + +group_reduce S=64: + requests source contiguous + requests result group_slots(num_groups, slots=1) + +group_broadcast: + requests source group_slots(num_groups, slots=K) + produces one dense result layout per consumer request + is cloned per incompatible dense consumer + +group_store: + requests source group_slots(num_groups, slots=K) + selected plan also records output stride legality + +group_slot_load: + requests result group_slots(num_groups, slots=8) for packed unit-stride slots + requests result group_slots(num_groups, slots=1) for row-local aligned slots + +group_load: + requests result deinterleaved=2/4, block_elems=8 for S=16/S=32 block + fragment plans, or contiguous for row-local full-chunk plans + +masked_load: + requests result layout from its consumers + requests mask layout matching the result + requires explicit passthrough; padding is not synthesized + +create_mask/create_group_mask: + produces whichever mask layout each consumer requests + may be cloned per incompatible mask layout or granularity +``` + +Important negative requests: + +```text +ordinary dense add/mul/store/truncf cannot request group_slots +packed group_slots(slots=8) cannot request width-changing cast unless a packed +slot-preserving cast plan is registered +slots=1 group_store cannot request unit-stride row-major output until a pack or +unaligned-store plan exists +``` + +### 5.6 Conflict Resolution Matrix + +When one value receives incompatible requests, assignment resolves it using the +first legal row below. `vmi-to-vpto` never repeats this decision. + +```text +cheap producer with multiple requested layouts: + clone the producer and assign each clone independently + examples: load, broadcast, create_mask, create_group_mask, group_broadcast + memory-read producers require the same explicit no-alias and safe-read proof + at each clone site + +non-cheap value with registered materialization: + keep one chosen layout on the value and insert ensure_layout at the use site + examples: deinterleaved=4 -> contiguous before dense store + +layout-transparent chain: + assign the whole equivalence class to the non-contiguous consumer request when + that avoids materialization + examples: broadcast -> addf -> S=32 group_reduce + +control-flow join: + all incoming values must be materialized to one layout before yield/branch + examples: scf.if yielding group_slots, scf.for loop-carried group_slots + +private function boundary: + specialize or materialize at call/callee-entry before vmi-to-vpto + +no clone/materialization/specialization plan: + emit a diagnostic naming the requesting op and both layouts +``` + +The cost model may choose between legal rows only when the observable contract +is identical. For example, S=16 `block_elems=1` and `block_elems=8` are both +valid reduce inputs, but `block_elems=8` is selected only when a producer plan +such as strided `group_load` naturally creates 32B row fragments or when cost +proves it cheaper without breaking another consumer such as `truncf`. + ## 6. Layout Assignment Algorithm `vmi-layout-assignment` is module-level. It must see function/call/control-flow From 8f8a8055bde1c1023d9f375ea9c0cba2341aa04d Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 12:40:16 +0800 Subject: [PATCH 13/54] Complete VMI layout request builder coverage --- .../vmi-layout-assignment-implementation.md | 15 +++++++++++++ .../vmi-layout-assignment-lowering-design.md | 22 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 3d0cab8215..36ef7a453c 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -459,15 +459,30 @@ buildGroupMemoryRequests: group_slot_load aligned row-local stride -> group_slots(G,1) unsupported dynamic/unaligned grouped memory -> diagnostic +buildElementwiseRequests: + dense add/mul/fma/min/max/select -> all dense operands/results share one + dense layout + group-slot add/mul/select -> all operands/results share one group_slots(G,K) + dense/group_slots mixing -> diagnostic unless an explicit group_broadcast or + group_store boundary exists + buildMaskRequests: mask layout follows each consuming data layout predicate granularity follows each consuming element type create_mask/create_group_mask may be cloned for incompatible mask layout or granularity requests + masked_store requests source layout, mask layout, and store predicate + granularity explicitly buildControlFlowRequests: region yields, branch operands, loop iter_args, call operands, and returns create equality requests on the carried VMI layout variable + +buildFunctionBoundaryRequests: + private/internal function argument/result layouts are specialized or + materialized with callee-entry/return-site helpers + public/external VMI arguments/results diagnose unless enablePublicVMIABI has + a real ABI plan ``` Request builders must record the requesting op. Diagnostics and inserted diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 5e43f6d9ec..c16fedcde3 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -467,6 +467,17 @@ group_store: requests source group_slots(num_groups, slots=K) selected plan also records output stride legality +dense elementwise add/mul/fma/min/max/select: + requests all dense data operands and results use one dense layout + mask operands request the same data layout and the consumer element + granularity + +group-slot elementwise add/mul/select: + requests all group-slot operands and results use the same + group_slots(num_groups, slots=K) + rejects mixing dense and group_slots without explicit group_broadcast or + group_store + group_slot_load: requests result group_slots(num_groups, slots=8) for packed unit-stride slots requests result group_slots(num_groups, slots=1) for row-local aligned slots @@ -480,9 +491,20 @@ masked_load: requests mask layout matching the result requires explicit passthrough; padding is not synthesized +masked_store: + requests dense source layout selected by the store plan + requests mask layout matching the source layout and store element granularity + does not choose memory safety for an earlier load + create_mask/create_group_mask: produces whichever mask layout each consumer requests may be cloned per incompatible mask layout or granularity + +scf.if/scf.for/call/return: + requests equality across carried VMI values, yielded values, call operands, + callee arguments, and function results + private/internal functions may specialize or materialize at boundaries + public/external VMI boundaries are diagnostics until an ABI is defined ``` Important negative requests: From 3f51933ad05f57c0161618e66e5d1474f11c04ff Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 13:05:04 +0800 Subject: [PATCH 14/54] Inline private VMI physical helpers before VPTO emission --- .../vmi-layout-assignment-implementation.md | 77 ++++++------ .../vmi-layout-assignment-lowering-design.md | 12 +- docs/designs/vmi-layout-lowering-cases.md | 37 +++++- ...o => vmi_ptoas_call_boundary_vecscope.pto} | 15 ++- .../lit/vmi/vmi_ptoas_private_call_inline.pto | 42 +++++++ .../compare.py | 40 ++++++ .../golden.py | 46 +++++++ .../kernel.pto | 70 +++++++++++ .../launch.cpp | 36 ++++++ .../main.cpp | 99 +++++++++++++++ .../ptoas.flags | 1 + .../vmi/private-call-inline-store/compare.py | 40 ++++++ .../vmi/private-call-inline-store/golden.py | 46 +++++++ .../vmi/private-call-inline-store/kernel.pto | 67 ++++++++++ .../vmi/private-call-inline-store/launch.cpp | 33 +++++ .../vmi/private-call-inline-store/main.cpp | 97 +++++++++++++++ .../vmi/private-call-inline-store/ptoas.flags | 1 + tools/ptoas/ptoas.cpp | 117 ++++++++++++++++++ 18 files changed, 829 insertions(+), 47 deletions(-) rename test/lit/vmi/{vmi_ptoas_call_boundary_vecscope_invalid.pto => vmi_ptoas_call_boundary_vecscope.pto} (78%) create mode 100644 test/lit/vmi/vmi_ptoas_private_call_inline.pto create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/compare.py create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/golden.py create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/launch.cpp create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/main.cpp create mode 100644 test/vpto/cases/vmi/private-call-argument-boundary-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/private-call-inline-store/compare.py create mode 100644 test/vpto/cases/vmi/private-call-inline-store/golden.py create mode 100644 test/vpto/cases/vmi/private-call-inline-store/kernel.pto create mode 100644 test/vpto/cases/vmi/private-call-inline-store/launch.cpp create mode 100644 test/vpto/cases/vmi/private-call-inline-store/main.cpp create mode 100644 test/vpto/cases/vmi/private-call-inline-store/ptoas.flags diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 36ef7a453c..dc54a4af09 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -1068,13 +1068,13 @@ the case catalog. Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-dynamic-scalar CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-private-calls CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh -PASS=41 FAIL=0 -summary: .tmp/vmi-runtime-batch-dynamic-scalar/parallel-summary.tsv +PASS=43 FAIL=0 +summary: .tmp/vmi-runtime-batch-private-calls/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-dynamic-scalar.log + .tmp/vmi-runtime-batch-private-calls.log result: no matches ``` @@ -1125,7 +1125,7 @@ Aggregate catalog headings are covered through their endpoint subcases: 3.16.2 row-local slots=1 positive plus dynamic/unaligned diagnostics 3.25 function boundary layout specialization: - 3.25.1 private/internal boundary lit coverage, runtime backend gap + 3.25.1 private/internal boundary lit and runtime coverage 3.25.2 public/external boundary diagnostics ``` @@ -1137,14 +1137,10 @@ SIM-backed positive endpoints: 3.6.1, 3.6.2, 3.6.3, 3.7.1, 3.7.2, 3.7.3, 3.8, 3.10, 3.11.1, 3.12, 3.15.1, 3.15.2, 3.16.1 positive, 3.16.2 positive, 3.17, 3.18, - 3.19.1, 3.20, 3.21, 3.22, 3.23, 3.24, 3.26, + 3.19.1, 3.20, 3.21, 3.22, 3.23, 3.24, 3.25.1, 3.26, 3.27 positive, 3.28 positive, 3.29, 3.31, 3.32, 3.33, 3.34, 3.35, 3.36, 3.37, 3.38, 3.39, - 3.40, 3.41, 3.42, 3.44, 3.45 - -lit-backed positive endpoints with runtime gap: - 3.25.1 private/internal function boundary - 3.43 internal function argument boundary materialization + 3.40, 3.41, 3.42, 3.43, 3.44, 3.45 diagnostic endpoints: 3.7.4, 3.9, 3.11.2, 3.13, 3.14, 3.15.3, @@ -1155,10 +1151,10 @@ diagnostic endpoints: repository evidence: all concrete lit/runtime paths listed below exist - all 41 runtime case directories contain kernel.pto, launch.cpp, main.cpp, + all 43 runtime case directories contain kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py - latest broad VMI runtime sweep passed: PASS=41 FAIL=0 - latest full VMI lit sweep passed: 312/312 + latest broad VMI runtime sweep passed: PASS=43 FAIL=0 + latest full VMI lit sweep passed: 313/313 ``` Current checked-in coverage for 3.3 dense f8->f32->compute->f8: @@ -1354,16 +1350,37 @@ runtime SIM: test/vpto/cases/vmi/group-slots-scf-for-store ``` -Current checked-in lit coverage for 3.43 internal function argument boundary +Current checked-in coverage for 3.25.1 private function result boundary: + +```text +lit: + test/lit/vmi/vmi_ptoas_private_call_inline.pto + +runtime SIM: + test/vpto/cases/vmi/private-call-inline-store + +implementation note: + after vmi-to-vpto physicalizes the private helper, ptoas inlines private + single-block helpers whose signatures contain !pto.vreg or !pto.mask. This + happens before VPTO vecscope/backend emission, so physical vector values do + not escape through a function return. +``` + +Current checked-in coverage for 3.43 internal function argument boundary materialization: ```text lit: test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto + test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto runtime SIM: - blocked by the current private vector callee backend path; see known - implementation gaps below + test/vpto/cases/vmi/private-call-argument-boundary-store + +implementation note: + private physical helper inlining also covers void helper calls with physical + VMI arguments, so the backend no longer sees a physical VPTO vector function + ABI for this internal boundary. ``` Current checked-in coverage for packed group-slot RHS elementwise continuations @@ -1550,7 +1567,6 @@ Diagnostic-only cases: 3.16.2 group_slot_load slots=1 dynamic or unaligned stride 3.27 S=32 source_group_stride not divisible by 8 f32 elements 3.19.2 block_elems=8 value consumed by truncf without materialization plan -3.25.1 full ptoas emission for private VMI callees that return VPTO vector values 3.25.2 public/external VMI boundary 3.30 unsafe masked_load tail without stable masked/gather fallback ``` @@ -1578,7 +1594,6 @@ lit: test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto test/lit/vmi/vmi_layout_assignment_external_decl_invalid.pto - test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto test/lit/vmi/vmi_to_vpto_masked_load_nonfull_invalid.pto test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto ``` @@ -1587,25 +1602,11 @@ Known implementation gaps before all catalog cases can become runtime SIM coverage: ```text -remaining function runtime coverage: - 3.25.1 internal function boundary specialization has layout-assignment and - vmi-to-vpto lit coverage, but full ptoas emission still fails after - physicalization because today's inferred pto.vecscope is resultless and VPTO - vector-scope values cannot escape through a function return. Runtime coverage - requires either a resultful vecscope/VPTO vector ABI or an explicit inlining - policy before vecscope inference. - - 3.43 internal function argument boundary materialization has - layout-assignment and vmi-to-vpto lit coverage. Full ptoas emission for a - private void vector callee currently reaches the Bisheng device backend and - fails on the physicalized callee with: - - fatal error: error in backend: Do not know how to split the result of this operator! - - Runtime coverage requires either inlining private vector callees before the - device backend path or adding backend support for the physical VPTO vector - function ABI. This is a runtime/backend gap, not a license for `vmi-to-vpto` - to infer layouts from caller/callee context. +private physical function ABI: + 3.25.1 and 3.43 runtime coverage is closed for private/internal single-block + helpers by inlining private physical VMI helpers after vmi-to-vpto and before + VPTO vecscope/backend emission. Public/external VMI boundaries are still + rejected until a stable VMI ABI is defined. memory-proof runtime coverage: 3.21 S=32 full-tile-readable tail is covered by a runtime case that uses diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index c16fedcde3..0b5a658fbe 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -133,8 +133,10 @@ dynamic active_elems_per_group runtime source: scalar loads or surrounding context. private vector function runtime: - assignment/lowering semantics are defined; full ptoas runtime depends on - backend support or an inlining policy for physical VPTO vector callees. + private/internal single-block helpers are runtime-covered by ptoas inlining + private physical VMI helpers after vmi-to-vpto and before VPTO vecscope/backend + emission. This is a post-physicalization backend hygiene step; vmi-to-vpto + still lowers only from assigned layouts and helper ops. diagnostic-only cases: compact S=12 gather fallback, packed slots=8 width-changing cast, public VMI @@ -683,8 +685,10 @@ the initial value or previous iteration during lowering. Internal/private VMI function boundaries must make layout choices explicit in the assigned IR. The baseline implementation keeps function arguments in a contiguous VMI ABI and inserts callee-entry `ensure_layout` helpers when the -callee body needs another layout. A later private-function optimization may -specialize signatures directly: +callee body needs another layout. Private helpers are then physicalized by +`vmi-to-vpto` and inlined before VPTO vecscope/backend emission so physical +`!pto.vreg`/`!pto.mask` values do not become a backend function ABI. A later +private-function optimization may specialize signatures directly: ```text func @producer() -> !vmi.vreg<256xf32, deinterleaved=4> diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 160b25a398..d0ec9f70a5 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -177,7 +177,7 @@ the immediately following complete endpoints. 3.22 scf.for loop-carried layout complete 3.23 group_broadcast with multiple dense consumers complete 3.24 mask with elementwise/select/store complete -3.25 function boundary layout specialization complete/design +3.25 function boundary layout specialization complete 3.26 S=16 grouped tail through broadcast/reduce/store complete 3.27 S=32 group_load with stride greater than group size complete 3.28 group_slot_load slots=1 aligned non-unit stride complete @@ -195,7 +195,7 @@ the immediately following complete endpoints. 3.40 scalar broadcast feeding dense and grouped users complete/materialization 3.41 non-rematerializable value with incompatible users complete/materialization 3.42 group_slots scf.for loop-carried accumulator complete -3.43 internal function argument boundary materialization complete/design +3.43 internal function argument boundary materialization complete 3.44 masked_load grouped tail feeding S=32 reduce complete 3.45 dynamic S=32 create_group_mask complete ``` @@ -3269,6 +3269,22 @@ for r = 0..7: out[off + r] = reduce(row_r[0..31]) ``` +Runtime closure: + +```text +lit: + test/lit/vmi/vmi_ptoas_private_call_inline.pto + +runtime SIM: + test/vpto/cases/vmi/private-call-inline-store + +ptoas pipeline: + vmi-layout-assignment makes the private result layout explicit + vmi-to-vpto physicalizes the private helper result into !pto.vreg values + ptoas then inlines private physical VMI helpers before VPTO vecscope/backend + emission, so physical vector values do not escape through a function return +``` + #### 3.25.2 Public Or External VMI Boundary VMI input: @@ -5125,6 +5141,23 @@ optimization must still be expressed in the assigned VMI function type before `vmi-to-vpto` runs. ``` +Runtime closure: + +```text +lit: + test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto + test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto + +runtime SIM: + test/vpto/cases/vmi/private-call-argument-boundary-store + +ptoas pipeline: + vmi-layout-assignment inserts explicit callee-entry materialization + vmi-to-vpto physicalizes the call operands and callee body + ptoas then inlines the private physical helper before VPTO vecscope/backend + emission, so the backend never needs a physical VPTO vector function ABI +``` + ### 3.44 `masked_load` Grouped Tail Feeding S=32 Reduce This case connects the explicit `masked_load` tail model from section 3.30 with diff --git a/test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto b/test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto similarity index 78% rename from test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto rename to test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto index 950215e5e4..771ae5904c 100644 --- a/test/lit/vmi/vmi_ptoas_call_boundary_vecscope_invalid.pto +++ b/test/lit/vmi/vmi_ptoas_call_boundary_vecscope.pto @@ -6,7 +6,7 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s module attributes {pto.target_arch = "a5"} { module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { @@ -31,5 +31,14 @@ module attributes {pto.target_arch = "a5"} { } } -// CHECK: cannot infer resultless pto.vecscope because VPTO vector-scope data cannot have external users -// CHECK-SAME: escaping value type is '!pto.vreg<64xf32>' +// CHECK-NOT: func.func private @callee +// CHECK-LABEL: func.func @caller +// CHECK: pto.vecscope +// CHECK: pto.vdup +// CHECK: pto.vadd +// CHECK: pto.vsts +// CHECK: pto.vsts +// CHECK-NOT: func.call @callee +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_ptoas_private_call_inline.pto b/test/lit/vmi/vmi_ptoas_private_call_inline.pto new file mode 100644 index 0000000000..c5e1604bec --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_private_call_inline.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func private @producer(%scalar: f32) + -> !pto.vmi.vreg<128xf32> { + %value = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32> + return %value : !pto.vmi.vreg<128xf32> + } + + func.func @vmi_ptoas_private_call_inline( + %scalar: f32, + %dst: !pto.ptr, + %offset: index) { + %value = call @producer(%scalar) + : (f32) -> !pto.vmi.vreg<128xf32> + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + } +} + +// CHECK-NOT: func.func private @producer +// CHECK-LABEL: func.func @vmi_ptoas_private_call_inline +// CHECK: pto.vecscope +// CHECK: pto.vdup +// CHECK: pto.vsts +// CHECK: pto.vsts +// CHECK-NOT: call @producer +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/compare.py b/test/vpto/cases/vmi/private-call-argument-boundary-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/golden.py b/test/vpto/cases/vmi/private-call-argument-boundary-store/golden.py new file mode 100644 index 0000000000..41f1b1b714 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + base = np.linspace(-0.875, 0.625, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src.copy() + golden_sum = np.sum(src, axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto b/test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto new file mode 100644 index 0000000000..eb8f7f5e6a --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto @@ -0,0 +1,70 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func private @consume(%x: !pto.vmi.vreg<256xf32>, + %mask: !pto.vmi.mask<256xpred>, + %out: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %out[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } + + func.func @vmi_private_call_argument_boundary_store_kernel( + %src_gm: !pto.ptr, %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + func.call @consume(%x, %mask, %ub_sum, %c0) + : (!pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred>, + !pto.ptr, index) -> () + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/launch.cpp b/test/vpto/cases/vmi/private-call-argument-boundary-store/launch.cpp new file mode 100644 index 0000000000..ba6be566de --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/launch.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_private_call_argument_boundary_store_kernel(__gm__ float *src, + __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_private_call_argument_boundary_store_kernel(float *src, + float *copy, + float *sum, + void *stream) { + vmi_private_call_argument_boundary_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/main.cpp b/test/vpto/cases/vmi/private-call-argument-boundary-store/main.cpp new file mode 100644 index 0000000000..5ce943feae --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/main.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_private_call_argument_boundary_store_kernel(float *src, + float *copy, + float *sum, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_private_call_argument_boundary_store_kernel(srcDevice, copyDevice, + sumDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/ptoas.flags b/test/vpto/cases/vmi/private-call-argument-boundary-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/private-call-inline-store/compare.py b/test/vpto/cases/vmi/private-call-inline-store/compare.py new file mode 100644 index 0000000000..9f34394fa1 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/compare.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check(name: str, golden_name: str) -> None: + golden = np.fromfile(golden_name, dtype=np.float32) + output = np.fromfile(name, dtype=np.float32) + if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): + return + if golden.shape != output.shape: + print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +def main() -> None: + check("v2.bin", "golden_v2.bin") + check("v3.bin", "golden_v3.bin") + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/private-call-inline-store/golden.py b/test/vpto/cases/vmi/private-call-inline-store/golden.py new file mode 100644 index 0000000000..41f1b1b714 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/golden.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 +SENTINEL = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + base = np.linspace(-0.875, 0.625, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) + sums = np.full(ROWS, SENTINEL, dtype=np.float32) + golden_copy = src.copy() + golden_sum = np.sum(src, axis=1, dtype=np.float32).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + copy.reshape(-1).tofile(output_dir / "v2.bin") + sums.tofile(output_dir / "v3.bin") + golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") + golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/private-call-inline-store/kernel.pto b/test/vpto/cases/vmi/private-call-inline-store/kernel.pto new file mode 100644 index 0000000000..5f7beec943 --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/kernel.pto @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func private @producer(%src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf32> { + %x = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + return %x : !pto.vmi.vreg<256xf32> + } + + func.func @vmi_private_call_inline_store_kernel(%src_gm: !pto.ptr, + %copy_gm: !pto.ptr, + %sum_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_copy = pto.castptr %c2048_i64 : i64 -> !pto.ptr + %ub_sum = pto.castptr %c4096_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = func.call @producer(%ub_src, %c0) + : (!pto.ptr, index) -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x, %ub_copy[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + + %mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_copy, %copy_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_sum, %sum_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/private-call-inline-store/launch.cpp b/test/vpto/cases/vmi/private-call-inline-store/launch.cpp new file mode 100644 index 0000000000..b5015d7cda --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_private_call_inline_store_kernel(__gm__ float *src, __gm__ float *copy, + __gm__ float *sum); + +void LaunchVmi_private_call_inline_store_kernel(float *src, float *copy, + float *sum, void *stream) { + vmi_private_call_inline_store_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); +} diff --git a/test/vpto/cases/vmi/private-call-inline-store/main.cpp b/test/vpto/cases/vmi/private-call-inline-store/main.cpp new file mode 100644 index 0000000000..325ebc902e --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/main.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_private_call_inline_store_kernel(float *src, float *copy, + float *sum, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 32; + constexpr size_t kSrcElems = kRows * kCols; + constexpr size_t kSumElems = kRows; + size_t srcBytes = kSrcElems * sizeof(float); + size_t copyBytes = kSrcElems * sizeof(float); + size_t sumBytes = kSumElems * sizeof(float); + float *srcHost = nullptr; + float *copyHost = nullptr; + float *sumHost = nullptr; + float *srcDevice = nullptr; + float *copyDevice = nullptr; + float *sumDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); + ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_private_call_inline_store_kernel(srcDevice, copyDevice, sumDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", copyHost, copyBytes); + WriteFile("./v3.bin", sumHost, sumBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(copyDevice); + aclrtFree(sumDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(copyHost); + aclrtFreeHost(sumHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/private-call-inline-store/ptoas.flags b/test/vpto/cases/vmi/private-call-inline-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/private-call-inline-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 4d0bc4b877..fbe74e9b69 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -19,6 +19,8 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Verifier.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Parser/Parser.h" @@ -1686,6 +1688,117 @@ static LogicalResult verifyNoPublicVMISignature(ModuleOp module) { return failure(result.wasInterrupted()); } +static bool containsVMIPhysicalType(Type type) { + if (isa(type)) + return true; + if (auto functionType = dyn_cast(type)) { + return llvm::any_of(functionType.getInputs(), containsVMIPhysicalType) || + llvm::any_of(functionType.getResults(), containsVMIPhysicalType); + } + return false; +} + +static bool isPrivatePhysicalVMIHelper(func::FuncOp func) { + return !func.isPublic() && !func.isExternal() && + func.getBody().hasOneBlock() && + containsVMIPhysicalType(func.getFunctionType()); +} + +static LogicalResult inlinePrivatePhysicalVMIHelperCall(func::CallOp call, + func::FuncOp callee) { + if (callee.isExternal()) + return call.emitOpError("callee must have a body before inlining"); + if (!callee.getBody().hasOneBlock()) + return call.emitOpError("callee must be single-block before inlining"); + + Block &entry = callee.getBody().front(); + if (entry.getNumArguments() != call.getNumOperands()) + return call.emitOpError("callee argument count mismatch during inlining"); + + auto returnOp = dyn_cast(entry.getTerminator()); + if (!returnOp) + return call.emitOpError("callee must terminate with func.return"); + if (returnOp.getNumOperands() != call.getNumResults()) + return call.emitOpError("callee return/result arity mismatch during inlining"); + + OpBuilder builder(call); + IRMapping mapping; + for (auto [arg, operand] : llvm::zip(entry.getArguments(), call.getOperands())) + mapping.map(arg, operand); + + for (Operation &op : entry.without_terminator()) { + Operation *newOp = builder.clone(op, mapping); + for (auto [oldResult, newResult] : + llvm::zip(op.getResults(), newOp->getResults())) + mapping.map(oldResult, newResult); + } + + for (auto [callResult, returnOperand] : + llvm::zip(call.getResults(), returnOp.getOperands())) + callResult.replaceAllUsesWith(mapping.lookup(returnOperand)); + + call.erase(); + return success(); +} + +static LogicalResult inlinePrivatePhysicalVMIHelpersInModule(ModuleOp module) { + bool madeProgress = true; + while (madeProgress) { + madeProgress = false; + + SmallVector calls; + module.walk([&](func::CallOp call) { calls.push_back(call); }); + + for (func::CallOp call : calls) { + if (!call || !call->getBlock()) + continue; + + func::FuncOp caller = call->getParentOfType(); + auto calleeAttr = call.getCalleeAttr(); + if (!caller || !calleeAttr) + continue; + + func::FuncOp callee = + SymbolTable::lookupNearestSymbolFrom( + call, calleeAttr.getAttr()); + if (!callee || !isPrivatePhysicalVMIHelper(callee)) + continue; + if (callee == caller) + return call.emitOpError("recursive private VMI helper call cannot be " + "inlined before VPTO emission"); + + if (failed(inlinePrivatePhysicalVMIHelperCall(call, callee))) + return failure(); + madeProgress = true; + } + } + + SymbolTable symbolTable(module); + SmallVector deadFuncs; + for (func::FuncOp func : module.getOps()) { + if (!isPrivatePhysicalVMIHelper(func)) + continue; + auto uses = symbolTable.getSymbolUses(func, module); + if (uses && uses->empty()) + deadFuncs.push_back(func); + } + for (func::FuncOp func : deadFuncs) + func.erase(); + + return success(); +} + +static LogicalResult inlinePrivatePhysicalVMIHelpers(ModuleOp module) { + if (failed(inlinePrivatePhysicalVMIHelpersInModule(module))) + return failure(); + WalkResult result = module.walk([&](ModuleOp nestedModule) { + if (failed(inlinePrivatePhysicalVMIHelpersInModule(nestedModule))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { if (failed(verifyNoPublicVMISignature(module.get()))) return failure(); @@ -1703,6 +1816,10 @@ static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { llvm::errs() << "Error: VMI-to-VPTO pipeline failed.\n"; return failure(); } + if (failed(inlinePrivatePhysicalVMIHelpers(module.get()))) { + llvm::errs() << "Error: failed to inline private VMI physical helpers.\n"; + return failure(); + } return success(); } From ea2491fda5604b090db0252f3e8a322dfbbd8c31 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 13:24:17 +0800 Subject: [PATCH 15/54] Validate required VMI selected plans --- .../vmi-layout-assignment-implementation.md | 58 +++++++- .../vmi-layout-assignment-lowering-design.md | 80 +++++++++++ docs/designs/vmi-layout-lowering-cases.md | 10 +- include/PTO/Transforms/Passes.td | 4 +- lib/PTO/Transforms/PTOValidateVMIIR.cpp | 132 ++++++++++++++++++ ...out_gate_missing_selected_plan_invalid.pto | 23 +++ 6 files changed, 297 insertions(+), 10 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index dc54a4af09..102bcc628c 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -196,6 +196,55 @@ Ops that are uniquely determined by layout may omit this attr, but the rule should be conservative. If future maintainers could reasonably ask "why this lowering?", assignment should write a plan. +Required-plan table for the current implementation: + +```text +op required when +group_load result layout matches a registered group_load plan +group_slot_load explicit group_slots slots=8 or slots=1 result +group_reduce_addf source/result layouts match a registered reduce plan +group_broadcast explicit slots=8 or slots=1 source and dense result +truncf group_slots slots=1 f32->f16 slot-preserving cast +ensure_layout always carries source/result layouts instead of plan +ensure_mask_layout always carries source/result layouts instead of plan +ensure_mask_granularity always carries source/result granularities instead of plan +``` + +Layout/attr-only decisions today: + +```text +load result layout plus full_read_elems/full chunk proof +group_store source group_slots layout plus explicit output stride +masked_load explicit passthrough, mask layout, and memory proof +masked_store/select operand/result layouts plus mask granularity +dense extf/truncf source/result layouts and element widths +``` + +Implementation rule: + +```text +vmi-layout-assignment attaches the required plan before type conversion. +validate-assigned-vmi rejects a required-plan op that lacks vmi.selected_plan. +vmi-to-vpto verifies the plan against the already assigned layouts and emits +VMI-LAYOUT-CONTRACT instead of selecting a fallback from producer/user context. +If a layout/attr-only op later gains a second legal recipe, that recipe must be +promoted into the required-plan table before vmi-to-vpto can emit it. +Unsupported shapes that have no registered plan still diagnose through their +specific capability check rather than failing with a generic missing-plan error. +``` + +Examples of forbidden recovery in `vmi-to-vpto`: + +```text +group_reduce_addf cannot walk to a load/group_load producer to choose S=16 + parity versus block8. +group_store cannot inspect the group_reduce producer; it consumes only the + assigned source layout and explicit stride. +group_broadcast cannot inspect sibling users to decide whether to rematerialize. +masked_load cannot inspect the mask producer to prove memory safety. +func.call cannot inspect the callee body to decide physical function layout. +``` + ## 4. VMI Surface Ops Required By Cases Initial op set from the case catalog: @@ -1068,13 +1117,13 @@ the case catalog. Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-private-calls CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-selected-plan-gate CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh PASS=43 FAIL=0 -summary: .tmp/vmi-runtime-batch-private-calls/parallel-summary.tsv +summary: .tmp/vmi-runtime-batch-selected-plan-gate/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-private-calls.log + .tmp/vmi-runtime-batch-selected-plan-gate.log result: no matches ``` @@ -1154,7 +1203,7 @@ repository evidence: all 43 runtime case directories contain kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py latest broad VMI runtime sweep passed: PASS=43 FAIL=0 - latest full VMI lit sweep passed: 313/313 + latest full VMI lit sweep passed: 314/314 ``` Current checked-in coverage for 3.3 dense f8->f32->compute->f8: @@ -1585,6 +1634,7 @@ entries: ```text lit: + test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 0b5a658fbe..9261a938dd 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -280,6 +280,86 @@ invariant is not illustrative: if a lowering decision is not uniquely implied by op + assigned operand/result layouts + explicit attrs, assignment must write a selected plan. +### 4.1 Selected Plan Contract + +`selected_plan` is not an optimization hint. It is the serialized answer to a +question that would otherwise require `vmi-to-vpto` to inspect producer, +consumer, control-flow, memory, or mask context. + +Required plans in the current implementation: + +```text +group_load: + required for registered result layouts. The plan fixes source_group_stride + handling and whether the result is contiguous chunks, S=16 block8, or S=32 + block8. Unsupported shapes diagnose through the capability check instead of + inventing a plan. + +group_slot_load: + required for explicit slots=8 or slots=1 layouts. The plan fixes packed + scalar load versus row-local lane-0 load. A single source op may be + rematerialized into two different planned ops. + +group_reduce_addf: + required for registered S=8/S=16/S=32/S=64 shapes. The plan fixes parity + versus block8, packed slots=8 versus row-local slots=1, and multi-chunk + arity. Unsupported group sizes diagnose as unsupported capability, not as + missing selected_plan. + +group_broadcast: + required for explicit slots=8 or slots=1 sources. The plan fixes source + interpretation and the vselr index recipe for the requested dense result + layout. Legacy bare group_slots are tolerated only as compatibility input and + must not be emitted by layout assignment. + +truncf: + required for group_slots slots=1 f32->f16, where the cast is a slot-preserving + group-slot cast rather than an ordinary dense VCVT path. +``` + +Layout-only or attr-only decisions in the current implementation: + +```text +load: + result layout plus explicit memory attrs decide the lowering. full_read_elems + is the memory-safety proof; vmi-to-vpto may not recover that proof from MTE or + caller context. + +group_store: + source group_slots layout and explicit output stride decide packed slots=8 + versus row-local slots=1 store legality. If another legal store recipe is + introduced, assignment must attach a selected plan before vmi-to-vpto uses it. + +masked_load: + explicit passthrough, mask layout, full physical read, shaped safe-tail memref, + or an explicit diagnostic decide legality. A future stable gather fallback + must be selected by assignment before vmi-to-vpto lowers it. + +masked_store/select/elementwise: + operand/result layouts and explicit mask granularity decide the lowering. + They remain transfer ops unless a future case introduces competing recipes. + +extf/truncf: + dense width-changing paths are layout-determined today. Any future + commute-through-group-broadcast or alternative VCVT recipe must become a + selected plan first. +``` + +Forbidden plan recovery: + +```text +No pattern may synthesize one of the required plans by: + - walking from group_reduce to the load/group_load producer + - walking from store/broadcast/truncf to the group_reduce producer + - scanning sibling users of a group_slots value + - inspecting branch bodies or loop bodies from a control-flow boundary + - inspecting private callee bodies while lowering a call +``` + +If a required plan is missing, `vmi-to-vpto` emits +`VMI-LAYOUT-CONTRACT` at the current op and prints the op name, logical type, +assigned layouts, and the missing plan class. + ## 5. Plan Registry The compiler owns a target-aware plan registry. Layout assignment queries this diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index d0ec9f70a5..8e2d6bfceb 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -167,13 +167,13 @@ the immediately following complete endpoints. 3.12 control-flow join before group_reduce complete 3.13 packed group-slot f32 -> f16 cast illegal diagnostic 3.14 unsupported group size illegal diagnostic -3.15 compact S=12 written as logical S=16 complete/design +3.15 compact S=12 written as logical S=16 complete/diagnostic 3.16 group_slot_load layout contract complete 3.17 group_broadcast feeding deinterleaved consumer complete 3.18 one value with dense and group-reduce consumers complete/materialization 3.19 S=16 reduce block_elems plan selection complete/diagnostic 3.20 group_slots control-flow join complete -3.21 S=32 tail with full-tile-readable source complete/design +3.21 S=32 tail with full-tile-readable source complete 3.22 scf.for loop-carried layout complete 3.23 group_broadcast with multiple dense consumers complete 3.24 mask with elementwise/select/store complete @@ -187,9 +187,9 @@ the immediately following complete endpoints. 3.32 f32 feeding f8 store and S=32 reduce complete 3.33 one dense value feeding S=16 and S=32 reduces complete/materialization 3.34 S=64 group-slot result f32->f16 cast complete -3.35 group_slots fanout to group_store and broadcast complete/design -3.36 same scalar source materialized as slots=8/slots=1 complete/design -3.37 S=64 group_store with non-unit output stride complete/design +3.35 group_slots fanout to group_store and broadcast complete +3.36 same scalar source materialized as slots=8/slots=1 complete/materialization +3.37 S=64 group_store with non-unit output stride complete 3.38 multi-tile S=32 group_reduce complete 3.39 strided S=32 group_load through broadcast/reduce complete 3.40 scalar broadcast feeding dense and grouped users complete/materialization diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 435b70a328..e64d5bba3f 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -810,7 +810,9 @@ def PTOValidateVMILayoutIR Checks the post-layout-assignment VMI stage: every VMI data value must have a concrete VMI layout, every VMI mask must have concrete b8/b16/b32 granularity and layout, physical VPTO register values must not appear yet, - and VMI typed values must stay inside VMI semantic/helper or structural ops. + VMI typed values must stay inside VMI semantic/helper or structural ops, + and context-sensitive VMI ops must carry the selected_plan contract emitted + by layout assignment. }]; let constructor = "mlir::pto::createPTOValidateVMILayoutIRPass()"; let dependentDialects = ["mlir::cf::ControlFlowDialect", diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 6ce3e8eecd..889a5ebe85 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -36,6 +36,8 @@ using namespace mlir::pto; namespace { +static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; + bool isVMIType(Type type) { return isa(type); } bool isPhysicalVPTOType(Type type) { @@ -159,6 +161,133 @@ LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, return failure(); } +LogicalResult emitLayoutContract(Operation *op, llvm::raw_ostream *diagOS, + Twine message) { + InFlightDiagnostic diag = + op->emitError() << kVMIDiagLayoutContractPrefix << message; + (void)diag; + mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); + return failure(); +} + +std::optional getGroupSize(VMIVRegType type, int64_t numGroups) { + if (!type || numGroups <= 0 || type.getElementCount() % numGroups != 0) + return std::nullopt; + return type.getElementCount() / numGroups; +} + +bool hasRegisteredGroupReducePlan(VMIGroupReduceAddFOp op) { + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType) + return false; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + if (!sourceLayout) + return false; + + std::optional groupSize = + getGroupSize(sourceType, op.getNumGroupsAttr().getInt()); + if (!groupSize) + return false; + + if (sourceLayout.isContiguous()) + return *groupSize == 8 || *groupSize == 64; + + if (!sourceLayout.isDeinterleaved()) + return false; + if (*groupSize == 16 && sourceLayout.getFactor() == 2) + return sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8; + if (*groupSize == 32 && sourceLayout.getFactor() == 4) + return sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8; + return false; +} + +bool hasRegisteredGroupLoadPlan(VMIGroupLoadOp op) { + auto resultType = dyn_cast(op.getResult().getType()); + if (!resultType) + return false; + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout) + return false; + if (layout.isContiguous()) + return true; + if (!layout.isDeinterleaved() || layout.getBlockElems() != 8) + return false; + + std::optional groupSize = + getGroupSize(resultType, op.getNumGroupsAttr().getInt()); + if (!groupSize) + return false; + return (*groupSize == 16 && layout.getFactor() == 2) || + (*groupSize == 32 && layout.getFactor() == 4); +} + +bool hasRegisteredGroupSlotLoadPlan(VMIGroupSlotLoadOp op) { + auto resultType = dyn_cast(op.getResult().getType()); + if (!resultType) + return false; + VMILayoutAttr layout = resultType.getLayoutAttr(); + return layout && layout.isGroupSlots() && + layout.getNumGroups() == op.getNumGroupsAttr().getInt() && + (layout.getSlots() == 8 || layout.getSlots() == 1); +} + +bool hasRegisteredGroupBroadcastPlan(VMIGroupBroadcastOp op) { + auto sourceType = dyn_cast(op.getSource().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!sourceType || !resultType) + return false; + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + return sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + sourceLayout.getNumGroups() == op.getNumGroupsAttr().getInt() && + !resultLayout.isGroupSlots() && + (sourceLayout.getSlots() == 8 || sourceLayout.getSlots() == 1); +} + +bool hasRegisteredGroupSlotTruncFPlan(Operation *op) { + auto truncf = dyn_cast(op); + if (!truncf) + return false; + + auto sourceType = dyn_cast(truncf.getSource().getType()); + auto resultType = dyn_cast(truncf.getResult().getType()); + if (!sourceType || !resultType) + return false; + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + return sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + resultLayout.isGroupSlots() && sourceLayout.getSlots() == 1 && + resultLayout.getSlots() == 1 && sourceType.getElementType().isF32() && + resultType.getElementType().isF16(); +} + +bool requiresSelectedPlan(Operation *op) { + if (auto groupLoad = dyn_cast(op)) + return hasRegisteredGroupLoadPlan(groupLoad); + if (auto groupSlotLoad = dyn_cast(op)) + return hasRegisteredGroupSlotLoadPlan(groupSlotLoad); + if (auto reduce = dyn_cast(op)) + return hasRegisteredGroupReducePlan(reduce); + if (auto broadcast = dyn_cast(op)) + return hasRegisteredGroupBroadcastPlan(broadcast); + return hasRegisteredGroupSlotTruncFPlan(op); +} + +LogicalResult verifySelectedPlanContract(Operation *op, + llvm::raw_ostream *diagOS) { + if (!requiresSelectedPlan(op)) + return success(); + if (op->getAttrOfType(kVMISelectedPlanAttrName)) + return success(); + return emitLayoutContract( + op, diagOS, + Twine(op->getName().getStringRef()) + + " requires vmi.selected_plan selected by vmi-layout-assignment"); +} + LogicalResult verifyBoundaryType(Operation *owner, Type type, llvm::raw_ostream *diagOS) { if (isPhysicalVPTOType(type)) @@ -378,6 +507,9 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, if (!hasVMIOrPhysicalType(op)) return success(); + if (failed(verifySelectedPlanContract(op, diagOS))) + return failure(); + if (isVMIHelperOp(op)) { if (isVMILayoutHelperOp(op)) return success(); diff --git a/test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto b/test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto new file mode 100644 index 0000000000..d06bd275ca --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_missing_selected_plan_invalid( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf requires vmi.selected_plan selected by vmi-layout-assignment From 3b7b585fbe81527f17efd3e9d3b5d1f45c1bae04 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 13:34:57 +0800 Subject: [PATCH 16/54] Document VMI layout closure matrix --- .../vmi-layout-assignment-implementation.md | 163 ++++++++++++++++-- 1 file changed, 151 insertions(+), 12 deletions(-) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 102bcc628c..dfc8588a86 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -692,6 +692,138 @@ Canonical assigned IR shape for `group_broadcast` multi-use: If the assigned IR does not have one of these explicit shapes, `vmi-to-vpto` must reject it instead of attempting to recover the missing decision. +### 6.6 Case-To-Implementation Closure Matrix + +The current case catalog is sufficient for the first implementation. No new +layout kind is justified by the supported endpoints. The implementation work +should instead close the following finite matrix. Each row names the request +builder that owns the decision, the assignment artifact that must appear in IR, +and the `vmi-to-vpto` contract. + +```text +case family builder / owner assignment artifact +3.1, 3.2, 3.3 dense casts buildCastRequests dense layout on each cast result +3.29 mask width split buildMaskRequests per-use mask granularity helper +3.31, 3.32 dense fanout conflict resolver cloned load or ensure_layout + +vmi-to-vpto contract: + consume only the assigned dense layouts. It may emit VCVT and dense + materialization, but it must not choose deinterleaved=2/4 by inspecting a + later truncf, store, or group_reduce user. +``` + +```text +case family builder / owner assignment artifact +3.4 S=8 reduce buildGroupReduceRequests s8_reduce_contiguous plan +3.5 S=16 reduce buildGroupReduceRequests s16_reduce_parity/block8 plan +3.6 S=32 reduce buildGroupReduceRequests s32_reduce_dintlv4/block8 plan +3.7 S=64 reduce buildGroupReduceRequests s64_reduce_row_local plan +3.11.1 S=64 active-row tail buildMaskRequests active-row store/reduce masks +3.19.1 S=16 block_elems choice buildGroupReduceRequests selected block_elems reduce plan +3.38 multi-tile S=32 reduce buildGroupReduceRequests multiple group_slots chunks +3.26 grouped tail buildMaskRequests split grouped masks +3.44, 3.45 grouped S=32 masks buildMaskRequests explicit deinterleaved mask values + +vmi-to-vpto contract: + lower each reduce from source layout, result group_slots layout, and + selected_plan. It must not walk to the load/group_load producer to decide + parity versus block8, row-local versus packed slots, or static versus dynamic + mask generation. +``` + +```text +case family builder / owner assignment artifact +3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load plan +3.15.2 S=16 row stride > 16 buildGroupMemoryRequests strided block_elems=8 plan +3.16.1 group_slot_load slots=8 buildGroupMemoryRequests unit-stride packed slots plan +3.16.2 group_slot_load slots=1 buildGroupMemoryRequests row-local aligned slots plan +3.27 strided group_load buildGroupMemoryRequests positive block_elems=8 plan +3.28 slots=1 non-unit load buildGroupMemoryRequests row-local group_slot_load plan +3.37 slots=1 strided store buildStoreRequests group_store stride/alignment proof +3.39 strided load fanout conflict resolver preserving layout or materialization + +vmi-to-vpto contract: + consume only explicit memory stride/alignment attrs, selected_plan, and + layouts. It must not infer safe read/write placement from neighboring + compute ops. Unsupported dynamic, unaligned, or compact-row gather shapes + stay diagnostics until a gather plan is registered. +``` + +```text +case family builder / owner assignment artifact +3.8 reduce->truncf->broadcast conflict resolver slot cast plus dense materialization +3.10 non-load S=32 producer buildElementwiseRequests transparent deinterleaved chain +3.17 broadcast deint consumer conflict resolver use-site group_broadcast layout +3.18 dense + reduce users conflict resolver clone/rematerialize/ensure_layout +3.23 broadcast multi-user conflict resolver cloned group_broadcast +3.33 S=16 + S=32 users conflict resolver cloned load or materialization +3.34 S=64 slots=1 cast buildCastRequests group_slot_cast selected plan +3.35 slots fanout buildElementwiseRequests same group_slots layout on users +3.36 scalar slots=8/slots=1 conflict resolver cloned group_slot_load/broadcast +3.40 scalar dense + grouped conflict resolver cloned broadcast +3.41 incompatible fixed value conflict resolver diagnostic or ensure_layout + +vmi-to-vpto contract: + each op instance is already single-plan. The lowering pass never scans + sibling users to decide whether to clone, pack, broadcast, or materialize. +``` + +```text +case family builder / owner assignment artifact +3.21 S=32 safe full-read tail buildMaskRequests full_read_elems memory proof +3.24 mask/select/store buildMaskRequests explicit mask layout/granularity +3.12 scf.if before reduce buildControlFlowRequests common yielded layout +3.20 group_slots scf.if buildControlFlowRequests common group_slots layout +3.22 scf.for carried value buildControlFlowRequests fixed-point iter_arg layout +3.25 function boundary buildFunctionBoundary specialized/internal boundary +3.42 loop accumulator buildControlFlowRequests loop-carried group_slots layout +3.43 call argument materialize buildFunctionBoundary callee-entry/return helper + +vmi-to-vpto contract: + block argument, region result, call operand, and function result layouts are + visible in types or helper ops. It must not inspect branch bodies, loop + bodies, callers, or callees to discover a layout. +``` + +```text +diagnostic family builder / owner required failure +3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store plan +3.9 dense store of group slots buildStoreRequests use group_store/group_broadcast +3.11.2 S=32 unsafe tail buildMaskRequests missing full_tile_readable/gather +3.13 slots=8 width cast buildCastRequests no packed slot cast plan +3.14 unsupported group size buildGroupReduceRequests no registered reduce plan +3.15.3 compact S=12 buildGroupMemoryRequests no compact gather plan +3.16.1 slots=8 non-unit load buildGroupMemoryRequests no packed strided slot load plan +3.16.2 slots=1 bad stride buildGroupMemoryRequests no dynamic/unaligned row-local plan +3.19.2 invalid block_elems use conflict resolver no preserving materialization +3.25.2 public/external ABI buildFunctionBoundary no stable public VMI ABI +3.27 unaligned group_load buildGroupMemoryRequests no gather/block fallback plan +3.30 masked_load unsafe tail buildMaskRequests no padding/gather fallback + +vmi-to-vpto contract: + these cases must fail before or at the layout contract boundary with the + requesting op named. They must not be accepted by falling back to a generic + dense load, dense store, or producer/user inspection. +``` + +Additional cases are needed only when the scope changes: + +```text +stable gather fallback enabled: + add compact S=12 positive lowering and masked_load unsafe-tail positive + lowering before accepting either path. + +pack-to-slots=8 or unaligned row-local stores enabled: + add positive S=64 unit-stride group_store and reduce->pack->dense store cases. + +public VMI ABI enabled: + add public call/return ABI cases before removing the public-boundary + diagnostic. + +packed group-slot width cast enabled: + add slots=8 f32->f16 cast and downstream group_store/broadcast cases. +``` + ## 7. OneToN Type Conversion `vmi-to-vpto` should use OneToN conversion for VMI values. @@ -1648,8 +1780,7 @@ lit: test/lit/vmi/vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto ``` -Known implementation gaps before all catalog cases can become runtime SIM -coverage: +Capability boundaries and runtime evidence notes: ```text private physical function ABI: @@ -1745,14 +1876,22 @@ public ABI diagnostic ## 13. Completion Checklist -The implementation is not complete until: - -```text -1. every case has a layout-assignment test -2. every positive case has a vmi-to-vpto test -3. every simulator-supported case has a sim validation -4. every unsupported case has a diagnostic test -5. vmi-to-vpto contains no producer/user context inference -6. missing selected_plan on context-sensitive ops is a hard failure -7. release docs are updated only after the design stabilizes +Current evidence for the case-catalog objective: + +```text +1. every catalog endpoint is mapped in section 6.6 to an assignment owner, + assignment artifact, and vmi-to-vpto contract +2. every SIM-backed positive endpoint is listed in section 11.3 and has a + checked-in runtime case directory +3. every runtime case directory contains kernel.pto, launch.cpp, main.cpp, + golden.py, and compare.py +4. the latest broad VMI runtime sweep passed: PASS=43 FAIL=0 +5. the latest full VMI lit sweep passed: 314/314 +6. every unsupported endpoint listed in section 11.3 has a diagnostic lit test +7. vmi-to-vpto context-sensitive decisions are represented by assigned layouts, + selected_plan, helper ops, rematerialization, or diagnostics +8. missing selected_plan on registered context-sensitive shapes is a hard + validation failure +9. release docs remain untouched; this is still a design/implementation plan + under docs/designs ``` From b0b076334edfa49f7aea50d546250c6e8e6a1172 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 14:24:40 +0800 Subject: [PATCH 17/54] Add VMI dense reduce multi-consumer case --- ...ment_widen_dense_reduce_multi_consumer.pto | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 test/lit/vmi/vmi_layout_assignment_widen_dense_reduce_multi_consumer.pto diff --git a/test/lit/vmi/vmi_layout_assignment_widen_dense_reduce_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_widen_dense_reduce_multi_consumer.pto new file mode 100644 index 0000000000..95f5becf6b --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_widen_dense_reduce_multi_consumer.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_widen_dense_reduce_multi_consumer( + %src: !pto.ptr, + %k1: !pto.vmi.vreg<128xf32>, + %init0: !pto.vmi.vreg<1xf32>, + %init1: !pto.vmi.vreg<1xf32>, + %out0: !pto.ptr, + %out1: !pto.ptr, + %off: index) { + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + %a = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %t1 = pto.vmi.mulf %w, %k1 + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %r0 = pto.vmi.reduce_addf %t1, %init0, %mask {reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf32> + pto.vmi.store %r0, %out0[%c0] + : !pto.vmi.vreg<1xf32>, !pto.ptr + %r = pto.vmi.reduce_addf %w, %init1, %mask {reassoc} + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf32> + pto.vmi.store %r, %out1[%c0] + : !pto.vmi.vreg<1xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_widen_dense_reduce_multi_consumer( +// ASSIGN-SAME: %arg1: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg2: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN-SAME: %arg3: !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN: %[[A:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[W:.*]] = pto.vmi.extf %[[A]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[T1:.*]] = pto.vmi.mulf %[[W]], %arg1 +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[T1_DENSE:.*]] = pto.vmi.ensure_layout %[[T1]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[R0:.*]] = pto.vmi.reduce_addf %[[T1_DENSE]], %arg2, %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[R0]] +// ASSIGN: %[[W_DENSE:.*]] = pto.vmi.ensure_layout %[[W]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[R:.*]] = pto.vmi.reduce_addf %[[W_DENSE]], %arg3, %[[MASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<1xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[R]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_widen_dense_reduce_multi_consumer( +// LOWER: pto.vlds +// LOWER: pto.vcvt +// LOWER: pto.vcvt +// LOWER: pto.vmul +// LOWER: pto.vintlv +// LOWER: pto.vcadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER: pto.vintlv +// LOWER: pto.vcadd +// LOWER: pto.vadd +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast From dd10763c02b0139fe0ca652d0548cef266b2a5bb Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 15:34:33 +0800 Subject: [PATCH 18/54] Remove VMI selected plan attrs --- .../vmi-layout-assignment-implementation.md | 240 ++++++++---------- .../vmi-layout-assignment-lowering-design.md | 126 +++++---- docs/designs/vmi-layout-lowering-cases.md | 10 +- include/PTO/Transforms/Passes.td | 8 +- lib/PTO/Transforms/PTOValidateVMIIR.cpp | 132 ---------- lib/PTO/Transforms/VMILayoutAssignment.cpp | 140 ---------- lib/PTO/Transforms/VMIToVPTO.cpp | 132 +--------- ...assignment_broadcast_dense_group_users.pto | 1 - ...yout_assignment_call_argument_boundary.pto | 1 - ...ayout_assignment_create_group_mask_s16.pto | 1 - ...signment_create_group_mask_s32_dynamic.pto | 1 - ...ment_dense_group_reduce_multi_consumer.pto | 1 - ..._layout_assignment_f32_f8_store_reduce.pto | 1 - ...ignment_group_broadcast_multi_consumer.pto | 4 - ...yout_assignment_group_broadcast_slots8.pto | 1 - .../vmi/vmi_layout_assignment_group_load.pto | 1 - ...assignment_group_load_s16_stride_store.pto | 2 - ...group_load_s32_stride_broadcast_reduce.pto | 4 - ...assignment_group_load_s32_stride_store.pto | 2 - ...yout_assignment_group_reduce_s16_store.pto | 1 - ...roup_reduce_s16_truncf_broadcast_store.pto | 2 - ...ment_group_reduce_s32_broadcast_reduce.pto | 3 - ...nment_group_reduce_s32_multitile_store.pto | 1 - ...yout_assignment_group_reduce_s32_store.pto | 1 - ...gnment_group_reduce_s32_tail_full_tile.pto | 2 - ...vmi_layout_assignment_group_reduce_s64.pto | 1 - ...ment_group_reduce_s64_broadcast_reduce.pto | 3 - ...assignment_group_reduce_s64_tail_store.pto | 1 - ...out_assignment_group_reduce_s64_truncf.pto | 2 - ..._layout_assignment_group_reduce_slots8.pto | 1 - ...t_assignment_group_reduce_slots8_store.pto | 1 - .../vmi_layout_assignment_group_slot_load.pto | 3 - ...assignment_group_slot_load_dual_layout.pto | 4 - ...i_layout_assignment_group_slots_fanout.pto | 3 - ..._layout_assignment_group_slots_scf_for.pto | 3 - ...signment_masked_load_dense_group_users.pto | 1 - ..._assignment_masked_load_group_tail_s32.pto | 1 - ..._layout_assignment_non_load_s32_reduce.pto | 1 - ...yout_assignment_widen_f16_store_reduce.pto | 1 - ...d.pto => vmi_layout_gate_local_recipe.pto} | 7 +- .../vmi_to_vpto_group_broadcast_slots8.pto | 2 +- ...o_group_broadcast_slots8_local_recipe.pto} | 28 +- ...> vmi_to_vpto_group_load_local_recipe.pto} | 24 +- test/lit/vmi/vmi_to_vpto_group_ops.pto | 2 +- test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto | 2 +- ...to_vpto_group_reduce_s64_local_recipe.pto} | 25 +- .../vmi/vmi_to_vpto_group_reduce_slots8.pto | 2 +- ...vpto_group_reduce_slots8_local_recipe.pto} | 20 +- test/lit/vmi/vmi_to_vpto_group_slot_load.pto | 6 +- ..._to_vpto_group_slot_load_local_recipe.pto} | 18 +- ...group_slot_load_nonunit_slots8_invalid.pto | 2 +- .../vmi_to_vpto_group_slot_truncf_slots1.pto | 1 - ...group_slot_truncf_slots1_local_recipe.pto} | 24 +- 53 files changed, 284 insertions(+), 723 deletions(-) rename test/lit/vmi/{vmi_layout_gate_missing_selected_plan_invalid.pto => vmi_layout_gate_local_recipe.pto} (80%) rename test/lit/vmi/{vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto => vmi_to_vpto_group_broadcast_slots8_local_recipe.pto} (51%) rename test/lit/vmi/{vmi_to_vpto_group_load_missing_plan_invalid.pto => vmi_to_vpto_group_load_local_recipe.pto} (58%) rename test/lit/vmi/{vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto => vmi_to_vpto_group_reduce_s64_local_recipe.pto} (62%) rename test/lit/vmi/{vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto => vmi_to_vpto_group_reduce_slots8_local_recipe.pto} (73%) rename test/lit/vmi/{vmi_to_vpto_group_slot_load_missing_plan_invalid.pto => vmi_to_vpto_group_slot_load_local_recipe.pto} (69%) rename test/lit/vmi/{vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto => vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto} (58%) diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index dfc8588a86..f4c8f8487f 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -35,7 +35,8 @@ vmi-layout-assignment: pto-validate-vmi-layout: verify every VMI data/mask value has layout - verify every context-sensitive op has selected_plan + verify every VMI value has an assigned layout and every non-local lowering + choice has been serialized explicitly verify helper ops have registered materialization plans vmi-to-vpto: @@ -160,51 +161,33 @@ Layout-assigned: Surface VMI types are legal before assignment. Layout-assigned VMI types are required after assignment. -### 3.3 Selected Plan Attribute +### 3.3 Explicit Recipe Carriers -Every context-sensitive op gets a selected plan attr after assignment. The -initial implementation may use a stable string attr: +Lowering decisions are carried by the current op and its types, not by a +separate recipe string. The allowed carriers are: ```text -vmi.selected_plan = "s16_reduce_parity" +op attrs and operands +operand/result VMI layouts +mask granularity and mask layouts +helper ops such as ensure_layout / ensure_mask_layout +cloned or rematerialized producers +diagnostics for unsupported shapes ``` -Once the plan registry syntax is stable, this can become a dedicated plan attr: +If assignment made a non-local choice by inspecting producers, users, sibling +users, control flow, callees, or memory context, it must rewrite the IR so that +the final choice is visible through those carriers before `vmi-to-vpto`. -```text -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -vmi.selected_plan = #pto.vmi.plan -``` - -Ops that are uniquely determined by layout may omit this attr, but the rule -should be conservative. If future maintainers could reasonably ask "why this -lowering?", assignment should write a plan. - -Required-plan table for the current implementation: +Local-decision table for the current implementation: ```text -op required when -group_load result layout matches a registered group_load plan -group_slot_load explicit group_slots slots=8 or slots=1 result -group_reduce_addf source/result layouts match a registered reduce plan -group_broadcast explicit slots=8 or slots=1 source and dense result -truncf group_slots slots=1 f32->f16 slot-preserving cast +op local decision inputs +group_load result layout, num_groups, row_stride, source type +group_slot_load result group_slots layout and source_group_stride +group_reduce_addf source/mask/result layouts, num_groups, reassoc +group_broadcast source/result layouts and num_groups +truncf source/result layouts and element widths ensure_layout always carries source/result layouts instead of plan ensure_mask_layout always carries source/result layouts instead of plan ensure_mask_granularity always carries source/result granularities instead of plan @@ -223,12 +206,12 @@ dense extf/truncf source/result layouts and element widths Implementation rule: ```text -vmi-layout-assignment attaches the required plan before type conversion. -validate-assigned-vmi rejects a required-plan op that lacks vmi.selected_plan. -vmi-to-vpto verifies the plan against the already assigned layouts and emits -VMI-LAYOUT-CONTRACT instead of selecting a fallback from producer/user context. -If a layout/attr-only op later gains a second legal recipe, that recipe must be -promoted into the required-plan table before vmi-to-vpto can emit it. +validate-assigned-vmi validates assigned layouts, mask granularity, boundaries, +and helper placement. +vmi-to-vpto emits VMI-LAYOUT-CONTRACT for missing local proof. +If a layout/attr-only op later gains a second legal recipe that cannot be +distinguished from current-op information, that recipe must be represented by a +new attr, helper op, or rematerialized op before vmi-to-vpto can emit it. Unsupported shapes that have no registered plan still diagnose through their specific capability check rather than failing with a generic missing-plan error. ``` @@ -316,7 +299,6 @@ struct VMILayoutPlan { SmallVector operandLayouts; SmallVector resultLayouts; int64_t cost; - bool requiresSelectedPlanAttr; bool requiresFullTileReadable; bool mayReadInactivePhysicalLanes; DiagnosticBuilder (*explainFailure)(...); @@ -605,15 +587,15 @@ Algorithm: - otherwise insert ensure_layout at use - otherwise diagnose 6. Rewrite VMI result/block/function types with chosen layouts. -7. Attach selected_plan attrs where required. -8. Insert helper ops with source/result layout attrs. +7. Insert helper ops with source/result layout attrs. ``` Rewrite invariants: ```text No VMI data/mask value after assignment has a null layout. -No context-sensitive VMI op after assignment lacks selected_plan. +Any non-local choice is represented by op attrs, operand/result layouts, a +helper op, a clone, or an explicit diagnostic. Every ensure_* helper has a registered materialization plan. Every function/call signature carrying VMI is specialized or diagnosed. ``` @@ -626,14 +608,9 @@ Assignment rewrites the IR so that later lowering has no hidden choices. type rewrite: every VMI data/mask result and block argument receives a layout attr -selected_plan rewrite: - context-sensitive ops receive vmi.selected_plan - examples: group_reduce_addf, group_load, group_slot_load, group_broadcast, - group_slot cast, full-read masked_load plans - clone rewrite: cheap producers are cloned before their divergent use sites - each clone receives its own layout and selected_plan + each clone receives its own layout and attrs ensure rewrite: non-cheap values use pto.vmi.ensure_layout or ensure_mask_layout at the use @@ -655,7 +632,7 @@ function rewrite: Canonical assigned IR shape for a conflicting load: ```text -%x = pto.vmi.load ... {vmi.selected_plan = "load_dintlv4"} +%x = pto.vmi.load ... : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> %x_dense = pto.vmi.ensure_layout %x @@ -668,10 +645,10 @@ pto.vmi.store %x_dense, ... Canonical assigned IR shape for a cloned cheap producer: ```text -%x_s16 = pto.vmi.load ... {vmi.selected_plan = "load_dintlv2"} +%x_s16 = pto.vmi.load ... : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -%x_s32 = pto.vmi.load ... {vmi.selected_plan = "load_dintlv4"} +%x_s32 = pto.vmi.load ... : ... -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> ``` @@ -679,12 +656,10 @@ Canonical assigned IR shape for `group_broadcast` multi-use: ```text %b0 = pto.vmi.group_broadcast %slots - {vmi.selected_plan = "group_broadcast_slots8_vselr"} : !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> %b1 = pto.vmi.group_broadcast %slots - {vmi.selected_plan = "group_broadcast_slots8_vselr"} : !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> ``` @@ -725,10 +700,10 @@ case family builder / owner assignment artifact 3.44, 3.45 grouped S=32 masks buildMaskRequests explicit deinterleaved mask values vmi-to-vpto contract: - lower each reduce from source layout, result group_slots layout, and - selected_plan. It must not walk to the load/group_load producer to decide - parity versus block8, row-local versus packed slots, or static versus dynamic - mask generation. + lower each reduce from the current op's attrs, source/mask layout, result + group_slots layout. It must not walk to the load/group_load producer to + decide parity versus block8, row-local versus packed slots, or static versus + dynamic mask generation. ``` ```text @@ -743,10 +718,10 @@ case family builder / owner assignment artifact 3.39 strided load fanout conflict resolver preserving layout or materialization vmi-to-vpto contract: - consume only explicit memory stride/alignment attrs, selected_plan, and - layouts. It must not infer safe read/write placement from neighboring + consume only explicit memory stride/alignment attrs, current op operands, + and layouts. It must not infer safe read/write placement from neighboring compute ops. Unsupported dynamic, unaligned, or compact-row gather shapes - stay diagnostics until a gather plan is registered. + stay diagnostics until a gather recipe is explicit in the current op. ``` ```text @@ -757,7 +732,7 @@ case family builder / owner assignment artifact 3.18 dense + reduce users conflict resolver clone/rematerialize/ensure_layout 3.23 broadcast multi-user conflict resolver cloned group_broadcast 3.33 S=16 + S=32 users conflict resolver cloned load or materialization -3.34 S=64 slots=1 cast buildCastRequests group_slot_cast selected plan +3.34 S=64 slots=1 cast buildCastRequests group_slot_cast layout 3.35 slots fanout buildElementwiseRequests same group_slots layout on users 3.36 scalar slots=8/slots=1 conflict resolver cloned group_slot_load/broadcast 3.40 scalar dense + grouped conflict resolver cloned broadcast @@ -862,112 +837,111 @@ Each pattern uses: ```text op +op attrs and operand values operand/result layouts -selected_plan adaptor physical values ``` Each pattern rejects: ```text -missing selected_plan for context-sensitive op -layout not matching selected_plan +missing current-op proof for an otherwise unsafe memory recipe missing target capability unexpected group_slots dense consumer ``` -Target selected-plan matrix: +Target local recipe matrix: ```text -load, selected_plan=dense_load_norm: +load, recipe=dense_load_norm: result layout contiguous emits pto.vlds / pto.vsts NORM paths covers dense store users and S=64 row-local reduce input -load, selected_plan=load_dintlv2: +load, recipe=load_dintlv2: result layout deinterleaved=2, block_elems=1 emits vldsx2 DINTLV_B32 or normal load + vdintlv materialization covers f32->f16, S=16 parity reduce, f16->f32 widened values -load, selected_plan=load_dintlv4: +load, recipe=load_dintlv4: result layout deinterleaved=4, block_elems=1 emits two vldsx2 DINTLV_B32 plus vdintlv covers f32->f8, S=32 dintlv4 reduce -group_load, selected_plan=s16_group_load_block8_unit_stride: +group_load, recipe=s16_group_load_block8_unit_stride: result layout deinterleaved=2, block_elems=8 emits vldsx2/BDINTLV for 8 rows of 16xf32 covers compact logical S=16 when source_group_stride == 16 -group_load, selected_plan=s16_group_load_block8_stride: +group_load, recipe=s16_group_load_block8_stride: result layout deinterleaved=2, block_elems=8 emits two vsldb strided 32B block loads requires source_group_stride % 8 == 0 -group_load, selected_plan=s32_group_load_block8_stride: +group_load, recipe=s32_group_load_block8_stride: result layout deinterleaved=4, block_elems=8 emits four vsldb strided 32B block loads requires source_group_stride % 8 == 0 -group_load, selected_plan=group_load_contiguous_chunks: +group_load, recipe=group_load_contiguous_chunks: result layout contiguous emits one vlds per physical group chunk using row_stride address arithmetic covers the currently implemented full-chunk row-local group_load path -group_reduce_addf, selected_plan=s8_reduce_contiguous: +group_reduce_addf, recipe=s8_reduce_contiguous: consumes contiguous f32 with group size 8 produces group_slots(G, slots=8) emits one vcgadd -group_reduce_addf, selected_plan=s16_reduce_parity: +group_reduce_addf, recipe=s16_reduce_parity: consumes deinterleaved=2, block_elems=1 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_addf, selected_plan=s16_reduce_block8: +group_reduce_addf, recipe=s16_reduce_block8: consumes deinterleaved=2, block_elems=8 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_addf, selected_plan=s32_reduce_dintlv4: +group_reduce_addf, recipe=s32_reduce_dintlv4: consumes deinterleaved=4, block_elems=1 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_addf, selected_plan=s32_reduce_block8_stride: +group_reduce_addf, recipe=s32_reduce_block8_stride: consumes deinterleaved=4, block_elems=8 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_addf, selected_plan=s64_reduce_row_local: +group_reduce_addf, recipe=s64_reduce_row_local: consumes contiguous f32 with group size 64 produces group_slots(G, slots=1) target lowering emits per-row vcgadd plus vcadd; the current prototype uses the existing row-local VCADD/VADD/VSEL sequence while preserving the same group_slots(G, slots=1) value contract -group_slot_load, selected_plan=group_slot_load_slots8_unit_stride: +group_slot_load, recipe=group_slot_load_slots8_unit_stride: result group_slots(G, slots=8) requires source_group_stride == 1 emits one packed vsldb load -group_slot_load, selected_plan=group_slot_load_slots1_row_local: +group_slot_load, recipe=group_slot_load_slots1_row_local: result group_slots(G, slots=1) supports aligned non-unit source_group_stride requires constant positive source_group_stride divisible by 256 / elementBits emits one lane-0 vsldb per group -group_broadcast, selected_plan=group_broadcast_slots8_vselr: +group_broadcast, recipe=group_broadcast_slots8_vselr: source group_slots(G, slots=8) result dense layout selected per use emits vselr using assigned result layout -group_broadcast, selected_plan=group_broadcast_slots1_vselr: +group_broadcast, recipe=group_broadcast_slots1_vselr: source group_slots(G, slots=1) result dense layout selected per use emits vdup/vselr row-local materialization -truncf, selected_plan=group_slot_cast_slots1_f32_to_f16: +truncf, recipe=group_slot_cast_slots1_f32_to_f16: source/result group_slots(G, slots=1) emits one lane-0 vcvt per group slot block rejects packed slots=8 unless another plan is registered @@ -980,36 +954,29 @@ Current staged implementation status: ```text group_slot_load: - vmi-to-vpto requires vmi.selected_plan and checks it against - #pto.vmi.layout. + vmi-to-vpto lowers from #pto.vmi.layout + and source_group_stride. group_reduce_addf: - explicit slots=8 VCGADD lowering requires - vmi.selected_plan = "s8_reduce_contiguous". Legacy bare num_groups and - generic VCADD lowering still need the plan-registry migration. + explicit slots=8 VCGADD lowering is selected from contiguous source/mask + layout, slots=8 result layout, num_groups, and reassoc. S=16 block8 assignment emits source/mask #pto.vmi.layout, result - #pto.vmi.layout, and - vmi.selected_plan = "s16_reduce_block8"; vmi-to-vpto checks that plan and - lowers through two VCGADDs plus a PAT_VL8 VADD per packed result block. + #pto.vmi.layout; vmi-to-vpto lowers through two + VCGADDs plus a PAT_VL8 VADD per packed result block. S=32 block8 assignment emits source/mask #pto.vmi.layout, result - #pto.vmi.layout, and - vmi.selected_plan = "s32_reduce_block8_stride"; vmi-to-vpto checks that - plan and lowers through four VCGADDs plus a PAT_VL8 VADD tree per packed - result block. - S=64 row-local assignment now emits - vmi.selected_plan = "s64_reduce_row_local" and has focused - layout-assignment/vmi-to-vpto lit coverage; the explicit slots=1 generic - VCADD row-local path also requires and checks that selected_plan. Other - legacy bare num_groups generic VCADD paths still need the plan-registry - migration. + #pto.vmi.layout; vmi-to-vpto lowers through four + VCGADDs plus a PAT_VL8 VADD tree per packed result block. + S=64 row-local assignment uses #pto.vmi.layout + and has focused layout-assignment/vmi-to-vpto lit coverage; the explicit + slots=1 generic VCADD row-local path is selected locally. group_broadcast: - explicit slots=8/1 source layouts require - vmi.selected_plan = "group_broadcast_slots8_vselr" or - "group_broadcast_slots1_vselr". Deinterleaved block-fragment results use - the result layout block_elems as the local vselr selection group, so + explicit slots=8/1 source layouts select + packed or row-local VSELR recipes locally. Deinterleaved block-fragment + results use the result layout block_elems as the local vselr selection group, + so `deinterleaved = 4, block_elems = 8` broadcasts one group slot across each 32B row fragment. VSELR index vectors are materialized per physical result chunk. For small-group results, layout assignment has already fixed the @@ -1018,27 +985,23 @@ group_broadcast: `sourceChunk = firstGroup / slots`, and `baseGroupSlot = firstGroup % slots`. The generated index vector selects `baseGroupSlot .. baseGroupSlot + groupsPerResultChunk - 1`; it must not be - reused across result chunks. Legacy bare num_groups still needs the - plan-registry migration. + reused across result chunks. group_load: - contiguous full-chunk path emits and checks - vmi.selected_plan = "group_load_contiguous_chunks". S=16/S=32 - block-aligned strided loads emit and check - vmi.selected_plan = "s16_group_load_block8_stride" or - "s32_group_load_block8_stride", assign + contiguous full-chunk path is selected from a contiguous result layout. + S=16/S=32 block-aligned strided loads are selected from #pto.vmi.layout, and lower to one vsldb per 32B row fragment and physical chunk. The dedicated S=16 unit-stride - vldsx2/BDINTLV plan remains a design target. S=16/S=32 group_load with a - non-constant, non-positive, or non-8-f32-aligned row_stride is rejected by - vmi-layout-assignment because the stable gather fallback is not implemented. + vldsx2/BDINTLV recipe remains a local peephole target. + S=16/S=32 group_load with a non-constant, non-positive, or non-8-f32-aligned + row_stride is rejected by vmi-layout-assignment because the stable gather + fallback is not implemented. truncf group-slot cast: - layout assignment and vmi-to-vpto support and check - vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" for - group_slots(G, slots=1) f32 -> f16. The reduce->truncf->group_store - slots=1 flow has focused lit coverage and no longer relies on vmi-to-vpto - inspecting the truncf producer. + layout assignment and vmi-to-vpto support group_slots(G, slots=1) + f32 -> f16 from source/result layouts and element widths. The reduce->truncf + -> group_store slots=1 flow has focused lit coverage and no longer relies on + vmi-to-vpto inspecting the truncf producer. group_store: row-local group_slots(G, slots=1) lowering is implemented as one lane-0 @@ -1058,15 +1021,15 @@ group_store: Examples: ```text -group_reduce_addf, selected_plan=s16_reduce_parity: +group_reduce_addf, recipe=s16_reduce_parity: consume deinterleaved=2, block_elems=1 emit two VCGADDs and one VADD -group_reduce_addf, selected_plan=s16_reduce_block8: +group_reduce_addf, recipe=s16_reduce_block8: consume deinterleaved=2, block_elems=8 emit two VCGADDs and one VADD -group_reduce_addf, selected_plan=s32_reduce_dintlv4: +group_reduce_addf, recipe=s32_reduce_dintlv4: consume deinterleaved=4 emit four VCGADDs and reduction tree @@ -1101,8 +1064,7 @@ After assignment: ```text Every VMI value has layout. Every VMI mask has layout and granularity plan. -Every context-sensitive op has selected_plan. -Every selected_plan matches operand/result layouts. +Every lowering choice is locally deterministic or explicit in attrs/layouts. Every ensure_* helper has a materialization plan. Every control-flow edge has matching VMI layouts. ``` @@ -1121,8 +1083,8 @@ allowed: diagnostic not allowed: - walking from a consumer to a producer to decide a selected_plan - walking from a consumer to a mask producer to decide whether a plan is legal + walking from a consumer to a producer to decide a recipe + walking from a consumer to a mask producer to decide whether a recipe is legal inspecting users to choose a result layout or materialization recovering full_tile_readable from surrounding MTE/caller context ``` @@ -1222,7 +1184,8 @@ Each positive layout-assignment test must check: ```text assigned data layouts assigned mask layouts -selected_plan attrs +assigned op attrs +direct vmi-to-vpto local lowering inserted ensure_layout/rematerialized producers control-flow/function signature specialization ``` @@ -1766,7 +1729,6 @@ entries: ```text lit: - test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto @@ -1805,7 +1767,6 @@ memory-proof runtime coverage: layout attrs vmi.vreg/vmi.mask types surface op definitions -selected_plan attr surface/layout validators ``` @@ -1888,10 +1849,9 @@ Current evidence for the case-catalog objective: 4. the latest broad VMI runtime sweep passed: PASS=43 FAIL=0 5. the latest full VMI lit sweep passed: 314/314 6. every unsupported endpoint listed in section 11.3 has a diagnostic lit test -7. vmi-to-vpto context-sensitive decisions are represented by assigned layouts, - selected_plan, helper ops, rematerialization, or diagnostics -8. missing selected_plan on registered context-sensitive shapes is a hard - validation failure +7. vmi-to-vpto decisions are represented by current-op attrs/operands, + assigned layouts, helper ops, rematerialization, or diagnostics +8. no separate recipe string attr is emitted or consumed 9. release docs remain untouched; this is still a design/implementation plan under docs/designs ``` diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 9261a938dd..b30c0c3472 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -217,7 +217,7 @@ S=64 row-local result -> slots=1 ```text 1. op name and explicit op attrs 2. converted operand/result types with layout -3. selected plan attrs written by layout assignment +3. helper/materialization ops written by layout assignment 4. inserted helper ops 5. target capability registry ``` @@ -251,73 +251,62 @@ or explicit helper: pto.vmi.ensure_mask_granularity ``` -Every context-sensitive op must also have a selected plan if layout alone does -not uniquely identify the lowering: +`vmi-to-vpto` is allowed to choose a deterministic recipe from local +information on the current op: ```text -vmi.selected_plan = "dense_load_norm" -vmi.selected_plan = "load_dintlv2" -vmi.selected_plan = "load_dintlv4" -vmi.selected_plan = "group_load_contiguous_chunks" -vmi.selected_plan = "s16_group_load_block8_unit_stride" -vmi.selected_plan = "s16_group_load_block8_stride" -vmi.selected_plan = "s32_group_load_block8_stride" -vmi.selected_plan = "s8_reduce_contiguous" -vmi.selected_plan = "s16_reduce_parity" -vmi.selected_plan = "s16_reduce_block8" -vmi.selected_plan = "s32_reduce_dintlv4" -vmi.selected_plan = "s32_reduce_block8_stride" -vmi.selected_plan = "s64_reduce_row_local" -vmi.selected_plan = "group_slot_load_slots8_unit_stride" -vmi.selected_plan = "group_slot_load_slots1_row_local" -vmi.selected_plan = "group_broadcast_slots8_vselr" -vmi.selected_plan = "group_broadcast_slots1_vselr" -vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" +current op name +current op attrs +operand/result types and layouts +current op operand values such as stride and offset +target capability and pass options ``` -The spelling above is illustrative; implementation may use an enum attr. The -invariant is not illustrative: if a lowering decision is not uniquely implied -by op + assigned operand/result layouts + explicit attrs, assignment must write -a selected plan. +This is not context inference. What remains forbidden is walking to producers, +users, sibling users, branch/loop bodies, callees/callers, or nearby memory/MTE +ops to recover a lowering decision or a memory-safety proof. -### 4.1 Selected Plan Contract +If a decision cannot be made from that local information, layout assignment +must rewrite the IR until the decision is explicit in attrs, operand/result +layouts, helper ops, cloned producers, or diagnostics. `vmi-to-vpto` must not +consume a separate string recipe attr. -`selected_plan` is not an optimization hint. It is the serialized answer to a -question that would otherwise require `vmi-to-vpto` to inspect producer, -consumer, control-flow, memory, or mask context. +### 4.1 Local Recipe Contract -Required plans in the current implementation: +The lowering recipe is derived from op + assigned operand/result layouts + +explicit attrs/operands. If two legal recipes cannot be distinguished from +that local information, the IR is missing a semantic carrier and must be +extended before the recipe is implemented. + +Locally deterministic decisions in the current implementation: ```text group_load: - required for registered result layouts. The plan fixes source_group_stride - handling and whether the result is contiguous chunks, S=16 block8, or S=32 - block8. Unsupported shapes diagnose through the capability check instead of - inventing a plan. + result layout, num_groups, row_stride, source type, and target capability + decide contiguous chunks versus S=16/S=32 block8 vsldb lowering. Unit-stride + vldsx2/BDINTLV can be a local peephole for the same block8 layout. group_slot_load: - required for explicit slots=8 or slots=1 layouts. The plan fixes packed - scalar load versus row-local lane-0 load. A single source op may be - rematerialized into two different planned ops. + result group_slots layout and source_group_stride decide packed slots=8 + versus row-local slots=1 vsldb lowering. A single source op may still be + rematerialized into two ops when different users require different result + layouts; each clone is then locally deterministic. group_reduce_addf: - required for registered S=8/S=16/S=32/S=64 shapes. The plan fixes parity - versus block8, packed slots=8 versus row-local slots=1, and multi-chunk - arity. Unsupported group sizes diagnose as unsupported capability, not as - missing selected_plan. + source/mask layout, result group_slots layout, num_groups, element type, and + reassoc decide S=8 contiguous vcgadd, S=16/S=32 deinterleaved vcgadd trees, + and S=64 row-local vcadd/vsel lowering. group_broadcast: - required for explicit slots=8 or slots=1 sources. The plan fixes source - interpretation and the vselr index recipe for the requested dense result - layout. Legacy bare group_slots are tolerated only as compatibility input and - must not be emitted by layout assignment. + source group_slots layout, result dense layout, num_groups, and element type + decide vdup/vselr materialization. truncf: - required for group_slots slots=1 f32->f16, where the cast is a slot-preserving - group-slot cast rather than an ordinary dense VCVT path. + source/result group_slots layouts and element widths decide the slots=1 + f32->f16 slot-preserving vcvt path. ``` -Layout-only or attr-only decisions in the current implementation: +Other layout-only or attr-only decisions in the current implementation: ```text load: @@ -327,8 +316,9 @@ load: group_store: source group_slots layout and explicit output stride decide packed slots=8 - versus row-local slots=1 store legality. If another legal store recipe is - introduced, assignment must attach a selected plan before vmi-to-vpto uses it. + versus row-local slots=1 store legality. If another legal store recipe + needs more information, assignment must make that information explicit in the + op or helper IR before vmi-to-vpto uses it. masked_load: explicit passthrough, mask layout, full physical read, shaped safe-tail memref, @@ -341,14 +331,14 @@ masked_store/select/elementwise: extf/truncf: dense width-changing paths are layout-determined today. Any future - commute-through-group-broadcast or alternative VCVT recipe must become a - selected plan first. + commute-through-group-broadcast or alternative VCVT recipe must have an + explicit IR carrier first. ``` -Forbidden plan recovery: +Forbidden non-local recipe recovery: ```text -No pattern may synthesize one of the required plans by: +No pattern may synthesize a recipe or memory proof by: - walking from group_reduce to the load/group_load producer - walking from store/broadcast/truncf to the group_reduce producer - scanning sibling users of a group_slots value @@ -356,9 +346,9 @@ No pattern may synthesize one of the required plans by: - inspecting private callee bodies while lowering a call ``` -If a required plan is missing, `vmi-to-vpto` emits +If the current op lacks enough local information, `vmi-to-vpto` emits `VMI-LAYOUT-CONTRACT` at the current op and prints the op name, logical type, -assigned layouts, and the missing plan class. +assigned layouts, and the missing decision class. ## 5. Plan Registry @@ -547,7 +537,7 @@ group_broadcast: group_store: requests source group_slots(num_groups, slots=K) - selected plan also records output stride legality + explicit output stride attrs/operands decide store legality dense elementwise add/mul/fma/min/max/select: requests all dense data operands and results use one dense layout @@ -720,7 +710,7 @@ Recommended solving order: 7. Rematerialize cheap producers instead of materializing when cheaper. 8. Specialize internal function signatures. 9. Emit diagnostics for unsatisfied hard constraints. -10. Rewrite VMI types and selected plan attrs. +10. Rewrite VMI types and insert explicit helper/rematerialized ops. ``` Tie-breaking must be deterministic. Suggested priority: @@ -787,10 +777,10 @@ For each op, the pattern: ```text 1. reads operand/result layouts -2. reads selected_plan if required +2. reads current op attrs and operand values 3. asks TypeConverter for ordered physical values -4. emits the registered VPTO recipe -5. fails if the selected plan is missing or target capability is absent +4. emits the locally implied VPTO recipe +5. fails if target capability or required local proof is absent ``` The pattern must not: @@ -825,12 +815,12 @@ diagnostic embellishment: Anything else is a layout-assignment responsibility. In particular, an unsupported producer/consumer combination must be rejected before assignment -writes a selected plan. Section 3.44 is the model for supported partial S=32 +emits layout-assigned IR. Section 3.44 is the model for supported partial S=32 grouped masks: assignment emits explicit contiguous and deinterleaved mask values, and `vmi-to-vpto` lowers the deinterleaved mask op itself through contiguous grouped-mask materialization followed by predicate deinterleave. It does not walk from `group_reduce_addf` to the mask producer to choose or reject -the plan. Dynamic `active_elems_per_group` follows the same rule: the +the recipe. Dynamic `active_elems_per_group` follows the same rule: the `create_group_mask` op lowers its own SSA scalar with vci/vshrs/vshls/vsub/vcmps for contiguous chunks before any predicate deinterleave. @@ -852,8 +842,8 @@ group_slots(G,K): slot_block0, slot_block1, ... ``` -Two physical bundle entries may alias the same VPTO SSA value when the selected -plan proves they have the same contents, such as group_broadcast feeding both +Two physical bundle entries may alias the same VPTO SSA value when the local +recipe proves they have the same contents, such as group_broadcast feeding both parts of a `deinterleaved=2` broadcast result. Arity still follows the layout; aliasing is not a different layout. @@ -866,7 +856,7 @@ Diagnostics are part of the design. They must name: 2. source logical type 3. assigned source layout 4. requested layout -5. missing plan or disabled fallback +5. missing local proof or disabled fallback 6. suggested rewrite when available ``` @@ -894,8 +884,8 @@ public VMI function boundary: The design is complete only when: ```text -1. every case in vmi-layout-lowering-cases.md maps to registered plans -2. every selected plan can be emitted without looking at producer/user context +1. every case in vmi-layout-lowering-cases.md maps to a local recipe +2. every local recipe can be emitted without looking at producer/user context 3. every unsupported case has a precise capability diagnostic 4. every control-flow/function boundary either specializes layout or diagnoses 5. every mask has explicit data layout and predicate granularity diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 8e2d6bfceb..262299b3a3 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -1522,8 +1522,8 @@ layout transition explicit: `group_broadcast` first produces a dense contiguous f32 value, then `pto.vmi.ensure_layout` materializes the deinterleaved=2 f32 view required by dense `f32 -> f16` truncation. A future direct `group_broadcast -> deinterleaved=2` lowering may remove that materialization, -but it must be implemented as a `group_broadcast` selected plan rather than -hidden inside `truncf` lowering. +but the `group_broadcast` result layout must make that recipe explicit rather +than hiding it inside `truncf` lowering. VPTO lowering result for one full 8-row tile: @@ -3045,9 +3045,9 @@ layout. It is that each use has an explicit layout boundary: %b_for_cast_split = pto.vmi.ensure_layout %b_for_cast ``` -If a future `group_broadcast -> deinterleaved` selected plan is added, layout +If a future direct `group_broadcast -> deinterleaved` recipe is added, layout assignment may assign `%b_for_mul` or `%b_for_cast` directly to that layout, but -the choice must still be visible in the assigned IR and selected plan. +the choice must still be visible in the assigned IR. VPTO lowering result: @@ -5266,7 +5266,7 @@ one contiguous value for `masked_load`, and one deinterleaved value for `create_group_mask` by materializing the contiguous grouped predicate chunks and then applying `pdintlv_b32` in the same tree shape as the data `vdintlv`. It does not walk from `group_reduce_addf` to the mask producer to -choose or reject the selected plan. +choose or reject the recipe. Assignment may select a deinterleaved S=32 load plan only when the rounded physical reads are memory-safe; otherwise it must diagnose or use a future diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index e64d5bba3f..770970ca36 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -810,9 +810,11 @@ def PTOValidateVMILayoutIR Checks the post-layout-assignment VMI stage: every VMI data value must have a concrete VMI layout, every VMI mask must have concrete b8/b16/b32 granularity and layout, physical VPTO register values must not appear yet, - VMI typed values must stay inside VMI semantic/helper or structural ops, - and context-sensitive VMI ops must carry the selected_plan contract emitted - by layout assignment. + and VMI typed values must stay inside VMI semantic/helper or structural ops. + vmi-to-vpto chooses deterministic local recipes from the current op's attrs, + operand/result types, layouts, and operand values; non-local choices must + be represented as explicit attrs, helper ops, cloned producers, or + diagnostics before this stage. }]; let constructor = "mlir::pto::createPTOValidateVMILayoutIRPass()"; let dependentDialects = ["mlir::cf::ControlFlowDialect", diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 889a5ebe85..6ce3e8eecd 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -36,8 +36,6 @@ using namespace mlir::pto; namespace { -static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; - bool isVMIType(Type type) { return isa(type); } bool isPhysicalVPTOType(Type type) { @@ -161,133 +159,6 @@ LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, return failure(); } -LogicalResult emitLayoutContract(Operation *op, llvm::raw_ostream *diagOS, - Twine message) { - InFlightDiagnostic diag = - op->emitError() << kVMIDiagLayoutContractPrefix << message; - (void)diag; - mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); - return failure(); -} - -std::optional getGroupSize(VMIVRegType type, int64_t numGroups) { - if (!type || numGroups <= 0 || type.getElementCount() % numGroups != 0) - return std::nullopt; - return type.getElementCount() / numGroups; -} - -bool hasRegisteredGroupReducePlan(VMIGroupReduceAddFOp op) { - auto sourceType = dyn_cast(op.getSource().getType()); - if (!sourceType) - return false; - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - if (!sourceLayout) - return false; - - std::optional groupSize = - getGroupSize(sourceType, op.getNumGroupsAttr().getInt()); - if (!groupSize) - return false; - - if (sourceLayout.isContiguous()) - return *groupSize == 8 || *groupSize == 64; - - if (!sourceLayout.isDeinterleaved()) - return false; - if (*groupSize == 16 && sourceLayout.getFactor() == 2) - return sourceLayout.getBlockElems() == 1 || - sourceLayout.getBlockElems() == 8; - if (*groupSize == 32 && sourceLayout.getFactor() == 4) - return sourceLayout.getBlockElems() == 1 || - sourceLayout.getBlockElems() == 8; - return false; -} - -bool hasRegisteredGroupLoadPlan(VMIGroupLoadOp op) { - auto resultType = dyn_cast(op.getResult().getType()); - if (!resultType) - return false; - VMILayoutAttr layout = resultType.getLayoutAttr(); - if (!layout) - return false; - if (layout.isContiguous()) - return true; - if (!layout.isDeinterleaved() || layout.getBlockElems() != 8) - return false; - - std::optional groupSize = - getGroupSize(resultType, op.getNumGroupsAttr().getInt()); - if (!groupSize) - return false; - return (*groupSize == 16 && layout.getFactor() == 2) || - (*groupSize == 32 && layout.getFactor() == 4); -} - -bool hasRegisteredGroupSlotLoadPlan(VMIGroupSlotLoadOp op) { - auto resultType = dyn_cast(op.getResult().getType()); - if (!resultType) - return false; - VMILayoutAttr layout = resultType.getLayoutAttr(); - return layout && layout.isGroupSlots() && - layout.getNumGroups() == op.getNumGroupsAttr().getInt() && - (layout.getSlots() == 8 || layout.getSlots() == 1); -} - -bool hasRegisteredGroupBroadcastPlan(VMIGroupBroadcastOp op) { - auto sourceType = dyn_cast(op.getSource().getType()); - auto resultType = dyn_cast(op.getResult().getType()); - if (!sourceType || !resultType) - return false; - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - return sourceLayout && resultLayout && sourceLayout.isGroupSlots() && - sourceLayout.getNumGroups() == op.getNumGroupsAttr().getInt() && - !resultLayout.isGroupSlots() && - (sourceLayout.getSlots() == 8 || sourceLayout.getSlots() == 1); -} - -bool hasRegisteredGroupSlotTruncFPlan(Operation *op) { - auto truncf = dyn_cast(op); - if (!truncf) - return false; - - auto sourceType = dyn_cast(truncf.getSource().getType()); - auto resultType = dyn_cast(truncf.getResult().getType()); - if (!sourceType || !resultType) - return false; - - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - return sourceLayout && resultLayout && sourceLayout.isGroupSlots() && - resultLayout.isGroupSlots() && sourceLayout.getSlots() == 1 && - resultLayout.getSlots() == 1 && sourceType.getElementType().isF32() && - resultType.getElementType().isF16(); -} - -bool requiresSelectedPlan(Operation *op) { - if (auto groupLoad = dyn_cast(op)) - return hasRegisteredGroupLoadPlan(groupLoad); - if (auto groupSlotLoad = dyn_cast(op)) - return hasRegisteredGroupSlotLoadPlan(groupSlotLoad); - if (auto reduce = dyn_cast(op)) - return hasRegisteredGroupReducePlan(reduce); - if (auto broadcast = dyn_cast(op)) - return hasRegisteredGroupBroadcastPlan(broadcast); - return hasRegisteredGroupSlotTruncFPlan(op); -} - -LogicalResult verifySelectedPlanContract(Operation *op, - llvm::raw_ostream *diagOS) { - if (!requiresSelectedPlan(op)) - return success(); - if (op->getAttrOfType(kVMISelectedPlanAttrName)) - return success(); - return emitLayoutContract( - op, diagOS, - Twine(op->getName().getStringRef()) + - " requires vmi.selected_plan selected by vmi-layout-assignment"); -} - LogicalResult verifyBoundaryType(Operation *owner, Type type, llvm::raw_ostream *diagOS) { if (isPhysicalVPTOType(type)) @@ -507,9 +378,6 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, if (!hasVMIOrPhysicalType(op)) return success(); - if (failed(verifySelectedPlanContract(op, diagOS))) - return failure(); - if (isVMIHelperOp(op)) { if (isVMILayoutHelperOp(op)) return success(); diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 9352ffce76..85a57e4ac1 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -65,8 +65,6 @@ struct MaskUseRequest { std::string granularity; }; -static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; - static unsigned getElementBitWidth(Type type) { if (isa(type)) return 64; @@ -1572,143 +1570,6 @@ struct LayoutSolver { return success(); } - std::optional getGroupReduceSelectedPlan(VMIGroupReduceAddFOp op) { - auto sourceType = dyn_cast(op.getSource().getType()); - if (!sourceType) - return std::nullopt; - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - if (!sourceLayout) - return std::nullopt; - - int64_t numGroups = op.getNumGroupsAttr().getInt(); - if (numGroups <= 0 || sourceType.getElementCount() % numGroups != 0) - return std::nullopt; - int64_t groupSize = sourceType.getElementCount() / numGroups; - - if (sourceLayout.isContiguous()) { - if (groupSize == 8) - return StringRef("s8_reduce_contiguous"); - if (groupSize == 64) - return StringRef("s64_reduce_row_local"); - return std::nullopt; - } - - if (!sourceLayout.isDeinterleaved()) - return std::nullopt; - - if (groupSize == 16 && sourceLayout.getFactor() == 2) { - if (sourceLayout.getBlockElems() == 1) - return StringRef("s16_reduce_parity"); - if (sourceLayout.getBlockElems() == 8) - return StringRef("s16_reduce_block8"); - } - - if (groupSize == 32 && sourceLayout.getFactor() == 4) { - if (sourceLayout.getBlockElems() == 1) - return StringRef("s32_reduce_dintlv4"); - if (sourceLayout.getBlockElems() == 8) - return StringRef("s32_reduce_block8_stride"); - } - - return std::nullopt; - } - - std::optional getGroupSlotLoadSelectedPlan(VMIGroupSlotLoadOp op) { - auto resultType = dyn_cast(op.getResult().getType()); - if (!resultType) - return std::nullopt; - VMILayoutAttr layout = resultType.getLayoutAttr(); - if (!layout || !layout.isGroupSlots() || - layout.getNumGroups() != op.getNumGroupsAttr().getInt()) - return std::nullopt; - if (layout.getSlots() == 8) - return StringRef("group_slot_load_slots8_unit_stride"); - if (layout.getSlots() == 1) - return StringRef("group_slot_load_slots1_row_local"); - return std::nullopt; - } - - std::optional getGroupLoadSelectedPlan(VMIGroupLoadOp op) { - auto resultType = dyn_cast(op.getResult().getType()); - if (!resultType) - return std::nullopt; - VMILayoutAttr layout = resultType.getLayoutAttr(); - if (!layout) - return std::nullopt; - if (layout.isContiguous()) - return StringRef("group_load_contiguous_chunks"); - if (!layout.isDeinterleaved() || layout.getBlockElems() != 8) - return std::nullopt; - - int64_t numGroups = op.getNumGroupsAttr().getInt(); - if (numGroups <= 0 || resultType.getElementCount() % numGroups != 0) - return std::nullopt; - int64_t groupSize = resultType.getElementCount() / numGroups; - if (groupSize == 16 && layout.getFactor() == 2) - return StringRef("s16_group_load_block8_stride"); - if (groupSize == 32 && layout.getFactor() == 4) - return StringRef("s32_group_load_block8_stride"); - return std::nullopt; - } - - std::optional - getGroupBroadcastSelectedPlan(VMIGroupBroadcastOp op) { - auto sourceType = dyn_cast(op.getSource().getType()); - auto resultType = dyn_cast(op.getResult().getType()); - if (!sourceType || !resultType) - return std::nullopt; - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - if (!sourceLayout || !resultLayout || !sourceLayout.isGroupSlots() || - sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt() || - resultLayout.isGroupSlots()) - return std::nullopt; - if (sourceLayout.getSlots() == 8) - return StringRef("group_broadcast_slots8_vselr"); - if (sourceLayout.getSlots() == 1) - return StringRef("group_broadcast_slots1_vselr"); - return std::nullopt; - } - - std::optional getTruncFSelectedPlan(VMITruncFOp op) { - auto sourceType = dyn_cast(op.getSource().getType()); - auto resultType = dyn_cast(op.getResult().getType()); - if (!sourceType || !resultType) - return std::nullopt; - - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - if (!sourceLayout || !resultLayout || sourceLayout != resultLayout || - !sourceLayout.isGroupSlots() || sourceLayout.getSlots() != 1) - return std::nullopt; - - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); - if (sourceBits == 32 && resultBits == 16) - return StringRef("group_slot_cast_slots1_f32_to_f16"); - return std::nullopt; - } - - void attachSelectedPlanAttrs() { - Builder builder(ctx); - module.walk([&](Operation *op) { - std::optional plan; - if (auto reduce = dyn_cast(op)) - plan = getGroupReduceSelectedPlan(reduce); - else if (auto load = dyn_cast(op)) - plan = getGroupLoadSelectedPlan(load); - else if (auto load = dyn_cast(op)) - plan = getGroupSlotLoadSelectedPlan(load); - else if (auto broadcast = dyn_cast(op)) - plan = getGroupBroadcastSelectedPlan(broadcast); - else if (auto truncf = dyn_cast(op)) - plan = getTruncFSelectedPlan(truncf); - - if (plan) - op->setAttr(kVMISelectedPlanAttrName, builder.getStringAttr(*plan)); - }); - } - void rewriteFunctionType() { module.walk([&](func::FuncOp func) { if (func.empty()) @@ -1755,7 +1616,6 @@ struct LayoutSolver { rewriteDataTypes(); if (failed(insertDataUseMaterializations())) return failure(); - attachSelectedPlanAttrs(); if (failed(inferMaskRequests())) return failure(); rewriteMaskTypes(); diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 36ccc21f3f..5b050d640a 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -50,8 +50,6 @@ using namespace mlir::pto; namespace { -static constexpr const char *kVMISelectedPlanAttrName = "vmi.selected_plan"; - bool isVMIType(Type type) { return isa(type); } bool containsVMIType(Type type) { @@ -1187,21 +1185,12 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, VMILayoutAttr resultLayout = resultType.getLayoutAttr(); if (!resultLayout) return fail("requires assigned result layout"); - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); FailureOr groupSize = getGroupSizeFromNumGroups( resultType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) return failure(); if (resultLayout.isContiguous()) { - StringRef expectedPlan = "group_load_contiguous_chunks"; - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match result layout; expected '" + expectedPlan + - "'"); if (failed(checkSupportedLoadShape(capabilities, resultType, op.getSource(), op.getSource().getType(), std::nullopt, std::nullopt, reason))) @@ -1211,18 +1200,10 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, if (resultLayout.isDeinterleaved() && resultLayout.getBlockElems() == 8 && resultType.getElementType().isF32()) { - StringRef expectedPlan; - if (*groupSize == 16 && resultLayout.getFactor() == 2) - expectedPlan = "s16_group_load_block8_stride"; - else if (*groupSize == 32 && resultLayout.getFactor() == 4) - expectedPlan = "s32_group_load_block8_stride"; - else + if ((*groupSize != 16 || resultLayout.getFactor() != 2) && + (*groupSize != 32 || resultLayout.getFactor() != 4)) return fail("block8 strided group_load requires S=16/factor=2 or " "S=32/factor=4"); - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match result layout; expected '" + expectedPlan + - "'"); if (!isa(op.getSource().getType())) return fail("block8 strided group_load requires !pto.ptr source"); if (op.getNumGroupsAttr().getInt() % 8 != 0) @@ -1260,24 +1241,9 @@ LogicalResult checkSupportedGroupSlotLoadShape( return fail("requires explicit group_slots result layout matching " "num_groups"); - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - - StringRef expectedPlan; - if (layout.getSlots() == 8) - expectedPlan = "group_slot_load_slots8_unit_stride"; - else if (layout.getSlots() == 1) - expectedPlan = "group_slot_load_slots1_row_local"; - else + if (layout.getSlots() != 8 && layout.getSlots() != 1) return fail("supports only slots=8 or slots=1 group_slot_load layouts"); - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match result layout; expected '" + expectedPlan + - "'"); - if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") .isSupported()) return fail("requires supported direct memory source"); @@ -2646,18 +2612,6 @@ LogicalResult checkS16Block8GroupReduceShape(VMIGroupReduceAddFOp op, return fail("s16 block8 group_reduce_addf requires two source/mask " "parts per result part"); - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - StringRef expectedPlan = sourceLayout.getBlockElems() == 1 - ? "s16_reduce_parity" - : "s16_reduce_block8"; - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match source/result layouts; expected '" + - expectedPlan + "'"); - return success(); } @@ -2711,17 +2665,6 @@ LogicalResult checkS32Block8GroupReduceShape(VMIGroupReduceAddFOp op, return fail("s32 block8 group_reduce_addf requires four source/mask " "parts per result part"); - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - StringRef expectedPlan = sourceLayout.getBlockElems() == 1 - ? "s32_reduce_dintlv4" - : "s32_reduce_block8_stride"; - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match source/result layouts; expected '" + - expectedPlan + "'"); return success(); } @@ -6974,15 +6917,6 @@ LogicalResult checkSupportedTruncFShape(VMITruncFOp op, "group_slots(num_groups=G, slots=1) source/result layouts, " "f32 source, f16 result, and matching physical arity"); - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - StringRef expectedPlan = "group_slot_cast_slots1_f32_to_f16"; - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match source/result layouts; expected '" + - expectedPlan + "'"); return success(); } @@ -7411,41 +7345,17 @@ LogicalResult checkSupportedGroupReduceAddFShape( if (*sourceArity != *resultArity || *sourceArity != *maskArity) return fail("requires source/result/mask physical arity to match"); if (succeeded(checkVcgaddGroupReduceShape(sourceType, maskType, resultType, - *groupSize, nullptr))) { - if (resultLayout.getSlots() > 0) { - auto selectedPlan = - op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - StringRef expectedPlan = "s8_reduce_contiguous"; - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match result layout; expected '" + - expectedPlan + "'"); - } + *groupSize, nullptr))) return success(); - } if (failed(checkSupportedGroupChunkShape(sourceType, *groupSize, reason))) return failure(); if (resultLayout.getSlots() <= 0) return success(); - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - StringRef expectedPlan; - if (sourceLayout.isContiguous() && *groupSize == 64 && - resultLayout.getSlots() == 1) - expectedPlan = "s64_reduce_row_local"; - else - return fail("explicit group_slots group_reduce_addf chunk path has no " - "registered selected_plan for the assigned layouts"); - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match result layout; expected '" + expectedPlan + - "'"); + if (!sourceLayout.isContiguous() || *groupSize != 64 || + resultLayout.getSlots() != 1) + return fail("explicit group_slots group_reduce_addf chunk path requires " + "contiguous group size 64 source and slots=1 result layout"); return success(); } @@ -7477,26 +7387,10 @@ LogicalResult checkSupportedGroupBroadcastShape( if (resultLayout.isGroupSlots()) return fail("requires dense result layout"); - if (sourceLayout.getSlots() > 0) { - auto selectedPlan = op->getAttrOfType(kVMISelectedPlanAttrName); - if (!selectedPlan) - return fail("requires vmi.selected_plan selected by " - "vmi-layout-assignment"); - - StringRef expectedPlan; - if (sourceLayout.getSlots() == 8) - expectedPlan = "group_broadcast_slots8_vselr"; - else if (sourceLayout.getSlots() == 1) - expectedPlan = "group_broadcast_slots1_vselr"; - else - return fail("supports only slots=8 or slots=1 group_broadcast source " - "layouts"); - - if (selectedPlan.getValue() != expectedPlan) - return fail(Twine("vmi.selected_plan '") + selectedPlan.getValue() + - "' does not match source layout; expected '" + expectedPlan + - "'"); - } + if (sourceLayout.getSlots() > 0 && sourceLayout.getSlots() != 8 && + sourceLayout.getSlots() != 1) + return fail("supports only slots=8 or slots=1 group_broadcast source " + "layouts"); std::string fullChunkReason; if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) @@ -8174,7 +8068,7 @@ verifySupportedVMIToVPTOOps(ModuleOp module, "to one contiguous f16 result chunk or f32 deinterleaved=4 " "source parts to one contiguous fp8-like result chunk, or f32 " "group_slots(num_groups=G, slots=1) to f16 " - "group_slots(num_groups=G, slots=1) with selected_plan (" + "group_slots(num_groups=G, slots=1) (" << reason << ")"; return WalkResult::interrupt(); } diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto index dce36f1b5d..51cd09053f 100644 --- a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto @@ -60,7 +60,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[PROD]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto index 49f2c5e2a8..00879170b1 100644 --- a/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto +++ b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto @@ -48,7 +48,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-LABEL: func.func @caller( diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto index f4790b5432..2bc648261f 100644 --- a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto @@ -36,7 +36,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s16( diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto index f68b4d5509..cb0e15864e 100644 --- a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto @@ -42,7 +42,6 @@ module { // ASSIGN: %[[MASK1:.*]] = pto.vmi.create_group_mask %[[ACTIVE]] // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( // LOWER: arith.index_cast diff --git a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto index a93ae52c17..8e8a86450d 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto @@ -38,7 +38,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto index e43d2e5591..27e304ae27 100644 --- a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -41,7 +41,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_dintlv4" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN: %[[X8:.*]] = pto.vmi.truncf %[[X32]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto index 7df6946741..20c2754e60 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto @@ -46,18 +46,14 @@ module { // ASSIGN: %[[X:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[B_MUL:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B_MUL]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[YSUM:.*]] = pto.vmi.group_reduce_addf %[[Y]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN: pto.vmi.group_store %[[YSUM]] // ASSIGN: %[[B_CAST:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[B_CAST_SPLIT:.*]] = pto.vmi.ensure_layout %[[B_CAST]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto index 7c1e569bf3..2c0f4f8ca7 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto @@ -22,6 +22,5 @@ module { // CHECK-LABEL: func.func @vmi_layout_assignment_group_broadcast_slots8( // CHECK-SAME: -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_broadcast -// CHECK-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // CHECK-SAME: -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load.pto b/test/lit/vmi/vmi_layout_assignment_group_load.pto index 2a90d02d08..864683cb04 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load.pto @@ -22,6 +22,5 @@ module { // CHECK-LABEL: func.func @vmi_layout_assignment_group_load( // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_load -// CHECK-SAME: vmi.selected_plan = "group_load_contiguous_chunks" // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto index 67215442e5..a3f045e503 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto @@ -31,12 +31,10 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s16_stride_store( // ASSIGN: %[[X:.*]] = pto.vmi.group_load -// ASSIGN-SAME: vmi.selected_plan = "s16_group_load_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto index c97a35855b..df03683335 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto @@ -40,22 +40,18 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( // ASSIGN: %[[X:.*]] = pto.vmi.group_load -// ASSIGN-SAME: vmi.selected_plan = "s32_group_load_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[B:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[MASK2:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[YSUM:.*]] = pto.vmi.group_reduce_addf %[[Y]], %[[MASK2]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[YSUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto index 0f506a3a1f..abe3301b90 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto @@ -31,12 +31,10 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_store( // ASSIGN: %[[X:.*]] = pto.vmi.group_load -// ASSIGN-SAME: vmi.selected_plan = "s32_group_load_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto index c4652169d4..fb25c2bd91 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto @@ -34,7 +34,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto index e9a3e7c9e9..6339aa15bc 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto @@ -35,10 +35,8 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %arg1 // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[B32:.*]] = pto.vmi.group_broadcast %[[SUM32]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[B32_SPLIT:.*]] = pto.vmi.ensure_layout %[[B32]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto index 9fb03c80b2..7a72876ff9 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto @@ -38,15 +38,12 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( // ASSIGN-SAME: %[[SOURCE:arg[0-9]+]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf %[[SOURCE]], %[[BROADCAST]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf %[[SCALED]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( // LOWER-DAG: %[[C2:.*]] = arith.constant 2 : i32 diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto index 1d61b4196e..b0d5a12676 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto @@ -34,7 +34,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.mask<512xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto index b51dd875b5..7fe8c425bf 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto @@ -34,7 +34,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto index 0a7550d004..d5fa902c56 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto @@ -51,7 +51,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask // ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] @@ -76,7 +75,6 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> // ASSIGN: %[[PMASK:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<192xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf %[[PX]], %[[PMASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( // LOWER-COUNT-4: pto.vlds diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto index 2e4c9dd02f..2901a43f7e 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto @@ -24,6 +24,5 @@ module { // CHECK-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// CHECK-SAME: vmi.selected_plan = "s64_reduce_row_local" // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto index 6fffb7c636..982d1d8a28 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto @@ -37,13 +37,10 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots1_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto index ec8816fbeb..6cbedb442b 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto @@ -29,7 +29,6 @@ module { // ASSIGN: %[[X:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<384xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] -// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<384xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto index bf38aee552..cf46aa5870 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto @@ -31,10 +31,8 @@ module { // ASSIGN-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> // ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: %[[SUM16:.*]] = pto.vmi.truncf %[[SUM32]] -// ASSIGN-SAME: vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16" // ASSIGN-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM16]] // ASSIGN-SAME: !pto.vmi.vreg<512xf16, #pto.vmi.layout>, !pto.ptr> // CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// CHECK-SAME: vmi.selected_plan = "s8_reduce_contiguous" // CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto index 1329965530..0042c64a15 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto @@ -30,7 +30,6 @@ module { // ASSIGN-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> // ASSIGN-SAME: %arg1: !pto.vmi.mask<64xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// ASSIGN-SAME: vmi.selected_plan = "s8_reduce_contiguous" // ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto index 9f4349d40e..9f629f55f2 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto @@ -39,20 +39,17 @@ module { // CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8( // CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load -// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" // CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] // CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots1( // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load -// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots1_row_local" // CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] // CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8_store( // CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load -// CHECK-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" // CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK: pto.vmi.group_store %[[OUT]] // CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto index a96b847256..b5533d9abc 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto @@ -50,18 +50,14 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slot_load_dual_layout( // ASSIGN: %[[RHS16:.*]] = pto.vmi.group_slot_load -// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SUM16:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.addf %[[SUM16]], %[[RHS16]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[RHS64:.*]] = pto.vmi.group_slot_load -// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots1_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: %[[SUM64:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s64_reduce_row_local" // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.addf %[[SUM64]], %[[RHS64]] // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto index d0ac525849..16905f1210 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto @@ -40,17 +40,14 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_fanout( // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr // ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] -// ASSIGN-SAME: vmi.selected_plan = "group_broadcast_slots8_vselr" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf %[[SCALED]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SCALED_SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto index e4b48121bc..c30502a252 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto @@ -45,20 +45,17 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_scf_for( // ASSIGN: %[[ACC0:.*]] = pto.vmi.group_slot_load -// ASSIGN-SAME: vmi.selected_plan = "group_slot_load_slots8_unit_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[ACC:.*]] = scf.for // ASSIGN-SAME: iter_args(%[[ARG:.*]] = %[[ACC0]]) // ASSIGN-SAME: -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) // ASSIGN: %[[X:.*]] = pto.vmi.group_load -// ASSIGN-SAME: vmi.selected_plan = "s16_group_load_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_block8" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.addf %[[ARG]], %[[SUM]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto index 4004ff6fcc..6c0b2d2ece 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto @@ -51,7 +51,6 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto index 33ee79cb57..968e8d03c2 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto @@ -43,7 +43,6 @@ module { // ASSIGN: pto.vmi.create_group_mask // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // LOWER: pto.pdintlv_b32 // LOWER: pto.pdintlv_b32 // LOWER: pto.pdintlv_b32 diff --git a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto index a2d4cab4d9..46f7ff71f2 100644 --- a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto @@ -45,7 +45,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s32_reduce_block8_stride" // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto index 01e8e55caf..63fc33cfe6 100644 --- a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto @@ -41,7 +41,6 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] -// ASSIGN-SAME: vmi.selected_plan = "s16_reduce_parity" // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN: %[[X32_DENSE:.*]] = pto.vmi.ensure_layout %[[X32]] diff --git a/test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto b/test/lit/vmi/vmi_layout_gate_local_recipe.pto similarity index 80% rename from test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto rename to test/lit/vmi/vmi_layout_gate_local_recipe.pto index d06bd275ca..7644fae1c6 100644 --- a/test/lit/vmi/vmi_layout_gate_missing_selected_plan_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_local_recipe.pto @@ -6,10 +6,10 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -pto-validate-vmi-layout-ir | FileCheck %s module { - func.func @vmi_layout_gate_missing_selected_plan_invalid( + func.func @vmi_layout_gate_local_recipe( %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} @@ -20,4 +20,5 @@ module { } } -// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_layout_gate_local_recipe( +// CHECK: pto.vmi.group_reduce_addf diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto index 3a96e94d67..01e40aaae7 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto @@ -16,7 +16,7 @@ module { !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_broadcast %source - {num_groups = 128, vmi.selected_plan = "group_broadcast_slots8_vselr"} + {num_groups = 128} : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto similarity index 51% rename from test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto index a03cdfd9df..dc1b938924 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto @@ -6,24 +6,36 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_broadcast_slots8_missing_plan_invalid( - %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) { + func.func @vmi_to_vpto_group_broadcast_slots8_local_recipe( + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_broadcast %source {num_groups = 128} : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> - "pto.vmi.unpack"(%out) + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) - return + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, + %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.group_broadcast requires full source chunks -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_slots8_local_recipe( +// CHECK-COUNT-16: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto similarity index 58% rename from test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto index 563f939f77..a1c5959f98 100644 --- a/test/lit/vmi/vmi_to_vpto_group_load_missing_plan_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto @@ -6,24 +6,32 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_load_missing_plan_invalid( + func.func @vmi_to_vpto_group_load_local_recipe( %source: !pto.ptr, - %row_stride: index) { + %row_stride: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %c0 = arith.constant 0 : index %out = pto.vmi.group_load %source[%c0], %row_stride {num_groups = 2} : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> - "pto.vmi.unpack"(%out) + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) - return + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.group_load requires contiguous full result chunks -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_load_local_recipe( +// CHECK-COUNT-8: pto.vlds +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_ops.pto b/test/lit/vmi/vmi_to_vpto_group_ops.pto index e757c583f6..380a090a71 100644 --- a/test/lit/vmi/vmi_to_vpto_group_ops.pto +++ b/test/lit/vmi/vmi_to_vpto_group_ops.pto @@ -16,7 +16,7 @@ module { %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) { %c0 = arith.constant 0 : index %v = pto.vmi.group_load %src[%c0], %row_stride - {num_groups = 2, vmi.selected_plan = "group_load_contiguous_chunks"} + {num_groups = 2} : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> %r = pto.vmi.group_reduce_addf %v, %mask {num_groups = 2, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto index ee12b742e8..55ae7fd255 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto @@ -16,7 +16,7 @@ module { !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_reduce_addf %source, %mask - {num_groups = 8, reassoc, vmi.selected_plan = "s64_reduce_row_local"} + {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.vmi.mask<512xb32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto similarity index 62% rename from test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto index 96d975ab7d..4b706dc08d 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_missing_plan_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto @@ -6,25 +6,34 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_reduce_s64_missing_plan_invalid( + func.func @vmi_to_vpto_group_reduce_s64_local_recipe( %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, - %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) { + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.vmi.mask<512xb32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> - "pto.vmi.unpack"(%out) + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) - return + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.group_reduce_addf lowers through pto.vcgadd -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64_local_recipe( +// CHECK-COUNT-8: pto.vcadd +// CHECK: pto.vsel +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto index 305c488dd5..2343869ceb 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto @@ -14,7 +14,7 @@ module { %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) -> !pto.vreg<64xf32> { %out = pto.vmi.group_reduce_addf %source, %mask - {num_groups = 8, reassoc, vmi.selected_plan = "s8_reduce_contiguous"} + {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.mask<64xb32, #pto.vmi.layout> -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto similarity index 73% rename from test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto index b67cb34f2d..a6737eae1f 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_missing_plan_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto @@ -6,23 +6,27 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_reduce_slots8_missing_plan_invalid( + func.func @vmi_to_vpto_group_reduce_slots8_local_recipe( %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, - %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.mask<64xb32, #pto.vmi.layout> -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> - "pto.vmi.unpack"(%out) + %part = "pto.vmi.unpack"(%out) : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> - return + return %part : !pto.vreg<64xf32> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.group_reduce_addf lowers through pto.vcgadd -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_slots8_local_recipe( +// CHECK: pto.vcgadd +// CHECK-NOT: pto.vcadd +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto index 5927f63069..cf6591f36c 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto @@ -13,7 +13,7 @@ module { %src: !pto.ptr, %off: index) -> !pto.vreg<64xf32> { %c1 = arith.constant 1 : index %out = pto.vmi.group_slot_load %src[%off], %c1 - {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) @@ -28,7 +28,7 @@ module { !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %c8 = arith.constant 8 : index %out = pto.vmi.group_slot_load %src[%off], %c8 - {num_groups = 8, vmi.selected_plan = "group_slot_load_slots1_row_local"} + {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) @@ -44,7 +44,7 @@ module { %src: !pto.ptr, %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index %out = pto.vmi.group_slot_load %src[%off], %c1 - {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto similarity index 69% rename from test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto index f442e2fbbe..3a9aa117b5 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load_missing_plan_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto @@ -6,22 +6,24 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_slot_load_missing_plan_invalid( - %src: !pto.ptr, %off: index) { + func.func @vmi_to_vpto_group_slot_load_local_recipe( + %src: !pto.ptr, %off: index) -> !pto.vreg<64xf32> { %c1 = arith.constant 1 : index %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> - "pto.vmi.unpack"(%out) + %part = "pto.vmi.unpack"(%out) : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> - return + return %part : !pto.vreg<64xf32> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.group_slot_load requires explicit group_slots result layout -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_local_recipe( +// CHECK: pto.vsldb +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto index 10d9a2d3fa..8e58305a01 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto @@ -13,7 +13,7 @@ module { %src: !pto.ptr, %off: index, %stride: index) -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { %out = pto.vmi.group_slot_load %src[%off], %stride - {num_groups = 8, vmi.selected_plan = "group_slot_load_slots8_unit_stride"} + {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> return %out : !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto index d24f504e67..3f03f4669a 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto @@ -15,7 +15,6 @@ module { !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>) { %narrow = pto.vmi.truncf %source - {vmi.selected_plan = "group_slot_cast_slots1_f32_to_f16"} : !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%narrow) diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto similarity index 58% rename from test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto rename to test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto index f265dc0912..eec3c06d2a 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto @@ -6,23 +6,31 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_slot_truncf_slots1_missing_plan_invalid( - %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) { + func.func @vmi_to_vpto_group_slot_truncf_slots1_local_recipe( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>) { %narrow = pto.vmi.truncf %source : !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> - "pto.vmi.unpack"(%narrow) + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%narrow) : (!pto.vmi.vreg<512xf16, #pto.vmi.layout>) -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>) - return + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, + !pto.vreg<128xf16>, !pto.vreg<128xf16> } } -// CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.truncf supports only -// CHECK: requires vmi.selected_plan selected by vmi-layout-assignment +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_truncf_slots1_local_recipe( +// CHECK-COUNT-8: pto.vcvt +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast From 1848d736acd6a8d344ad0104f0a4249c0780e22d Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 20:10:43 +0800 Subject: [PATCH 19/54] Implement VMI layout optimization pipeline --- docs/designs/vmi-dialect-design.md | 9 +- docs/designs/vmi-implementation-manual.md | 112 +- .../vmi-layout-assignment-implementation.md | 398 +++++-- .../vmi-layout-assignment-lowering-design.md | 152 ++- docs/designs/vmi-layout-lowering-cases.md | 130 ++- include/PTO/Transforms/Passes.h | 4 + include/PTO/Transforms/Passes.td | 69 ++ .../PTO/Transforms/VMILocalRecipeRegistry.h | 198 ++++ lib/PTO/Transforms/CMakeLists.txt | 5 + lib/PTO/Transforms/PTOValidateVMIIR.cpp | 245 +++- lib/PTO/Transforms/VMILayoutFoldConsumers.cpp | 134 +++ lib/PTO/Transforms/VMILayoutRematerialize.cpp | 172 +++ .../VMILayoutSinkMaterialization.cpp | 363 ++++++ lib/PTO/Transforms/VMILegalizeArithSelect.cpp | 88 ++ lib/PTO/Transforms/VMILocalRecipeRegistry.cpp | 1006 +++++++++++++++++ lib/PTO/Transforms/VMIToVPTO.cpp | 310 ++--- ...gnment_dense_store_group_slots_invalid.pto | 9 +- ...nment_group_load_block8_truncf_invalid.pto | 9 +- ...ut_assignment_group_reduce_s12_invalid.pto | 5 +- ...p_reduce_s32_tail_no_full_tile_invalid.pto | 5 +- .../vmi_layout_assignment_group_slot_load.pto | 5 +- ...lot_load_slots1_dynamic_stride_invalid.pto | 2 +- ...t_load_slots1_unaligned_stride_invalid.pto | 2 +- ...group_store_slots1_unit_stride_invalid.pto | 2 +- ...ment_packed_group_slots_truncf_invalid.pto | 9 +- .../vmi/vmi_layout_fold_consumers_deint4.pto | 90 ++ ...vmi_layout_fold_consumers_masked_store.pto | 57 + .../vmi/vmi_layout_fold_consumers_store.pto | 92 ++ ...ayout_gate_bitcast_group_slots_invalid.pto | 22 + ...vmi_layout_gate_bitcast_recipe_invalid.pto | 22 + .../vmi_layout_gate_extf_recipe_invalid.pto | 22 + ...ut_gate_group_broadcast_recipe_invalid.pto | 22 + ..._layout_gate_group_load_recipe_invalid.pto | 23 + ...ayout_gate_group_reduce_recipe_invalid.pto | 25 + ...ate_group_reduce_slots1_recipe_invalid.pto | 25 + ...ut_gate_group_slot_load_recipe_invalid.pto | 23 + ..._group_slots_unsupported_slots_invalid.pto | 40 + ...layout_gate_group_store_recipe_invalid.pto | 24 + ...e_helper_materialization_shape_invalid.pto | 35 + .../vmi_layout_gate_helper_recipe_invalid.pto | 22 + .../vmi_layout_gate_store_recipe_invalid.pto | 37 + .../vmi_layout_gate_truncf_recipe_invalid.pto | 22 + .../lit/vmi/vmi_layout_rematerialize_data.pto | 49 + .../lit/vmi/vmi_layout_rematerialize_mask.pto | 55 + ...vmi_layout_sink_materialization_binary.pto | 202 ++++ .../vmi_layout_sink_materialization_mask.pto | 86 ++ test/lit/vmi/vmi_legalize_arith_select.pto | 47 + test/lit/vmi/vmi_ptoas_cli_pipeline.pto | 23 + .../vmi/vmi_to_vpto_bitcast_deint_tail.pto | 33 + .../vmi_to_vpto_bitcast_footprint_invalid.pto | 23 + ...mi_to_vpto_bitcast_group_slots_invalid.pto | 23 + ...vpto_truncf_fp8_128_contiguous_invalid.pto | 4 +- tools/ptoas/ptoas.cpp | 14 + 53 files changed, 4175 insertions(+), 430 deletions(-) create mode 100644 include/PTO/Transforms/VMILocalRecipeRegistry.h create mode 100644 lib/PTO/Transforms/VMILayoutFoldConsumers.cpp create mode 100644 lib/PTO/Transforms/VMILayoutRematerialize.cpp create mode 100644 lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp create mode 100644 lib/PTO/Transforms/VMILegalizeArithSelect.cpp create mode 100644 lib/PTO/Transforms/VMILocalRecipeRegistry.cpp create mode 100644 test/lit/vmi/vmi_layout_fold_consumers_deint4.pto create mode 100644 test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto create mode 100644 test/lit/vmi/vmi_layout_fold_consumers_store.pto create mode 100644 test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_rematerialize_data.pto create mode 100644 test/lit/vmi/vmi_layout_rematerialize_mask.pto create mode 100644 test/lit/vmi/vmi_layout_sink_materialization_binary.pto create mode 100644 test/lit/vmi/vmi_layout_sink_materialization_mask.pto create mode 100644 test/lit/vmi/vmi_legalize_arith_select.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitcast_deint_tail.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto diff --git a/docs/designs/vmi-dialect-design.md b/docs/designs/vmi-dialect-design.md index 5578ca93d1..7569b787a0 100644 --- a/docs/designs/vmi-dialect-design.md +++ b/docs/designs/vmi-dialect-design.md @@ -1626,10 +1626,11 @@ lowering 不能因为 VPTO 有更快指令就加强或放松这些属性。比 不能拆成 `mulf + addf`,也不能把 `mulf + addf` 合成 `fma`;带 `nsw/nuw` 的 integer op 可以利用 flag 做优化,不带 flag 的 op 必须保持 wraparound/defined overflow 语义。 -`pto.vmi.fma` 不能默认拆成 `mulf + addf`。`bitcast` 只有在当前 layout 下 bit grouping -physically adjacent、且每个对应 physical chunk 的 logical bit footprint 相同时才能 direct; -padding bits 只能流向 result padding bits。否则需要 layout conversion、scratch materialization -或 target capability diagnostic。 +`pto.vmi.fma` 不能默认拆成 `mulf + addf`。`bitcast` 只有在当前 contiguous/deinterleaved +layout 下 bit grouping physically adjacent、且每个对应 physical chunk 的 logical bit +footprint 相同时才能 direct;padding bits 只能流向 result padding bits。group_slots bitcast +暂不复用这个规则,必须等 slot-wise bitcast contract 定义清楚后再支持。否则需要 layout +conversion、scratch materialization 或 target capability diagnostic。 当前 VPTO direct lowering 对逐元素算术、逻辑、比较和 select 还有一条共同硬约束:物理 element width 必须能对应到 `pto.mask`。因此 VMI 语义层可以承载 `index` 或 `f64` diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index cd674db32a..04da993699 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -125,8 +125,17 @@ pipeline: ```text pto-validate-vmi-ir vmi-layout-assignment +canonicalize/cse +vmi-layout-fold-consumers +canonicalize/cse +vmi-layout-rematerialize +canonicalize/cse +vmi-layout-sink-materialization +canonicalize/cse +vmi-legalize-arith-select pto-validate-vmi-layout-ir vmi-to-vpto +canonicalize/cse ``` `--enable-vmi` requires `--pto-backend=vpto` or `pto.backend = "vpto"` because the pipeline produces physical VPTO @@ -145,6 +154,8 @@ vmi_ptoas_cli_pipeline.pto: --pto-backend=vpto + --enable-vmi lowers the VMI pipeline pto.backend = "vpto" also selects the VPTO-compatible path explicit --pto-backend=emitc with --enable-vmi is rejected + f16->f32 store lowers through the fold-consumers path, proving the driver + uses the optimized pipeline rather than only the hard skeleton vmi_ptoas_backend_required_invalid.pto: default emitc backend with --enable-vmi and no pto.backend = "vpto" is rejected @@ -155,8 +166,9 @@ vmi_ptoas_public_abi_invalid.pto / vmi_ptoas_public_result_abi_invalid.pto: ## MLIR Framework Usage -三个核心 pass 不应该用同一种 MLIR 机制硬套。这里先定义实现框架选择,避免后续把 layout -求解、结构化控制流改写和 1:N physicalization 混在一个 pattern pass 里。 +三个 correctness stage 和若干 layout optimization pass 不应该用同一种 MLIR 机制硬套。 +这里先定义实现框架选择,避免后续把 layout 求解、优化重写、结构化控制流改写和 1:N +physicalization 混在一个 pattern pass 里。 当前实现框架按下面的职责切开: @@ -168,6 +180,14 @@ vmi-layout-assignment: module-level per-SSA-value constraint solver。先收集等价类、producer natural layout 和 consumer request, 再把结果写回 VMI type/helper op。它可以使用 IRRewriter 改 IR,但不以 TypeConverter 为主模型。 +vmi-layout-fold-consumers / vmi-layout-rematerialize / vmi-layout-sink-materialization: + legal-to-legal VMI optimization passes。它们只消费 layout-assigned VMI IR,并继续产出 + layout-assigned VMI IR;所有新选择必须体现在 current op、type 或 helper IR 中。 + +vmi-legalize-arith-select: + canonicalize 之后的 hygiene pass。它把 scalar-condition arith.select with VMI result + 恢复成 VMI pipeline 可控的结构化控制流形态。 + vmi-to-vpto: MLIR OneToNTypeConversion。每个 layout-assigned VMI value 按统一 physical ordering 展开成多个 VPTO value,并依靠 OneToN structural patterns 重写函数、return、region result、block argument 和 @@ -178,7 +198,7 @@ vmi-to-vpto: 写成 `pto.vmi.ensure_*`,physicalization 后不允许残留 `pto.vmi.*`、`!pto.vmi.*` 或 `unrealized_conversion_cast`。不能把 layout 决策藏在 pass-private side table 里让后续 pass 猜。 -源码级实现应该进一步拆成五个独立层次: +源码级实现应该进一步拆成六个独立层次: ```text IR layer: @@ -201,6 +221,15 @@ Layout solving layer: 负责从 producer/consumer/control-flow/call 关系解出每个 logical value 的 layout, 然后把结果写回 type 或 ensure_* helper。 +Layout optimization layer: + lib/PTO/Transforms/VMILayoutFoldConsumers.cpp + lib/PTO/Transforms/VMILayoutRematerialize.cpp + lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp + lib/PTO/Transforms/VMILegalizeArithSelect.cpp + + 负责在 layout-assigned VMI IR 内做 legal-to-legal 改写。它可以让公共 canonicalize/cse + 协助清理和合并 IR,但不能把决策藏到 side table 里。 + Physicalization layer: lib/PTO/Transforms/VMIToVPTO.cpp @@ -265,6 +294,8 @@ pass input output --------------------------- ---------------------------- ---------------------------- pto-validate-vmi-ir surface VMI IR same IR, or hard failure vmi-layout-assignment surface/layout-partial VMI layout-assigned VMI IR +layout optimization passes layout-assigned VMI IR layout-assigned VMI IR +vmi-legalize-arith-select layout-assigned VMI IR layout-assigned VMI IR pto-validate-vmi-layout-ir layout-assigned VMI IR same IR, or hard failure vmi-to-vpto layout-assigned VMI IR physical VPTO IR final residual verifier physical VPTO candidate no pto.vmi.*, no !pto.vmi.* @@ -314,6 +345,26 @@ lib/PTO/Transforms/VMILayoutAssignment.cpp hide chosen layout in a pass-private side table infer external VMI ABI +lib/PTO/Transforms/VMILayoutFoldConsumers.cpp +lib/PTO/Transforms/VMILayoutRematerialize.cpp +lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp +lib/PTO/Transforms/VMILegalizeArithSelect.cpp + pass: + VMILayoutFoldConsumersPass + VMILayoutRematerializePass + VMILayoutSinkMaterializationPass + VMILegalizeArithSelectPass + role: + legal-to-legal layout-assigned VMI optimization and hygiene + MLIR API: + Operation::walk for local discovery + OpBuilder/RewriterBase for explicit IR rewrites + canonicalize/cse between passes for cleanup and deduplication + must not: + introduce physical VPTO register types + require vmi-to-vpto to inspect producers, users, or CFG + preserve optimization decisions outside IR + lib/PTO/Transforms/VMIToVPTO.cpp pass: VMIToVPTOPass @@ -369,6 +420,15 @@ source file pass primary lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-ir Operation::walk + recursive type/attr scan lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-layout-ir Operation::walk + recursive type/attr scan lib/PTO/Transforms/VMILayoutAssignment.cpp vmi-layout-assignment module-level union-find solver + IRRewriter +lib/PTO/Transforms/VMILayoutFoldConsumers.cpp + vmi-layout-fold-consumers Pattern-free local IR rewrite +lib/PTO/Transforms/VMILayoutRematerialize.cpp + vmi-layout-rematerialize Pattern-free local IR rewrite +lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp + vmi-layout-sink-materialization + Pattern-free local IR rewrite +lib/PTO/Transforms/VMILegalizeArithSelect.cpp + vmi-legalize-arith-select Operation::walk + OpBuilder rewrite lib/PTO/Transforms/VMIToVPTO.cpp vmi-to-vpto OneToNTypeConverter + OneToNOpConversionPattern ``` @@ -1108,14 +1168,24 @@ vmi-to-vpto: raw VMI producer -> pto-validate-vmi-ir -> vmi-layout-assignment + -> canonicalize/cse + -> vmi-layout-fold-consumers + -> canonicalize/cse + -> vmi-layout-rematerialize + -> canonicalize/cse + -> vmi-layout-sink-materialization + -> canonicalize/cse + -> vmi-legalize-arith-select -> pto-validate-vmi-layout-ir -> vmi-to-vpto + -> canonicalize/cse -> final residual verifier ``` -The `ptoas --enable-vmi` driver entry uses exactly this sequence before the existing VPTO backend pipeline. The -test-opt entry remains useful for isolated pass debugging, while the `ptoas` flag proves the same sequence is wired -through the user-facing compiler driver. +The `ptoas --enable-vmi` driver entry uses this sequence before the existing VPTO backend pipeline. +The test-opt entry remains useful for isolated pass debugging, while the `ptoas` flag proves the same sequence is +wired through the user-facing compiler driver. The optimization passes are legal-to-legal VMI rewrites; removing one +may affect quality or reject fewer/fewer optimized forms, but it must not make `vmi-to-vpto` recover hidden context. 各阶段之间只通过 IR 传递状态,不通过 pass-private side table 传递语义。也就是说: @@ -2415,11 +2485,14 @@ truncf f32 -> fp8-like: bitcast: source and result layouts must match source/result total logical bits must match - current implementation supports identical physical arity when every source/result - physical chunk carries the same number of logical bits. This covers full chunks - and partial/tail chunks such as 65xf32 -> 130xi16, where the second physical - chunk carries 32 logical bits on both sides. Partial/tail bitcast remains - unsupported if source padding bits would become result logical bits. + current implementation supports contiguous/deinterleaved layouts with identical + physical arity when every source/result physical chunk carries the same number + of logical bits. This covers full chunks and partial/tail chunks such as + 65xf32 -> 130xi16, where the second physical chunk carries 32 logical bits on + both sides, and uneven deinterleaved tails such as 129xf32 -> 129xi32. + Partial/tail bitcast remains unsupported if source padding bits would become + result logical bits. group_slots bitcast is unsupported until a slot-wise + bitcast contract is defined. load/tile_read: result layout chosen by consumers unless memory plan has a cheaper registered sink/source @@ -3141,10 +3214,12 @@ pto.vmi.truncf, direct path: pto.vmi.bitcast: for each physical part: emit pto.vbitcast(source_part) -> result_part_type - source/result layouts must match, physical arity must match, and every - corresponding physical chunk must carry the same number of logical bits. - Padding bits may map only to result padding bits; any shape where source - padding would become result logical data remains unsupported. + source/result layouts must match and must be contiguous/deinterleaved, + physical arity must match, and every corresponding physical chunk must carry + the same number of logical bits. Padding bits may map only to result padding + bits; any shape where source padding would become result logical data remains + unsupported. group_slots bitcast is rejected before vmi-to-vpto until it has + a slot-wise contract. pto.vmi.channel_split / pto.vmi.channel_merge: support 2-way and 4-way channel transforms for contiguous per-channel values @@ -3474,8 +3549,8 @@ Unsupported diagnostics: or f32 deinterleaved=4 source parts to one contiguous fp8-like result chunk unsupported pto.vmi.bitcast shape: - VMI-UNSUPPORTED: pto.vmi.bitcast requires matching source/result layouts with identical physical arity and matching - per-chunk logical bit footprints (...) + VMI-UNSUPPORTED: pto.vmi.bitcast requires matching non-group_slots source/result layouts with identical physical + arity and matching per-chunk logical bit footprints (...) unsupported pto.vmi.channel_split / pto.vmi.channel_merge channel count: VMI-UNSUPPORTED: pto.vmi.channel_split supports only 2 or 4 channels @@ -4343,7 +4418,8 @@ use VMI-UNSUPPORTED in preflight: partial/tail memory access pred-only constant mask without concrete b8/b16/b32 granularity shuffle that requires vselr index-vector materialization - bitcast across partial physical chunks + bitcast with mismatched per-chunk logical bit footprints or group_slots + bitcast without a slot-wise contract use VMI-RESIDUAL-OP: conversion framework finished but VMI op/type/helper/cast remains. diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index f4c8f8487f..03f22ffd42 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -11,40 +11,149 @@ Recommended pass pipeline: ```text -pto-validate-vmi-surface - -> vmi-layout-assignment - -> pto-validate-vmi-layout +pto-validate-vmi-ir + -> vmi-layout-assignment // hard legalization baseline + -> canonicalize/cse + -> vmi-layout-fold-consumers // optional optimization + -> canonicalize/cse + -> vmi-layout-rematerialize // optional optimization + -> canonicalize/cse + -> vmi-layout-sink-materialization // optional optimization + -> canonicalize/cse + -> vmi-legalize-arith-select + -> pto-validate-vmi-layout-ir -> vmi-to-vpto -> canonicalize/cse -> existing VPTO lowering/codegen ``` +Only `vmi-layout-assignment` is required for the first legal implementation. +The optimization passes may be introduced one by one. Their contract is that +they consume legal layout-assigned VMI IR and produce legal layout-assigned VMI +IR; they never move a hidden decision into `vmi-to-vpto`. + Pass responsibilities: ```text -pto-validate-vmi-surface: +pto-validate-vmi-ir: verify surface VMI has no physical VPTO layout dependency reject public/external VMI ABI unless explicitly enabled vmi-layout-assignment: - solve value layouts - choose selected lowering plans + solve hard value layout constraints + choose explicit layouts and local recipe carriers visible in IR insert ensure/rematerialization helpers make internal function boundary layouts explicit rewrite VMI types with layout attrs -pto-validate-vmi-layout: +canonicalize/cse: + remove dead helpers and merge identical cloned producers where MLIR legality + permits + +vmi-layout-fold-consumers: + fold use-site materialization into consumers that can directly consume the + source layout while preserving the same logical effect + example: ensure_layout(deinterleaved=2 -> contiguous) feeding store may become + a store of deinterleaved=2 when the store has a local vstsx2 INTLV recipe + current implementation: pto.vmi.store, pto.vmi.tile_write, and the value + operand of pto.vmi.masked_store when the existing mask arity matches, fed by + ensure_layout from deinterleaved=2/4, block_elems=1 to contiguous. factor=2 + uses the store's vstsx2 INTLV recipe; factor=4 is still store-local, but it + materializes through physical interleave before vsts. + +vmi-layout-rematerialize: + replace explicit ensure_* helpers with cloned cheap layout-polymorphic + producers when the clone directly creates the requested result type + current implementation: splat pto.vmi.constant, pto.vmi.broadcast, + pto.vmi.iota, pto.vmi.create_mask, pto.vmi.create_group_mask, and + pto.vmi.constant_mask + not included in the first implementation: load, group_load, masked_load, + group_slot_load, and group_broadcast; those require separate memory, + execution-count, or source-layout proof before they can be rematerialized + +vmi-layout-sink-materialization: + move ensure_layout across pure layout-transparent elementwise chains when the + rewritten IR reduces materialization cost and keeps every op locally legal + current implementation: sink two identical operand ensure_layout helpers + across binary add/sub/mul/div/min/max/and/or/xor/shl/shru VMI ops, or one + source ensure_layout across unary neg/abs/sqrt/exp/ln/relu/not VMI ops, + producing one result ensure_layout. It also sinks matching + ensure_mask_layout or ensure_mask_granularity helpers across + mask_and/mask_or/mask_xor/mask_not, producing one result mask helper. It + does not sink through select, fma, cast, load, store, reduce, + group_broadcast, or control-flow ops + +vmi-legalize-arith-select: + restore scalar-condition arith.select with VMI result type back to scf.if + after canonicalize; canonicalize may fold simple scf.if into arith.select, + but VMI values must not cross non-VMI semantic ops before vmi-to-vpto + +pto-validate-vmi-layout-ir: verify every VMI data/mask value has layout verify every VMI value has an assigned layout and every non-local lowering choice has been serialized explicitly - verify helper ops have registered materialization plans + verify helper ops have registered materialization recipes. Current + implementation checks `ensure_layout`, `ensure_mask_layout`, and + `ensure_mask_granularity` at the layout gate, so unsupported helper recipes + fail before `vmi-to-vpto`. It also checks the first semantic local-recipe + families, non-contiguous `pto.vmi.store`/`pto.vmi.tile_write`, block8 + `pto.vmi.group_load`, `pto.vmi.group_slot_load`, group_slots + `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_addf`, + explicit-slots `pto.vmi.group_broadcast`, `pto.vmi.truncf`, + `pto.vmi.extf`, and `pto.vmi.bitcast`, at the layout gate. vmi-to-vpto: use OneToN type conversion - lower only from explicit layout/plan information + lower only from current-op attrs/operands, operand/result layouts, and helper + ops emit VPTO or precise unsupported diagnostic ``` +### 1.1 Hard Constraints Versus Optimizations + +Hard legalization answers "can this program be lowered correctly?" It is +allowed to be conservative: + +```text +%w = pto.vmi.extf %a // natural layout deinterleaved=2 +%t1 = pto.vmi.mulf %w, %k1 // layout-transparent, stays deinterleaved=2 +%t1_c = pto.vmi.ensure_layout %t1 // hard store contract wants contiguous +pto.vmi.store %t1_c, %OUT1 +%w_c = pto.vmi.ensure_layout %w +pto.vmi.store %w_c, %OUT2 +``` + +This is a correct legal shape. The contiguous action is explicit at each store +use, and `vmi-to-vpto` lowers the helper with register materialization such as +`vintlv` before ordinary `vsts`. + +Optimization answers "can the same external effect be cheaper?" A fold pass +may rewrite the two store uses to consume the deinterleaved values directly: + +```text +pto.vmi.store %t1, %OUT1 // value type still says deinterleaved=2 +pto.vmi.store %w, %OUT2 +``` + +This optimized shape is legal only because `pto.vmi.store` has enough local +information to lower a `deinterleaved=2` f32 value to row-major memory, for +example with `vstsx2 INTLV_B32`. The optimization does not require +`vmi-to-vpto` to inspect `%w`'s producer or the sibling store. + +The split gives later passes room to improve layout choices: + +```text +hard pass: + guarantee legality with explicit ensure_* helpers + +optimization passes: + remove, fold, clone, or sink helpers when the optimized IR is still locally + deterministic + +vmi-to-vpto: + physicalize exactly the IR it sees, with no global planning +``` + ## 2. Files To Add Or Update Expected implementation files: @@ -56,10 +165,10 @@ include/PTO/IR/VMIAttrs.td lib/PTO/IR/VMI.cpp include/PTO/Transforms/Passes.td -lib/PTO/Transforms/ValidateVMI.cpp +lib/PTO/Transforms/PTOValidateVMIIR.cpp lib/PTO/Transforms/VMILayoutAssignment.cpp lib/PTO/Transforms/VMIToVPTO.cpp -lib/PTO/Transforms/VMILayoutPlanRegistry.cpp +lib/PTO/Transforms/VMILocalRecipeRegistry.cpp test/lit/vmi/vmi_layout_assignment_*.pto test/lit/vmi/vmi_to_vpto_*.pto @@ -115,7 +224,7 @@ contiguous: deinterleaved: F > 1 B > 0 - direct full-chunk plans require N % (F * B) == 0 + direct full-chunk recipes require N % (F * B) == 0 group_slots: G > 0 @@ -188,9 +297,9 @@ group_slot_load result group_slots layout and source_group_stride group_reduce_addf source/mask/result layouts, num_groups, reassoc group_broadcast source/result layouts and num_groups truncf source/result layouts and element widths -ensure_layout always carries source/result layouts instead of plan -ensure_mask_layout always carries source/result layouts instead of plan -ensure_mask_granularity always carries source/result granularities instead of plan +ensure_layout always carries source/result layouts instead of recipe +ensure_mask_layout always carries source/result layouts instead of recipe +ensure_mask_granularity always carries source/result granularities instead of recipe ``` Layout/attr-only decisions today: @@ -212,8 +321,9 @@ vmi-to-vpto emits VMI-LAYOUT-CONTRACT for missing local proof. If a layout/attr-only op later gains a second legal recipe that cannot be distinguished from current-op information, that recipe must be represented by a new attr, helper op, or rematerialized op before vmi-to-vpto can emit it. -Unsupported shapes that have no registered plan still diagnose through their -specific capability check rather than failing with a generic missing-plan error. +Unsupported shapes that have no registered recipe still diagnose through their +specific capability check rather than failing with a generic missing-recipe +error. ``` Examples of forbidden recovery in `vmi-to-vpto`: @@ -273,29 +383,31 @@ group_slot_load: loads one scalar per group and produces group_slots ``` -## 5. Plan Registry +## 5. Local Recipe Registry -Create one registry object shared by assignment and lowering. +Create one target-aware local recipe registry shared by assignment and lowering. +It is not serialized as a separate recipe-selection attribute. It answers local legality +questions from op kind, explicit attrs/operands, layouts, and target capability. ```c++ -class VMILayoutPlanRegistry { +class VMILocalRecipeRegistry { public: - SmallVector getProducerPlans(Operation *op); - SmallVector getConsumerPlans(OpOperand &use); - SmallVector getTransferPlans(Operation *op); - FailureOr getMaterializationPlan(Type valueType, - VMILayoutKey from, - VMILayoutKey to); + SmallVector getProducerRecipes(Operation *op); + SmallVector getConsumerRecipes(OpOperand &use); + SmallVector getTransferRecipes(Operation *op); + FailureOr + getMaterializationRecipe(Type valueType, VMILayoutKey from, + VMILayoutKey to); bool isCheaplyRematerializable(Operation *op); - bool hasTargetCapability(PlanID plan) const; + bool hasTargetCapability(RecipeID recipe) const; }; ``` -Plan record: +Recipe record: ```c++ -struct VMILayoutPlan { - PlanID id; +struct VMILayoutRecipe { + RecipeID id; SmallVector operandLayouts; SmallVector resultLayouts; int64_t cost; @@ -315,6 +427,69 @@ enablePublicVMIABI diagnosticVerbosity ``` +Assignment and optimization passes may query the registry to decide which IR +shape to produce. `vmi-to-vpto` may query the same registry to verify the +current op is locally lowerable. If the same op, attrs, operands, and +operand/result layouts could map to two different physical recipes with +different observable preconditions, the IR is under-specified; add an explicit +attr, operand, helper op, or distinct VMI semantic op before implementing that +recipe. + +Current implementation status: `VMILocalRecipeRegistry` exists and currently +owns nine local recipe families: + +```text +contiguous store/tile_write consumer recipes: + contiguous vsts + deinterleaved=2 vstsx2 INTLV + deinterleaved=4 materialize-then-vsts + +helper materialization recipes: + data/mask layout identity + data/mask contiguous <-> deinterleaved=2/4 when source/result physical + arity matches and the physical part shape can be materialized + mask granularity identity or b8/b16/b32 predicate cast + +group_slot_load semantic recipes: + slots=8 unit-stride vsldb + slots=1 aligned lane-0 vsldb per group + +block8 group_load semantic recipes: + S=16 deinterleaved=2, block_elems=8 vsldb per row fragment + S=32 deinterleaved=4, block_elems=8 vsldb per row fragment + +group_slots group_store semantic recipes: + slots=8 unit-stride vsts + slots=1 aligned lane-0 vsts per group + +group_slots group_reduce_addf semantic recipes: + S=8 vcgadd + S=16 deinterleaved=2 vcgadd+vadd + S=32 deinterleaved=4 vcgadd+vadd tree + S=64 contiguous slots=1 vcadd/vadd/vsel row-local reduction + +explicit-slots group_broadcast semantic recipes: + slots=8/slots=1 vselr materialization to contiguous or supported + deinterleaved result layouts + +extf/truncf semantic recipes: + contiguous f16/bf16 -> deinterleaved=2 f32 + contiguous f8-like -> deinterleaved=4 f32 + deinterleaved=2 f32 -> contiguous f16 + deinterleaved=4 f32 -> contiguous f8-like + group_slots(G, slots=1) f32 -> f16 + +bitcast semantic recipes: + per-part vbitcast for contiguous/deinterleaved layouts when source/result + layouts match, physical arity matches, and every physical chunk carries the + same logical bit footprint; this does not require each deinterleaved part to + contain the same number of chunks. group_slots bitcast is unsupported until a + slot-wise bitcast contract is defined. +``` + +`vmi-layout-fold-consumers`, `pto-validate-vmi-layout-ir`, and `vmi-to-vpto` +query this registry for the decisions implemented above. + ## 6. Layout Assignment Data Model ### 6.1 Solver State @@ -331,14 +506,14 @@ struct ValueLayoutState { struct UseRequest { OpOperand *operand; VMILayoutKey requestedLayout; - PlanID requestingPlan; + RecipeID requestingRecipe; bool hard; }; -struct OpPlanState { +struct OpRecipeState { Operation *op; - SmallVector candidates; - std::optional chosen; + SmallVector candidates; + std::optional chosen; }; ``` @@ -350,7 +525,7 @@ Walk the module and collect: 1. every VMI value 2. every VMI block argument 3. every VMI function argument/result -4. every VMI op with candidate plans +4. every VMI op with candidate local recipes 5. every branch/yield/call/return edge carrying VMI ``` @@ -455,11 +630,11 @@ compact S=12 logical S=16: ### 6.3.1 Request Builders Implement request generation as small per-op builders. The builders produce -candidate plans and use-site requests; they do not rewrite IR. +candidate recipes and use-site requests; they do not rewrite IR. ```text buildStoreRequests: - ordinary store -> dense contiguous request unless a layout-aware store plan is + ordinary store -> dense contiguous request unless a layout-aware store recipe is selected group_store -> group_slots(G,K) request plus stride/alignment capability checks @@ -469,8 +644,8 @@ buildCastRequests: extf f8->f32 -> source contiguous, result deinterleaved=4 truncf f32->f16 -> source deinterleaved=2/block_elems=1, result contiguous truncf f32->f8 -> source deinterleaved=4/block_elems=1, result contiguous - group_slots slots=1 f32->f16 -> slot-preserving plan - group_slots slots=8 width-changing cast -> diagnostic unless a packed plan + group_slots slots=1 f32->f16 -> slot-preserving recipe + group_slots slots=8 width-changing cast -> diagnostic unless a packed recipe exists buildGroupReduceRequests: @@ -481,11 +656,11 @@ buildGroupReduceRequests: S=32 -> deinterleaved=4/block_elems=1 or block_elems=8 source, group_slots(G,8) result S=64 -> contiguous source, group_slots(G,1) result - other S -> diagnostic unless an explicit fallback plan is enabled + other S -> diagnostic unless an explicit fallback recipe is enabled buildGroupMemoryRequests: - group_load S=16/S=32 with aligned constant stride -> block_elems=8 plan - group_load row-local full chunks -> contiguous plan + group_load S=16/S=32 with aligned constant stride -> block_elems=8 recipe + group_load row-local full chunks -> contiguous recipe group_slot_load unit stride -> group_slots(G,8) group_slot_load aligned row-local stride -> group_slots(G,1) unsupported dynamic/unaligned grouped memory -> diagnostic @@ -513,7 +688,7 @@ buildFunctionBoundaryRequests: private/internal function argument/result layouts are specialized or materialized with callee-entry/return-site helpers public/external VMI arguments/results diagnose unless enablePublicVMIABI has - a real ABI plan + a real ABI recipe ``` Request builders must record the requesting op. Diagnostics and inserted @@ -534,7 +709,7 @@ cheap rematerializable producers: create_group_mask group_broadcast group_slot_load when the same address/no-alias/proof conditions as load hold - and the selected memory plan is legal at the clone site + and the memory recipe is legal at the clone site layout-transparent producers: add/sub/mul/fma/min/max/neg/abs @@ -543,10 +718,10 @@ layout-transparent producers: integer bitwise and shift ops fixed-layout producers: - extf/truncf physical conversion plans - group_load block-fragment plans + extf/truncf physical conversion recipes + group_load block-fragment recipes group_reduce result group_slots - masked_load when the physical memory-safety proof fixes a full-read plan + masked_load when the physical memory-safety proof fixes a full-read recipe ``` Conflict policy: @@ -568,17 +743,17 @@ This is the rule that keeps case 3.32 legal: a plain `load` can be assigned to `deinterleaved=4, block_elems=1` for both `truncf f32->f8` and S=32 `group_reduce`. It also keeps case 3.19.2 diagnostic: a strided `group_load` that selected `block_elems=8` is fixed unless a block8-to-parity -materialization or rematerialized memory plan is registered. +materialization or rematerialized memory recipe is registered. ### 6.4 Solving And Rewriting Algorithm: ```text -1. Pick candidate plan sets for every op. +1. Pick candidate recipe sets for every op. 2. Propagate hard constraints through SCCs. 3. Resolve transfer-equivalent dense values. -4. Choose multi-plan ops by cost: +4. Choose multi-recipe ops by cost: - S=16 parity vs block8 - load memory-fused vs load+materialize - group_slot_load slots=8 vs slots=1 @@ -596,7 +771,7 @@ Rewrite invariants: No VMI data/mask value after assignment has a null layout. Any non-local choice is represented by op attrs, operand/result layouts, a helper op, a clone, or an explicit diagnostic. -Every ensure_* helper has a registered materialization plan. +Every ensure_* helper has a registered materialization recipe. Every function/call signature carrying VMI is specialized or diagnosed. ``` @@ -689,12 +864,12 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.4 S=8 reduce buildGroupReduceRequests s8_reduce_contiguous plan -3.5 S=16 reduce buildGroupReduceRequests s16_reduce_parity/block8 plan -3.6 S=32 reduce buildGroupReduceRequests s32_reduce_dintlv4/block8 plan -3.7 S=64 reduce buildGroupReduceRequests s64_reduce_row_local plan +3.4 S=8 reduce buildGroupReduceRequests s8_reduce_contiguous recipe +3.5 S=16 reduce buildGroupReduceRequests s16_reduce_parity/block8 recipe +3.6 S=32 reduce buildGroupReduceRequests s32_reduce_dintlv4/block8 recipe +3.7 S=64 reduce buildGroupReduceRequests s64_reduce_row_local recipe 3.11.1 S=64 active-row tail buildMaskRequests active-row store/reduce masks -3.19.1 S=16 block_elems choice buildGroupReduceRequests selected block_elems reduce plan +3.19.1 S=16 block_elems choice buildGroupReduceRequests explicit block_elems layout 3.38 multi-tile S=32 reduce buildGroupReduceRequests multiple group_slots chunks 3.26 grouped tail buildMaskRequests split grouped masks 3.44, 3.45 grouped S=32 masks buildMaskRequests explicit deinterleaved mask values @@ -708,12 +883,12 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load plan +3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load recipe 3.15.2 S=16 row stride > 16 buildGroupMemoryRequests strided block_elems=8 plan 3.16.1 group_slot_load slots=8 buildGroupMemoryRequests unit-stride packed slots plan 3.16.2 group_slot_load slots=1 buildGroupMemoryRequests row-local aligned slots plan 3.27 strided group_load buildGroupMemoryRequests positive block_elems=8 plan -3.28 slots=1 non-unit load buildGroupMemoryRequests row-local group_slot_load plan +3.28 slots=1 non-unit load buildGroupMemoryRequests row-local group_slot_load recipe 3.37 slots=1 strided store buildStoreRequests group_store stride/alignment proof 3.39 strided load fanout conflict resolver preserving layout or materialization @@ -762,17 +937,17 @@ vmi-to-vpto contract: ```text diagnostic family builder / owner required failure -3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store plan +3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store recipe 3.9 dense store of group slots buildStoreRequests use group_store/group_broadcast 3.11.2 S=32 unsafe tail buildMaskRequests missing full_tile_readable/gather -3.13 slots=8 width cast buildCastRequests no packed slot cast plan -3.14 unsupported group size buildGroupReduceRequests no registered reduce plan +3.13 slots=8 width cast buildCastRequests no packed slot cast recipe +3.14 unsupported group size buildGroupReduceRequests no registered reduce recipe 3.15.3 compact S=12 buildGroupMemoryRequests no compact gather plan -3.16.1 slots=8 non-unit load buildGroupMemoryRequests no packed strided slot load plan +3.16.1 slots=8 non-unit load buildGroupMemoryRequests no packed strided slot load recipe 3.16.2 slots=1 bad stride buildGroupMemoryRequests no dynamic/unaligned row-local plan 3.19.2 invalid block_elems use conflict resolver no preserving materialization 3.25.2 public/external ABI buildFunctionBoundary no stable public VMI ABI -3.27 unaligned group_load buildGroupMemoryRequests no gather/block fallback plan +3.27 unaligned group_load buildGroupMemoryRequests no gather/block fallback recipe 3.30 masked_load unsafe tail buildMaskRequests no padding/gather fallback vmi-to-vpto contract: @@ -970,7 +1145,7 @@ group_reduce_addf: VCGADDs plus a PAT_VL8 VADD tree per packed result block. S=64 row-local assignment uses #pto.vmi.layout and has focused layout-assignment/vmi-to-vpto lit coverage; the explicit - slots=1 generic VCADD row-local path is selected locally. + slots=1 generic VCADD row-local path is registered and selected locally. group_broadcast: explicit slots=8/1 source layouts select @@ -991,8 +1166,10 @@ group_load: contiguous full-chunk path is selected from a contiguous result layout. S=16/S=32 block-aligned strided loads are selected from #pto.vmi.layout, and lower to one - vsldb per 32B row fragment and physical chunk. The dedicated S=16 unit-stride - vldsx2/BDINTLV recipe remains a local peephole target. + vsldb per 32B row fragment and physical chunk. The explicit block8 recipe + is registered and checked by pto-validate-vmi-layout-ir before vmi-to-vpto. + The dedicated S=16 unit-stride vldsx2/BDINTLV recipe remains a local + peephole target. S=16/S=32 group_load with a non-constant, non-positive, or non-8-f32-aligned row_stride is rejected by vmi-layout-assignment because the stable gather fallback is not implemented. @@ -1010,12 +1187,12 @@ group_store: multiple of the 32B store alignment in destination elements: 8 for f32, 16 for f16, and 32 for f8. Unit-stride f32 output is rejected because only the first row-local store is 32B-aligned; later `group_off + r` stores are - 4B apart. A future pack-to-slots=8 or unaligned-store plan is required before + 4B apart. A future pack-to-slots=8 or unaligned-store recipe is required before contiguous `%c1` slots=1 group_store can be accepted. Packed group_slots(G, slots=8) group_store is implemented only when num_groups is a multiple of 8 and row_stride is constant 1; it emits one PAT_VL8 store per packed slot block. Non-unit packed group stores remain a - design target unless a strided packed-lane store plan is selected explicitly. + design target unless a strided packed-lane store recipe is made explicit. ``` Examples: @@ -1065,7 +1242,7 @@ After assignment: Every VMI value has layout. Every VMI mask has layout and granularity plan. Every lowering choice is locally deterministic or explicit in attrs/layouts. -Every ensure_* helper has a materialization plan. +Every ensure_* helper has a materialization recipe. Every control-flow edge has matching VMI layouts. ``` @@ -1143,8 +1320,8 @@ VMI-LAYOUT-CONTRACT: pto.vmi.truncf requires #pto.vmi.layout, but the source value is fixed to #pto.vmi.layout by the selected - strided group_load plan. Register a rematerialization or preserving - materialization plan, or avoid consuming this block-loaded value with truncf. + strided group_load recipe. Register a rematerialization or preserving + materialization recipe, or avoid consuming this block-loaded value with truncf. ``` ## 11. Test And Simulator Acceptance @@ -1212,13 +1389,13 @@ the case catalog. Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-selected-plan-gate CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-layout-gate CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh PASS=43 FAIL=0 -summary: .tmp/vmi-runtime-batch-selected-plan-gate/parallel-summary.tsv +summary: .tmp/vmi-runtime-batch-layout-gate/parallel-summary.tsv log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-selected-plan-gate.log + .tmp/vmi-runtime-batch-layout-gate.log result: no matches ``` @@ -1298,7 +1475,7 @@ repository evidence: all 43 runtime case directories contain kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py latest broad VMI runtime sweep passed: PASS=43 FAIL=0 - latest full VMI lit sweep passed: 314/314 + latest full VMI lit sweep passed: 340/340 ``` Current checked-in coverage for 3.3 dense f8->f32->compute->f8: @@ -1656,6 +1833,65 @@ runtime SIM: test/vpto/cases/vmi/widen-f16-to-f32-store-reduce ``` +Current checked-in lit coverage for the first `vmi-layout-fold-consumers` +optimization is: + +```text +test/lit/vmi/vmi_layout_fold_consumers_store.pto +test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto +test/lit/vmi/vmi_layout_fold_consumers_deint4.pto +``` + +Current checked-in lit coverage for the first `vmi-layout-rematerialize` +optimization is: + +```text +test/lit/vmi/vmi_layout_rematerialize_data.pto +test/lit/vmi/vmi_layout_rematerialize_mask.pto +``` + +Current checked-in lit coverage for the first +`vmi-layout-sink-materialization` optimization is: + +```text +test/lit/vmi/vmi_layout_sink_materialization_binary.pto +test/lit/vmi/vmi_layout_sink_materialization_mask.pto +``` + +Current checked-in lit coverage for canonicalized VMI control-flow restoration is: + +```text +test/lit/vmi/vmi_legalize_arith_select.pto +test/lit/vmi/vmi_ptoas_cli_control_flow.pto +``` + +Current checked-in lit coverage for the first semantic local-recipe layout gate +is: + +```text +test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto +test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto +test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto +``` + +Current checked-in direct `vmi-to-vpto` preflight coverage for bitcast local +recipes is: + +```text +test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto +test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto +``` + Current checked-in coverage for 3.32 f32 feeding f8 store and S=32 reduce: ```text @@ -1710,7 +1946,7 @@ Diagnostic-only cases: 3.16.1 group_slot_load slots=8 non-unit stride 3.16.2 group_slot_load slots=1 dynamic or unaligned stride 3.27 S=32 source_group_stride not divisible by 8 f32 elements -3.19.2 block_elems=8 value consumed by truncf without materialization plan +3.19.2 block_elems=8 value consumed by truncf without materialization recipe 3.25.2 public/external VMI boundary 3.30 unsafe masked_load tail without stable masked/gather fallback ``` @@ -1729,11 +1965,17 @@ entries: ```text lit: + test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto + test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto + test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto test/lit/vmi/vmi_ptoas_public_abi_invalid.pto test/lit/vmi/vmi_ptoas_public_result_abi_invalid.pto test/lit/vmi/vmi_layout_assignment_external_call_invalid.pto @@ -1795,7 +2037,7 @@ group_store ```text 3.8 cast commute through group_broadcast 3.18 dense/group-reduce multi-consumer -3.19 block_elems plan selection +3.19 block_elems recipe selection 3.23 group_broadcast multi-consumer 3.32 f32 feeding f8 store and S=32 reduce 3.33 S=16/S=32 reduce multi-consumer rematerialization @@ -1847,7 +2089,7 @@ Current evidence for the case-catalog objective: 3. every runtime case directory contains kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py 4. the latest broad VMI runtime sweep passed: PASS=43 FAIL=0 -5. the latest full VMI lit sweep passed: 314/314 +5. the latest full VMI lit sweep passed: 340/340 6. every unsupported endpoint listed in section 11.3 has a diagnostic lit test 7. vmi-to-vpto decisions are represented by current-op attrs/operands, assigned layouts, helper ops, rematerialization, or diagnostics diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index b30c0c3472..4c13b07ef8 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -8,8 +8,20 @@ ```text VMI surface IR - -> vmi-layout-assignment - -> layout-assigned VMI IR + -> pto-validate-vmi-ir + -> vmi-layout-assignment // hard legalization baseline + -> canonicalize/cse + -> vmi-layout-fold-consumers // optional optimization + -> canonicalize/cse + -> vmi-layout-rematerialize // optional optimization + -> canonicalize/cse + -> vmi-layout-sink-materialization // optional optimization + -> canonicalize/cse + -> optional later layout optimization passes + -> canonicalize/cse + -> vmi-legalize-arith-select + -> pto-validate-vmi-layout-ir + -> layout-assigned and optimized VMI IR -> vmi-to-vpto -> VPTO IR ``` @@ -20,15 +32,59 @@ VMI surface IR vmi-to-vpto 不允许通过上下文猜 lowering。 任何需要 producer/consumer/control-flow/memory/mask 上下文才能决定的事, -必须在 vmi-layout-assignment 阶段变成显式 IR 信息: +必须在 vmi-layout-assignment 或后续 VMI layout optimization 阶段变成显式 IR: 1. vmi.vreg/vmi.mask 的 layout -2. op 的 selected lowering plan -3. use-site ensure_layout / ensure_mask_layout -4. rematerialized producer +2. current-op attrs/operands that make the local recipe deterministic +3. use-site ensure_layout / ensure_mask_layout / ensure_mask_granularity +4. rematerialized or cloned producer 5. target capability diagnostic ``` +## 0. Hard Legalization And Optimization Boundary + +Layout assignment is a stage, not necessarily one monolithic pass. The design +separates correctness from optimization: + +```text +hard legalization: + produces legal layout-assigned VMI IR for all supported semantics + inserts conservative ensure_* helpers at incompatible uses + may choose a simple canonical layout even when a fused consumer recipe exists + must diagnose unsupported semantics before vmi-to-vpto has to guess + +layout optimization: + rewrites already legal VMI IR into cheaper but equivalent VMI IR + may fold ensure_layout into a layout-aware consumer + may clone/rematerialize cheap producers for different use-site layouts + may sink or hoist layout materialization through pure elementwise chains + may specialize private VMI function signatures +``` + +The driver currently runs MLIR's normal `canonicalize` and `cse` between these +VMI-specific passes. They are allowed to clean up trivially unused helpers, +merge identical rematerialized producers, and expose simpler use-def shapes. +They are not a source of hidden lowering information; after every optimization, +the IR must still carry enough local information for `vmi-to-vpto`. + +The baseline hard pass may emit: + +```text +%x_c = pto.vmi.ensure_layout %x : deinterleaved=2 -> contiguous +pto.vmi.store %x_c +``` + +A later optimization may replace that use with: + +```text +pto.vmi.store %x : deinterleaved=2 +``` + +only if the store op itself has a local deterministic recipe for preserving the +same row-major memory effect, such as a layout-aware `vstsx2 INTLV` lowering. +Both forms are semantically complete. The second form is an optimization, not +a hard requirement for correctness. + ## 1. Source Case Coverage 设计必须覆盖 case catalog 中的端到端场景: @@ -55,7 +111,7 @@ layout conflict: one scalar broadcast rematerialized for dense and grouped users one non-rematerializable value materialized with use-site ensure_layout one scalar group-slot source rematerialized as slots=8 and slots=1 - S=16 block_elems=1/8 plan selection + S=16 block_elems=1/8 recipe selection dense consumer of group_slots diagnostic packed group-slot width-changing cast diagnostic S=64 slots=1 group-slot width-changing cast @@ -122,7 +178,7 @@ memory legality: ``` No extra layout kind should be added unless a new case proves that the existing -layouts and plans cannot express the logical behavior. The remaining open +layouts and recipes cannot express the logical behavior. The remaining open items are not missing layout semantics: ```text @@ -203,14 +259,14 @@ slot_lane(g) = g % K All non-slot lanes are undefined and may only be read by group-aware operations. Ordinary dense `add/mul/store/truncf` cannot consume `group_slots`. -`K` is selected by the lowering plan: +`K` is selected by the producer/consumer local recipe: ```text S=8/16/32 packed VCG result -> slots=8 S=64 row-local result -> slots=1 ``` -## 3. Lowering Context Must Become Assignment Output +## 3. Lowering Context Must Become Explicit IR Output `vmi-to-vpto` may inspect only: @@ -233,7 +289,7 @@ It must not: 6. specialize function signatures during vmi-to-vpto ``` -Any of those decisions belongs to `vmi-layout-assignment`. +Any of those decisions belongs to the layout stage before `vmi-to-vpto`. ## 4. Explicit Assignment Products @@ -323,7 +379,7 @@ group_store: masked_load: explicit passthrough, mask layout, full physical read, shaped safe-tail memref, or an explicit diagnostic decide legality. A future stable gather fallback - must be selected by assignment before vmi-to-vpto lowers it. + must be made explicit by assignment before vmi-to-vpto lowers it. masked_store/select/elementwise: operand/result layouts and explicit mask granularity decide the lowering. @@ -350,40 +406,48 @@ If the current op lacks enough local information, `vmi-to-vpto` emits `VMI-LAYOUT-CONTRACT` at the current op and prints the op name, logical type, assigned layouts, and the missing decision class. -## 5. Plan Registry +## 5. Local Recipe Registry + +The compiler owns a target-aware local recipe registry. Layout assignment and +layout optimization query this registry to decide which explicit IR shape to +produce. `vmi-to-vpto` queries the same registry only to verify and lower the +current op from local information. -The compiler owns a target-aware plan registry. Layout assignment queries this -registry; vmi-to-vpto verifies and consumes the chosen plan. +The registry is not serialized as a separate recipe-selection attribute. If +two legal physical recipes cannot be distinguished by the current op's name, +attrs, operands, operand/result layouts, helper ops, and target options, the +VMI IR is missing a carrier. Add an explicit attr, operand, helper op, or +semantic op before implementing that recipe. -### 5.1 Plan Kinds +### 5.1 Recipe Kinds ```text -ProducerPlan: +ProducerRecipe: op can produce result layout L example: load -> deinterleaved=4 using DINTLV_B32 + vdintlv -ConsumerPlan: +ConsumerRecipe: op can consume operand layout L example: group_reduce S=32 consumes deinterleaved=4 -TransferPlan: +TransferRecipe: op ties operand/result layouts example: addf requires same dense layout for operands/result -MaterializationPlan: +MaterializationRecipe: layout A -> layout B without changing logical value example: deinterleaved=4 -> contiguous by vintlv tree -RematerializationPlan: +RematerializationRecipe: cheap producer can be cloned for a use-site layout example: broadcast/create_mask/group_broadcast -DiagnosticPlan: +DiagnosticRecipe: known unsupported semantic/capability boundary example: compact S=12 requires gather materialization ``` -### 5.2 Dense Plans From Cases +### 5.2 Dense Recipes From Cases ```text f16 -> f32: @@ -413,7 +477,7 @@ load: layouts, such as S=16 deinterleaved=2 and S=32 deinterleaved=4 ``` -### 5.3 Group Plans From Cases +### 5.3 Group Recipes From Cases ```text group_reduce f32 S=8: @@ -449,11 +513,11 @@ group_store: group_slot_cast f32 -> f16: slots=1 row-local source/result is legal with group_slot_cast_slots1_f32_to_f16 - slots=8 packed source is illegal unless a packed slot-preserving plan is + slots=8 packed source is illegal unless a packed slot-preserving recipe is registered ``` -### 5.4 Tail And Memory Safety Plans +### 5.4 Tail And Memory Safety Recipes Mask semantics and memory legality are separate: @@ -504,7 +568,7 @@ new catalog case or a proof that it is equivalent to one listed here. dense store: requests dense contiguous source if source is deinterleaved, assignment must insert ensure_layout or select a - store plan such as vstsx2 that consumes the assigned layout explicitly + store recipe such as vstsx2 that consumes the assigned layout explicitly truncf f32 -> f16: requests source deinterleaved=2, block_elems=1 @@ -556,7 +620,7 @@ group_slot_load: group_load: requests result deinterleaved=2/4, block_elems=8 for S=16/S=32 block - fragment plans, or contiguous for row-local full-chunk plans + fragment recipes, or contiguous for row-local full-chunk recipes masked_load: requests result layout from its consumers @@ -564,7 +628,7 @@ masked_load: requires explicit passthrough; padding is not synthesized masked_store: - requests dense source layout selected by the store plan + requests dense source layout selected by the store recipe requests mask layout matching the source layout and store element granularity does not choose memory safety for an earlier load @@ -584,9 +648,9 @@ Important negative requests: ```text ordinary dense add/mul/store/truncf cannot request group_slots packed group_slots(slots=8) cannot request width-changing cast unless a packed -slot-preserving cast plan is registered +slot-preserving cast recipe is registered slots=1 group_store cannot request unit-stride row-major output until a pack or -unaligned-store plan exists +unaligned-store recipe exists ``` ### 5.6 Conflict Resolution Matrix @@ -617,13 +681,13 @@ control-flow join: private function boundary: specialize or materialize at call/callee-entry before vmi-to-vpto -no clone/materialization/specialization plan: +no clone/materialization/specialization recipe: emit a diagnostic naming the requesting op and both layouts ``` The cost model may choose between legal rows only when the observable contract is identical. For example, S=16 `block_elems=1` and `block_elems=8` are both -valid reduce inputs, but `block_elems=8` is selected only when a producer plan +valid reduce inputs, but `block_elems=8` is selected only when a producer recipe such as strided `group_load` naturally creates 32B row fragments or when cost proves it cheaper without breaking another consumer such as `truncf`. @@ -648,7 +712,7 @@ Create a use-site request for: ```text 1. every operand use that requires a specific layout 2. every control-flow yield/branch/call/return edge -3. every memory operation that requires a memory legality plan +3. every memory operation that requires a memory legality recipe ``` ### 6.2 Constraints @@ -657,14 +721,14 @@ Hard constraints: ```text group_slots cannot feed ordinary dense consumers -direct group-slot width-changing cast requires a slot-preserving plan +direct group-slot width-changing cast requires a slot-preserving recipe public/external VMI function boundary requires a stable ABI or diagnostic S=32 fast tail load requires full_tile_readable or gather fallback ``` -`slots = 1` row-local cast may satisfy the slot-preserving plan requirement. +`slots = 1` row-local cast may satisfy the slot-preserving recipe requirement. Packed `slots = 8` f32->f16 remains a diagnostic unless a separate packed cast -or unpack/materialization plan is registered. +or unpack/materialization recipe is registered. Equivalence constraints: @@ -686,11 +750,11 @@ S=16 group_reduce: one dense value feeding S=16 and S=32 group_reduce: rematerialize a cheap producer per consumer layout, or insert an explicit - materialization plan; the final lowering pass must not pick one layout after + materialization recipe; the final lowering pass must not pick one layout after seeing both users load/group_load: - choose memory plan and result layout together + choose memory recipe and result layout together group_broadcast: rematerialize per dense consumer layout @@ -702,10 +766,10 @@ Recommended solving order: ```text 1. Build function/control-flow SCCs. -2. Collect candidate plans for every op. +2. Collect candidate recipes for every op. 3. Propagate hard required layouts from consumers. 4. Propagate producer natural layouts where they are unique. -5. Resolve multi-plan ops by cost. +5. Resolve multi-recipe ops by cost. 6. Insert use-site materialization where a value has multiple incompatible uses. 7. Rematerialize cheap producers instead of materializing when cheaper. 8. Specialize internal function signatures. @@ -716,10 +780,10 @@ Recommended solving order: Tie-breaking must be deterministic. Suggested priority: ```text -1. Avoid unsupported plans. +1. Avoid unsupported recipes. 2. Prefer rematerializing cheap producers over register materialization. 3. Prefer layouts accepted by all consumers without conversion. -4. Prefer memory-fused layout plans over load + register rearrange. +4. Prefer memory-fused layout recipes over load + register rearrange. 5. Prefer fewer VPTO instructions. 6. Prefer contiguous only when cost ties and no consumer requests a special layout. ``` @@ -806,7 +870,7 @@ current VMI op body/attrs: helper materialization chain: allowed only to strip ensure_mask_layout / ensure_mask_granularity for - static predicate analysis that does not choose a different layout or plan + static predicate analysis that does not choose a different layout or recipe diagnostic embellishment: allowed only to improve an already-failed capability message, such as naming diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 262299b3a3..e084ad58c0 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -83,7 +83,7 @@ G % K == 0 K must fit in the physical vreg element count ``` -`K` is selected by the producer/consumer plan. It is not always 8. For +`K` is selected by the producer/consumer local recipe. It is not always 8. For `VCGADD`-packed results, `K = 8` matches the eight 32B block results written to the low lanes of one destination vreg. For row-local reductions where each logical group already occupies one full 256B vreg, `K = 1` keeps each group's @@ -99,9 +99,9 @@ physical slot block slot_block(g), lane slot_lane(g) All other lanes are undefined for ordinary VMI consumers. They may only be read by group-aware ops that define how to interpret group slots. -## 2. Plan Selection Rules +## 2. Recipe Selection Rules -VMI cast ops must not hard-code one physical `vcvt` plan as their semantic +VMI cast ops must not hard-code one physical `vcvt` recipe as their semantic layout rule. ```text @@ -112,7 +112,7 @@ dense cast: group-slot cast: source/result are both group_slots(G,K). lowering preserves slot_block(g) and slot_lane(g). Width-changing casts are - legal only when a slot-preserving VPTO plan is registered, or when the cast + legal only when a slot-preserving VPTO recipe is registered, or when the cast can be commuted through a later group-aware consumer such as group_broadcast. ``` @@ -171,7 +171,7 @@ the immediately following complete endpoints. 3.16 group_slot_load layout contract complete 3.17 group_broadcast feeding deinterleaved consumer complete 3.18 one value with dense and group-reduce consumers complete/materialization -3.19 S=16 reduce block_elems plan selection complete/diagnostic +3.19 S=16 reduce block_elems recipe selection complete/diagnostic 3.20 group_slots control-flow join complete 3.21 S=32 tail with full-tile-readable source complete 3.22 scf.for loop-carried layout complete @@ -198,6 +198,7 @@ the immediately following complete endpoints. 3.43 internal function argument boundary materialization complete 3.44 masked_load grouped tail feeding S=32 reduce complete 3.45 dynamic S=32 create_group_mask complete +3.46 extf value and derived elemwise value both stored complete/optimization ``` ### 3.1 `f16 -> f32 -> store` @@ -2561,7 +2562,7 @@ VMI-LAYOUT-CONTRACT: use site. ``` -### 3.19 S=16 Reduce `block_elems` Plan Selection +### 3.19 S=16 Reduce `block_elems` Recipe Selection S=16 f32 group reduction has two legal dense input layouts: @@ -5349,3 +5350,120 @@ The runtime case passes `active_cols` as a kernel scalar argument and casts it to `index` inside `pto.vecscope`. This keeps scalar materialization outside `vmi-to-vpto`; the lowering pass only consumes the current `create_group_mask` operand. + +### 3.46 `extf` Value And Derived Elementwise Value Both Stored + +This case fixes where contiguous materialization belongs when one widened value +is used directly by a store and also by a layout-transparent elementwise chain +that is stored. + +VMI input: + +```text +%a = pto.vmi.load %in[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%k = pto.vmi.broadcast %k1 + : f32 -> !pto.vmi.vreg<128xf32> + +%w = pto.vmi.extf %a + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +%t1 = pto.vmi.mulf %w, %k + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + +pto.vmi.store %t1, %out1[%off] +pto.vmi.store %w, %out2[%off] +``` + +Hard-legalized assigned layouts: + +```text +%a: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +%w, %k, %t1: + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%t1_c = pto.vmi.ensure_layout %t1: + #pto.vmi.layout -> #pto.vmi.layout +pto.vmi.store %t1_c, %out1[%off] + +%w_c = pto.vmi.ensure_layout %w: + #pto.vmi.layout -> #pto.vmi.layout +pto.vmi.store %w_c, %out2[%off] +``` + +Baseline VPTO lowering result: + +```text +%a0 = pto.vlds %in[%off] {dist = "NORM"} + : !pto.ptr, index -> !pto.vreg<128xf16> + +%w_p0 = pto.vcvt %a0, PAT_ALL_B16 {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +%w_p1 = pto.vcvt %a0, PAT_ALL_B16 {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +%k_p0 = pto.vdup %k1, PAT_ALL_B32 : f32, !pto.mask -> !pto.vreg<64xf32> +%k_p1 = pto.vdup %k1, PAT_ALL_B32 : f32, !pto.mask -> !pto.vreg<64xf32> + +%t1_p0 = pto.vmul %w_p0, %k_p0, PAT_ALL_B32 : !pto.vreg<64xf32> +%t1_p1 = pto.vmul %w_p1, %k_p1, PAT_ALL_B32 : !pto.vreg<64xf32> + +// ensure_layout for the first store. +%t1_0, %t1_1 = pto.vintlv %t1_p0, %t1_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +pto.vsts %t1_0, %out1[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %t1_1, %out1[%off_plus_64], %all_b32 {dist = "NORM_B32"} + +// ensure_layout for the second store. +%w_0, %w_1 = pto.vintlv %w_p0, %w_p1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32> + -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +pto.vsts %w_0, %out2[%off], %all_b32 {dist = "NORM_B32"} +pto.vsts %w_1, %out2[%off_plus_64], %all_b32 {dist = "NORM_B32"} +``` + +Memory result: + +```text +for i = 0..127: + out1[off + i] = f32(in[off + i]) * k1 + out2[off + i] = f32(in[off + i]) +``` + +Optimization pass result: + +```text +// vmi-layout-fold-consumers may remove both ensure_layout ops if the target +// supports a store recipe that consumes deinterleaved=2 and writes contiguous +// row-major memory. +pto.vmi.store %t1, %out1[%off] +pto.vmi.store %w, %out2[%off] +``` + +Optimized VPTO lowering result: + +```text +pto.vstsx2 %t1_p0, %t1_p1, %out1[%off], "INTLV_B32", PAT_ALL_B32 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask + +pto.vstsx2 %w_p0, %w_p1, %out2[%off], "INTLV_B32", PAT_ALL_B32 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.ptr, index, + !pto.mask +``` + +Required assignment and optimization rule: + +```text +Hard legalization may always preserve `%w` and `%t1` in deinterleaved=2 and +insert use-site ensure_layout before ordinary stores. This is correct because +the layout change is explicit at the store use. + +Consumer folding is optional. It may remove the ensure_layout only when the +store itself can locally prove the same contiguous memory effect from the +source layout. vmi-to-vpto must not scan the `%w` producer or both store users +to decide this. +``` diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index b83bdbc195..9207993a14 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -114,6 +114,10 @@ LogicalResult validateVMILayoutAssignedIR(ModuleOp module, std::unique_ptr createPTOValidateVMIIRPass(); std::unique_ptr createPTOValidateVMILayoutIRPass(); std::unique_ptr createVMILayoutAssignmentPass(); +std::unique_ptr createVMILayoutFoldConsumersPass(); +std::unique_ptr createVMILayoutRematerializePass(); +std::unique_ptr createVMILayoutSinkMaterializationPass(); +std::unique_ptr createVMILegalizeArithSelectPass(); std::unique_ptr createVMIToVPTOPass(); std::unique_ptr createExpandTileOpPass(); std::unique_ptr createExpandTileOpPass(const ExpandTileOpOptions &options); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 770970ca36..74c9bf607e 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -839,6 +839,75 @@ def VMILayoutAssignment : Pass<"vmi-layout-assignment", "ModuleOp"> { "mlir::scf::SCFDialect"]; } +def VMILayoutFoldConsumers : Pass<"vmi-layout-fold-consumers", "ModuleOp"> { + let summary = "Fold VMI layout materialization into layout-aware consumers"; + let description = [{ + Optimizes legal layout-assigned VMI IR by replacing selected use-site + ensure_layout consumers with consumers that can directly lower from the + source layout while preserving the same logical effect. The pass does not + choose layouts by inspecting producer/user context for vmi-to-vpto; it only + rewrites explicit helper IR into an equivalent local-consumer form. + }]; + let constructor = "mlir::pto::createVMILayoutFoldConsumersPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMILayoutRematerialize : Pass<"vmi-layout-rematerialize", "ModuleOp"> { + let summary = "Rematerialize cheap VMI producers at layout helpers"; + let description = [{ + Optimizes legal layout-assigned VMI IR by replacing selected ensure_layout, + ensure_mask_layout, and ensure_mask_granularity helpers with cloned cheap + producers that directly create the requested result type. The pass is + deliberately limited to pure construction ops, so memory, control-flow, and + mask-tail proofs remain explicit in the IR. + }]; + let constructor = "mlir::pto::createVMILayoutRematerializePass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMILayoutSinkMaterialization + : Pass<"vmi-layout-sink-materialization", "ModuleOp"> { + let summary = "Sink VMI layout materialization through transfer ops"; + let description = [{ + Optimizes legal layout-assigned VMI IR by moving matching operand + ensure_layout helpers across pure layout-transparent elementwise operations. + The rewritten IR keeps the layout conversion explicit as a result + ensure_layout, so vmi-to-vpto still lowers from local op information only. + }]; + let constructor = "mlir::pto::createVMILayoutSinkMaterializationPass()"; + let dependentDialects = ["mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + +def VMILegalizeArithSelect : Pass<"vmi-legalize-arith-select", "ModuleOp"> { + let summary = "Legalize canonical arith.select over VMI values"; + let description = [{ + Rewrites scalar-condition arith.select operations that produce VMI values + back to scf.if. MLIR canonicalization may fold simple scf.if regions into + arith.select, but VMI values must not cross non-VMI semantic ops before + vmi-to-vpto. This pass restores an explicit structural control-flow form + that the VMI converter already handles. + }]; + let constructor = "mlir::pto::createVMILegalizeArithSelectPass()"; + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::cf::ControlFlowDialect", + "mlir::func::FuncDialect", + "mlir::pto::PTODialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect"]; +} + def VMIToVPTO : Pass<"vmi-to-vpto", "ModuleOp"> { let summary = "Convert layout-assigned VMI IR to physical VPTO IR"; let description = [{ diff --git a/include/PTO/Transforms/VMILocalRecipeRegistry.h b/include/PTO/Transforms/VMILocalRecipeRegistry.h new file mode 100644 index 0000000000..7356be9e92 --- /dev/null +++ b/include/PTO/Transforms/VMILocalRecipeRegistry.h @@ -0,0 +1,198 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILocalRecipeRegistry.h - VMI local recipe queries ------*- C++ -*-===// +//===----------------------------------------------------------------------===// + +#ifndef PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H +#define PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H + +#include "PTO/IR/PTO.h" +#include "mlir/Support/LLVM.h" + +#include + +namespace mlir::pto { + +class VMITargetCapabilityRegistry; + +enum class VMIContiguousStoreRecipeKind { + ContiguousVsts, + Deinterleaved2Vstsx2, + DeinterleavedMaterializeThenVsts, +}; + +struct VMIContiguousStoreRecipe { + VMIContiguousStoreRecipeKind kind = + VMIContiguousStoreRecipeKind::ContiguousVsts; +}; + +enum class VMILayoutMaterializationRecipeKind { + Identity, + ContiguousToDeinterleaved, + DeinterleavedToContiguous, +}; + +struct VMILayoutMaterializationRecipe { + VMILayoutMaterializationRecipeKind kind = + VMILayoutMaterializationRecipeKind::Identity; +}; + +enum class VMIMaskGranularityMaterializationRecipeKind { + Identity, + PredicateCast, +}; + +struct VMIMaskGranularityMaterializationRecipe { + VMIMaskGranularityMaterializationRecipeKind kind = + VMIMaskGranularityMaterializationRecipeKind::Identity; +}; + +enum class VMIGroupSlotLoadRecipeKind { + Slots8UnitStrideVsldb, + Slots1AlignedLane0Vsldb, +}; + +struct VMIGroupSlotLoadRecipe { + VMIGroupSlotLoadRecipeKind kind = + VMIGroupSlotLoadRecipeKind::Slots8UnitStrideVsldb; +}; + +enum class VMIGroupLoadRecipeKind { + S16Block8Vsldb, + S32Block8Vsldb, +}; + +struct VMIGroupLoadRecipe { + VMIGroupLoadRecipeKind kind = VMIGroupLoadRecipeKind::S16Block8Vsldb; +}; + +enum class VMIGroupSlotsStoreRecipeKind { + Slots8UnitStrideVsts, + Slots1AlignedLane0Vsts, +}; + +struct VMIGroupSlotsStoreRecipe { + VMIGroupSlotsStoreRecipeKind kind = + VMIGroupSlotsStoreRecipeKind::Slots8UnitStrideVsts; +}; + +enum class VMIGroupReduceAddFRecipeKind { + S8Vcgadd, + S16Deinterleaved2VcgaddVadd, + S32Deinterleaved4VcgaddTree, + S64ContiguousVcaddRows, +}; + +struct VMIGroupReduceAddFRecipe { + VMIGroupReduceAddFRecipeKind kind = VMIGroupReduceAddFRecipeKind::S8Vcgadd; +}; + +enum class VMIGroupBroadcastRecipeKind { + GroupSlotsVselr, +}; + +struct VMIGroupBroadcastRecipe { + VMIGroupBroadcastRecipeKind kind = + VMIGroupBroadcastRecipeKind::GroupSlotsVselr; +}; + +enum class VMITruncFRecipeKind { + Deinterleaved2F32ToContiguousF16, + Deinterleaved4F32ToContiguousF8, + GroupSlots1F32ToF16, +}; + +struct VMITruncFRecipe { + VMITruncFRecipeKind kind = + VMITruncFRecipeKind::Deinterleaved2F32ToContiguousF16; +}; + +enum class VMIExtFRecipeKind { + ContiguousF16ToDeinterleaved2F32, + ContiguousF8ToDeinterleaved4F32, +}; + +struct VMIExtFRecipe { + VMIExtFRecipeKind kind = + VMIExtFRecipeKind::ContiguousF16ToDeinterleaved2F32; +}; + +enum class VMIBitcastRecipeKind { + PerPartVbitcast, +}; + +struct VMIBitcastRecipe { + VMIBitcastRecipeKind kind = VMIBitcastRecipeKind::PerPartVbitcast; +}; + +class VMILocalRecipeRegistry { +public: + FailureOr + getContiguousStoreRecipe(VMIVRegType valueType, + std::string *reason = nullptr) const; + + LogicalResult canFoldContiguousStoreMaterialization( + VMIVRegType sourceType, VMIVRegType resultType, + std::string *reason = nullptr) const; + + FailureOr + getDataLayoutMaterializationRecipe(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason = nullptr) const; + + FailureOr + getMaskLayoutMaterializationRecipe(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + FailureOr + getMaskGranularityMaterializationRecipe(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + FailureOr + getGroupSlotLoadRecipe(const VMITargetCapabilityRegistry &capabilities, + VMIGroupSlotLoadOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupLoadRecipe(const VMITargetCapabilityRegistry &capabilities, + VMIGroupLoadOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupSlotsStoreRecipe(const VMITargetCapabilityRegistry &capabilities, + VMIGroupStoreOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupReduceAddFRecipe(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceAddFOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupBroadcastRecipe(const VMITargetCapabilityRegistry &capabilities, + VMIGroupBroadcastOp op, + std::string *reason = nullptr) const; + + FailureOr + getTruncFRecipe(VMITruncFOp op, std::string *reason = nullptr) const; + + FailureOr + getExtFRecipe(VMIExtFOp op, std::string *reason = nullptr) const; + + FailureOr + getBitcastRecipe(VMIBitcastOp op, std::string *reason = nullptr) const; +}; + +} // namespace mlir::pto + +#endif // PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index f2fe3ece10..12dbb7c8e9 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -36,7 +36,12 @@ add_mlir_dialect_library(PTOTransforms PTOValidateVPTOIR.cpp PTOUnrollSIMTForPass.cpp PTOValidateVMIIR.cpp + VMILegalizeArithSelect.cpp VMILayoutAssignment.cpp + VMILayoutFoldConsumers.cpp + VMILocalRecipeRegistry.cpp + VMILayoutRematerialize.cpp + VMILayoutSinkMaterialization.cpp VMIToVPTO.cpp PTOInferVPTOVecScope.cpp diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 6ce3e8eecd..7234084c47 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -12,6 +12,8 @@ #include "PTO/IR/PTO.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILocalRecipeRegistry.h" +#include "PTO/Transforms/VMITargetCapabilities.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -159,6 +161,49 @@ LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, return failure(); } +LogicalResult emitLayoutContract(Operation *op, llvm::raw_ostream *diagOS, + Twine message) { + InFlightDiagnostic diag = + op->emitError() << kVMIDiagLayoutContractPrefix << message; + (void)diag; + mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); + return failure(); +} + +LogicalResult emitHelperMaterializationContract(Operation *helper, + Type sourceType, + Type resultType, + StringRef helperName, + StringRef reason, + llvm::raw_ostream *diagOS) { + auto emitFallback = [&]() { + return emitLayoutContract( + helper, diagOS, + Twine(helperName) + " has no registered materialization recipe: " + + reason); + }; + + if (helper->getNumResults() != 1 || !helper->getResult(0).hasOneUse()) + return emitFallback(); + + OpOperand &use = *helper->getResult(0).use_begin(); + Operation *requester = use.getOwner(); + std::string message; + llvm::raw_string_ostream os(message); + os << requester->getName() << " operand #" << use.getOperandNumber() + << " has type " << sourceType << " but requires " << resultType << "; " + << helperName << " has no registered materialization recipe: " << reason; + os.flush(); + + InFlightDiagnostic diag = + requester->emitError() << kVMIDiagLayoutContractPrefix << message; + diag.attachNote(helper->getLoc()) + << "failed helper conversion " << sourceType << " -> " << resultType + << " (" << reason << ")"; + mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); + return failure(); +} + LogicalResult verifyBoundaryType(Operation *owner, Type type, llvm::raw_ostream *diagOS) { if (isPhysicalVPTOType(type)) @@ -350,6 +395,12 @@ LogicalResult verifyLayoutAssignedOperationTypes(Operation *op, return success(); } +LogicalResult verifyLayoutHelperRecipe(Operation *op, + llvm::raw_ostream *diagOS); + +LogicalResult verifyLayoutSemanticRecipe(Operation *op, + llvm::raw_ostream *diagOS); + LogicalResult verifyOperationBoundary(Operation *op, llvm::raw_ostream *diagOS) { if (failed(verifyOperationTypes(op, diagOS))) @@ -380,19 +431,209 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, if (isVMIHelperOp(op)) { if (isVMILayoutHelperOp(op)) - return success(); + return verifyLayoutHelperRecipe(op, diagOS); return emitInvariant( op, diagOS, "VMI pack/unpack helper appears before VMI-to-VPTO physicalization"); } - if (isVMISemanticOp(op) || isStructuralOp(op)) + if (isVMISemanticOp(op)) + return verifyLayoutSemanticRecipe(op, diagOS); + if (isStructuralOp(op)) return success(); return emitInvariant(op, diagOS, "VMI typed value is used by a non-VMI semantic op"); } +LogicalResult verifyLayoutHelperRecipe(Operation *op, + llvm::raw_ostream *diagOS) { + VMILocalRecipeRegistry recipes; + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (failed(recipes.getDataLayoutMaterializationRecipe(sourceType, + resultType, + &reason))) + return emitHelperMaterializationContract( + op, sourceType, resultType, "pto.vmi.ensure_layout", reason, diagOS); + return success(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (failed(recipes.getMaskLayoutMaterializationRecipe(sourceType, + resultType, + &reason))) + return emitHelperMaterializationContract( + op, sourceType, resultType, "pto.vmi.ensure_mask_layout", reason, + diagOS); + return success(); + } + + if (auto ensure = dyn_cast(op)) { + auto sourceType = cast(ensure.getSource().getType()); + auto resultType = cast(ensure.getResult().getType()); + std::string reason; + if (failed(recipes.getMaskGranularityMaterializationRecipe( + sourceType, resultType, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.ensure_mask_granularity has no registered " + "materialization recipe: ") + + reason); + return success(); + } + + return success(); +} + +LogicalResult verifyLayoutSemanticRecipe(Operation *op, + llvm::raw_ostream *diagOS) { + VMILocalRecipeRegistry recipes; + VMITargetCapabilityRegistry capabilities; + + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || layout.isContiguous()) + return success(); + + std::string reason; + if (failed(recipes.getContiguousStoreRecipe(valueType, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.store has no registered contiguous-memory local " + "recipe: ") + + reason); + return success(); + } + + if (auto tileWrite = dyn_cast(op)) { + auto valueType = cast(tileWrite.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || layout.isContiguous()) + return success(); + + std::string reason; + if (failed(recipes.getContiguousStoreRecipe(valueType, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.tile_write has no registered contiguous-memory local " + "recipe: ") + + reason); + return success(); + } + + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isDeinterleaved() || layout.getBlockElems() != 8 || + !resultType.getElementType().isF32()) + return success(); + + std::string reason; + if (failed(recipes.getGroupLoadRecipe(capabilities, load, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.group_load has no registered block8 local recipe: ") + + reason); + return success(); + } + + if (auto load = dyn_cast(op)) { + std::string reason; + if (failed(recipes.getGroupSlotLoadRecipe(capabilities, load, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.group_slot_load has no registered local recipe: ") + + reason); + return success(); + } + + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed(recipes.getGroupSlotsStoreRecipe(capabilities, store, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.group_store has no registered group_slots local " + "recipe: ") + + reason); + return success(); + } + + if (auto reduce = dyn_cast(op)) { + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed(recipes.getGroupReduceAddFRecipe(capabilities, reduce, + &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.group_reduce_addf has no registered group_slots " + "local recipe: ") + + reason); + return success(); + } + + if (auto broadcast = dyn_cast(op)) { + auto sourceType = cast(broadcast.getSource().getType()); + VMILayoutAttr layout = sourceType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || layout.getSlots() <= 0) + return success(); + + std::string reason; + if (failed(recipes.getGroupBroadcastRecipe(capabilities, broadcast, + &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.group_broadcast has no registered local recipe: ") + + reason); + return success(); + } + + if (auto truncf = dyn_cast(op)) { + std::string reason; + if (failed(recipes.getTruncFRecipe(truncf, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.truncf has no registered local recipe: ") + reason); + return success(); + } + + if (auto extf = dyn_cast(op)) { + std::string reason; + if (failed(recipes.getExtFRecipe(extf, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.extf has no registered local recipe: ") + reason); + return success(); + } + + if (auto bitcast = dyn_cast(op)) { + std::string reason; + if (failed(recipes.getBitcastRecipe(bitcast, &reason))) + return emitLayoutContract( + op, diagOS, + Twine("pto.vmi.bitcast has no registered local recipe: ") + reason); + return success(); + } + + return success(); +} + struct PTOValidateVMIIRPass : public mlir::pto::impl::PTOValidateVMIIRBase { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PTOValidateVMIIRPass) diff --git a/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp b/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp new file mode 100644 index 0000000000..26536f196d --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILayoutFoldConsumers.cpp - Fold VMI layout consumers ------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILocalRecipeRegistry.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTFOLDCONSUMERS +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool isFoldableStoreEnsure(VMIEnsureLayoutOp ensure) { + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!sourceType || !resultType) + return false; + + VMILocalRecipeRegistry recipes; + return succeeded( + recipes.canFoldContiguousStoreMaterialization(sourceType, resultType)); +} + +static void tryFoldEnsureLayoutIntoOperand( + OpOperand &operand, SmallVectorImpl &maybeDeadEnsures) { + auto ensure = operand.get().getDefiningOp(); + if (!ensure || !isFoldableStoreEnsure(ensure)) + return; + + operand.set(ensure.getSource()); + maybeDeadEnsures.push_back(ensure); +} + +static void tryFoldEnsureLayoutIntoMaskedStore( + VMIMaskedStoreOp store, + SmallVectorImpl &maybeDeadEnsures, + SmallVectorImpl &maybeDeadMaskEnsures) { + auto ensure = store.getValue().getDefiningOp(); + if (!ensure || !isFoldableStoreEnsure(ensure)) + return; + auto maskEnsure = store.getMask().getDefiningOp(); + if (!maskEnsure) + return; + + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto maskSourceType = dyn_cast(maskEnsure.getSource().getType()); + auto maskResultType = dyn_cast(maskEnsure.getResult().getType()); + if (!sourceType || !maskSourceType || !maskResultType) + return; + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskSourceLayout = maskSourceType.getLayoutAttr(); + VMILayoutAttr maskResultLayout = maskResultType.getLayoutAttr(); + if (!sourceLayout || !maskSourceLayout || !maskResultLayout) + return; + if (sourceLayout != maskSourceLayout || !maskResultLayout.isContiguous()) + return; + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskSourceType); + if (failed(sourceArity) || failed(maskArity) || *sourceArity != *maskArity) + return; + + store.getValueMutable().set(ensure.getSource()); + store.getMaskMutable().set(maskEnsure.getSource()); + maybeDeadEnsures.push_back(ensure); + maybeDeadMaskEnsures.push_back(maskEnsure); +} + +struct VMILayoutFoldConsumersPass + : public mlir::pto::impl::VMILayoutFoldConsumersBase< + VMILayoutFoldConsumersPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutFoldConsumersPass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector maybeDeadEnsures; + SmallVector maybeDeadMaskEnsures; + + module.walk([&](Operation *op) { + if (auto store = dyn_cast(op)) + tryFoldEnsureLayoutIntoOperand(store.getValueMutable(), + maybeDeadEnsures); + if (auto tileWrite = dyn_cast(op)) + tryFoldEnsureLayoutIntoOperand(tileWrite.getValueMutable(), + maybeDeadEnsures); + if (auto maskedStore = dyn_cast(op)) + tryFoldEnsureLayoutIntoMaskedStore(maskedStore, maybeDeadEnsures, + maybeDeadMaskEnsures); + }); + + for (VMIEnsureMaskLayoutOp ensure : llvm::reverse(maybeDeadMaskEnsures)) { + if (ensure->use_empty()) + ensure.erase(); + } + for (VMIEnsureLayoutOp ensure : llvm::reverse(maybeDeadEnsures)) { + if (ensure->use_empty()) + ensure.erase(); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILayoutFoldConsumersPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILayoutRematerialize.cpp b/lib/PTO/Transforms/VMILayoutRematerialize.cpp new file mode 100644 index 0000000000..4f230d4189 --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutRematerialize.cpp @@ -0,0 +1,172 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILayoutRematerialize.cpp - Rematerialize VMI producers ----------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTREMATERIALIZE +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool hasConcreteLayout(VMIVRegType type) { + return type && static_cast(type.getLayoutAttr()); +} + +static bool hasConcreteLayout(VMIMaskType type) { + return type && static_cast(type.getLayoutAttr()); +} + +static std::optional rematerializeDataProducer(Value value, + VMIVRegType resultType, + Location loc, + OpBuilder &builder) { + if (!hasConcreteLayout(resultType)) + return std::nullopt; + + if (auto constant = value.getDefiningOp()) { + auto denseAttr = dyn_cast(constant.getValue()); + if (denseAttr && denseAttr.isSplat()) + return builder + .create(loc, resultType, constant.getValue()) + .getResult(); + } + + if (auto broadcast = value.getDefiningOp()) + return builder.create(loc, resultType, + broadcast.getValue()) + .getResult(); + + if (auto iota = value.getDefiningOp()) + return builder + .create(loc, resultType, iota.getBase(), + iota.getOrderAttr()) + .getResult(); + + return std::nullopt; +} + +static std::optional rematerializeMaskProducer(Value value, + VMIMaskType resultType, + Location loc, + OpBuilder &builder) { + if (!hasConcreteLayout(resultType)) + return std::nullopt; + + if (auto createMask = value.getDefiningOp()) + return builder + .create(loc, resultType, createMask.getActiveLanes()) + .getResult(); + + if (auto createGroupMask = value.getDefiningOp()) + return builder + .create( + loc, resultType, createGroupMask.getActiveElemsPerGroup(), + createGroupMask.getNumGroupsAttr(), createGroupMask.getGroupSizeAttr()) + .getResult(); + + if (auto constantMask = value.getDefiningOp()) + return builder + .create(loc, resultType, + constantMask.getValueAttr()) + .getResult(); + + return std::nullopt; +} + +static bool tryReplaceDataEnsure(VMIEnsureLayoutOp ensure) { + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!resultType) + return false; + + OpBuilder builder(ensure); + auto result = rematerializeDataProducer(ensure.getSource(), resultType, + ensure->getLoc(), builder); + if (!result) + return false; + + ensure.getResult().replaceAllUsesWith(*result); + ensure.erase(); + return true; +} + +template +static bool tryReplaceMaskEnsure(EnsureOp ensure) { + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!resultType) + return false; + + OpBuilder builder(ensure); + auto result = rematerializeMaskProducer(ensure.getSource(), resultType, + ensure->getLoc(), builder); + if (!result) + return false; + + ensure.getResult().replaceAllUsesWith(*result); + ensure.erase(); + return true; +} + +struct VMILayoutRematerializePass + : public mlir::pto::impl::VMILayoutRematerializeBase< + VMILayoutRematerializePass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutRematerializePass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector helpers; + module.walk([&](Operation *op) { + if (isa(op)) + helpers.push_back(op); + }); + + for (Operation *op : helpers) { + if (op->getBlock() == nullptr) + continue; + + if (auto ensure = dyn_cast(op)) { + tryReplaceDataEnsure(ensure); + continue; + } + + if (auto ensure = dyn_cast(op)) { + tryReplaceMaskEnsure(ensure); + continue; + } + + if (auto ensure = dyn_cast(op)) + tryReplaceMaskEnsure(ensure); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILayoutRematerializePass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp b/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp new file mode 100644 index 0000000000..c3bbf67731 --- /dev/null +++ b/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp @@ -0,0 +1,363 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILayoutSinkMaterialization.cpp - Sink VMI layout helpers --------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILAYOUTSINKMATERIALIZATION +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +struct BinaryVRegOperands { + OpOperand *lhs = nullptr; + OpOperand *rhs = nullptr; +}; + +struct UnaryVRegOperand { + OpOperand *source = nullptr; +}; + +struct BinaryMaskOperands { + OpOperand *lhs = nullptr; + OpOperand *rhs = nullptr; +}; + +struct UnaryMaskOperand { + OpOperand *source = nullptr; +}; + +static std::optional getSinkableBinaryOperands(Operation *op) { + if (auto addf = dyn_cast(op)) + return BinaryVRegOperands{&addf.getLhsMutable(), &addf.getRhsMutable()}; + if (auto addi = dyn_cast(op)) + return BinaryVRegOperands{&addi.getLhsMutable(), &addi.getRhsMutable()}; + if (auto subf = dyn_cast(op)) + return BinaryVRegOperands{&subf.getLhsMutable(), &subf.getRhsMutable()}; + if (auto subi = dyn_cast(op)) + return BinaryVRegOperands{&subi.getLhsMutable(), &subi.getRhsMutable()}; + if (auto mulf = dyn_cast(op)) + return BinaryVRegOperands{&mulf.getLhsMutable(), &mulf.getRhsMutable()}; + if (auto muli = dyn_cast(op)) + return BinaryVRegOperands{&muli.getLhsMutable(), &muli.getRhsMutable()}; + if (auto divf = dyn_cast(op)) + return BinaryVRegOperands{&divf.getLhsMutable(), &divf.getRhsMutable()}; + if (auto minf = dyn_cast(op)) + return BinaryVRegOperands{&minf.getLhsMutable(), &minf.getRhsMutable()}; + if (auto maxf = dyn_cast(op)) + return BinaryVRegOperands{&maxf.getLhsMutable(), &maxf.getRhsMutable()}; + if (auto andi = dyn_cast(op)) + return BinaryVRegOperands{&andi.getLhsMutable(), &andi.getRhsMutable()}; + if (auto ori = dyn_cast(op)) + return BinaryVRegOperands{&ori.getLhsMutable(), &ori.getRhsMutable()}; + if (auto xori = dyn_cast(op)) + return BinaryVRegOperands{&xori.getLhsMutable(), &xori.getRhsMutable()}; + if (auto shli = dyn_cast(op)) + return BinaryVRegOperands{&shli.getLhsMutable(), &shli.getRhsMutable()}; + if (auto shrui = dyn_cast(op)) + return BinaryVRegOperands{&shrui.getLhsMutable(), &shrui.getRhsMutable()}; + return std::nullopt; +} + +static std::optional getSinkableUnaryOperand(Operation *op) { + if (auto negf = dyn_cast(op)) + return UnaryVRegOperand{&negf.getSourceMutable()}; + if (auto absf = dyn_cast(op)) + return UnaryVRegOperand{&absf.getSourceMutable()}; + if (auto absi = dyn_cast(op)) + return UnaryVRegOperand{&absi.getSourceMutable()}; + if (auto sqrt = dyn_cast(op)) + return UnaryVRegOperand{&sqrt.getSourceMutable()}; + if (auto exp = dyn_cast(op)) + return UnaryVRegOperand{&exp.getSourceMutable()}; + if (auto ln = dyn_cast(op)) + return UnaryVRegOperand{&ln.getSourceMutable()}; + if (auto relu = dyn_cast(op)) + return UnaryVRegOperand{&relu.getSourceMutable()}; + if (auto notOp = dyn_cast(op)) + return UnaryVRegOperand{¬Op.getSourceMutable()}; + return std::nullopt; +} + +static std::optional +getSinkableBinaryMaskOperands(Operation *op) { + if (auto maskAnd = dyn_cast(op)) + return BinaryMaskOperands{&maskAnd.getLhsMutable(), + &maskAnd.getRhsMutable()}; + if (auto maskOr = dyn_cast(op)) + return BinaryMaskOperands{&maskOr.getLhsMutable(), + &maskOr.getRhsMutable()}; + if (auto maskXor = dyn_cast(op)) + return BinaryMaskOperands{&maskXor.getLhsMutable(), + &maskXor.getRhsMutable()}; + return std::nullopt; +} + +static std::optional +getSinkableUnaryMaskOperand(Operation *op) { + if (auto maskNot = dyn_cast(op)) + return UnaryMaskOperand{&maskNot.getSourceMutable()}; + return std::nullopt; +} + +static bool isSameMaterialization(VMIEnsureLayoutOp ensure, + VMIVRegType resultType) { + if (!ensure || !resultType) + return false; + + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto ensureResultType = dyn_cast(ensure.getResult().getType()); + if (!sourceType || !ensureResultType) + return false; + + return ensureResultType == resultType && sourceType != resultType; +} + +static bool isSameMaterialization(VMIEnsureLayoutOp lhsEnsure, + VMIEnsureLayoutOp rhsEnsure, + VMIVRegType resultType) { + if (!lhsEnsure || !rhsEnsure || !resultType) + return false; + + auto lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = dyn_cast(rhsEnsure.getResult().getType()); + if (!lhsSourceType || !rhsSourceType || !lhsResultType || !rhsResultType) + return false; + + return lhsSourceType == rhsSourceType && lhsResultType == rhsResultType && + lhsResultType == resultType && lhsSourceType != resultType; +} + +template +static bool isSameMaskMaterialization(EnsureOp ensure, VMIMaskType resultType) { + if (!ensure || !resultType) + return false; + + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto ensureResultType = dyn_cast(ensure.getResult().getType()); + if (!sourceType || !ensureResultType) + return false; + + return ensureResultType == resultType && sourceType != resultType; +} + +template +static bool isSameMaskMaterialization(EnsureOp lhsEnsure, EnsureOp rhsEnsure, + VMIMaskType resultType) { + if (!lhsEnsure || !rhsEnsure || !resultType) + return false; + + auto lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = dyn_cast(rhsEnsure.getResult().getType()); + if (!lhsSourceType || !rhsSourceType || !lhsResultType || !rhsResultType) + return false; + + return lhsSourceType == rhsSourceType && lhsResultType == rhsResultType && + lhsResultType == resultType && lhsSourceType != resultType; +} + +static bool trySinkBinaryMaterialization(Operation *op) { + std::optional operands = getSinkableBinaryOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + if (!isSameMaterialization(lhsEnsure, rhsEnsure, resultType)) + return false; + + auto sourceType = cast(lhsEnsure.getSource().getType()); + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + return true; +} + +template +static bool trySinkBinaryMaskMaterialization(Operation *op) { + std::optional operands = getSinkableBinaryMaskOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + if (!isSameMaskMaterialization(lhsEnsure, rhsEnsure, resultType)) + return false; + + auto sourceType = cast(lhsEnsure.getSource().getType()); + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = + builder.create(op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + return true; +} + +static bool trySinkUnaryMaterialization(Operation *op) { + std::optional operand = getSinkableUnaryOperand(op); + if (!operand || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto sourceEnsure = + operand->source->get().getDefiningOp(); + if (!isSameMaterialization(sourceEnsure, resultType)) + return false; + + auto sourceType = cast(sourceEnsure.getSource().getType()); + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands(sourceEnsure.getSource()); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (sourceEnsure->use_empty()) + sourceEnsure.erase(); + return true; +} + +template +static bool trySinkUnaryMaskMaterialization(Operation *op) { + std::optional operand = getSinkableUnaryMaskOperand(op); + if (!operand || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto sourceEnsure = + operand->source->get().getDefiningOp(); + if (!isSameMaskMaterialization(sourceEnsure, resultType)) + return false; + + auto sourceType = cast(sourceEnsure.getSource().getType()); + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands(sourceEnsure.getSource()); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = + builder.create(op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (sourceEnsure->use_empty()) + sourceEnsure.erase(); + return true; +} + +static bool trySinkMaskMaterialization(Operation *op) { + return trySinkBinaryMaskMaterialization(op) || + trySinkBinaryMaskMaterialization(op) || + trySinkUnaryMaskMaterialization(op) || + trySinkUnaryMaskMaterialization(op); +} + +struct VMILayoutSinkMaterializationPass + : public mlir::pto::impl::VMILayoutSinkMaterializationBase< + VMILayoutSinkMaterializationPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + VMILayoutSinkMaterializationPass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector candidates; + module.walk([&](Operation *op) { + if (getSinkableBinaryOperands(op) || getSinkableUnaryOperand(op) || + getSinkableBinaryMaskOperands(op) || getSinkableUnaryMaskOperand(op)) + candidates.push_back(op); + }); + + for (Operation *op : candidates) { + if (op->getBlock() == nullptr) + continue; + if (!trySinkBinaryMaterialization(op)) { + if (!trySinkUnaryMaterialization(op)) + trySinkMaskMaterialization(op); + } + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILayoutSinkMaterializationPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILegalizeArithSelect.cpp b/lib/PTO/Transforms/VMILegalizeArithSelect.cpp new file mode 100644 index 0000000000..471215985f --- /dev/null +++ b/lib/PTO/Transforms/VMILegalizeArithSelect.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILegalizeArithSelect.cpp - Legalize VMI arith.select ------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMILEGALIZEARITHSELECT +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool isVMIValueType(Type type) { + return isa(type); +} + +static bool hasScalarI1Condition(arith::SelectOp select) { + return select.getCondition().getType().isSignlessInteger(1); +} + +static void rewriteSelectToIf(arith::SelectOp select) { + OpBuilder builder(select); + auto ifOp = builder.create( + select.getLoc(), TypeRange{select.getResult().getType()}, + select.getCondition(), /*withElseRegion=*/true); + + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + builder.create(select.getLoc(), select.getTrueValue()); + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + builder.create(select.getLoc(), select.getFalseValue()); + } + + select.getResult().replaceAllUsesWith(ifOp.getResult(0)); + select.erase(); +} + +struct VMILegalizeArithSelectPass + : public mlir::pto::impl::VMILegalizeArithSelectBase< + VMILegalizeArithSelectPass> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILegalizeArithSelectPass) + + void runOnOperation() override { + ModuleOp module = getOperation(); + SmallVector selects; + module.walk([&](arith::SelectOp select) { + if (isVMIValueType(select.getResult().getType()) && + hasScalarI1Condition(select)) + selects.push_back(select); + }); + + for (arith::SelectOp select : llvm::reverse(selects)) { + if (select->getBlock() != nullptr) + rewriteSelectToIf(select); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMILegalizeArithSelectPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp new file mode 100644 index 0000000000..7364084028 --- /dev/null +++ b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp @@ -0,0 +1,1006 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILocalRecipeRegistry.cpp - VMI local recipe queries --------------===// +//===----------------------------------------------------------------------===// + +#include "PTO/Transforms/VMILocalRecipeRegistry.h" + +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/IR/VMIUtils.h" +#include "PTO/Transforms/VMITargetCapabilities.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "llvm/ADT/Twine.h" + +#include + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static LogicalResult failWithReason(const Twine &message, std::string *reason) { + if (reason) + *reason = message.str(); + return failure(); +} + +static LogicalResult checkFullDataPhysicalChunks(VMIVRegType type, + std::string *reason) { + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (failed(lanesPerPart)) + return failWithReason("requires known physical lanes per part", reason); + + FailureOr arity = getVMIPhysicalArity(type); + if (failed(arity)) + return failWithReason("requires computable physical arity", reason); + + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return failWithReason("requires assigned layout", reason); + int64_t factor = layout.isDeinterleaved() ? layout.getFactor() : 1; + if (factor <= 0 || *arity % factor != 0) + return failWithReason("requires arity divisible by layout factor", reason); + + int64_t chunksPerPart = *arity / factor; + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0; chunk < chunksPerPart; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return failWithReason("failed to map physical padding lane", reason); + if (*padding) + return failWithReason("found padding lane in physical chunk", reason); + } + } + } + + return success(); +} + +static bool hasX2MemoryDistToken(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + return elementBits == 8 || elementBits == 16 || elementBits == 32; +} + +static std::optional getConstantIndexValue(Value value) { + if (auto constant = value.getDefiningOp()) + return constant.value(); + if (auto constant = value.getDefiningOp()) { + if (constant.getType().isIndex()) + return constant.value(); + } + return std::nullopt; +} + +static int64_t ceilDivNonNegative(int64_t lhs, int64_t rhs) { + assert(lhs >= 0 && rhs > 0); + return (lhs + rhs - 1) / rhs; +} + +static FailureOr getVMITypeElementCount(Type type) { + if (auto vregType = dyn_cast(type)) + return vregType.getElementCount(); + if (auto maskType = dyn_cast(type)) + return maskType.getElementCount(); + return failure(); +} + +static FailureOr getVMITypeLayoutFactor(Type type) { + VMILayoutAttr layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayoutAttr(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayoutAttr(); + else + return failure(); + if (!layout) + return failure(); + return layout.isDeinterleaved() ? layout.getFactor() : 1; +} + +static FailureOr getVMITypeLanesPerPart(Type type) { + if (auto vregType = dyn_cast(type)) + return getDataLanesPerPart(vregType.getElementType()); + if (auto maskType = dyn_cast(type)) + return getMaskLanesPerPart(maskType.getGranularity()); + return failure(); +} + +static FailureOr getVMITypeChunksInPart(Type type, int64_t part) { + FailureOr elementCount = getVMITypeElementCount(type); + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(elementCount) || failed(factor) || failed(lanesPerPart) || + part < 0 || part >= *factor) + return failure(); + + int64_t logicalLanesInPart = (*elementCount + *factor - 1 - part) / *factor; + return ceilDivNonNegative(logicalLanesInPart, *lanesPerPart); +} + +static LogicalResult checkFullVMIPhysicalChunks(Type type, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr factor = getVMITypeLayoutFactor(type); + FailureOr lanesPerPart = getVMITypeLanesPerPart(type); + if (failed(factor) || failed(lanesPerPart)) + return fail("requires assigned layout with known physical lanes per part"); + + for (int64_t part = 0; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks"); + for (int64_t chunk = 0; chunk < *chunks; ++chunk) { + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return fail("failed to map physical padding lane"); + if (*padding) + return fail("found padding lane in physical chunk"); + } + } + } + + return success(); +} + +static FailureOr +getContiguousMaterializationPartCount(Type type, std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr arity = getVMIPhysicalArity(type); + FailureOr factor = getVMITypeLayoutFactor(type); + if (failed(arity) || failed(factor)) + return fail("requires computable physical arity and assigned layout"); + + VMILayoutAttr layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayoutAttr(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayoutAttr(); + else + return fail("requires VMI data or mask type"); + + if (!layout) + return fail("requires assigned layout"); + if (layout.isContiguous()) + return *arity; + if (!layout.isDeinterleaved() || + (layout.getFactor() != 2 && layout.getFactor() != 4)) + return fail("requires contiguous or deinterleaved=2/4 layout"); + + FailureOr chunksPerGroup = getVMITypeChunksInPart(type, 0); + if (failed(chunksPerGroup)) + return fail("requires known physical chunks per part"); + if (*chunksPerGroup == 0) + return fail("requires at least one physical chunk per part"); + + for (int64_t part = 1; part < *factor; ++part) { + FailureOr chunks = getVMITypeChunksInPart(type, part); + if (failed(chunks)) + return fail("requires known physical chunks per part"); + if (*chunks != *chunksPerGroup) + return fail("requires every deinterleaved part to have the same " + "physical chunk count"); + } + + return *arity; +} + +static LogicalResult checkLayoutMaterializationShape(Type sourceType, + Type resultType, + VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity != *resultArity) + return fail("requires source and result to have the same physical arity"); + + if (sourceLayout == resultLayout) + return success(); + + std::string sourceReason; + std::string resultReason; + LogicalResult sourceFull = + checkFullVMIPhysicalChunks(sourceType, &sourceReason); + LogicalResult resultFull = + checkFullVMIPhysicalChunks(resultType, &resultReason); + if (succeeded(sourceFull) && succeeded(resultFull)) + return success(); + + std::string sourceMaterializationReason; + FailureOr sourceMaterializedParts = + getContiguousMaterializationPartCount(sourceType, + &sourceMaterializationReason); + std::string resultMaterializationReason; + FailureOr resultMaterializedParts = + getContiguousMaterializationPartCount(resultType, + &resultMaterializationReason); + if (succeeded(sourceMaterializedParts) && + succeeded(resultMaterializedParts) && + *sourceMaterializedParts == *sourceArity && + *resultMaterializedParts == *resultArity) + return success(); + + if (failed(sourceFull)) + return fail(Twine("source ") + sourceReason + "; source materialization " + + sourceMaterializationReason); + return fail(Twine("result ") + resultReason + "; result materialization " + + resultMaterializationReason); +} + +static FailureOr getGroupSizeFromNumGroups(VMIVRegType type, + int64_t numGroups, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + if (numGroups <= 0) + return fail("requires num_groups to be positive"); + if (type.getElementCount() % numGroups != 0) + return fail("requires num_groups to evenly divide logical lane count"); + return type.getElementCount() / numGroups; +} + +static FailureOr getDataLayoutFactor(VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return failure(); + return layout.isDeinterleaved() ? layout.getFactor() : 1; +} + +static FailureOr> +getPhysicalLogicalBitFootprint(VMIVRegType type) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (elementBits == 0) + return failure(); + + FailureOr factor = getDataLayoutFactor(type); + FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); + FailureOr arity = getVMIPhysicalArity(type); + if (failed(factor) || failed(lanesPerPart) || failed(arity) || *factor <= 0) + return failure(); + + SmallVector bits; + bits.reserve(*arity); + for (int64_t part = 0; part < *factor; ++part) { + for (int64_t chunk = 0; chunk < *arity; ++chunk) { + int64_t activeLanes = 0; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(type, part, chunk, lane); + if (failed(padding)) + return failure(); + if (!*padding) + ++activeLanes; + } + if (activeLanes > 0) + bits.push_back(activeLanes * static_cast(elementBits)); + } + } + if (static_cast(bits.size()) != *arity) + return failure(); + return bits; +} + +static FailureOr +getLayoutMaterializationRecipe(VMILayoutAttr sourceLayout, + VMILayoutAttr resultLayout, + std::string *reason) { + auto fail = [&](const Twine &message) + -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (sourceLayout == resultLayout) + return VMILayoutMaterializationRecipe{ + VMILayoutMaterializationRecipeKind::Identity}; + if (sourceLayout.isContiguous() && resultLayout.isDeinterleaved() && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) + return VMILayoutMaterializationRecipe{ + VMILayoutMaterializationRecipeKind::ContiguousToDeinterleaved}; + if (sourceLayout.isDeinterleaved() && resultLayout.isContiguous() && + (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) + return VMILayoutMaterializationRecipe{ + VMILayoutMaterializationRecipeKind::DeinterleavedToContiguous}; + return fail("unsupported source/result layout pair"); +} + +} // namespace + +FailureOr +VMILocalRecipeRegistry::getContiguousStoreRecipe(VMIVRegType valueType, + std::string *reason) const { + auto fail = [&](const Twine &message) + -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout) + return fail("requires assigned value layout"); + if (layout.isContiguous()) + return VMIContiguousStoreRecipe{ + VMIContiguousStoreRecipeKind::ContiguousVsts}; + if (!layout.isDeinterleaved()) + return fail("requires contiguous or deinterleaved value layout"); + if (layout.getBlockElems() != 1) + return fail("requires block_elems=1 deinterleaved value layout"); + if (failed(checkFullDataPhysicalChunks(valueType, reason))) + return failure(); + + if (layout.getFactor() == 2) { + if (!hasX2MemoryDistToken(valueType.getElementType())) + return fail("requires 8/16/32-bit element type for vstsx2 INTLV"); + return VMIContiguousStoreRecipe{ + VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2}; + } + + if (layout.getFactor() == 4) + return VMIContiguousStoreRecipe{ + VMIContiguousStoreRecipeKind::DeinterleavedMaterializeThenVsts}; + + return fail("requires deinterleaved factor 2 or 4"); +} + +LogicalResult VMILocalRecipeRegistry::canFoldContiguousStoreMaterialization( + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { + if (sourceType.getElementType() != resultType.getElementType()) + return failWithReason("source/result element types must match", reason); + if (sourceType.getElementCount() != resultType.getElementCount()) + return failWithReason("source/result element counts must match", reason); + + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!resultLayout || !resultLayout.isContiguous()) + return failWithReason("result layout must be contiguous", reason); + + FailureOr recipe = + getContiguousStoreRecipe(sourceType, reason); + if (failed(recipe)) + return failure(); + if (recipe->kind == VMIContiguousStoreRecipeKind::ContiguousVsts) + return failWithReason("source layout is already contiguous", reason); + + return success(); +} + +FailureOr +VMILocalRecipeRegistry::getDataLayoutMaterializationRecipe( + VMIVRegType sourceType, VMIVRegType resultType, + std::string *reason) const { + auto fail = [&](const Twine &message) + -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementType() != resultType.getElementType()) + return fail("source/result element types must match"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("source/result element counts must match"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr recipe = + getLayoutMaterializationRecipe(sourceLayout, resultLayout, reason); + if (failed(recipe)) + return failure(); + if (failed(checkLayoutMaterializationShape(sourceType, resultType, + sourceLayout, resultLayout, + reason))) + return failure(); + return recipe; +} + +FailureOr +VMILocalRecipeRegistry::getMaskLayoutMaterializationRecipe( + VMIMaskType sourceType, VMIMaskType resultType, + std::string *reason) const { + auto fail = [&](const Twine &message) + -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("source/result mask element counts must match"); + if (sourceType.getGranularity() != resultType.getGranularity()) + return fail("source/result mask granularities must match"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr recipe = + getLayoutMaterializationRecipe(sourceLayout, resultLayout, reason); + if (failed(recipe)) + return failure(); + if (failed(checkLayoutMaterializationShape(sourceType, resultType, + sourceLayout, resultLayout, + reason))) + return failure(); + return recipe; +} + +FailureOr +VMILocalRecipeRegistry::getMaskGranularityMaterializationRecipe( + VMIMaskType sourceType, VMIMaskType resultType, + std::string *reason) const { + auto fail = [&](const Twine &message) + -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("source/result mask element counts must match"); + if (sourceType.getLayoutAttr() != resultType.getLayoutAttr()) + return fail("source/result mask layouts must match"); + if (!VMIMaskType::isConcreteGranularity(sourceType.getGranularity()) || + !VMIMaskType::isConcreteGranularity(resultType.getGranularity())) + return fail("requires concrete b8/b16/b32 source and result granularities"); + if (sourceType.getGranularity() == resultType.getGranularity()) + return VMIMaskGranularityMaterializationRecipe{ + VMIMaskGranularityMaterializationRecipeKind::Identity}; + + return VMIMaskGranularityMaterializationRecipe{ + VMIMaskGranularityMaterializationRecipeKind::PredicateCast}; +} + +FailureOr +VMILocalRecipeRegistry::getGroupSlotLoadRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || + layout.getNumGroups() != op.getNumGroupsAttr().getInt() || + layout.getSlots() <= 0) + return fail("requires explicit group_slots result layout matching " + "num_groups"); + + if (layout.getSlots() != 8 && layout.getSlots() != 1) + return fail("supports only slots=8 or slots=1 group_slot_load layouts"); + + if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") + .isSupported()) + return fail("requires supported direct memory source"); + if (!isa(op.getSource().getType())) + return fail("requires !pto.ptr source for vsldb lowering"); + + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (layout.getSlots() == 8) { + if (!stride || *stride != 1) + return fail("slots=8 group_slot_load requires constant unit " + "source_group_stride"); + return VMIGroupSlotLoadRecipe{ + VMIGroupSlotLoadRecipeKind::Slots8UnitStrideVsldb}; + } + + unsigned elementBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return fail("slots=1 group_slot_load requires an 8/16/32-bit element " + "type"); + int64_t alignedStrideElems = 256 / elementBits; + if (!stride || *stride <= 0 || *stride % alignedStrideElems != 0) + return fail(Twine("slots=1 group_slot_load currently lowers as one " + "lane-0 vsldb per group and requires constant " + "positive source_group_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B load alignment; packed or unaligned " + "scalar load lowering is not implemented"); + + return VMIGroupSlotLoadRecipe{ + VMIGroupSlotLoadRecipeKind::Slots1AlignedLane0Vsldb}; +} + +FailureOr VMILocalRecipeRegistry::getGroupLoadRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupLoadOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isDeinterleaved() || layout.getBlockElems() != 8 || + !resultType.getElementType().isF32()) + return fail("requires deinterleaved block8 f32 result layout"); + + FailureOr groupSize = + getGroupSizeFromNumGroups(resultType, op.getNumGroupsAttr().getInt(), + reason); + if (failed(groupSize)) + return failure(); + + if ((*groupSize != 16 || layout.getFactor() != 2) && + (*groupSize != 32 || layout.getFactor() != 4)) + return fail("block8 strided group_load requires S=16/factor=2 or " + "S=32/factor=4"); + + if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") + .isSupported()) + return fail("requires supported direct memory source"); + if (!isa(op.getSource().getType())) + return fail("block8 strided group_load requires !pto.ptr source"); + + if (op.getNumGroupsAttr().getInt() % 8 != 0) + return fail("block8 strided group_load requires num_groups multiple of 8"); + + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % 8 != 0) + return fail("block8 strided group_load requires constant positive " + "row_stride divisible by 8 f32 elements"); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("block8 strided group_load requires full physical " + "result chunks; ") + + fullChunkReason); + + if (*groupSize == 16) + return VMIGroupLoadRecipe{VMIGroupLoadRecipeKind::S16Block8Vsldb}; + return VMIGroupLoadRecipe{VMIGroupLoadRecipeKind::S32Block8Vsldb}; +} + +FailureOr +VMILocalRecipeRegistry::getGroupSlotsStoreRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupStoreOp op, + std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto valueType = cast(op.getValue().getType()); + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return fail("requires group_slots value layout"); + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (layout.getNumGroups() != numGroups) + return fail("group_slots group_store requires layout num_groups to " + "match op num_groups"); + + VMICapabilityResult elementCapability = capabilities.supportsElementType( + valueType.getElementType(), VMIElementPurpose::PredicateMask); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + + FailureOr arity = getVMIPhysicalArity(valueType); + if (failed(arity) || *arity < 1) + return fail("requires computable non-empty physical vreg parts"); + + if (layout.getSlots() == 1) { + if (*arity != numGroups) + return fail("slots=1 group_store requires one physical part per " + "group"); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(valueType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return fail("slots=1 group_store requires an 8/16/32-bit element " + "type"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride <= 0 || *rowStride % alignedStrideElems != 0) + return fail(Twine("slots=1 group_store currently lowers as one " + "lane-0 vsts per group and requires constant " + "positive row_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B store alignment; packed or unaligned " + "contiguous store lowering is not implemented"); + return VMIGroupSlotsStoreRecipe{ + VMIGroupSlotsStoreRecipeKind::Slots1AlignedLane0Vsts}; + } + + if (layout.getSlots() == 8) { + std::optional rowStride = getConstantIndexValue(op.getRowStride()); + if (!rowStride || *rowStride != 1) + return fail("slots=8 group_store currently requires constant unit " + "row_stride"); + if (*arity != ceilDivNonNegative(numGroups, 8)) + return fail("slots=8 group_store arity must equal ceil(num_groups / " + "8)"); + return VMIGroupSlotsStoreRecipe{ + VMIGroupSlotsStoreRecipeKind::Slots8UnitStrideVsts}; + } + + return fail("group_slots group_store currently supports only slots=1 or " + "unit-stride slots=8"); +} + +FailureOr +VMILocalRecipeRegistry::getGroupReduceAddFRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, + std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!op->hasAttr("reassoc")) + return fail("requires reassoc attr for pair-wise floating-point " + "reduction"); + + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (!sourceLayout || !maskLayout || !resultLayout) + return fail("requires assigned source, mask, and result layouts"); + if (!resultLayout.isGroupSlots() || resultLayout.getNumGroups() != numGroups) + return fail("requires group_slots result layout matching num_groups"); + if (resultLayout.getSlots() != 8 && resultLayout.getSlots() != 1) { + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, numGroups, reason); + if (succeeded(groupSize) && resultLayout.getSlots() <= 0 && + *groupSize != 8 && *groupSize != 16 && *groupSize != 32) + return fail("stable group_reduce_addf slots=8 recipes support group " + "size 8, 16, or 32"); + return fail("stable group_reduce_addf local recipes currently require " + "result layout slots=8 or slots=1"); + } + + VMICapabilityResult elementCapability = + capabilities.supportsReductionElementType(VMIReductionKind::AddF, + sourceType.getElementType()); + if (!elementCapability.isSupported()) + return fail(elementCapability.reason); + if (!sourceType.getElementType().isF32() || + sourceType.getElementType() != resultType.getElementType()) + return fail("stable group_reduce_addf local recipes require f32 " + "source/result"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source/result lane count to match"); + + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, numGroups, reason); + if (failed(groupSize)) + return failure(); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) + return fail("requires computable source/mask/result physical arity"); + if (*sourceArity < 1 || *maskArity != *sourceArity) + return fail("requires matching non-empty source/mask physical arity"); + + if (resultLayout.getSlots() == 1) { + if (*groupSize != 64) + return fail("stable group_reduce_addf slots=1 recipes support group " + "size 64"); + if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("s64 group_reduce_addf requires contiguous source/mask " + "layouts"); + if (*resultArity != *sourceArity) + return fail("s64 group_reduce_addf requires source/result physical " + "arity to match"); + std::string sourceFullReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) + return fail(Twine("s64 group_reduce_addf requires full source chunks; ") + + sourceFullReason); + return VMIGroupReduceAddFRecipe{ + VMIGroupReduceAddFRecipeKind::S64ContiguousVcaddRows}; + } + + if (*groupSize == 8) { + if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("s8 group_reduce_addf requires contiguous source/mask " + "layouts"); + std::string sourceFullReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) + return fail(Twine("s8 group_reduce_addf requires full source chunks; ") + + sourceFullReason); + if (*resultArity != *sourceArity) + return fail("s8 group_reduce_addf requires source/result physical " + "arity to match"); + return VMIGroupReduceAddFRecipe{VMIGroupReduceAddFRecipeKind::S8Vcgadd}; + } + + if (*groupSize == 16) { + if (!sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 2 || + (sourceLayout.getBlockElems() != 1 && + sourceLayout.getBlockElems() != 8)) + return fail("s16 group_reduce_addf requires source layout " + "deinterleaved=2 with block_elems=1 or block_elems=8"); + if (!maskLayout.isDeinterleaved() || maskLayout.getFactor() != 2 || + maskLayout.getBlockElems() != sourceLayout.getBlockElems()) + return fail("s16 group_reduce_addf requires matching mask layout " + "deinterleaved=2 with the same block_elems"); + int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); + if (*resultArity != expectedResultArity || + *sourceArity != *resultArity * 2) + return fail("s16 group_reduce_addf requires two source/mask parts per " + "result part"); + return VMIGroupReduceAddFRecipe{ + VMIGroupReduceAddFRecipeKind::S16Deinterleaved2VcgaddVadd}; + } + + if (*groupSize == 32) { + if (!sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 4 || + (sourceLayout.getBlockElems() != 1 && + sourceLayout.getBlockElems() != 8)) + return fail("s32 group_reduce_addf requires source layout " + "deinterleaved=4 with block_elems=1 or block_elems=8"); + if (!maskLayout.isDeinterleaved() || maskLayout.getFactor() != 4 || + maskLayout.getBlockElems() != sourceLayout.getBlockElems()) + return fail("s32 group_reduce_addf requires matching mask layout " + "deinterleaved=4 with the same block_elems"); + int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); + if (*resultArity != expectedResultArity || + *sourceArity != *resultArity * 4) + return fail("s32 group_reduce_addf requires four source/mask parts per " + "result part"); + return VMIGroupReduceAddFRecipe{ + VMIGroupReduceAddFRecipeKind::S32Deinterleaved4VcgaddTree}; + } + + return fail("stable group_reduce_addf slots=8 recipes support group size " + "8, 16, or 32"); +} + +FailureOr +VMILocalRecipeRegistry::getGroupBroadcastRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, + std::string *reason) const { + (void)capabilities; + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + if (sourceType.getElementType() != resultType.getElementType() || + sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source/result shape and element type to match"); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != numGroups) + return fail("requires matching num_groups source layout"); + if (resultLayout.isGroupSlots()) + return fail("requires dense result layout"); + if (sourceLayout.getSlots() > 0 && sourceLayout.getSlots() != 8 && + sourceLayout.getSlots() != 1) + return fail("supports only slots=8 or slots=1 group_broadcast source " + "layouts"); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) + return fail(Twine("requires full source physical chunks; ") + + fullChunkReason); + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("requires full result physical chunks; ") + + fullChunkReason); + + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + FailureOr resultLanesPerPart = + getDataLanesPerPart(resultType.getElementType()); + if (failed(lanesPerPart) || failed(resultLanesPerPart) || + *lanesPerPart != *resultLanesPerPart) + return fail("requires matching physical lanes per part"); + + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, numGroups, reason); + if (failed(groupSize)) + return failure(); + if (*lanesPerPart % *groupSize != 0 && *groupSize % *lanesPerPart != 0) + return fail("requires derived group size to divide or be a multiple of " + "physical lanes per part"); + + FailureOr resultFactor = getDataLayoutFactor(resultType); + if (failed(resultFactor)) + return fail("requires known result layout factor"); + if (*resultFactor == 1) + return VMIGroupBroadcastRecipe{ + VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; + + bool blockFragmentSmallGroup = + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && + *groupSize < *lanesPerPart && + *lanesPerPart % resultLayout.getBlockElems() == 0; + if (blockFragmentSmallGroup) + return VMIGroupBroadcastRecipe{ + VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; + + int64_t logicalSpanPerResultChunk = *lanesPerPart * *resultFactor; + if (*groupSize < *lanesPerPart || *groupSize % logicalSpanPerResultChunk != 0) + return fail("deinterleaved result requires every physical result chunk to " + "stay within one logical group"); + + return VMIGroupBroadcastRecipe{ + VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; +} + +FailureOr +VMILocalRecipeRegistry::getTruncFRecipe(VMITruncFOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + + if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { + if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || + sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + !sourceType.getElementType().isF32() || resultBits != 16 || + *sourceArity != *resultArity) + return fail("group-slot truncf requires matching " + "group_slots(num_groups=G, slots=1) source/result layouts, " + "f32 source, f16 result, and matching physical arity"); + return VMITruncFRecipe{VMITruncFRecipeKind::GroupSlots1F32ToF16}; + } + + if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || + !sourceType.getElementType().isF32() || *resultArity != 1) + return fail("requires f32 deinterleaved source and contiguous result"); + + if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) + return VMITruncFRecipe{ + VMITruncFRecipeKind::Deinterleaved2F32ToContiguousF16}; + if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8) + return VMITruncFRecipe{ + VMITruncFRecipeKind::Deinterleaved4F32ToContiguousF8}; + + return fail("unsupported deinterleaved truncf factor, arity, or result " + "element width"); +} + +FailureOr +VMILocalRecipeRegistry::getExtFRecipe(VMIExtFOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + if (!sourceLayout.isContiguous() || !resultLayout.isDeinterleaved() || + !resultType.getElementType().isF32()) + return fail("requires contiguous source layout and deinterleaved f32 " + "result layout"); + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + if (sourceBits == 16 && resultLayout.getFactor() == 2 && + *resultArity == 2 * *sourceArity) + return VMIExtFRecipe{ + VMIExtFRecipeKind::ContiguousF16ToDeinterleaved2F32}; + if (sourceBits == 8 && resultLayout.getFactor() == 4 && + *resultArity == 4 * *sourceArity) + return VMIExtFRecipe{ + VMIExtFRecipeKind::ContiguousF8ToDeinterleaved4F32}; + + return fail("unsupported extf source element width, result factor, or " + "physical arity"); +} + +FailureOr +VMILocalRecipeRegistry::getBitcastRecipe(VMIBitcastOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source and result layouts"); + if (sourceLayout != resultLayout) + return fail("requires matching source and result layouts"); + if (sourceLayout.isGroupSlots()) + return fail("does not support group_slots layouts"); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source and result physical arity"); + if (*sourceArity != *resultArity) + return fail("requires source and result to have the same physical arity"); + + FailureOr> sourceBits = + getPhysicalLogicalBitFootprint(sourceType); + FailureOr> resultBits = + getPhysicalLogicalBitFootprint(resultType); + if (failed(sourceBits) || failed(resultBits)) + return fail("requires computable physical logical bit footprints"); + if (sourceBits->size() != resultBits->size()) + return fail("requires source and result physical footprint counts to " + "match"); + for (auto [source, result] : llvm::zip_equal(*sourceBits, *resultBits)) { + if (source != result) + return fail("requires matching logical bit footprint in every physical " + "chunk"); + } + + return VMIBitcastRecipe{VMIBitcastRecipeKind::PerPartVbitcast}; +} diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 5b050d640a..a59b5dbadb 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -18,6 +18,7 @@ #include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILocalRecipeRegistry.h" #include "PTO/Transforms/VMITargetCapabilities.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -1200,24 +1201,9 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, if (resultLayout.isDeinterleaved() && resultLayout.getBlockElems() == 8 && resultType.getElementType().isF32()) { - if ((*groupSize != 16 || resultLayout.getFactor() != 2) && - (*groupSize != 32 || resultLayout.getFactor() != 4)) - return fail("block8 strided group_load requires S=16/factor=2 or " - "S=32/factor=4"); - if (!isa(op.getSource().getType())) - return fail("block8 strided group_load requires !pto.ptr source"); - if (op.getNumGroupsAttr().getInt() % 8 != 0) - return fail("block8 strided group_load requires num_groups multiple " - "of 8"); - std::optional rowStride = getConstantIndexValue(op.getRowStride()); - if (!rowStride || *rowStride <= 0 || *rowStride % 8 != 0) - return fail("block8 strided group_load requires constant positive " - "row_stride divisible by 8 f32 elements"); - std::string fullChunkReason; - if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) - return fail(Twine("block8 strided group_load requires full physical " - "result chunks; ") + - fullChunkReason); + VMILocalRecipeRegistry recipes; + if (failed(recipes.getGroupLoadRecipe(capabilities, op, reason))) + return failure(); return success(); } @@ -1227,55 +1213,10 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, LogicalResult checkSupportedGroupSlotLoadShape( const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, std::string *reason) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); + VMILocalRecipeRegistry recipes; + if (failed(recipes.getGroupSlotLoadRecipe(capabilities, op, reason))) return failure(); - }; - - auto resultType = cast(op.getResult().getType()); - VMILayoutAttr layout = resultType.getLayoutAttr(); - if (!layout || !layout.isGroupSlots() || - layout.getNumGroups() != op.getNumGroupsAttr().getInt() || - layout.getSlots() <= 0) - return fail("requires explicit group_slots result layout matching " - "num_groups"); - - if (layout.getSlots() != 8 && layout.getSlots() != 1) - return fail("supports only slots=8 or slots=1 group_slot_load layouts"); - - if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") - .isSupported()) - return fail("requires supported direct memory source"); - if (!isa(op.getSource().getType())) - return fail("requires !pto.ptr source for vsldb lowering"); - if (layout.getSlots() == 8) { - std::optional stride = - getConstantIndexValue(op.getSourceGroupStride()); - if (!stride || *stride != 1) - return fail("slots=8 group_slot_load requires constant unit " - "source_group_stride"); - return success(); - } - if (layout.getSlots() == 1) { - unsigned elementBits = - pto::getPTOStorageElemBitWidth(resultType.getElementType()); - if (elementBits == 0 || 256 % elementBits != 0) - return fail("slots=1 group_slot_load requires an 8/16/32-bit element " - "type"); - int64_t alignedStrideElems = 256 / elementBits; - std::optional stride = - getConstantIndexValue(op.getSourceGroupStride()); - if (!stride || *stride <= 0 || *stride % alignedStrideElems != 0) - return fail(Twine("slots=1 group_slot_load currently lowers as one " - "lane-0 vsldb per group and requires constant " - "positive source_group_stride divisible by ") + - Twine(alignedStrideElems) + - " elements for 32B load alignment; packed or unaligned " - "scalar load lowering is not implemented"); - return success(); - } - llvm_unreachable("unsupported group_slot_load slots should be rejected"); + return success(); } LogicalResult @@ -1301,46 +1242,10 @@ checkSupportedGroupStoreShape(const VMITargetCapabilityRegistry &capabilities, if (!accessPlan.targetCapability.isSupported()) return fail(accessPlan.targetCapability.reason); - if (failed(checkSupportedMaskableVReg(capabilities, valueType, reason))) + VMILocalRecipeRegistry recipes; + if (failed(recipes.getGroupSlotsStoreRecipe(capabilities, op, reason))) return failure(); - - FailureOr arity = getVMIPhysicalArity(valueType); - if (failed(arity)) - return fail("requires computable physical arity"); - if (layout.getSlots() == 1) { - if (*arity != numGroups) - return fail("slots=1 group_store requires one physical part per " - "group"); - unsigned elementBits = - pto::getPTOStorageElemBitWidth(valueType.getElementType()); - if (elementBits == 0 || 256 % elementBits != 0) - return fail("slots=1 group_store requires an 8/16/32-bit element " - "type"); - int64_t alignedStrideElems = 256 / elementBits; - std::optional rowStride = - getConstantIndexValue(op.getRowStride()); - if (!rowStride || *rowStride <= 0 || *rowStride % alignedStrideElems != 0) - return fail(Twine("slots=1 group_store currently lowers as one " - "lane-0 vsts per group and requires constant " - "positive row_stride divisible by ") + - Twine(alignedStrideElems) + - " elements for 32B store alignment; packed or unaligned " - "contiguous store lowering is not implemented"); - return success(); - } - if (layout.getSlots() == 8) { - std::optional rowStride = - getConstantIndexValue(op.getRowStride()); - if (!rowStride || *rowStride != 1) - return fail("slots=8 group_store currently requires constant unit " - "row_stride"); - if (*arity != ceilDivNonNegative(numGroups, 8)) - return fail("slots=8 group_store arity must equal ceil(num_groups / " - "8)"); - return success(); - } - return fail("group_slots group_store currently supports only slots=1 or " - "unit-stride slots=8"); + return success(); } FailureOr groupSize = getGroupSizeFromNumGroups( @@ -3309,6 +3214,13 @@ struct OneToNVMIEnsureLayoutOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); + VMILocalRecipeRegistry recipes; + std::string recipeReason; + if (failed(recipes.getDataLayoutMaterializationRecipe( + sourceType, resultType, &recipeReason))) + return rewriter.notifyMatchFailure( + op, Twine("ensure_layout has no registered materialization recipe: ") + + recipeReason); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); if (!sourceLayout || !resultLayout) @@ -3336,6 +3248,14 @@ struct OneToNVMIEnsureMaskLayoutOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); + VMILocalRecipeRegistry recipes; + std::string recipeReason; + if (failed(recipes.getMaskLayoutMaterializationRecipe( + sourceType, resultType, &recipeReason))) + return rewriter.notifyMatchFailure( + op, + Twine("ensure_mask_layout has no registered materialization recipe: ") + + recipeReason); if (sourceType.getGranularity() != resultType.getGranularity()) return rewriter.notifyMatchFailure( op, "mask layout helper cannot also change granularity"); @@ -3367,6 +3287,14 @@ struct OneToNVMIEnsureMaskGranularityOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); + VMILocalRecipeRegistry recipes; + std::string recipeReason; + if (failed(recipes.getMaskGranularityMaterializationRecipe( + sourceType, resultType, &recipeReason))) + return rewriter.notifyMatchFailure( + op, Twine("ensure_mask_granularity has no registered materialization " + "recipe: ") + + recipeReason); if (sourceType.getLayout() != resultType.getLayout()) return rewriter.notifyMatchFailure( op, "mask granularity helper cannot also change layout"); @@ -4549,7 +4477,6 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { op, "store requires known physical lanes per part"); bool fullPhysicalChunks = succeeded(checkFullDataPhysicalChunks(valueVMIType, nullptr)); - FailureOr destination = getSingleValue(op, adaptor.getDestination(), "store destination must convert to one value", rewriter); @@ -4560,9 +4487,12 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { return failure(); ValueRange valueParts = adaptor.getValue(); - VMILayoutAttr valueLayout = valueVMIType.getLayoutAttr(); - if (fullPhysicalChunks && valueLayout && valueLayout.isDeinterleaved() && - valueLayout.getFactor() == 2) { + VMILocalRecipeRegistry localRecipes; + FailureOr storeRecipe = + localRecipes.getContiguousStoreRecipe(valueVMIType); + if (succeeded(storeRecipe) && + storeRecipe->kind == + VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2) { std::optional dist = getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { @@ -5007,7 +4937,6 @@ struct OneToNVMITileWriteOpPattern : OneToNOpConversionPattern { op, "tile_write requires known physical lanes per part"); bool fullPhysicalChunks = succeeded(checkFullDataPhysicalChunks(valueVMIType, nullptr)); - FailureOr destination = getSingleValue( op, adaptor.getDestination(), "tile_write destination must convert to one value", rewriter); @@ -5016,9 +4945,12 @@ struct OneToNVMITileWriteOpPattern : OneToNOpConversionPattern { ValueRange valueParts = adaptor.getValue(); Value zero = rewriter.create(op.getLoc(), 0); - VMILayoutAttr valueLayout = valueVMIType.getLayoutAttr(); - if (fullPhysicalChunks && valueLayout && valueLayout.isDeinterleaved() && - valueLayout.getFactor() == 2) { + VMILocalRecipeRegistry localRecipes; + FailureOr storeRecipe = + localRecipes.getContiguousStoreRecipe(valueVMIType); + if (succeeded(storeRecipe) && + storeRecipe->kind == + VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2) { std::optional dist = getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { @@ -6862,145 +6794,26 @@ LogicalResult verifyNoResidualVMIIR(ModuleOp module) { return failure(result.wasInterrupted()); } -LogicalResult checkSupportedExtFShape(VMIExtFOp op) { - auto sourceType = cast(op.getSource().getType()); - auto resultType = cast(op.getResult().getType()); - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (!sourceLayout || !resultLayout || failed(sourceArity) || - failed(resultArity) || !sourceLayout.isContiguous() || - !resultLayout.isDeinterleaved() || !resultType.getElementType().isF32()) +LogicalResult checkSupportedExtFShape(VMIExtFOp op, + std::string *reason = nullptr) { + VMILocalRecipeRegistry recipes; + if (failed(recipes.getExtFRecipe(op, reason))) return failure(); - - unsigned sourceBits = - pto::getPTOStorageElemBitWidth(sourceType.getElementType()); - if (sourceBits == 16 && resultLayout.getFactor() == 2 && - *resultArity == 2 * *sourceArity) - return success(); - if (sourceBits == 8 && resultLayout.getFactor() == 4 && - *resultArity == 4 * *sourceArity) - return success(); - return failure(); + return success(); } LogicalResult checkSupportedTruncFShape(VMITruncFOp op, std::string *reason = nullptr) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); + VMILocalRecipeRegistry recipes; + if (failed(recipes.getTruncFRecipe(op, reason))) return failure(); - }; - - auto sourceType = cast(op.getSource().getType()); - auto resultType = cast(op.getResult().getType()); - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (!sourceLayout || !resultLayout || failed(sourceArity) || - failed(resultArity)) - return fail("requires assigned source/result layouts and computable " - "physical arity"); - - unsigned resultBits = - pto::getPTOStorageElemBitWidth(resultType.getElementType()); - - if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { - if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || - sourceLayout.getNumGroups() != resultLayout.getNumGroups() || - sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || - !sourceType.getElementType().isF32() || resultBits != 16 || - *sourceArity != *resultArity) - return fail("group-slot truncf requires matching " - "group_slots(num_groups=G, slots=1) source/result layouts, " - "f32 source, f16 result, and matching physical arity"); - - return success(); - } - - if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || - !sourceType.getElementType().isF32() || *resultArity != 1) - return fail("requires f32 deinterleaved source and contiguous result"); - - if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) - return success(); - if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8) - return success(); - return fail("unsupported deinterleaved truncf factor, arity, or result " - "element width"); -} - -FailureOr> -getPhysicalLogicalBitFootprint(VMIVRegType type) { - unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); - if (elementBits == 0) - return failure(); - - FailureOr factor = getDataLayoutFactor(type); - FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); - if (failed(factor) || failed(lanesPerPart)) - return failure(); - - SmallVector bits; - for (int64_t part = 0; part < *factor; ++part) { - FailureOr chunks = getDataChunksInPart(type, part); - if (failed(chunks)) - return failure(); - for (int64_t chunk = 0; chunk < *chunks; ++chunk) { - int64_t activeLanes = 0; - for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { - FailureOr padding = isPaddingLane(type, part, chunk, lane); - if (failed(padding)) - return failure(); - if (!*padding) - ++activeLanes; - } - bits.push_back(activeLanes * static_cast(elementBits)); - } - } - return bits; + return success(); } LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, std::string *reason) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); + VMILocalRecipeRegistry recipes; + if (failed(recipes.getBitcastRecipe(op, reason))) return failure(); - }; - - auto sourceType = cast(op.getSource().getType()); - auto resultType = cast(op.getResult().getType()); - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - if (!sourceLayout || !resultLayout) - return fail("requires assigned source and result layouts"); - if (sourceLayout != resultLayout) - return fail("requires matching source and result layouts"); - - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (failed(sourceArity) || failed(resultArity)) - return fail("requires computable source and result physical arity"); - if (*sourceArity != *resultArity) - return fail("requires source and result to have the same physical arity"); - - FailureOr> sourceBits = - getPhysicalLogicalBitFootprint(sourceType); - FailureOr> resultBits = - getPhysicalLogicalBitFootprint(resultType); - if (failed(sourceBits) || failed(resultBits)) - return fail("requires computable physical logical bit footprints"); - if (sourceBits->size() != resultBits->size()) - return fail("requires source and result physical footprint counts to " - "match"); - for (auto [source, result] : llvm::zip_equal(*sourceBits, *resultBits)) { - if (source != result) - return fail("requires matching logical bit footprint in every physical " - "chunk"); - } - return success(); } @@ -7315,6 +7128,10 @@ LogicalResult checkSupportedGroupReduceAddFShape( if (!sourceLayout || !resultLayout || !maskLayout) return fail("requires assigned source, mask, and result layouts"); + VMILocalRecipeRegistry recipes; + if (succeeded(recipes.getGroupReduceAddFRecipe(capabilities, op, nullptr))) + return success(); + FailureOr groupSize = getGroupSizeFromNumGroups( sourceType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) @@ -7381,6 +7198,9 @@ LogicalResult checkSupportedGroupBroadcastShape( VMILayoutAttr resultLayout = resultType.getLayoutAttr(); if (!sourceLayout || !resultLayout) return fail("requires assigned source/result layouts"); + VMILocalRecipeRegistry recipes; + if (succeeded(recipes.getGroupBroadcastRecipe(capabilities, op, nullptr))) + return success(); if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) return fail("requires matching num_groups source layout"); @@ -8045,7 +7865,8 @@ verifySupportedVMIToVPTOOps(ModuleOp module, } if (auto extf = dyn_cast(op)) { - if (succeeded(checkSupportedExtFShape(extf))) + std::string reason; + if (succeeded(checkSupportedExtFShape(extf, &reason))) return WalkResult::advance(); extf.emitError() @@ -8053,7 +7874,8 @@ verifySupportedVMIToVPTOOps(ModuleOp module, << "pto.vmi.extf supports contiguous 16-bit float-like or fp8-like " "physical source chunks to f32 deinterleaved=2/4 results; " "partial/tail is allowed only when source padding maps to result " - "padding"; + "padding (" + << reason << ")"; return WalkResult::interrupt(); } diff --git a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto index af6623a995..ec29b4387a 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto @@ -18,13 +18,8 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xf32> - // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.store operand #0 has type - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: requires - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion - // CHECK: failed helper conversion - // CHECK-SAME: unsupported source/result layout pair + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.store operand #0 has type !pto.vmi.vreg<64xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<64xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result layout pair + // CHECK: failed helper conversion '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' (unsupported source/result layout pair) pto.vmi.store %sum, %dst[%off] : !pto.vmi.vreg<64xf32>, !pto.ptr return diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto index c928df5320..6ed4e7f9e7 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto @@ -26,13 +26,8 @@ module { -> !pto.vmi.vreg<128xf32> pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} : !pto.vmi.vreg<128xf32>, !pto.ptr - // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.truncf operand #0 has type - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: requires - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion - // CHECK: failed helper conversion - // CHECK-SAME: unsupported source/result layout pair + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type !pto.vmi.vreg<128xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<128xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result layout pair + // CHECK: failed helper conversion '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' (unsupported source/result layout pair) %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> pto.vmi.store %h, %dense_dst[%off] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto index 3bea54d83f..eccb4e0007 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto @@ -15,9 +15,8 @@ module { %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index - // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.group_reduce_addf lowers through pto.vcgadd - // CHECK-SAME: num_groups deriving a group size aligned to physical chunks - // CHECK-SAME: found padding lane in physical chunk + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe + // CHECK-SAME: stable group_reduce_addf slots=8 recipes support group size 8, 16, or 32 %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<96xf32>, !pto.vmi.mask<96xpred> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto index c66ff0eb3c..cface43bab 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto @@ -15,13 +15,12 @@ module { %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index - // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.group_reduce_addf operand #0 has type + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf operand #0 has type // CHECK-SAME: #pto.vmi.layout // CHECK-SAME: requires // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion + // CHECK-SAME: pto.vmi.ensure_layout has no registered materialization recipe // CHECK: requires source and result to have the same physical arity - // CHECK-SAME: partial/tail layout materialization requires an explicit packing plan %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 6, reassoc} : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto index 9f629f55f2..94cd55c58c 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto @@ -18,9 +18,10 @@ module { } func.func @vmi_layout_assignment_group_slot_load_slots1( - %src: !pto.ptr, %off: index, %stride: index) + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<512xf32> { - %out = pto.vmi.group_slot_load %src[%off], %stride {num_groups = 8} + %c8 = arith.constant 8 : index + %out = pto.vmi.group_slot_load %src[%off], %c8 {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<512xf32> return %out : !pto.vmi.vreg<512xf32> } diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto index e6e459c435..b8cd439d23 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto @@ -11,7 +11,7 @@ module { func.func @vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid( %src: !pto.ptr, %off: index, %stride: index) { - // CHECK: VMI-UNSUPPORTED: pto.vmi.group_slot_load + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements // CHECK-SAME: packed or unaligned scalar load lowering is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto index f8d7bc8af8..b432d7c68c 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto @@ -12,7 +12,7 @@ module { func.func @vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid( %src: !pto.ptr, %off: index) { %c2 = arith.constant 2 : index - // CHECK: VMI-UNSUPPORTED: pto.vmi.group_slot_load + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements // CHECK-SAME: packed or unaligned scalar load lowering is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto index 452ee085ac..996760ed66 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto @@ -19,7 +19,7 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> -> !pto.vmi.vreg<512xf32> - // CHECK: VMI-UNSUPPORTED: pto.vmi.group_store + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots local recipe // CHECK-SAME: slots=1 group_store currently lowers as one lane-0 vsts per group // CHECK-SAME: requires constant positive row_stride divisible by 8 elements // CHECK-SAME: packed or unaligned contiguous store lowering is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto index 3005e53c0a..e57954b16e 100644 --- a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto @@ -19,13 +19,8 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<128xf32> - // CHECK: {{VMI-UNSUPPORTED}}: pto.vmi.truncf operand #0 has type - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: requires - // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion - // CHECK: failed helper conversion - // CHECK-SAME: unsupported source/result layout pair + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type !pto.vmi.vreg<128xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<128xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result layout pair + // CHECK: failed helper conversion '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' (unsupported source/result layout pair) %h = pto.vmi.truncf %sum : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> pto.vmi.group_store %h, %dst[%off], %c1 {num_groups = 8} diff --git a/test/lit/vmi/vmi_layout_fold_consumers_deint4.pto b/test/lit/vmi/vmi_layout_fold_consumers_deint4.pto new file mode 100644 index 0000000000..84ba3b5b1e --- /dev/null +++ b/test/lit/vmi/vmi_layout_fold_consumers_deint4.pto @@ -0,0 +1,90 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-fold-consumers | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold-consumers -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_consumers_store_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + pto.vmi.store %value_c, %dst[%offset] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_layout_fold_consumers_masked_store_deint4( + %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %mask_c = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.mask<256xb32, #pto.vmi.layout> + pto.vmi.masked_store %value_c, %dst[%offset], %mask_c + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<256xb32, #pto.vmi.layout> + return + } +} + +// FOLD-LABEL: func.func @vmi_layout_fold_consumers_store_deint4( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.store %[[VALUE]] +// FOLD-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: return + +// FOLD-LABEL: func.func @vmi_layout_fold_consumers_masked_store_deint4( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: pto.vmi.masked_store %[[VALUE]] +// FOLD-SAME: %[[MASK]] +// FOLD-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// FOLD-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: return + +// LOWER-LABEL: func.func @vmi_layout_fold_consumers_store_deint4( +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts + +// LOWER-LABEL: func.func @vmi_layout_fold_consumers_masked_store_deint4( +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.vintlv +// LOWER: pto.pintlv_b32 +// LOWER: pto.pintlv_b32 +// LOWER: pto.pintlv_b32 +// LOWER: pto.pintlv_b32 +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto b/test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto new file mode 100644 index 0000000000..8f31b78f7b --- /dev/null +++ b/test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto @@ -0,0 +1,57 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-fold-consumers | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold-consumers -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_consumers_masked_store( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %dst: !pto.ptr, + %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %mask_c = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + pto.vmi.masked_store %value_c, %dst[%offset], %mask_c + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} + +// FOLD-LABEL: func.func @vmi_layout_fold_consumers_masked_store( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: pto.vmi.masked_store %[[VALUE]] +// FOLD-SAME: %[[MASK]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: return + +// LOWER-LABEL: func.func @vmi_layout_fold_consumers_masked_store( +// LOWER-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// LOWER-SAME: %[[M0:[^,]+]]: !pto.mask +// LOWER-SAME: %[[M1:[^,]+]]: !pto.mask +// LOWER: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[V0]], %[[V1]] +// LOWER: %[[ML:.*]], %[[MH:.*]] = pto.pintlv_b32 %[[M0]], %[[M1]] +// LOWER: pto.vsts %[[LOW]] +// LOWER-SAME: %[[ML]] +// LOWER: pto.vsts %[[HIGH]] +// LOWER-SAME: %[[MH]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_fold_consumers_store.pto b/test/lit/vmi/vmi_layout_fold_consumers_store.pto new file mode 100644 index 0000000000..281d737861 --- /dev/null +++ b/test/lit/vmi/vmi_layout_fold_consumers_store.pto @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-fold-consumers | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-fold-consumers -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_consumers_store( + %src: !pto.vmi.vreg<128xf16>, + %scale: f32, + %out1: !pto.ptr, + %out2: !pto.ptr, + %offset: index) { + %scale_v = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32> + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %prod = pto.vmi.mulf %wide, %scale_v + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %prod, %out1[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + pto.vmi.store %wide, %out2[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } + + func.func @vmi_layout_fold_consumers_tile_write( + %src: !pto.vmi.vreg<128xf16>, + %dst: memref<128xf32>) { + %wide = pto.vmi.extf %src + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.tile_write %wide, %dst + : !pto.vmi.vreg<128xf32>, memref<128xf32> + return + } + +} + +// FOLD-LABEL: func.func @vmi_layout_fold_consumers_store( +// FOLD-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// FOLD: %[[SCALE:.*]] = pto.vmi.broadcast +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD: %[[PROD:.*]] = pto.vmi.mulf %[[WIDE]], %[[SCALE]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.store %[[PROD]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.store %[[WIDE]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: return + +// LOWER-LABEL: func.func @vmi_layout_fold_consumers_store( +// LOWER: %[[SCALE0:.*]] = pto.vdup +// LOWER: %[[SCALE1:.*]] = pto.vdup +// LOWER: %[[WIDE0:.*]] = pto.vcvt +// LOWER: %[[WIDE1:.*]] = pto.vcvt +// LOWER: %[[PROD0:.*]] = pto.vmul %[[WIDE0]], %[[SCALE0]] +// LOWER: %[[PROD1:.*]] = pto.vmul %[[WIDE1]], %[[SCALE1]] +// LOWER-NOT: pto.vintlv +// LOWER: pto.vstsx2 %[[PROD0]], %[[PROD1]] +// LOWER: pto.vstsx2 %[[WIDE0]], %[[WIDE1]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + + +// FOLD-LABEL: func.func @vmi_layout_fold_consumers_tile_write( +// FOLD-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// FOLD: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.tile_write %[[WIDE]] +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: return + +// LOWER-LABEL: func.func @vmi_layout_fold_consumers_tile_write( +// LOWER: %[[WIDE0:.*]] = pto.vcvt +// LOWER: %[[WIDE1:.*]] = pto.vcvt +// LOWER-NOT: pto.vintlv +// LOWER: pto.vstsx2 %[[WIDE0]], %[[WIDE1]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto new file mode 100644 index 0000000000..e63567e48d --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_bitcast_group_slots_invalid( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered local recipe + // CHECK-SAME: does not support group_slots layouts + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.bitcast" + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto new file mode 100644 index 0000000000..806aaa26dd --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_bitcast_recipe_invalid( + %source: !pto.vmi.vreg<65xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered local recipe + // CHECK-SAME: requires matching logical bit footprint in every physical chunk + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.bitcast" + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<65xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<130xi16, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto new file mode 100644 index 0000000000..7bda214fed --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_extf_recipe_invalid( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.extf has no registered local recipe + // CHECK-SAME: requires contiguous source layout and deinterleaved f32 result layout + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.extf" + %out = pto.vmi.extf %source + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto new file mode 100644 index 0000000000..224858064c --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_broadcast_recipe_invalid( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_broadcast has no registered local recipe + // CHECK-SAME: supports only slots=8 or slots=1 group_broadcast source layouts + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_broadcast" + %out = pto.vmi.group_broadcast %source {num_groups = 8} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto new file mode 100644 index 0000000000..8f9fb2c809 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_load_recipe_invalid( + %src: !pto.ptr, %off: index, %stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load has no registered block8 local recipe + // CHECK-SAME: block8 strided group_load requires constant positive row_stride divisible by 8 f32 elements + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_load" + %out = pto.vmi.group_load %src[%off], %stride + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto new file mode 100644 index 0000000000..673f3ee47b --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_reduce_recipe_invalid( + %source: !pto.vmi.vreg<96xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<96xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe + // CHECK-SAME: stable group_reduce_addf slots=8 recipes support group size 8, 16, or 32 + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<96xf32, #pto.vmi.layout>, + !pto.vmi.mask<96xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<96xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto new file mode 100644 index 0000000000..f4071e4c47 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_reduce_slots1_recipe_invalid( + %source: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe + // CHECK-SAME: stable group_reduce_addf slots=1 recipes support group size 64 + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto new file mode 100644 index 0000000000..31e7f13c3e --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_slot_load_recipe_invalid( + %src: !pto.ptr, %off: index, %stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe + // CHECK-SAME: slots=8 group_slot_load requires constant unit source_group_stride + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" + %out = pto.vmi.group_slot_load %src[%off], %stride + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto new file mode 100644 index 0000000000..b8576fe3b7 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_store_slots2_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index, %row_stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots local recipe + // CHECK-SAME: group_slots group_store currently supports only slots=1 or unit-stride slots=8 + pto.vmi.group_store %value, %dst[%off], %row_stride + {num_groups = 8} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// ----- + +module { + func.func @vmi_layout_gate_group_reduce_slots2_invalid( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe + // CHECK-SAME: stable group_reduce_addf local recipes currently require result layout slots=8 or slots=1 + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto new file mode 100644 index 0000000000..c7003a887d --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_group_store_recipe_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index, %row_stride: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots local recipe + // CHECK-SAME: slots=8 group_store currently requires constant unit row_stride + // CHECK: note: see current operation: "pto.vmi.group_store" + pto.vmi.group_store %value, %dst[%off], %row_stride + {num_groups = 8} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto b/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto new file mode 100644 index 0000000000..53cc5c2a12 --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_ensure_layout_shape_invalid( + %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization recipe + // CHECK-SAME: requires source and result to have the same physical arity + %dense = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return + } +} + +// ----- + +module { + func.func @vmi_layout_gate_ensure_mask_layout_shape_invalid( + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_mask_layout has no registered materialization recipe + // CHECK-SAME: requires source and result to have the same physical arity + %dense = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto new file mode 100644 index 0000000000..871e14eb5b --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_helper_recipe_invalid( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %bad = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization recipe +// CHECK-SAME: unsupported source/result layout pair diff --git a/test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto new file mode 100644 index 0000000000..3877eb1a3a --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_store_deint_tail_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.store has no registered contiguous-memory local recipe + // CHECK-SAME: requires arity divisible by layout factor + pto.vmi.store %value, %dst[%offset] + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// ----- + +module { + func.func @vmi_layout_gate_tile_write_deint_tail_invalid( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + %dst: memref<129xf32>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.tile_write has no registered contiguous-memory local recipe + // CHECK-SAME: requires arity divisible by layout factor + pto.vmi.tile_write %value, %dst + : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, + memref<129xf32> + return + } +} diff --git a/test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto new file mode 100644 index 0000000000..68e7963b1b --- /dev/null +++ b/test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_gate_truncf_recipe_invalid( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf has no registered local recipe + // CHECK-SAME: group-slot truncf requires matching group_slots(num_groups=G, slots=1) + // CHECK: note: see current operation: %{{.*}} = "pto.vmi.truncf" + %out = pto.vmi.truncf %source + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + return + } +} diff --git a/test/lit/vmi/vmi_layout_rematerialize_data.pto b/test/lit/vmi/vmi_layout_rematerialize_data.pto new file mode 100644 index 0000000000..29faa34fb1 --- /dev/null +++ b/test/lit/vmi/vmi_layout_rematerialize_data.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-rematerialize | FileCheck %s + +module { + func.func @vmi_layout_rematerialize_data( + %scalar: f32, + %base: f32) + -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %broadcast = pto.vmi.broadcast %scalar + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %broadcast_deint = pto.vmi.ensure_layout %broadcast + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + %iota = pto.vmi.iota %base + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %iota_deint = pto.vmi.ensure_layout %iota + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + %constant = "pto.vmi.constant"() { + value = dense<1.000000e+00> : tensor<128xf32> + } : () -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %constant_deint = pto.vmi.ensure_layout %constant + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + + return %broadcast_deint, %iota_deint, %constant_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_rematerialize_data( +// CHECK: %[[BCAST:.*]] = pto.vmi.broadcast %arg0 : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[IOTA:.*]] = pto.vmi.iota %arg1 : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[CONST:.*]] = "pto.vmi.constant"(){{.*}}dense<1.000000e+00> : tensor<128xf32>{{.*}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-NOT: pto.vmi.ensure_layout +// CHECK: return %[[BCAST]], %[[IOTA]], %[[CONST]] diff --git a/test/lit/vmi/vmi_layout_rematerialize_mask.pto b/test/lit/vmi/vmi_layout_rematerialize_mask.pto new file mode 100644 index 0000000000..6c3bb60053 --- /dev/null +++ b/test/lit/vmi/vmi_layout_rematerialize_mask.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-rematerialize | FileCheck %s + +module { + func.func @vmi_layout_rematerialize_mask(%active: index) + -> (!pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %mask = pto.vmi.create_mask %active + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %mask_b16 = pto.vmi.ensure_mask_granularity %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %mask_deint = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + + %group_mask = pto.vmi.create_group_mask %active + {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %group_mask_deint = pto.vmi.ensure_mask_layout %group_mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + + %constant_mask = "pto.vmi.constant_mask"() { + value = dense : tensor<128xi1> + } : () -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %constant_mask_b16 = pto.vmi.ensure_mask_granularity %constant_mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + + return %mask_b16, %mask_deint, %group_mask_deint, %constant_mask_b16 + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_rematerialize_mask( +// CHECK: %[[M16:.*]] = pto.vmi.create_mask %arg0 : index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: %[[MDEINT:.*]] = pto.vmi.create_mask %arg0 : index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[GMDEINT:.*]] = pto.vmi.create_group_mask %arg0{{.*}}group_size = 16{{.*}}num_groups = 8{{.*}}index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[CM16:.*]] = "pto.vmi.constant_mask"(){{.*}}dense : tensor<128xi1>{{.*}}!pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK-NOT: pto.vmi.ensure_mask_layout +// CHECK-NOT: pto.vmi.ensure_mask_granularity +// CHECK: return %[[M16]], %[[MDEINT]], %[[GMDEINT]], %[[CM16]] diff --git a/test/lit/vmi/vmi_layout_sink_materialization_binary.pto b/test/lit/vmi/vmi_layout_sink_materialization_binary.pto new file mode 100644 index 0000000000..9db3fcb22b --- /dev/null +++ b/test/lit/vmi/vmi_layout_sink_materialization_binary.pto @@ -0,0 +1,202 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-sink-materialization -pto-validate-vmi-layout-ir | FileCheck %s + +module { + func.func @vmi_layout_sink_materialization_addf( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %sum : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_muli( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %prod = pto.vmi.muli %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %prod : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_single_ensure_kept( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %sum = pto.vmi.addf %lhs_deint, %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %sum : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_unary( + %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %src_deint = pto.vmi.ensure_layout %src + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %neg = pto.vmi.negf %src_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %neg : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_unary_integer( + %src: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %src_deint = pto.vmi.ensure_layout %src + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %abs = pto.vmi.absi %src_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %abs : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_bitwise( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %and = pto.vmi.andi %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %and : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_shift( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %shift = pto.vmi.shli %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %shift : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_not( + %src: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %src_deint = pto.vmi.ensure_layout %src + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %not = pto.vmi.not %src_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return %not : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_addf( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[SUM:.*]] = pto.vmi.addf %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[SUM_DEINT:.*]] = pto.vmi.ensure_layout %[[SUM]] +// CHECK-SAME: #pto.vmi.layout +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[SUM_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_muli( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[PROD:.*]] = pto.vmi.muli %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[PROD_DEINT:.*]] = pto.vmi.ensure_layout %[[PROD]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[PROD_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_single_ensure_kept( +// CHECK: %[[LHS_DEINT:.*]] = pto.vmi.ensure_layout %arg0 +// CHECK: %[[SUM2:.*]] = pto.vmi.addf %[[LHS_DEINT]], %arg1 +// CHECK: return %[[SUM2]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_unary( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK: %[[NEG:.*]] = pto.vmi.negf %arg0 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[NEG_DEINT:.*]] = pto.vmi.ensure_layout %[[NEG]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[NEG_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_unary_integer( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK: %[[ABS:.*]] = pto.vmi.absi %arg0 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[ABS_DEINT:.*]] = pto.vmi.ensure_layout %[[ABS]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[ABS_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_bitwise( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[AND:.*]] = pto.vmi.andi %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[AND_DEINT:.*]] = pto.vmi.ensure_layout %[[AND]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[AND_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_shift( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[SHIFT:.*]] = pto.vmi.shli %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[SHIFT_DEINT:.*]] = pto.vmi.ensure_layout %[[SHIFT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[SHIFT_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_not( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK: %[[NOT:.*]] = pto.vmi.not %arg0 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[NOT_DEINT:.*]] = pto.vmi.ensure_layout %[[NOT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[NOT_DEINT]] diff --git a/test/lit/vmi/vmi_layout_sink_materialization_mask.pto b/test/lit/vmi/vmi_layout_sink_materialization_mask.pto new file mode 100644 index 0000000000..0effb48323 --- /dev/null +++ b/test/lit/vmi/vmi_layout_sink_materialization_mask.pto @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-sink-materialization -pto-validate-vmi-layout-ir | FileCheck %s + +module { + func.func @vmi_layout_sink_mask_layout_binary( + %lhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_mask_layout %lhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_mask_layout %rhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %out = pto.vmi.mask_and %lhs_deint, %rhs_deint + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %out : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_mask_granularity_binary( + %lhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> { + %lhs_b16 = pto.vmi.ensure_mask_granularity %lhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %rhs_b16 = pto.vmi.ensure_mask_granularity %rhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + %out = pto.vmi.mask_or %lhs_b16, %rhs_b16 + : !pto.vmi.mask<128xb16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.mask<128xb16, #pto.vmi.layout> + return %out : !pto.vmi.mask<128xb16, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_mask_layout_unary( + %source: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %source_deint = pto.vmi.ensure_mask_layout %source + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %out = pto.vmi.mask_not %source_deint + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %out : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_layout_sink_mask_layout_binary( +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg1 +// CHECK: %[[OUT:.*]] = pto.vmi.mask_and %arg0, %arg1 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[OUT_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[OUT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[OUT_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_mask_granularity_binary( +// CHECK-NOT: pto.vmi.ensure_mask_granularity %arg0 +// CHECK-NOT: pto.vmi.ensure_mask_granularity %arg1 +// CHECK: %[[OUT:.*]] = pto.vmi.mask_or %arg0, %arg1 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[OUT_B16:.*]] = pto.vmi.ensure_mask_granularity %[[OUT]] +// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// CHECK: return %[[OUT_B16]] + +// CHECK-LABEL: func.func @vmi_layout_sink_mask_layout_unary( +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg0 +// CHECK: %[[OUT:.*]] = pto.vmi.mask_not %arg0 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[OUT_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[OUT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[OUT_DEINT]] diff --git a/test/lit/vmi/vmi_legalize_arith_select.pto b/test/lit/vmi/vmi_legalize_arith_select.pto new file mode 100644 index 0000000000..0661b6764e --- /dev/null +++ b/test/lit/vmi/vmi_legalize_arith_select.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-legalize-arith-select -pto-validate-vmi-layout-ir | FileCheck %s + +module { + func.func @vmi_legalize_arith_select_vreg( + %cond: i1, + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %selected = arith.select %cond, %lhs, %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %selected : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_legalize_arith_select_mask( + %cond: i1, + %lhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %rhs: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %selected = arith.select %cond, %lhs, %rhs + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %selected : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_legalize_arith_select_vreg( +// CHECK-NOT: arith.select +// CHECK: %[[IF:.*]] = scf.if %arg0 -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) { +// CHECK: scf.yield %arg1 +// CHECK: } else { +// CHECK: scf.yield %arg2 +// CHECK: return %[[IF]] + +// CHECK-LABEL: func.func @vmi_legalize_arith_select_mask( +// CHECK-NOT: arith.select +// CHECK: %[[IF:.*]] = scf.if %arg0 -> (!pto.vmi.mask<128xb32, #pto.vmi.layout>) { +// CHECK: scf.yield %arg1 +// CHECK: } else { +// CHECK: scf.yield %arg2 +// CHECK: return %[[IF]] diff --git a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto index 8957bb1f40..e49dba60c3 100644 --- a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto +++ b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto @@ -22,6 +22,19 @@ module attributes {pto.target_arch = "a5"} { : !pto.vmi.vreg<128xf32>, !pto.ptr return } + + func.func @vmi_ptoas_cli_fold_consumers_pipeline( + %src: !pto.ptr, + %dst: !pto.ptr, + %offset: index) { + %x16 = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.store %x32, %dst[%offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + return + } } } @@ -34,6 +47,16 @@ module attributes {pto.target_arch = "a5"} { // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast +// CHECK-LABEL: func.func @vmi_ptoas_cli_fold_consumers_pipeline +// CHECK: pto.vlds +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} +// CHECK: pto.vcvt {{.*}} {part = "ODD"} +// CHECK-NOT: pto.vintlv +// CHECK: pto.vstsx2 +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + // ATTR-LABEL: func.func @vmi_ptoas_cli_pipeline // ATTR: pto.vecscope // ATTR: pto.vdup diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_bitcast_deint_tail.pto new file mode 100644 index 0000000000..fa1a5524dc --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_deint_tail.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_deint_tail( + %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<129xi32, #pto.vmi.layout> { + %cast = pto.vmi.bitcast %value + : !pto.vmi.vreg<129xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<129xi32, #pto.vmi.layout> + return %cast : !pto.vmi.vreg<129xi32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_bitcast_deint_tail( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: %[[V2:[^)]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.vreg<64xi32>) +// CHECK-DAG: %[[B0:.*]] = pto.vbitcast %[[V0]] : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK-DAG: %[[B1:.*]] = pto.vbitcast %[[V1]] : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK-DAG: %[[B2:.*]] = pto.vbitcast %[[V2]] : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK: return %[[B0]], %[[B1]], %[[B2]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto b/test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto new file mode 100644 index 0000000000..2d7b904af1 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_footprint_invalid( + %source: !pto.vmi.vreg<65xf32, #pto.vmi.layout>) { + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<65xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<130xi16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.bitcast requires matching source/result layouts +// CHECK-SAME: identical physical arity and matching per-chunk logical bit footprints +// CHECK-SAME: requires matching logical bit footprint in every physical chunk diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto b/test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto new file mode 100644 index 0000000000..49d728f73d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_bitcast_group_slots_invalid( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %out = pto.vmi.bitcast %source + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.bitcast requires matching source/result layouts +// CHECK-SAME: identical physical arity and matching per-chunk logical bit footprints +// CHECK-SAME: does not support group_slots layouts diff --git a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto index 5297123e5a..f78e4ef5f2 100644 --- a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto @@ -17,9 +17,9 @@ module { } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: but requires {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> -// CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion +// CHECK-SAME: pto.vmi.ensure_layout has no registered materialization recipe // CHECK: failed helper conversion {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: requires source and result to have the same physical arity diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index fbe74e9b69..0639dcd7e8 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1807,8 +1807,22 @@ static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { pm.enableVerifier(); pm.addPass(pto::createPTOValidateVMIIRPass()); pm.addPass(pto::createVMILayoutAssignmentPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILayoutFoldConsumersPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILayoutRematerializePass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILayoutSinkMaterializationPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILegalizeArithSelectPass()); pm.addPass(pto::createPTOValidateVMILayoutIRPass()); pm.addPass(pto::createVMIToVPTOPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); if (failed(applyConfiguredPassManagerCLOptions(pm, "VMI-to-VPTO pipeline"))) return failure(); From 077a613096854ad230e9a7e7910fdf8a35f88b8e Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 20:36:42 +0800 Subject: [PATCH 20/54] Support multi-chunk VMI group reduce slots --- docs/designs/vmi-implementation-manual.md | 4 +- .../vmi-layout-assignment-implementation.md | 20 ++++--- .../vmi-layout-assignment-lowering-design.md | 7 +++ .../PTO/Transforms/VMILocalRecipeRegistry.h | 2 +- lib/PTO/Transforms/VMILayoutAssignment.cpp | 5 +- lib/PTO/Transforms/VMILocalRecipeRegistry.cpp | 20 ++++--- lib/PTO/Transforms/VMIToVPTO.cpp | 55 +++++++++++++++---- ...mi_layout_assignment_group_reduce_s256.pto | 28 ++++++++++ ...ate_group_reduce_slots1_recipe_invalid.pto | 2 +- .../vmi/vmi_to_vpto_group_broadcast_deint.pto | 4 +- ...mi_to_vpto_group_reduce_s256_broadcast.pto | 44 +++++++++++++++ 11 files changed, 159 insertions(+), 32 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 04da993699..04d32aea51 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -3111,7 +3111,9 @@ pto.vmi.group_broadcast: result may be contiguous with full physical chunks result may also be deinterleaved when S is large enough that every physical result chunk stays inside one logical group, for example N=512, G=2, S=256, - L=64, deinterleaved=4 + L=64, deinterleaved=4. If the source is + #pto.vmi.layout, the source physical part is + selected by group id rather than by source chunk id. derived group size S must divide or be a multiple of L for canonical group-slot addressing if result is contiguous and S < L, each physical chunk contains multiple group diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 03f22ffd42..31cc618335 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -466,7 +466,9 @@ group_slots group_reduce_addf semantic recipes: S=8 vcgadd S=16 deinterleaved=2 vcgadd+vadd S=32 deinterleaved=4 vcgadd+vadd tree - S=64 contiguous slots=1 vcadd/vadd/vsel row-local reduction + S>=physical_chunk_lanes contiguous slots=1 vcadd/vadd/vsel row-local + reduction, with one physical result part per group. For f32 this covers + S=64, S=128, S=256, ... explicit-slots group_broadcast semantic recipes: slots=8/slots=1 vselr materialization to contiguous or supported @@ -1031,7 +1033,7 @@ Target local recipe matrix: load, recipe=dense_load_norm: result layout contiguous emits pto.vlds / pto.vsts NORM paths - covers dense store users and S=64 row-local reduce input + covers dense store users and full-chunk row-local reduce input load, recipe=load_dintlv2: result layout deinterleaved=2, block_elems=1 @@ -1088,8 +1090,9 @@ group_reduce_addf, recipe=s32_reduce_block8_stride: produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_addf, recipe=s64_reduce_row_local: - consumes contiguous f32 with group size 64 +group_reduce_addf, recipe=full_chunk_reduce_row_local: + consumes contiguous f32 with group size that is a multiple of one physical + chunk produces group_slots(G, slots=1) target lowering emits per-row vcgadd plus vcadd; the current prototype uses the existing row-local VCADD/VADD/VSEL sequence while preserving the same @@ -1143,9 +1146,10 @@ group_reduce_addf: #pto.vmi.layout, result #pto.vmi.layout; vmi-to-vpto lowers through four VCGADDs plus a PAT_VL8 VADD tree per packed result block. - S=64 row-local assignment uses #pto.vmi.layout - and has focused layout-assignment/vmi-to-vpto lit coverage; the explicit - slots=1 generic VCADD row-local path is registered and selected locally. + Full-chunk row-local assignment, including S=64 and S=256 f32 cases, uses + #pto.vmi.layout and has focused + layout-assignment/vmi-to-vpto lit coverage; the explicit slots=1 generic + VCADD row-local path is registered and selected locally. group_broadcast: explicit slots=8/1 source layouts select @@ -2089,7 +2093,7 @@ Current evidence for the case-catalog objective: 3. every runtime case directory contains kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py 4. the latest broad VMI runtime sweep passed: PASS=43 FAIL=0 -5. the latest full VMI lit sweep passed: 340/340 +5. the latest full VMI lit sweep passed: 342/342 6. every unsupported endpoint listed in section 11.3 has a diagnostic lit test 7. vmi-to-vpto decisions are represented by current-op attrs/operands, assigned layouts, helper ops, rematerialization, or diagnostics diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 4c13b07ef8..13588bce3b 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -498,6 +498,13 @@ group_reduce f32 S=64: input contiguous result group_slots(G, slots=1) +group_reduce f32 S=128/S=256/...: + input contiguous + result group_slots(G, slots=1) + lowering reduces each full physical chunk with vcadd, accumulates all chunks + in the same logical group with lane0 vadd, and writes one physical result + part per group + group_slot_load: result group_slots(G, slots=8) for packed slots result group_slots(G, slots=1) for row-local slots diff --git a/include/PTO/Transforms/VMILocalRecipeRegistry.h b/include/PTO/Transforms/VMILocalRecipeRegistry.h index 7356be9e92..10cde1dc96 100644 --- a/include/PTO/Transforms/VMILocalRecipeRegistry.h +++ b/include/PTO/Transforms/VMILocalRecipeRegistry.h @@ -88,7 +88,7 @@ enum class VMIGroupReduceAddFRecipeKind { S8Vcgadd, S16Deinterleaved2VcgaddVadd, S32Deinterleaved4VcgaddTree, - S64ContiguousVcaddRows, + ContiguousVcaddRows, }; struct VMIGroupReduceAddFRecipe { diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 85a57e4ac1..2ff9e50ae2 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -249,7 +249,10 @@ struct LayoutSolver { return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); if (groupSize == 32) return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); - if (groupSize == 64) + FailureOr lanesPerPart = + getDataLanesPerPart(type.getElementType()); + if (succeeded(lanesPerPart) && groupSize >= *lanesPerPart && + groupSize % *lanesPerPart == 0) return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); } return getGroupSlotsLayout(numGroups); diff --git a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp index 7364084028..7cd5281353 100644 --- a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp +++ b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp @@ -719,21 +719,25 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( return fail("requires matching non-empty source/mask physical arity"); if (resultLayout.getSlots() == 1) { - if (*groupSize != 64) + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart) || *groupSize < *lanesPerPart || + *groupSize % *lanesPerPart != 0) return fail("stable group_reduce_addf slots=1 recipes support group " - "size 64"); + "sizes that are multiples of one physical chunk"); if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) - return fail("s64 group_reduce_addf requires contiguous source/mask " + return fail("slots=1 group_reduce_addf requires contiguous source/mask " "layouts"); - if (*resultArity != *sourceArity) - return fail("s64 group_reduce_addf requires source/result physical " - "arity to match"); + if (*resultArity != numGroups) + return fail("slots=1 group_reduce_addf requires one physical result " + "part per group"); std::string sourceFullReason; if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) - return fail(Twine("s64 group_reduce_addf requires full source chunks; ") + + return fail(Twine("slots=1 group_reduce_addf requires full source " + "chunks; ") + sourceFullReason); return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::S64ContiguousVcaddRows}; + VMIGroupReduceAddFRecipeKind::ContiguousVcaddRows}; } if (*groupSize == 8) { diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index a59b5dbadb..1c92be4018 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -5725,10 +5725,17 @@ struct OneToNVMIGroupReduceAddFOpPattern &lanesPerPart, &groupCount, &chunksPerGroup, rewriter))) return failure(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + bool rowLocalSlots1Result = + resultLayout && resultLayout.isGroupSlots() && + resultLayout.getNumGroups() == groupCount && + resultLayout.getSlots() == 1; + int64_t expectedResultParts = + rowLocalSlots1Result ? groupCount : groupCount * chunksPerGroup; if (sourceParts.size() != maskParts.size() || static_cast(sourceParts.size()) != groupCount * chunksPerGroup || - resultTypes.size() != sourceParts.size()) + static_cast(resultTypes.size()) != expectedResultParts) return rewriter.notifyMatchFailure( op, "group_reduce_addf requires matching source/mask/result arity"); @@ -5782,7 +5789,7 @@ struct OneToNVMIGroupReduceAddFOpPattern .getResult(); } - int64_t destChunk = group * chunksPerGroup; + int64_t destChunk = rowLocalSlots1Result ? group : group * chunksPerGroup; results[destChunk] = rewriter .create(op.getLoc(), resultType, *accumulator, @@ -5857,6 +5864,18 @@ struct OneToNVMIGroupBroadcastOpPattern resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && *groupSize < lanesPerPart) selectionGroupSize = resultLayout.getBlockElems(); + auto resolveLargeGroupSource = [&](int64_t group, int64_t chunksPerGroup, + int64_t &sourceChunk, + int64_t &baseGroupSlot) { + int64_t slots = sourceLayout.getSlots(); + if (slots > 0) { + sourceChunk = group / slots; + baseGroupSlot = group % slots; + return; + } + sourceChunk = group * chunksPerGroup; + baseGroupSlot = 0; + }; SmallVector results; results.resize(resultTypes.size()); @@ -5871,7 +5890,8 @@ struct OneToNVMIGroupBroadcastOpPattern if (*groupSize >= lanesPerPart) { int64_t chunksPerGroup = *groupSize / lanesPerPart; int64_t group = flatIndex / chunksPerGroup; - sourceChunk = group * chunksPerGroup; + resolveLargeGroupSource(group, chunksPerGroup, sourceChunk, + baseGroupSlot); } else { VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); int64_t slots = sourceLayout.getSlots(); @@ -5953,7 +5973,8 @@ struct OneToNVMIGroupBroadcastOpPattern return rewriter.notifyMatchFailure( op, "group_broadcast result chunk crosses logical groups"); int64_t chunksPerGroup = *groupSize / lanesPerPart; - sourceChunk = firstGroup * chunksPerGroup; + resolveLargeGroupSource(firstGroup, chunksPerGroup, sourceChunk, + baseGroupSlot); found = true; break; } @@ -5968,12 +5989,26 @@ struct OneToNVMIGroupBroadcastOpPattern sourceChunk >= static_cast(sourceParts.size())) return rewriter.notifyMatchFailure( op, "group_broadcast source chunk is out of range"); - results[flatIndex] = - rewriter - .create(op.getLoc(), resultType, - sourceParts[sourceChunk], *allMask, - rewriter.getStringAttr("LOWEST")) - .getResult(); + if (sourceLayout.getSlots() > 1) { + FailureOr groupSlotIndex = createGroupSlotIndexVector( + op.getLoc(), indexType, selectionGroupSize, baseGroupSlot, + rewriter); + if (failed(groupSlotIndex)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); + results[flatIndex] = + rewriter + .create(op.getLoc(), resultType, + sourceParts[sourceChunk], *groupSlotIndex) + .getResult(); + } else { + results[flatIndex] = + rewriter + .create(op.getLoc(), resultType, + sourceParts[sourceChunk], *allMask, + rewriter.getStringAttr("LOWEST")) + .getResult(); + } } else { bool blockFragmentSmallGroup = resultLayout && resultLayout.isDeinterleaved() && diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto new file mode 100644 index 0000000000..15fba5a1de --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_s256( + %source: !pto.vmi.vreg<512xf32>, + %mask: !pto.vmi.mask<512xpred>) -> !pto.vmi.vreg<512xf32> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> + -> !pto.vmi.vreg<512xf32> + return %out : !pto.vmi.vreg<512xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_s256( +// CHECK-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 +// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto index f4071e4c47..6e0b04e8f6 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto @@ -13,7 +13,7 @@ module { %source: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_addf slots=1 recipes support group size 64 + // CHECK-SAME: stable group_reduce_addf slots=1 recipes support group sizes that are multiples of one physical chunk // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto index 078b61b5bf..9c2aff3759 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto @@ -10,13 +10,13 @@ module { func.func @vmi_to_vpto_group_broadcast_deint( - %sum: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %sum: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, %src_f8: !pto.vmi.vreg<512xf8E4M3FN>) -> !pto.vmi.vreg<512xf32> { %src_f32 = pto.vmi.extf %src_f8 : !pto.vmi.vreg<512xf8E4M3FN> -> !pto.vmi.vreg<512xf32> %sum_vec = pto.vmi.group_broadcast %sum {num_groups = 2} - : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf32> %out = pto.vmi.mulf %sum_vec, %src_f32 : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto new file mode 100644 index 0000000000..f2681f3359 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_s256_broadcast( + %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 2, reassoc} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.vmi.mask<512xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %broadcast = pto.vmi.group_broadcast %sum {num_groups = 2} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%broadcast) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s256_broadcast( +// CHECK: pto.vcadd +// CHECK: pto.vadd +// CHECK: pto.vsel +// CHECK: pto.vdup {{.*}} {position = "LOWEST"} +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast From ef7405220c9dfe44c19daeddedf6c2797a5b2f66 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 22 Jun 2026 23:48:11 +0800 Subject: [PATCH 21/54] Implement typed VMI group reduce lowering --- docs/designs/vmi-implementation-manual.md | 4 +- .../vmi-layout-assignment-implementation.md | 168 ++++-- .../vmi-layout-assignment-lowering-design.md | 125 ++-- docs/designs/vmi-layout-lowering-cases.md | 552 ++++++++++++++++++ include/PTO/IR/VMIOps.td | 34 ++ .../PTO/Transforms/VMILocalRecipeRegistry.h | 44 +- .../PTO/Transforms/VMITargetCapabilities.h | 23 +- lib/PTO/IR/VMI.cpp | 104 ++++ lib/PTO/Transforms/VMILayoutAssignment.cpp | 158 ++++- lib/PTO/Transforms/VMILocalRecipeRegistry.cpp | 225 +++++-- lib/PTO/Transforms/VMIToVPTO.cpp | 429 ++++++++++++-- .../vmi/vmi_group_reduce_addi_i16_invalid.pto | 24 + .../vmi/vmi_group_reduce_addi_i8_invalid.pto | 24 + ...ut_assignment_group_reduce_s12_invalid.pto | 2 +- ...i_layout_assignment_group_reduce_typed.pto | 56 ++ ...ayout_gate_group_reduce_recipe_invalid.pto | 2 +- ...ate_group_reduce_slots1_recipe_invalid.pto | 2 +- ..._group_slots_unsupported_slots_invalid.pto | 2 +- .../vmi/vmi_to_vpto_group_reduce_typed.pto | 80 +++ .../vmi/vmi_to_vpto_integer_cast_reduce.pto | 44 ++ test/lit/vmi/vmi_to_vpto_integer_casts.pto | 64 ++ ...id.pto => vmi_to_vpto_reduce_addf_f16.pto} | 27 +- .../vmi_to_vpto_trunci_i8_signed_invalid.pto | 29 + .../group-reduce-f16-addf-store/compare.py | 37 ++ .../vmi/group-reduce-f16-addf-store/golden.py | 43 ++ .../group-reduce-f16-addf-store/kernel.pto | 51 ++ .../group-reduce-f16-addf-store/launch.cpp | 34 ++ .../vmi/group-reduce-f16-addf-store/main.cpp | 86 +++ .../group-reduce-f16-addf-store/ptoas.flags | 1 + .../compare.py | 37 ++ .../golden.py | 43 ++ .../kernel.pto | 52 ++ .../launch.cpp | 36 ++ .../main.cpp | 88 +++ .../ptoas.flags | 1 + .../group-reduce-i32-addi-store/compare.py | 37 ++ .../vmi/group-reduce-i32-addi-store/golden.py | 42 ++ .../group-reduce-i32-addi-store/kernel.pto | 51 ++ .../group-reduce-i32-addi-store/launch.cpp | 35 ++ .../vmi/group-reduce-i32-addi-store/main.cpp | 86 +++ .../group-reduce-i32-addi-store/ptoas.flags | 1 + .../compare.py | 37 ++ .../golden.py | 44 ++ .../kernel.pto | 52 ++ .../launch.cpp | 36 ++ .../main.cpp | 88 +++ .../ptoas.flags | 1 + 47 files changed, 3039 insertions(+), 202 deletions(-) create mode 100644 test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto create mode 100644 test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto create mode 100644 test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto create mode 100644 test/lit/vmi/vmi_to_vpto_integer_casts.pto rename test/lit/vmi/{vmi_to_vpto_reduce_addf_f16_invalid.pto => vmi_to_vpto_reduce_addf_f16.pto} (55%) create mode 100644 test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-f16-addf-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i32-addi-store/ptoas.flags create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/main.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/ptoas.flags diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 04d32aea51..497e951e73 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -3867,8 +3867,8 @@ Slice 4 完成条件: per-feature negative tests. 9. Same-family reduction ops reject unsupported direct-lowering shapes consistently. Covered by vmi_to_vpto_reduce_shape_invalid.pto together with the existing reduce add/min/max positive and - per-feature negative tests, including vmi_to_vpto_reduce_addi_i16_invalid.pto and - vmi_to_vpto_reduce_addf_f16_invalid.pto. + per-feature tests, including vmi_to_vpto_reduce_addi_i16_invalid.pto for narrow integer rejection and + vmi_to_vpto_reduce_addf_f16.pto for f16 floating-point reduction lowering. 10. Target-specific element contracts are checked before OneToN rewriting for direct VPTO ops. Covered by vmi_to_vpto_bf16_arith.pto, vmi_to_vpto_math_element_type_invalid.pto, vmi_to_vpto_cmp_select.pto, vmi_to_vpto_cmp_element_type_invalid.pto, diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 31cc618335..a6583a3d8b 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -98,7 +98,7 @@ pto-validate-vmi-layout-ir: fail before `vmi-to-vpto`. It also checks the first semantic local-recipe families, non-contiguous `pto.vmi.store`/`pto.vmi.tile_write`, block8 `pto.vmi.group_load`, `pto.vmi.group_slot_load`, group_slots - `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_addf`, + `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_add{f|i}`, explicit-slots `pto.vmi.group_broadcast`, `pto.vmi.truncf`, `pto.vmi.extf`, and `pto.vmi.bitcast`, at the layout gate. @@ -294,7 +294,7 @@ Local-decision table for the current implementation: op local decision inputs group_load result layout, num_groups, row_stride, source type group_slot_load result group_slots layout and source_group_stride -group_reduce_addf source/mask/result layouts, num_groups, reassoc +group_reduce_add{f|i} source/mask/result layouts, num_groups, typed reduce semantics group_broadcast source/result layouts and num_groups truncf source/result layouts and element widths ensure_layout always carries source/result layouts instead of recipe @@ -329,8 +329,8 @@ error. Examples of forbidden recovery in `vmi-to-vpto`: ```text -group_reduce_addf cannot walk to a load/group_load producer to choose S=16 - parity versus block8. +group_reduce_add{f|i} cannot walk to a load/group_load producer to choose + two-vlane parity versus block8. group_store cannot inspect the group_reduce producer; it consumes only the assigned source layout and explicit stride. group_broadcast cannot inspect sibling users to decide whether to rematerialize. @@ -354,12 +354,17 @@ create_group_mask extf truncf +extsi +extui +trunci addf +addi mulf select broadcast group_reduce_addf +group_reduce_addi group_broadcast group_store @@ -368,6 +373,30 @@ ensure_mask_layout // internal ensure_mask_granularity // internal ``` +Type policy before lowering: + +```text +storage / memory boundary: + f8-like, i8, f16, i16, f32, i32 may appear as load/store element types when + the target memory instruction supports the physical width. + +cast boundary: + f8-like may appear as extf/truncf source or destination. + i8 may appear as extsi/extui/trunci source or destination. Signedness is an + op semantic, not a VMI type spelling. + Current VPTO lowering supports 32-bit integer narrowing to unsigned i8 + storage, matching the available VCVTII s32/u32 -> u8 forms; signed i8 + narrowing needs a separate target recipe. + +compute / accumulator: + floating compute baseline: f16/f32, with reassoc required for reductions + that lower through pair-wise VPTO reductions. + integer compute baseline: i32 for grouped reduction; i8/i16 storage must + first cast to i32 because integer reduction instructions widen narrow inputs. + f8/i8 are not baseline accumulator/compute types. Supporting direct 8-bit + compute requires a target capability entry and a separate recipe family. +``` + Important semantic split: ```text @@ -462,13 +491,20 @@ group_slots group_store semantic recipes: slots=8 unit-stride vsts slots=1 aligned lane-0 vsts per group -group_slots group_reduce_addf semantic recipes: - S=8 vcgadd - S=16 deinterleaved=2 vcgadd+vadd - S=32 deinterleaved=4 vcgadd+vadd tree - S>=physical_chunk_lanes contiguous slots=1 vcadd/vadd/vsel row-local - reduction, with one physical result part per group. For f32 this covers - S=64, S=128, S=256, ... +group_slots group_reduce_add{f|i} semantic recipes: + define E = sizeof(T), VLaneElems = 32B / E, L = 256B / E, S = N / G. + T is the accumulator/reduce element type after any required storage cast. + f8 storage reduces through f32; i8 storage reduces through an explicit + signed/unsigned integer cast to an accumulator type such as i32. In the + baseline contract, f8/i8 are cast-boundary storage types rather than + accumulator/compute types. + S=VLaneElems contiguous vcgadd + S=2*VLaneElems deinterleaved=2 vcgadd+vadd + S=4*VLaneElems deinterleaved=4 vcgadd+vadd tree + S>=L && S%L==0 contiguous slots=1 vcadd/vadd/vsel row-local reduction, + with one physical result part per group. For 32-bit element types this covers + S=64, S=128, S=256, ...; for 16-bit element types this covers S=128, S=256, + ... explicit-slots group_broadcast semantic recipes: slots=8/slots=1 vselr materialization to contiguous or supported @@ -481,6 +517,14 @@ extf/truncf semantic recipes: deinterleaved=4 f32 -> contiguous f8-like group_slots(G, slots=1) f32 -> f16 +extsi/extui/trunci semantic recipes: + contiguous i8 -> deinterleaved=2 i16 through VCVTII.{s,u}82{s,u}16 #part + contiguous i8 -> deinterleaved=4 i32 through VCVTII.{s,u}82{s,u}32 #pp + deinterleaved=2 i16 -> contiguous i8 through VCVTII.*162*8 #part + deinterleaved=4 i32 -> contiguous ui8 through VCVTII.*322u8 #pp + packed group_slots integer width-changing cast is unsupported until a + slot-wise cast recipe is defined. + bitcast semantic recipes: per-part vbitcast for contiguous/deinterleaved layouts when source/result layouts match, physical arity matches, and every physical chunk carries the @@ -651,13 +695,15 @@ buildCastRequests: exists buildGroupReduceRequests: - derive S = logical_lanes / num_groups - S=8 -> contiguous source, group_slots(G,8) result - S=16 -> deinterleaved=2/block_elems=1 or block_elems=8 source, - group_slots(G,8) result - S=32 -> deinterleaved=4/block_elems=1 or block_elems=8 source, - group_slots(G,8) result - S=64 -> contiguous source, group_slots(G,1) result + derive E = sizeof(accumulator type), VLaneElems = 32B / E, + L = 256B / E, and S = logical_lanes / num_groups + S=VLaneElems -> contiguous source, group_slots(G,8) result + S=2*VLaneElems -> deinterleaved=2/block_elems=1 or block_elems=8 source, + group_slots(G,8) result + S=4*VLaneElems -> deinterleaved=4/block_elems=1 or block_elems=8 source, + group_slots(G,8) result + S>=L && S%L==0 -> contiguous source, group_slots(G,1) result + 8-bit storage must be cast to an accumulator type before this request builder other S -> diagnostic unless an explicit fallback recipe is enabled buildGroupMemoryRequests: @@ -866,10 +912,10 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.4 S=8 reduce buildGroupReduceRequests s8_reduce_contiguous recipe -3.5 S=16 reduce buildGroupReduceRequests s16_reduce_parity/block8 recipe -3.6 S=32 reduce buildGroupReduceRequests s32_reduce_dintlv4/block8 recipe -3.7 S=64 reduce buildGroupReduceRequests s64_reduce_row_local recipe +3.4 32-bit S=8 reduce buildGroupReduceRequests one_vlane contiguous recipe +3.5 32-bit S=16 reduce buildGroupReduceRequests two_vlane parity/block8 recipe +3.6 32-bit S=32 reduce buildGroupReduceRequests four_vlane dintlv4/block8 recipe +3.7 32-bit S=64 reduce buildGroupReduceRequests full_chunk row_local recipe 3.11.1 S=64 active-row tail buildMaskRequests active-row store/reduce masks 3.19.1 S=16 block_elems choice buildGroupReduceRequests explicit block_elems layout 3.38 multi-tile S=32 reduce buildGroupReduceRequests multiple group_slots chunks @@ -1065,34 +1111,34 @@ group_load, recipe=group_load_contiguous_chunks: emits one vlds per physical group chunk using row_stride address arithmetic covers the currently implemented full-chunk row-local group_load path -group_reduce_addf, recipe=s8_reduce_contiguous: - consumes contiguous f32 with group size 8 +group_reduce_add{f|i}, recipe=one_vlane_reduce_contiguous: + consumes contiguous accumulator type T with group size VLaneElems(T) produces group_slots(G, slots=8) emits one vcgadd -group_reduce_addf, recipe=s16_reduce_parity: +group_reduce_add{f|i}, recipe=two_vlane_reduce_deinterleaved: consumes deinterleaved=2, block_elems=1 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_addf, recipe=s16_reduce_block8: +group_reduce_add{f|i}, recipe=two_vlane_reduce_block8: consumes deinterleaved=2, block_elems=8 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_addf, recipe=s32_reduce_dintlv4: +group_reduce_add{f|i}, recipe=four_vlane_reduce_dintlv4: consumes deinterleaved=4, block_elems=1 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_addf, recipe=s32_reduce_block8_stride: +group_reduce_add{f|i}, recipe=four_vlane_reduce_block8_stride: consumes deinterleaved=4, block_elems=8 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_addf, recipe=full_chunk_reduce_row_local: - consumes contiguous f32 with group size that is a multiple of one physical - chunk +group_reduce_add{f|i}, recipe=full_chunk_reduce_row_local: + consumes contiguous accumulator type T with group size that is a multiple of + one physical chunk L(T) produces group_slots(G, slots=1) target lowering emits per-row vcgadd plus vcadd; the current prototype uses the existing row-local VCADD/VADD/VSEL sequence while preserving the same @@ -1150,6 +1196,9 @@ group_reduce_addf: #pto.vmi.layout and has focused layout-assignment/vmi-to-vpto lit coverage; the explicit slots=1 generic VCADD row-local path is registered and selected locally. + group_reduce_addi is implemented for i32 accumulator values. i8/i16 storage + must be widened explicitly before grouped reduction because narrow integer + reduction instructions widen their result. group_broadcast: explicit slots=8/1 source layouts select @@ -1199,18 +1248,69 @@ group_store: design target unless a strided packed-lane store recipe is made explicit. ``` +Current implementation contract for type-generic grouped reduction: + +```text +ODS/verifiers: + pto.vmi.group_reduce_addi is the integer counterpart to group_reduce_addf. + group_reduce_addi accepts i32 accumulator element types; i8/i16 direct + grouped reduction is rejected with a diagnostic that points users to + extsi/extui. + extsi/extui/trunci carry integer signedness across storage/accumulator + boundaries without overloading add semantics. + +Layout assignment: + compute VLaneElems and L from the accumulator/reduce element type: + VLaneElems = 32B / sizeof(accumulator T) + L = 256B / sizeof(accumulator T) + use the same S formula for f16/f32/i32 once the typed reduce op and target + capability say the type is legal. + route f8 storage through extf to f32 before group_reduce_addf. + route i8/i16 storage through extsi/extui to i32 before group_reduce_addi. + route integer narrowing to i8 through trunci; direct i8 compute remains + illegal unless the target capability registry exposes an explicit recipe. + diagnose direct f8/i8 compute use with a message that points at the offending + op and suggests inserting the explicit cast when the op is meant to consume + storage data. + +Local recipe registry: + replace f32-shaped recipe keys with width-parametric recipe classes: + one_vlane_reduce + two_vlane_reduce_deinterleaved + four_vlane_reduce_deinterleaved + full_chunk_row_local_reduce + key legality on accumulator byte width, source/mask layout, result + group_slots layout, num_groups, and target instruction capability. + +VMI-to-VPTO: + lower group_reduce_addi through the same VCGADD/VADD skeleton used for + floating-point where the target supports the integer accumulator type. + materialize integer casts explicitly before reduction; direct i8 group reduce + and direct i16 group reduce must not silently become a widening reduction in + this pass. + keep VPTO lowering local: it consumes assigned layouts and registered local + recipes, but does not invent a new global layout plan. + +Tests: + cover f16 direct and i16-storage-to-i32 grouped reductions. + add i32 S=8/S=16/S=32/S=64 group-reduce cases. + add f8 storage -> extf -> f32 group_reduce_addf cases. + add i8/i16 storage -> extsi/extui -> i32 group_reduce_addi cases. + add invalid direct f8/i8/i16 grouped-reduce diagnostics. +``` + Examples: ```text -group_reduce_addf, recipe=s16_reduce_parity: +group_reduce_add{f|i}, recipe=two_vlane_reduce_deinterleaved: consume deinterleaved=2, block_elems=1 emit two VCGADDs and one VADD -group_reduce_addf, recipe=s16_reduce_block8: +group_reduce_add{f|i}, recipe=two_vlane_reduce_block8: consume deinterleaved=2, block_elems=8 emit two VCGADDs and one VADD -group_reduce_addf, recipe=s32_reduce_dintlv4: +group_reduce_add{f|i}, recipe=four_vlane_reduce_dintlv4: consume deinterleaved=4 emit four VCGADDs and reduction tree diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 13588bce3b..82a84082c6 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -94,11 +94,18 @@ dense cast: f16 -> f32 -> store f32 -> f16 -> store f8 -> f32 -> compute -> f8 + f8 -> f32 accumulator -> group_reduce_addf + i8/i16 -> signed/unsigned integer cast to i32 accumulator + -> group_reduce_addi + f8/i8 appear as cast source or cast destination at compute boundaries + integer narrowing back to i8 is an explicit cast, not implicit arithmetic f16 -> f32 shared by dense store and S=16 reduce f32 shared by f8 store and S=32 reduce group reduce: - S=8, S=16, S=32, S=64 + 32-bit accumulator: S=8, S=16, S=32, S=64 + 16-bit accumulator: S=16, S=32, S=64, S=128 + 8-bit storage reduces only through an explicit accumulator cast reduce -> group_store reduce -> group_slot_load/elemwise -> group_store reduce -> group_broadcast -> elemwise -> reduce -> store @@ -205,6 +212,29 @@ diagnostic-only cases: Layout is a property of a layout-assigned VMI value, not a property inferred by the final lowering pattern. +Type policy: + +```text +storage boundary: + f8-like/i8/f16/i16/f32/i32 may appear in load/store values when the target + memory instruction supports the physical width. + +cast boundary: + f8-like participates through extf/truncf. + i8 participates through extsi/extui/trunci. Signedness is carried by the + cast op semantics, not by a separate layout. + On the current VPTO target, 32-bit to 8-bit integer narrowing is only a + baseline recipe for unsigned i8 results because the available VCVTII forms + are s32/u32 -> u8. + +compute boundary: + baseline floating compute uses f16/f32. + baseline integer grouped reduction compute uses i32 accumulators. i8/i16 + storage must be widened first because integer reduction instructions widen + narrow inputs. + f8/i8 are not baseline accumulator/compute element types. +``` + ### 2.1 Dense Layouts ```text @@ -348,10 +378,37 @@ group_slot_load: rematerialized into two ops when different users require different result layouts; each clone is then locally deterministic. -group_reduce_addf: +group_reduce_add{f|i}: source/mask layout, result group_slots layout, num_groups, element type, and - reassoc decide S=8 contiguous vcgadd, S=16/S=32 deinterleaved vcgadd trees, - and S=64 row-local vcadd/vsel lowering. + the typed reduce semantics decide the local reduction recipe. The recipe is + not keyed by f32 shape names. It is derived from the element byte width. + Floating-point `group_reduce_addf` carries `reassoc`; integer + `group_reduce_addi` does not. + + VLaneElems = 32B / sizeof(T) + L = 256B / sizeof(T) + S = logical_lane_count / num_groups + + S == VLaneElems -> contiguous vcgadd, result slots=8 + S == 2 * VLaneElems -> deinterleaved=2 vcgadd tree, result slots=8 + S == 4 * VLaneElems -> deinterleaved=4 vcgadd tree, result slots=8 + S >= L && S % L == 0 -> contiguous row-local vcadd/vsel, result slots=1 + + Type support is controlled by the typed reduce op semantics and target + capability, not by separate per-type shape rules. Once a type is legal for a + reduce op, the same formula above selects its layout and local recipe. The + current checked-in implementation may lag this design target; that is staged + implementation status, not a design boundary. + + The formula is applied to the accumulator/reduce element type, not + necessarily the storage element type. 8-bit floating-point storage first + casts to f32 for `group_reduce_addf`; 8-bit and 16-bit integer storage first + casts to a signed/unsigned i32 accumulator for + `group_reduce_addi`. In the baseline VMI contract, f8/i8 are storage and + cast-boundary types: they may be the source or destination of cast, load, and + store, but they are not accumulator/compute types for group reduce. Direct + 8-bit grouped reduction is illegal unless the target exposes an explicit + 8-bit compute recipe. group_broadcast: source group_slots layout, result dense layout, num_groups, and element type @@ -480,30 +537,15 @@ load: ### 5.3 Group Recipes From Cases ```text -group_reduce f32 S=8: - input contiguous - result group_slots(G, slots=8) - -group_reduce f32 S=16: - legal input layout A: deinterleaved=2, block_elems=1 - legal input layout B: deinterleaved=2, block_elems=8 - result group_slots(G, slots=8) - -group_reduce f32 S=32: - legal input layout A: deinterleaved=4, block_elems=1 - legal input layout B: deinterleaved=4, block_elems=8 - result group_slots(G, slots=8) - -group_reduce f32 S=64: - input contiguous - result group_slots(G, slots=1) - -group_reduce f32 S=128/S=256/...: - input contiguous - result group_slots(G, slots=1) - lowering reduces each full physical chunk with vcadd, accumulates all chunks - in the same logical group with lane0 vadd, and writes one physical result - part per group +group_reduce_add{f|i} typed shape classification: + define E = sizeof(T), VLaneElems = 32B / E, L = 256B / E, S = N / G. + S=VLaneElems uses contiguous input and group_slots(G, slots=8). + S=2*VLaneElems uses deinterleaved=2 input/mask and group_slots(G, slots=8). + S=4*VLaneElems uses deinterleaved=4 input/mask and group_slots(G, slots=8). + S>=L && S%L==0 uses contiguous input/mask and group_slots(G, slots=1); + lowering reduces each full physical chunk, accumulates all chunks in the + same logical group through lane0, and writes one physical result part per + group. group_slot_load: result group_slots(G, slots=8) for packed slots @@ -585,21 +627,18 @@ truncf f32 -> f8: requests source deinterleaved=4, block_elems=1 requests result contiguous f8 -group_reduce S=8: - requests source contiguous - requests result group_slots(num_groups, slots=8) - -group_reduce S=16: - requests source deinterleaved=2, block_elems=1 or block_elems=8 - requests result group_slots(num_groups, slots=8) - -group_reduce S=32: - requests source deinterleaved=4, block_elems=1 or block_elems=8 - requests result group_slots(num_groups, slots=8) - -group_reduce S=64: - requests source contiguous - requests result group_slots(num_groups, slots=1) +group_reduce_add{f|i}: + computes E = sizeof(accumulator type), VLaneElems = 32B / E, + L = 256B / E, and S = logical_lanes / num_groups + S=VLaneElems requests source contiguous and result group_slots(G, slots=8) + S=2*VLaneElems requests source deinterleaved=2 and result + group_slots(G, slots=8) + S=4*VLaneElems requests source deinterleaved=4 and result + group_slots(G, slots=8) + S>=L && S%L==0 requests source contiguous and result + group_slots(G, slots=1) + 8-bit storage reaches this request only after an explicit cast to the + accumulator type group_broadcast: requests source group_slots(num_groups, slots=K) diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index e084ad58c0..e17c14844b 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -5467,3 +5467,555 @@ store itself can locally prove the same contiguous memory effect from the source layout. vmi-to-vpto must not scan the `%w` producer or both store users to decide this. ``` + +### 3.47 Type-Parametric Group Reduce Rule + +The group-reduce layout rule is parameterized by the element width, not by f32 +case names. + +```text +E = sizeof(T) +VLaneElems = 32B / E +L = 256B / E +S = logical_lane_count / num_groups +``` + +The canonical grouped-reduce layouts are: + +```text +S == VLaneElems: + source/mask layout = contiguous + result layout = group_slots(num_groups=G, slots=8) + +S == 2 * VLaneElems: + source/mask layout = deinterleaved=2 + result layout = group_slots(num_groups=G, slots=8) + +S == 4 * VLaneElems: + source/mask layout = deinterleaved=4 + result layout = group_slots(num_groups=G, slots=8) + +S >= L && S % L == 0: + source/mask layout = contiguous + result layout = group_slots(num_groups=G, slots=1) +``` + +Concrete shape table: + +```text +T VLaneElems L packed cases row-local cases +f32 8 64 S=8, S=16, S=32 S=64, S=128, ... +i32 8 64 S=8, S=16, S=32 S=64, S=128, ... +f16 16 128 S=16, S=32, S=64 S=128, S=256, ... +i16 16 128 S=16, S=32, S=64 S=128, S=256, ... +f8 32 256 cast to f32 before grouped reduce +i8 32 256 cast to i16/i32 before grouped reduce +``` + +These non-f32 cases are part of the type-generic layout/lowering design. If a +typed reduce op admits the element type and the target capability registry +accepts it, assignment must use the same `VLaneElems/L/S` formula instead of +adding per-type shape special cases. Any f32-only behavior in the current +implementation is staged implementation status, not the intended design limit. +For the current baseline, `f8/i8` are storage and cast-boundary types: they are +valid as load/store element types and as cast source/destination, but compute +ops such as group reduce consume the post-cast accumulator type. + +### 3.48 16-bit Typed Group Reduce, `S = VLaneElems = 16` + +This case covers both `f16` and `i16`. The element width is the same, so the +layout and VPTO instruction skeleton are identical. The VMI op name carries the +semantic difference: + +```text +f16: pto.vmi.group_reduce_addf ... {reassoc} +i16 storage: pto.vmi.extsi/extui ... -> i32 group_reduce_addi ... +``` + +VMI-shaped input: + +```text +// Floating form. +%xf = pto.vmi.load %base_f16[%off] + : memref<128xf16> -> !pto.vmi.vreg<128xf16> +%mf = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sumf = pto.vmi.group_reduce_addf %xf, %mf {num_groups = 8, reassoc} +pto.vmi.group_store %sumf, %out_f16[%group_off], %c1 {num_groups = 8} + +// Integer form. +%xi = pto.vmi.load %base_i16[%off] + : memref<128xi16> -> !pto.vmi.vreg<128xi16> +%mi = pto.vmi.create_group_mask %c16 {num_groups = 8, group_size = 16} + : index -> !pto.vmi.mask<128xpred> +%sumi = pto.vmi.group_reduce_addi %xi, %mi {num_groups = 8} +pto.vmi.group_store %sumi, %out_i16[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%xf, %mf, %xi, %mi: + #pto.vmi.layout + +%sumf: + !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +%sumi: + !pto.vmi.vreg<128xi16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x0 = pto.vlds %base[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xT16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" + +%sum0 = pto.vcgadd %x0, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> + +pto.vsts %sum0, %out[%group_off], %slot8_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 16 + 0 .. 15]) +``` + +### 3.49 16-bit Typed Group Reduce, `S = 2 * VLaneElems = 32` + +This case covers both `f16` and `i16`. Each logical row is 64B and must be +split into two 32B VLane fragments before `vcgadd`. + +VMI-shaped input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xT16> -> !pto.vmi.vreg<256xT16> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_add{f|i} %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<256xT16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x_p0, %x_p1 = pto.vldsx2 %base[%off], "DINTLV_B16" + : !pto.ptr, index -> !pto.vreg<128xT16>, !pto.vreg<128xT16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> +%s1 = pto.vcgadd %x_p1, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> +%sum0 = pto.vadd %s0, %s1, %slot8_b16 + : !pto.vreg<128xT16>, !pto.vreg<128xT16>, !pto.mask + -> !pto.vreg<128xT16> + +pto.vsts %sum0, %out[%group_off], %slot8_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 32 + 0 .. 31]) +``` + +### 3.50 16-bit Typed Group Reduce, `S = 4 * VLaneElems = 64` + +This is the four-fragment packed case for both `f16` and `i16`. + +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<512xT16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x_p0, %x_p1, %x_p2, %x_p3 = materialize deinterleaved=4 input + : four !pto.vreg<128xT16> + +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b16 : !pto.vreg<128xT16> +%s1 = pto.vcgadd %x_p1, %all_b16 : !pto.vreg<128xT16> +%s2 = pto.vcgadd %x_p2, %all_b16 : !pto.vreg<128xT16> +%s3 = pto.vcgadd %x_p3, %all_b16 : !pto.vreg<128xT16> + +%s01 = pto.vadd %s0, %s1, %slot8_b16 : !pto.vreg<128xT16> +%s23 = pto.vadd %s2, %s3, %slot8_b16 : !pto.vreg<128xT16> +%sum0 = pto.vadd %s01, %s23, %slot8_b16 : !pto.vreg<128xT16> + +pto.vsts %sum0, %out[%group_off], %slot8_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 64 + 0 .. 63]) +``` + +### 3.51 16-bit Typed Group Reduce, `S = L = 128` + +This is the first row-local full-physical-chunk case for both `f16` and `i16`. +The canonical result is row-local `slots = 1`, not packed `slots = 8`. + +VMI-shaped input: + +```text +%x = pto.vmi.load %base[%off] + : memref<1024xT16> -> !pto.vmi.vreg<1024xT16> +%mask = pto.vmi.create_group_mask %c128 {num_groups = 8, group_size = 128} + : index -> !pto.vmi.mask<1024xpred> +%sum = pto.vmi.group_reduce_add{f|i} %x, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<1024xT16, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%all_b16 = pto.pge_b16 "PAT_ALL" +%slot8_b16 = pto.pge_b16 "PAT_VL8" +%slot1_b16 = pto.pge_b16 "PAT_VL1" + +// Repeated for r = 0..7. +%x_r = pto.vlds %base[%row_off_r] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xT16> +%partial_r = pto.vcgadd %x_r, %all_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> +%sum_r = pto.vcadd %partial_r, %slot8_b16 + : !pto.vreg<128xT16>, !pto.mask -> !pto.vreg<128xT16> + +pto.vsts %sum_r, %out[%group_off_plus_r], %slot1_b16 {dist = "NORM_B16"} + : !pto.vreg<128xT16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out[group_off + r] = reduce_T16(base[off + r * 128 + 0 .. 127]) +``` + +### 3.52 32-bit Typed Group Reduce + +This case covers both `f32` and `i32`. The element width is the same, so +`VLaneElems = 8` and `L = 64` for both. Floating-point uses +`group_reduce_addf` with `reassoc`; integer uses `group_reduce_addi`. + +Example for `S = 2 * VLaneElems = 16`: + +```text +%x: + !pto.vmi.vreg<128xT32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<128xT32, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x_p0, %x_p1 = pto.vldsx2 %base[%off], "DINTLV_B32" + : !pto.ptr, index -> !pto.vreg<64xT32>, !pto.vreg<64xT32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8_b32 = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x_p0, %all_b32 + : !pto.vreg<64xT32>, !pto.mask -> !pto.vreg<64xT32> +%s1 = pto.vcgadd %x_p1, %all_b32 + : !pto.vreg<64xT32>, !pto.mask -> !pto.vreg<64xT32> +%sum0 = pto.vadd %s0, %s1, %slot8_b32 + : !pto.vreg<64xT32>, !pto.vreg<64xT32>, !pto.mask + -> !pto.vreg<64xT32> + +pto.vsts %sum0, %out[%group_off], %slot8_b32 {dist = "NORM_B32"} + : !pto.vreg<64xT32>, !pto.ptr, !pto.mask +``` + +The same formula gives: + +```text +S=8: + contiguous, slots=8, one vcgadd. + +S=32: + deinterleaved=4, slots=8, four vcgadd plus vadd tree. + +S=64: + contiguous, slots=1, row-local vcgadd plus vcadd. + +S=128: + contiguous, slots=1, row-local multi-chunk accumulation. +``` + +### 3.53 Integer Semantics And Invalid Typed Reductions + +Integer group reduction is not a variant of `group_reduce_addf`; it requires a +typed integer op: + +```text +%sum = pto.vmi.group_reduce_addi %x, %mask {num_groups = G} +``` + +Required semantics: + +```text +inactive lanes contribute integer zero +addition uses the target's normal integer add behavior +wrap/saturating variants must be represented by distinct ops if both are needed +signedness does not affect add, but does affect future max/min integer reduces +``` + +Required invalid cases: + +```text +pto.vmi.group_reduce_addf with integer element type -> verifier error +pto.vmi.group_reduce_addi with floating-point element type -> verifier error +pto.vmi.group_reduce_addi i8 -> invalid direct 8-bit accumulator reduce; + cast to i16/i32 first unless target exposes i8 vcgadd +S not in {VLaneElems, 2*VLaneElems, 4*VLaneElems} and not a full-chunk multiple + -> layout-contract diagnostic +``` + +### 3.54 8-bit Floating Group Reduce + +There is no direct f8 `vcgadd` grouped reduction in the current target model, +but f8 supports cast to an accumulator type. The semantic path is: + +```text +f8 storage -> cast/extf to f32 accumulator -> group_reduce_addf on f32 +``` + +Here `f8` is only the cast source and the memory element type. The reduction +itself is a f32 accumulator operation. + +The group size remains a logical-lane property. For example, reducing eight +rows of 32 f8 elements produces the same logical result as reducing eight rows +of 32 f32 accumulator elements after extension. + +VMI-shaped input: + +```text +%x8 = pto.vmi.load %base_f8[%off] + : memref<256xf8> -> !pto.vmi.vreg<256xf8> +%x32 = pto.vmi.extf %x8 + : !pto.vmi.vreg<256xf8> -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} +pto.vmi.group_store %sum, %out_f32[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x8: + !pto.vmi.vreg<256xf8, #pto.vmi.layout> + +%x32, %mask: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xf32, #pto.vmi.layout> +``` + +VPTO lowering shape: + +```text +%x8_packed = pto.vlds %base_f8[%off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<256xf8> + +%all_b8 = pto.pge_b8 "PAT_ALL" +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8_b32 = pto.pge_b32 "PAT_VL8" + +%x32_p0 = pto.vcvt %x8_packed, %all_b8 {part = "P0"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p1 = pto.vcvt %x8_packed, %all_b8 {part = "P1"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p2 = pto.vcvt %x8_packed, %all_b8 {part = "P2"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> +%x32_p3 = pto.vcvt %x8_packed, %all_b8 {part = "P3"} + : !pto.vreg<256xf8>, !pto.mask -> !pto.vreg<64xf32> + +%s0 = pto.vcgadd %x32_p0, %all_b32 : !pto.vreg<64xf32> +%s1 = pto.vcgadd %x32_p1, %all_b32 : !pto.vreg<64xf32> +%s2 = pto.vcgadd %x32_p2, %all_b32 : !pto.vreg<64xf32> +%s3 = pto.vcgadd %x32_p3, %all_b32 : !pto.vreg<64xf32> +%s01 = pto.vadd %s0, %s1, %slot8_b32 : !pto.vreg<64xf32> +%s23 = pto.vadd %s2, %s3, %slot8_b32 : !pto.vreg<64xf32> +%sum0 = pto.vadd %s01, %s23, %slot8_b32 : !pto.vreg<64xf32> + +pto.vsts %sum0, %out_f32[%group_off], %slot8_b32 {dist = "NORM_B32"} + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out_f32[group_off + r] = + reduce_f32(f32(base_f8[off + r * 32 + 0 .. 31])) +``` + +Direct f8 grouped reduction is invalid: + +```text +pto.vmi.group_reduce_addf %x8, %mask + : !pto.vmi.vreg<256xf8>, !pto.vmi.mask<256xpred> + -> verifier or layout-contract diagnostic +``` + +### 3.55 8-bit Integer Group Reduce + +The current target model has no i8 `vcgadd`. It does have widening `vcadd` for +full-vector reductions, but grouped reduction needs one partial result per +32B VLane. Since 8-bit integers support cast to wider integer types, the +baseline grouped path casts before reducing: + +```text +i8/i16 storage -> signed/unsigned cast to i32 accumulator + -> group_reduce_addi on the accumulator type +``` + +Here `i8`/`i16` are only cast sources and memory element types. The reduction +itself is an i32 accumulator operation, with signedness handled by the cast. + +The integer cast operation must carry signedness. This document uses +`extsi/extui` as the widening spelling and `trunci` as the narrowing spelling: + +```text +%x32 = pto.vmi.extsi %x8 : !pto.vmi.vreg -> !pto.vmi.vreg +%x32 = pto.vmi.extui %x8 : !pto.vmi.vreg -> !pto.vmi.vreg +%x8 = pto.vmi.trunci %x32 : !pto.vmi.vreg -> !pto.vmi.vreg +``` + +The last form is unsigned i8 on the current VPTO target: VISA exposes +VCVTII.s322u8/u322u8 for 32-bit to 8-bit narrowing, not a signed-i8 +destination form. + +VMI-shaped input: + +```text +%x8 = pto.vmi.load %base_i8[%off] + : memref<256xi8> -> !pto.vmi.vreg<256xi8> +%x32 = pto.vmi.extsi %x8 + : !pto.vmi.vreg<256xi8> -> !pto.vmi.vreg<256xi32> +%mask = pto.vmi.create_group_mask %c32 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} +pto.vmi.group_store %sum, %out_i32[%group_off], %c1 {num_groups = 8} +``` + +Assigned layouts: + +```text +%x8: + !pto.vmi.vreg<256xi8, #pto.vmi.layout> + +%x32, %mask: + !pto.vmi.vreg<256xi32, #pto.vmi.layout> + !pto.vmi.mask<256xb32, #pto.vmi.layout> + +%sum: + !pto.vmi.vreg<256xi32, #pto.vmi.layout> +``` + +VPTO lowering shape after integer cast materialization: + +```text +%x32_p0, %x32_p1, %x32_p2, %x32_p3 = + materialize signed cast i8 -> i32 with deinterleaved=4 layout + : four !pto.vreg<64xi32> + +%all_b32 = pto.pge_b32 "PAT_ALL" +%slot8_b32 = pto.pge_b32 "PAT_VL8" + +%s0 = pto.vcgadd %x32_p0, %all_b32 : !pto.vreg<64xi32> +%s1 = pto.vcgadd %x32_p1, %all_b32 : !pto.vreg<64xi32> +%s2 = pto.vcgadd %x32_p2, %all_b32 : !pto.vreg<64xi32> +%s3 = pto.vcgadd %x32_p3, %all_b32 : !pto.vreg<64xi32> +%s01 = pto.vadd %s0, %s1, %slot8_b32 : !pto.vreg<64xi32> +%s23 = pto.vadd %s2, %s3, %slot8_b32 : !pto.vreg<64xi32> +%sum0 = pto.vadd %s01, %s23, %slot8_b32 : !pto.vreg<64xi32> + +pto.vsts %sum0, %out_i32[%group_off], %slot8_b32 {dist = "NORM_B32"} + : !pto.vreg<64xi32>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..7: + out_i32[group_off + r] = + reduce_i32(sign_extend(base_i8[off + r * 32 + 0 .. 31])) +``` + +Direct i8 grouped reduction without the cast is invalid: + +```text +pto.vmi.group_reduce_addi %x8, %mask + : !pto.vmi.vreg<256xi8>, !pto.vmi.mask<256xpred> + -> verifier or layout-contract diagnostic +``` + +An optimized row-local i8 full-chunk recipe may be added later for +`S = 256` by using widening `vcadd`, but that requires a widening +`group_slots` result contract and must not change the baseline cast-to-accumulator +semantics above. + +If the final memory result is i8, narrowing is a separate cast after the +accumulator computation: + +```text +%sum32 = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} +%sum8 = pto.vmi.trunci %sum32 +pto.vmi.group_store %sum8, %out_i8[%group_off], %c1 {num_groups = 8} +``` + +That packed group-slot `trunci` path is not a baseline recipe yet; the +implementation must either define a slot-wise VCVTII recipe or diagnose at +layout assignment. diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 80036f9946..d14b6fe8ee 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -418,6 +418,16 @@ def VMIGroupReduceAddFOp : VMI_Op<"group_reduce_addf"> { let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; } +def VMIGroupReduceAddIOp : VMI_Op<"group_reduce_addi"> { + let summary = "VMI masked integer add reduction within fixed logical groups"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + def VMIGroupBroadcastOp : VMI_Op<"group_broadcast"> { let summary = "VMI broadcast group-slot values back to each logical group"; let arguments = (ins VMI_VRegTypeConstraint:$source, @@ -443,6 +453,30 @@ def VMITruncFOp : VMI_Op<"truncf"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } +def VMIExtSIOp : VMI_Op<"extsi"> { + let summary = "VMI signed integer elementwise extension"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIExtUIOp : VMI_Op<"extui"> { + let summary = "VMI unsigned integer elementwise extension"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMITruncIOp : VMI_Op<"trunci"> { + let summary = "VMI saturating integer elementwise truncation"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + def VMIBitcastOp : VMI_Op<"bitcast"> { let summary = "VMI bitwise vector reinterpretation"; let arguments = (ins VMI_VRegTypeConstraint:$source); diff --git a/include/PTO/Transforms/VMILocalRecipeRegistry.h b/include/PTO/Transforms/VMILocalRecipeRegistry.h index 10cde1dc96..8472a32c4c 100644 --- a/include/PTO/Transforms/VMILocalRecipeRegistry.h +++ b/include/PTO/Transforms/VMILocalRecipeRegistry.h @@ -85,14 +85,15 @@ struct VMIGroupSlotsStoreRecipe { }; enum class VMIGroupReduceAddFRecipeKind { - S8Vcgadd, - S16Deinterleaved2VcgaddVadd, - S32Deinterleaved4VcgaddTree, + OneVLaneVcgadd, + TwoVLaneDeinterleaved2VcgaddVadd, + FourVLaneDeinterleaved4VcgaddTree, ContiguousVcaddRows, }; struct VMIGroupReduceAddFRecipe { - VMIGroupReduceAddFRecipeKind kind = VMIGroupReduceAddFRecipeKind::S8Vcgadd; + VMIGroupReduceAddFRecipeKind kind = + VMIGroupReduceAddFRecipeKind::OneVLaneVcgadd; }; enum class VMIGroupBroadcastRecipeKind { @@ -125,6 +126,27 @@ struct VMIExtFRecipe { VMIExtFRecipeKind::ContiguousF16ToDeinterleaved2F32; }; +enum class VMITruncIRecipeKind { + Deinterleaved2I32ToContiguousI16, + Deinterleaved4I32ToContiguousI8, + GroupSlots1I32ToI16, +}; + +struct VMITruncIRecipe { + VMITruncIRecipeKind kind = + VMITruncIRecipeKind::Deinterleaved2I32ToContiguousI16; +}; + +enum class VMIExtIRecipeKind { + ContiguousI16ToDeinterleaved2I32, + ContiguousI8ToDeinterleaved4I32, +}; + +struct VMIExtIRecipe { + VMIExtIRecipeKind kind = + VMIExtIRecipeKind::ContiguousI16ToDeinterleaved2I32; +}; + enum class VMIBitcastRecipeKind { PerPartVbitcast, }; @@ -178,6 +200,11 @@ class VMILocalRecipeRegistry { VMIGroupReduceAddFOp op, std::string *reason = nullptr) const; + FailureOr + getGroupReduceAddIRecipe(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceAddIOp op, + std::string *reason = nullptr) const; + FailureOr getGroupBroadcastRecipe(const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, @@ -189,6 +216,15 @@ class VMILocalRecipeRegistry { FailureOr getExtFRecipe(VMIExtFOp op, std::string *reason = nullptr) const; + FailureOr + getExtSIRecipe(VMIExtSIOp op, std::string *reason = nullptr) const; + + FailureOr + getExtUIRecipe(VMIExtUIOp op, std::string *reason = nullptr) const; + + FailureOr + getTruncIRecipe(VMITruncIOp op, std::string *reason = nullptr) const; + FailureOr getBitcastRecipe(VMIBitcastOp op, std::string *reason = nullptr) const; }; diff --git a/include/PTO/Transforms/VMITargetCapabilities.h b/include/PTO/Transforms/VMITargetCapabilities.h index 15b4f19f1d..a96a73a6d0 100644 --- a/include/PTO/Transforms/VMITargetCapabilities.h +++ b/include/PTO/Transforms/VMITargetCapabilities.h @@ -42,6 +42,8 @@ enum class VMIElementPurpose { enum class VMIReductionKind { AddI, AddF, + GroupAddI, + GroupAddF, MaxF, MinF, }; @@ -229,11 +231,26 @@ class VMITargetCapabilityRegistry { "currently supports only 32-bit integer elements because narrow " "vcadd widens its result"); case VMIReductionKind::AddF: - if (elementType.isF32()) + if (elementType.isF16() || elementType.isF32()) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "currently supports only f16/f32 elements for floating-point " + "reduction"); + case VMIReductionKind::GroupAddI: { + auto intType = dyn_cast(elementType); + if (intType && intType.getWidth() == 32) + return VMICapabilityResult::supported(); + return VMICapabilityResult::missingCapability( + "grouped integer add reduction supports only i32 accumulator " + "elements because narrow integer reductions widen their result; " + "cast i8/i16 storage before grouped reduction"); + } + case VMIReductionKind::GroupAddF: + if (elementType.isF16() || elementType.isF32()) return VMICapabilityResult::supported(); return VMICapabilityResult::missingCapability( - "currently supports only f32 elements; f16 requires an explicit " - "accumulator precision and rounding contract"); + "grouped floating-point add reduction supports f16/f32 accumulator " + "elements"); case VMIReductionKind::MaxF: case VMIReductionKind::MinF: if (elementType.isF16() || elementType.isF32()) diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index ff7170044e..b504de67f5 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -61,6 +61,16 @@ static bool isVMIIntegerLikeType(Type type) { return isa(type); } +static bool isVMISignedOrSignlessIntegerType(Type type) { + auto integerType = dyn_cast(type); + return integerType && !integerType.isUnsigned(); +} + +static bool isVMIUnsignedIntegerType(Type type) { + auto integerType = dyn_cast(type); + return integerType && integerType.isUnsigned(); +} + static bool isVMIIotaElementType(Type type) { if (auto intType = dyn_cast(type)) return intType.getWidth() == 8 || intType.getWidth() == 16 || @@ -1154,6 +1164,50 @@ LogicalResult VMIGroupReduceAddFOp::verify() { getNumGroupsAttr().getInt()); } +LogicalResult VMIGroupReduceAddIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto maskType = cast(getMask().getType()); + auto resultType = cast(getResult().getType()); + if (!isVMIIntegerLikeType(sourceType.getElementType())) + return emitOpError("requires integer-like VMI source element type"); + auto intType = dyn_cast(sourceType.getElementType()); + if (!intType || intType.getWidth() != 32) + return emitOpError( + "requires i32 accumulator element type; cast i8/i16 storage to i32 " + "before grouped reduction because integer reduction widens narrow " + "inputs"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (sourceType.getElementType() != resultType.getElementType()) + return emitOpError("requires source and result element types to match"); + if (auto sourceLayout = sourceType.getLayoutAttr()) { + bool supportedSourceLayout = + sourceLayout.isContiguous() || + (sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 2 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)) || + (sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 4 && + (sourceLayout.getBlockElems() == 1 || + sourceLayout.getBlockElems() == 8)); + if (!supportedSourceLayout) + return emitOpError( + "requires layout-assigned source to use contiguous layout or " + "deinterleaved=2/4 layout with block_elems=1 or block_elems=8"); + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isGroupSlots() || + resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) + return emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; + } + if (failed(verifyMaskMatchesData(getOperation(), maskType, sourceType))) + return failure(); + return verifyNumGroups(getOperation(), sourceType, + getNumGroupsAttr().getInt()); +} + LogicalResult VMIGroupBroadcastOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); @@ -1212,6 +1266,56 @@ LogicalResult VMITruncFOp::verify() { return success(); } +LogicalResult VMIExtSIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMISignedOrSignlessIntegerType(sourceType.getElementType()) || + !isVMISignedOrSignlessIntegerType(resultType.getElementType())) + return emitOpError( + "requires signed or signless integer source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) >= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError( + "requires result element type to be wider than source element type"); + return success(); +} + +LogicalResult VMIExtUIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMIUnsignedIntegerType(sourceType.getElementType()) || + !isVMIUnsignedIntegerType(resultType.getElementType())) + return emitOpError( + "requires unsigned integer source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) >= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError( + "requires result element type to be wider than source element type"); + return success(); +} + +LogicalResult VMITruncIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMIIntegerLikeType(sourceType.getElementType()) || + !isVMIIntegerLikeType(resultType.getElementType())) + return emitOpError("requires integer source and result element types"); + if (getVMIElementBitWidth(sourceType.getElementType()) <= + getVMIElementBitWidth(resultType.getElementType())) + return emitOpError( + "requires result element type to be narrower than source element type"); + return success(); +} + LogicalResult VMIBitcastOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 2ff9e50ae2..5f30ba82e0 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -236,6 +236,13 @@ struct LayoutSolver { return VMILayoutAttr::getGroupSlots(ctx, numGroups); } + std::optional getVLaneElems(Type elementType) { + FailureOr lanesPerPart = getDataLanesPerPart(elementType); + if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) + return std::nullopt; + return *lanesPerPart / 8; + } + VMILayoutAttr getPreferredGroupSlotsLayout(VMIVRegType type, int64_t numGroups) { if (VMILayoutAttr existing = type.getLayoutAttr()) @@ -243,11 +250,10 @@ struct LayoutSolver { return existing; if (numGroups > 0 && type.getElementCount() % numGroups == 0) { int64_t groupSize = type.getElementCount() / numGroups; - if (groupSize == 8) - return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); - if (groupSize == 16) - return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); - if (groupSize == 32) + std::optional vlaneElems = getVLaneElems(type.getElementType()); + if (vlaneElems && (groupSize == *vlaneElems || + groupSize == 2 * *vlaneElems || + groupSize == 4 * *vlaneElems)) return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); FailureOr lanesPerPart = getDataLanesPerPart(type.getElementType()); @@ -264,9 +270,10 @@ struct LayoutSolver { return existing; if (numGroups > 0 && type.getElementCount() % numGroups == 0) { int64_t groupSize = type.getElementCount() / numGroups; - if (groupSize == 16) + std::optional vlaneElems = getVLaneElems(type.getElementType()); + if (vlaneElems && groupSize == 2 * *vlaneElems) return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); - if (groupSize == 32) + if (vlaneElems && groupSize == 4 * *vlaneElems) return VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); } return getContiguousLayout(); @@ -279,7 +286,9 @@ struct LayoutSolver { return existing; if (numGroups > 0 && type.getElementCount() % numGroups == 0) { int64_t groupSize = type.getElementCount() / numGroups; - if (groupSize == 64) + FailureOr lanesPerPart = + getDataLanesPerPart(type.getElementType()); + if (succeeded(lanesPerPart) && groupSize == *lanesPerPart) return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); } return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); @@ -349,7 +358,8 @@ struct LayoutSolver { if (solved && solved.isGroupSlots() && solved.getNumGroups() == numGroups && solved.getSlots() > 0) return solved; - if (value.getDefiningOp()) + if (value.getDefiningOp() || + value.getDefiningOp()) return getPreferredGroupSlotsLayout(type, numGroups); if (value.getDefiningOp()) return getPreferredGroupSlotLoadLayout(type, numGroups); @@ -387,9 +397,10 @@ struct LayoutSolver { if (!resultType) continue; unsigned resultBits = getElementBitWidth(resultType.getElementType()); - if (groupSize == 16 && resultBits == 16) + std::optional vlaneElems = getVLaneElems(sourceType.getElementType()); + if (vlaneElems && groupSize == 2 * *vlaneElems && resultBits == 16) return true; - if (groupSize == 32 && resultBits == 8) + if (vlaneElems && groupSize == 4 * *vlaneElems && resultBits == 8) return true; } return false; @@ -810,12 +821,16 @@ struct LayoutSolver { if (solvedSourceLayout && numGroups > 0 && sourceType.getElementCount() % numGroups == 0) { int64_t groupSize = sourceType.getElementCount() / numGroups; - if (groupSize == 16 && solvedSourceLayout.isDeinterleaved() && + std::optional vlaneElems = + getVLaneElems(sourceType.getElementType()); + if (vlaneElems && groupSize == 2 * *vlaneElems && + solvedSourceLayout.isDeinterleaved() && solvedSourceLayout.getFactor() == 2 && (solvedSourceLayout.getBlockElems() == 1 || solvedSourceLayout.getBlockElems() == 8)) sourceLayout = solvedSourceLayout; - if (groupSize == 32 && solvedSourceLayout.isDeinterleaved() && + if (vlaneElems && groupSize == 4 * *vlaneElems && + solvedSourceLayout.isDeinterleaved() && solvedSourceLayout.getFactor() == 4 && (solvedSourceLayout.getBlockElems() == 1 || solvedSourceLayout.getBlockElems() == 8)) @@ -825,10 +840,12 @@ struct LayoutSolver { int64_t groupSize = sourceType.getElementCount() / numGroups; if (hasCompatibleTruncFUseForGroupReduce(reduce.getSource(), groupSize)) { - if (groupSize == 16) + std::optional vlaneElems = + getVLaneElems(sourceType.getElementType()); + if (vlaneElems && groupSize == 2 * *vlaneElems) sourceLayout = VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); - if (groupSize == 32) + if (vlaneElems && groupSize == 4 * *vlaneElems) sourceLayout = VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/1); } @@ -846,6 +863,45 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( + sourceType, reduce.getNumGroupsAttr().getInt()); + VMILayoutAttr solvedSourceLayout = + getExplicitDataLayout(reduce.getSource()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + if (solvedSourceLayout && numGroups > 0 && + sourceType.getElementCount() % numGroups == 0) { + int64_t groupSize = sourceType.getElementCount() / numGroups; + std::optional vlaneElems = + getVLaneElems(sourceType.getElementType()); + if (vlaneElems && groupSize == 2 * *vlaneElems && + solvedSourceLayout.isDeinterleaved() && + solvedSourceLayout.getFactor() == 2 && + (solvedSourceLayout.getBlockElems() == 1 || + solvedSourceLayout.getBlockElems() == 8)) + sourceLayout = solvedSourceLayout; + if (vlaneElems && groupSize == 4 * *vlaneElems && + solvedSourceLayout.isDeinterleaved() && + solvedSourceLayout.getFactor() == 4 && + (solvedSourceLayout.getBlockElems() == 1 || + solvedSourceLayout.getBlockElems() == 8)) + sourceLayout = solvedSourceLayout; + } + requestDataUse(reduce.getSourceMutable(), sourceLayout); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceLayout, + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + reduce.getResult(), + getPreferredGroupSlotsLayout( + resultType, reduce.getNumGroupsAttr().getInt()), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto broadcast = dyn_cast(op)) { auto sourceType = cast(broadcast.getSource().getType()); requestDataUse(broadcast.getSourceMutable(), @@ -873,6 +929,46 @@ struct LayoutSolver { } return WalkResult::advance(); } + if (auto extsi = dyn_cast(op)) { + auto sourceType = cast(extsi.getSource().getType()); + auto resultType = cast(extsi.getResult().getType()); + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (sourceBits == 16 && resultBits == 32) { + requestDataUse(extsi.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extsi.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 2), + op))) + return WalkResult::interrupt(); + } else if (sourceBits == 8 && resultBits == 32) { + requestDataUse(extsi.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extsi.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 4), + op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + if (auto extui = dyn_cast(op)) { + auto sourceType = cast(extui.getSource().getType()); + auto resultType = cast(extui.getResult().getType()); + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + if (sourceBits == 16 && resultBits == 32) { + requestDataUse(extui.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extui.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 2), + op))) + return WalkResult::interrupt(); + } else if (sourceBits == 8 && resultBits == 32) { + requestDataUse(extui.getSourceMutable(), getContiguousLayout()); + if (failed(setNaturalLayout(extui.getResult(), + VMILayoutAttr::getDeinterleaved(ctx, 4), + op))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } if (auto truncf = dyn_cast(op)) { auto sourceType = cast(truncf.getSource().getType()); auto resultType = cast(truncf.getResult().getType()); @@ -897,6 +993,30 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto trunci = dyn_cast(op)) { + auto sourceType = cast(trunci.getSource().getType()); + auto resultType = cast(trunci.getResult().getType()); + unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); + unsigned resultBits = getElementBitWidth(resultType.getElementType()); + VMILayoutAttr sourceLayout = getDataLayout(trunci.getSource()); + if (sourceBits == 32 && resultBits == 16 && sourceLayout && + sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { + requestDataUse(trunci.getSourceMutable(), sourceLayout); + if (failed(setNaturalLayout(trunci.getResult(), sourceLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (sourceBits == 32 && resultBits == 16) + requestDataUse(trunci.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, 2)); + else if (sourceBits == 32 && resultBits == 8) + requestDataUse(trunci.getSourceMutable(), + VMILayoutAttr::getDeinterleaved(ctx, 4)); + if (failed(setNaturalLayout(trunci.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto bitcast = dyn_cast(op)) { if (failed(unite(bitcast.getSource(), bitcast.getResult(), op))) return WalkResult::interrupt(); @@ -1463,6 +1583,14 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto load = dyn_cast(op)) { auto resultType = cast(load.getResult().getType()); if (failed(requestMaskUse( diff --git a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp index 7cd5281353..34b843737c 100644 --- a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp +++ b/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp @@ -657,9 +657,12 @@ VMILocalRecipeRegistry::getGroupSlotsStoreRecipe( } FailureOr -VMILocalRecipeRegistry::getGroupReduceAddFRecipe( - const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, - std::string *reason) const { +getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, + Operation *op, VMIVRegType sourceType, + VMIMaskType maskType, VMIVRegType resultType, + int64_t numGroups, bool requiresReassoc, + VMIReductionKind reductionKind, + std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr { if (reason) @@ -667,17 +670,13 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( return failure(); }; - if (!op->hasAttr("reassoc")) + if (requiresReassoc && !op->hasAttr("reassoc")) return fail("requires reassoc attr for pair-wise floating-point " "reduction"); - auto sourceType = cast(op.getSource().getType()); - auto maskType = cast(op.getMask().getType()); - auto resultType = cast(op.getResult().getType()); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr maskLayout = maskType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - int64_t numGroups = op.getNumGroupsAttr().getInt(); if (!sourceLayout || !maskLayout || !resultLayout) return fail("requires assigned source, mask, and result layouts"); if (!resultLayout.isGroupSlots() || resultLayout.getNumGroups() != numGroups) @@ -685,23 +684,28 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( if (resultLayout.getSlots() != 8 && resultLayout.getSlots() != 1) { FailureOr groupSize = getGroupSizeFromNumGroups(sourceType, numGroups, reason); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + int64_t vlaneElems = + succeeded(lanesPerPart) && *lanesPerPart % 8 == 0 ? *lanesPerPart / 8 + : -1; if (succeeded(groupSize) && resultLayout.getSlots() <= 0 && - *groupSize != 8 && *groupSize != 16 && *groupSize != 32) - return fail("stable group_reduce_addf slots=8 recipes support group " - "size 8, 16, or 32"); - return fail("stable group_reduce_addf local recipes currently require " + (*groupSize != vlaneElems && *groupSize != 2 * vlaneElems && + *groupSize != 4 * vlaneElems)) + return fail("stable group_reduce_add slots=8 recipes support group " + "sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems"); + return fail("stable group_reduce_add local recipes currently require " "result layout slots=8 or slots=1"); } VMICapabilityResult elementCapability = - capabilities.supportsReductionElementType(VMIReductionKind::AddF, + capabilities.supportsReductionElementType(reductionKind, sourceType.getElementType()); if (!elementCapability.isSupported()) return fail(elementCapability.reason); - if (!sourceType.getElementType().isF32() || - sourceType.getElementType() != resultType.getElementType()) - return fail("stable group_reduce_addf local recipes require f32 " - "source/result"); + if (sourceType.getElementType() != resultType.getElementType()) + return fail("stable group_reduce_add local recipes require matching " + "source/result element types"); if (sourceType.getElementCount() != resultType.getElementCount()) return fail("requires source/result lane count to match"); @@ -709,6 +713,11 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( getGroupSizeFromNumGroups(sourceType, numGroups, reason); if (failed(groupSize)) return failure(); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) + return fail("requires element type with known physical VLane width"); + int64_t vlaneElems = *lanesPerPart / 8; FailureOr sourceArity = getVMIPhysicalArity(sourceType); FailureOr maskArity = getVMIPhysicalArity(maskType); @@ -719,81 +728,105 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( return fail("requires matching non-empty source/mask physical arity"); if (resultLayout.getSlots() == 1) { - FailureOr lanesPerPart = - getDataLanesPerPart(sourceType.getElementType()); if (failed(lanesPerPart) || *groupSize < *lanesPerPart || *groupSize % *lanesPerPart != 0) - return fail("stable group_reduce_addf slots=1 recipes support group " + return fail("stable group_reduce_add slots=1 recipes support group " "sizes that are multiples of one physical chunk"); if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) - return fail("slots=1 group_reduce_addf requires contiguous source/mask " + return fail("slots=1 group_reduce_add requires contiguous source/mask " "layouts"); if (*resultArity != numGroups) - return fail("slots=1 group_reduce_addf requires one physical result " + return fail("slots=1 group_reduce_add requires one physical result " "part per group"); std::string sourceFullReason; if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) - return fail(Twine("slots=1 group_reduce_addf requires full source " + return fail(Twine("slots=1 group_reduce_add requires full source " "chunks; ") + sourceFullReason); return VMIGroupReduceAddFRecipe{ VMIGroupReduceAddFRecipeKind::ContiguousVcaddRows}; } - if (*groupSize == 8) { + if (*groupSize == vlaneElems) { if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) - return fail("s8 group_reduce_addf requires contiguous source/mask " + return fail("one-vlane group_reduce_add requires contiguous source/mask " "layouts"); std::string sourceFullReason; if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) - return fail(Twine("s8 group_reduce_addf requires full source chunks; ") + + return fail(Twine("one-vlane group_reduce_add requires full source " + "chunks; ") + sourceFullReason); if (*resultArity != *sourceArity) - return fail("s8 group_reduce_addf requires source/result physical " + return fail("one-vlane group_reduce_add requires source/result physical " "arity to match"); - return VMIGroupReduceAddFRecipe{VMIGroupReduceAddFRecipeKind::S8Vcgadd}; + return VMIGroupReduceAddFRecipe{ + VMIGroupReduceAddFRecipeKind::OneVLaneVcgadd}; } - if (*groupSize == 16) { + if (*groupSize == 2 * vlaneElems) { if (!sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 2 || (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("s16 group_reduce_addf requires source layout " + return fail("two-vlane group_reduce_add requires source layout " "deinterleaved=2 with block_elems=1 or block_elems=8"); if (!maskLayout.isDeinterleaved() || maskLayout.getFactor() != 2 || maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("s16 group_reduce_addf requires matching mask layout " + return fail("two-vlane group_reduce_add requires matching mask layout " "deinterleaved=2 with the same block_elems"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 2) - return fail("s16 group_reduce_addf requires two source/mask parts per " + return fail("two-vlane group_reduce_add requires two source/mask parts per " "result part"); return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::S16Deinterleaved2VcgaddVadd}; + VMIGroupReduceAddFRecipeKind::TwoVLaneDeinterleaved2VcgaddVadd}; } - if (*groupSize == 32) { + if (*groupSize == 4 * vlaneElems) { if (!sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 4 || (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("s32 group_reduce_addf requires source layout " + return fail("four-vlane group_reduce_add requires source layout " "deinterleaved=4 with block_elems=1 or block_elems=8"); if (!maskLayout.isDeinterleaved() || maskLayout.getFactor() != 4 || maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("s32 group_reduce_addf requires matching mask layout " + return fail("four-vlane group_reduce_add requires matching mask layout " "deinterleaved=4 with the same block_elems"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 4) - return fail("s32 group_reduce_addf requires four source/mask parts per " + return fail("four-vlane group_reduce_add requires four source/mask parts per " "result part"); return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::S32Deinterleaved4VcgaddTree}; + VMIGroupReduceAddFRecipeKind::FourVLaneDeinterleaved4VcgaddTree}; } - return fail("stable group_reduce_addf slots=8 recipes support group size " - "8, 16, or 32"); + return fail("stable group_reduce_add slots=8 recipes support group sizes " + "VLaneElems, 2*VLaneElems, or 4*VLaneElems"); +} + +FailureOr +VMILocalRecipeRegistry::getGroupReduceAddFRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, + std::string *reason) const { + return getGroupReduceAddRecipeImpl( + capabilities, op.getOperation(), cast(op.getSource().getType()), + cast(op.getMask().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/true, + VMIReductionKind::GroupAddF, reason); +} + +FailureOr +VMILocalRecipeRegistry::getGroupReduceAddIRecipe( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddIOp op, + std::string *reason) const { + return getGroupReduceAddRecipeImpl( + capabilities, op.getOperation(), cast(op.getSource().getType()), + cast(op.getMask().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/false, + VMIReductionKind::GroupAddI, reason); } FailureOr @@ -964,6 +997,118 @@ VMILocalRecipeRegistry::getExtFRecipe(VMIExtFOp op, "physical arity"); } +template +static FailureOr getExtIRecipeImpl(OpT op, + std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + if (!sourceLayout.isContiguous() || !resultLayout.isDeinterleaved() || + !isa(sourceType.getElementType()) || + !isa(resultType.getElementType())) + return fail("requires contiguous integer source layout and deinterleaved " + "integer result layout"); + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (sourceBits == 16 && resultBits == 32 && resultLayout.getFactor() == 2 && + *resultArity == 2 * *sourceArity) + return VMIExtIRecipe{ + VMIExtIRecipeKind::ContiguousI16ToDeinterleaved2I32}; + if (sourceBits == 8 && resultBits == 32 && resultLayout.getFactor() == 4 && + *resultArity == 4 * *sourceArity) + return VMIExtIRecipe{ + VMIExtIRecipeKind::ContiguousI8ToDeinterleaved4I32}; + + return fail("unsupported integer extension source/result element width, " + "result factor, or physical arity"); +} + +FailureOr +VMILocalRecipeRegistry::getExtSIRecipe(VMIExtSIOp op, + std::string *reason) const { + return getExtIRecipeImpl(op, reason); +} + +FailureOr +VMILocalRecipeRegistry::getExtUIRecipe(VMIExtUIOp op, + std::string *reason) const { + return getExtIRecipeImpl(op, reason); +} + +FailureOr +VMILocalRecipeRegistry::getTruncIRecipe(VMITruncIOp op, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (!sourceLayout || !resultLayout || failed(sourceArity) || + failed(resultArity)) + return fail("requires assigned source/result layouts and computable " + "physical arity"); + if (!isa(sourceType.getElementType()) || + !isa(resultType.getElementType())) + return fail("requires integer source and result element types"); + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + + if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { + if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || + sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + sourceBits != 32 || resultBits != 16 || *sourceArity != *resultArity) + return fail("group-slot trunci requires matching " + "group_slots(num_groups=G, slots=1) source/result layouts, " + "32-bit integer source, 16-bit integer result, and matching " + "physical arity"); + return VMITruncIRecipe{VMITruncIRecipeKind::GroupSlots1I32ToI16}; + } + + if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || + sourceBits != 32 || *resultArity != 1) + return fail("requires 32-bit integer deinterleaved source and contiguous " + "integer result"); + + if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) + return VMITruncIRecipe{ + VMITruncIRecipeKind::Deinterleaved2I32ToContiguousI16}; + if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8 && + cast(resultType.getElementType()).isUnsigned()) + return VMITruncIRecipe{ + VMITruncIRecipeKind::Deinterleaved4I32ToContiguousI8}; + + return fail("unsupported deinterleaved trunci factor, arity, result element " + "width, or result signedness; 32-bit to 8-bit integer narrowing " + "requires unsigned i8 result"); +} + FailureOr VMILocalRecipeRegistry::getBitcastRecipe(VMIBitcastOp op, std::string *reason) const { diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 1c92be4018..c44fc114ec 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -38,6 +38,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Support/raw_ostream.h" #include +#include namespace mlir { namespace pto { @@ -2434,12 +2435,17 @@ LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, return failure(); }; - if (!sourceType.getElementType().isF32() || - sourceType.getElementType() != resultType.getElementType()) - return fail("vcgadd group_reduce_addf path requires f32 source/result"); - if (groupSize != 8) - return fail("vcgadd group_reduce_addf path requires group size = 8 for " - "f32 32-byte VLane groups"); + if (sourceType.getElementType() != resultType.getElementType()) + return fail("vcgadd group_reduce_add path requires matching " + "source/result element types"); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) + return fail("vcgadd group_reduce_add path requires known VLane width"); + int64_t vlaneElems = *lanesPerPart / 8; + if (groupSize != vlaneElems) + return fail("vcgadd group_reduce_add path requires group size equal to " + "one 32-byte VLane"); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); VMILayoutAttr maskLayout = maskType.getLayoutAttr(); @@ -2447,27 +2453,28 @@ LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, if (!sourceLayout || !resultLayout || !maskLayout || !sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || resultLayout.getNumGroups() != numGroups || !maskLayout.isContiguous()) - return fail("vcgadd group_reduce_addf path requires contiguous source/mask " + return fail("vcgadd group_reduce_add path requires contiguous source/mask " "layouts and matching num_groups result layout"); std::string sourceFullReason; if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) - return fail(Twine("vcgadd group_reduce_addf path requires full source " + return fail(Twine("vcgadd group_reduce_add path requires full source " "chunks; ") + sourceFullReason); FailureOr sourceArity = getVMIPhysicalArity(sourceType); FailureOr maskArity = getVMIPhysicalArity(maskType); FailureOr resultArity = getVMIPhysicalArity(resultType); if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("vcgadd group_reduce_addf path requires computable physical " + return fail("vcgadd group_reduce_add path requires computable physical " "arity"); if (*sourceArity < 1 || *sourceArity != *maskArity || *sourceArity != *resultArity) - return fail("vcgadd group_reduce_addf path requires matching non-empty " + return fail("vcgadd group_reduce_add path requires matching non-empty " "source/mask/result physical arity"); return success(); } -LogicalResult checkS16Block8GroupReduceShape(VMIGroupReduceAddFOp op, +template +LogicalResult checkS16Block8GroupReduceShape(OpTy op, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) @@ -2478,14 +2485,18 @@ LogicalResult checkS16Block8GroupReduceShape(VMIGroupReduceAddFOp op, auto sourceType = cast(op.getSource().getType()); auto maskType = cast(op.getMask().getType()); auto resultType = cast(op.getResult().getType()); - if (!sourceType.getElementType().isF32() || - sourceType.getElementType() != resultType.getElementType()) - return fail("s16 block8 group_reduce_addf requires f32 source/result"); + if (sourceType.getElementType() != resultType.getElementType()) + return fail("two-vlane group_reduce_add requires matching source/result " + "element types"); FailureOr groupSize = getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); - if (failed(groupSize) || *groupSize != 16) - return fail("s16 block8 group_reduce_addf requires group size 16"); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(groupSize) || failed(lanesPerPart) || *lanesPerPart % 8 != 0 || + *groupSize != 2 * (*lanesPerPart / 8)) + return fail("two-vlane group_reduce_add requires group size equal to two " + "32-byte VLanes"); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr maskLayout = maskType.getLayoutAttr(); @@ -2494,33 +2505,34 @@ LogicalResult checkS16Block8GroupReduceShape(VMIGroupReduceAddFOp op, if (!sourceLayout || !sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 2 || (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("s16 group_reduce_addf requires source layout " + return fail("two-vlane group_reduce_add requires source layout " "deinterleaved=2 with block_elems=1 or block_elems=8"); if (!maskLayout || !maskLayout.isDeinterleaved() || maskLayout.getFactor() != 2 || maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("s16 group_reduce_addf requires matching mask layout " + return fail("two-vlane group_reduce_add requires matching mask layout " "deinterleaved=2 with the same block_elems"); if (!resultLayout || !resultLayout.isGroupSlots() || resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) - return fail("s16 block8 group_reduce_addf requires " + return fail("two-vlane group_reduce_add requires " "group_slots(num_groups, slots=8) result layout"); FailureOr sourceArity = getVMIPhysicalArity(sourceType); FailureOr maskArity = getVMIPhysicalArity(maskType); FailureOr resultArity = getVMIPhysicalArity(resultType); if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("s16 block8 group_reduce_addf requires computable physical " + return fail("two-vlane group_reduce_add requires computable physical " "arity"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 2 || *maskArity != *sourceArity) - return fail("s16 block8 group_reduce_addf requires two source/mask " + return fail("two-vlane group_reduce_add requires two source/mask " "parts per result part"); return success(); } -LogicalResult checkS32Block8GroupReduceShape(VMIGroupReduceAddFOp op, +template +LogicalResult checkS32Block8GroupReduceShape(OpTy op, std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) @@ -2531,14 +2543,18 @@ LogicalResult checkS32Block8GroupReduceShape(VMIGroupReduceAddFOp op, auto sourceType = cast(op.getSource().getType()); auto maskType = cast(op.getMask().getType()); auto resultType = cast(op.getResult().getType()); - if (!sourceType.getElementType().isF32() || - sourceType.getElementType() != resultType.getElementType()) - return fail("s32 block8 group_reduce_addf requires f32 source/result"); + if (sourceType.getElementType() != resultType.getElementType()) + return fail("four-vlane group_reduce_add requires matching source/result " + "element types"); FailureOr groupSize = getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); - if (failed(groupSize) || *groupSize != 32) - return fail("s32 block8 group_reduce_addf requires group size 32"); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(groupSize) || failed(lanesPerPart) || *lanesPerPart % 8 != 0 || + *groupSize != 4 * (*lanesPerPart / 8)) + return fail("four-vlane group_reduce_add requires group size equal to four " + "32-byte VLanes"); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr maskLayout = maskType.getLayoutAttr(); @@ -2547,27 +2563,27 @@ LogicalResult checkS32Block8GroupReduceShape(VMIGroupReduceAddFOp op, if (!sourceLayout || !sourceLayout.isDeinterleaved() || sourceLayout.getFactor() != 4 || (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("s32 group_reduce_addf requires source layout " + return fail("four-vlane group_reduce_add requires source layout " "deinterleaved=4 with block_elems=1 or block_elems=8"); if (!maskLayout || !maskLayout.isDeinterleaved() || maskLayout.getFactor() != 4 || maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("s32 group_reduce_addf requires matching mask layout " + return fail("four-vlane group_reduce_add requires matching mask layout " "deinterleaved=4 with the same block_elems"); if (!resultLayout || !resultLayout.isGroupSlots() || resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) - return fail("s32 block8 group_reduce_addf requires " + return fail("four-vlane group_reduce_add requires " "group_slots(num_groups, slots=8) result layout"); FailureOr sourceArity = getVMIPhysicalArity(sourceType); FailureOr maskArity = getVMIPhysicalArity(maskType); FailureOr resultArity = getVMIPhysicalArity(resultType); if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("s32 block8 group_reduce_addf requires computable physical " + return fail("four-vlane group_reduce_add requires computable physical " "arity"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 4 || *maskArity != *sourceArity) - return fail("s32 block8 group_reduce_addf requires four source/mask " + return fail("four-vlane group_reduce_add requires four source/mask " "parts per result part"); return success(); @@ -5551,13 +5567,12 @@ struct OneToNVMIReduceAddFOpPattern } }; -struct OneToNVMIGroupReduceAddFOpPattern - : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIGroupReduceAddFOp>::OneToNOpConversionPattern; +template +struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult - matchAndRewrite(VMIGroupReduceAddFOp op, OpAdaptor adaptor, + matchAndRewrite(OpTy op, typename OneToNOpConversionPattern::OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto sourceVMIType = cast(op.getSource().getType()); auto maskVMIType = cast(op.getMask().getType()); @@ -6295,6 +6310,215 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { } }; +template +struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite( + OpT op, typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.empty()) + return rewriter.notifyMatchFailure( + op, "integer extension requires at least one physical source chunk"); + + auto sourceType = dyn_cast(sourceParts.front().getType()); + if (!sourceType) + return rewriter.notifyMatchFailure( + op, "expected physical integer extension source"); + for (Value sourcePart : sourceParts) { + auto currentSourceType = dyn_cast(sourcePart.getType()); + if (!currentSourceType || currentSourceType != sourceType) + return rewriter.notifyMatchFailure( + op, "integer extension source physical parts must have matching " + "type"); + } + + SmallVector resultVRegTypes; + resultVRegTypes.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || + !isa(resultVRegType.getElementType()) || + (resultVRegTypes.empty() ? pto::getPTOStorageElemBitWidth( + resultVRegType.getElementType()) != 32 + : resultVRegType != + resultVRegTypes.front())) + return rewriter.notifyMatchFailure( + op, "unsupported physical integer extension result type"); + resultVRegTypes.push_back(resultVRegType); + } + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + ArrayRef parts; + int64_t factor = 0; + if (sourceBits == 16 && resultTypes.size() == 2 * sourceParts.size()) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + parts = kEvenOddParts; + factor = 2; + } else if (sourceBits == 8 && + resultTypes.size() == 4 * sourceParts.size()) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + parts = kPacked4Parts; + factor = 4; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported physical integer extension source/result width " + "relation"); + } + + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to build integer extension seed mask"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t partIndex = 0; partIndex < factor; ++partIndex) { + for (auto [chunkIndex, sourcePart] : llvm::enumerate(sourceParts)) { + VRegType resultType = + resultVRegTypes[partIndex * sourceParts.size() + chunkIndex]; + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *mask, + /*rnd=*/nullptr, /*sat=*/nullptr, + rewriter.getStringAttr(parts[partIndex])) + .getResult()); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMITruncIOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMITruncIOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + resultLayout.isGroupSlots()) { + if (sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + pto::getPTOStorageElemBitWidth(sourceVMIType.getElementType()) != + 32 || + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) != + 16 || + sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot trunci shape"); + + SmallVector results; + results.reserve(resultTypes.size()); + StringAttr sat = rewriter.getStringAttr("SAT"); + StringAttr even = rewriter.getStringAttr("EVEN"); + FailureOr lane0Mask = createPrefixMask( + op.getLoc(), MaskType::get(rewriter.getContext(), "b32"), "PAT_VL1", + rewriter); + if (failed(lane0Mask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot trunci lane0 mask"); + for (auto [sourcePart, physicalResultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto sourceType = dyn_cast(sourcePart.getType()); + auto resultType = dyn_cast(physicalResultType); + if (!sourceType || + pto::getPTOStorageElemBitWidth(sourceType.getElementType()) != 32 || + !resultType || + pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 16) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot trunci physical type"); + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *lane0Mask, + /*rnd=*/nullptr, sat, even) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + + if ((sourceParts.size() != 2 && sourceParts.size() != 4) || + resultTypes.size() != 1) + return rewriter.notifyMatchFailure( + op, "only 32-bit integer deinterleaved=2/4 to 16/8-bit contiguous " + "trunci is supported"); + + auto sourceType0 = dyn_cast(sourceParts.front().getType()); + auto resultType = dyn_cast(resultTypes.front()); + if (!sourceType0 || !isa(sourceType0.getElementType()) || + !resultType || !isa(resultType.getElementType())) + return rewriter.notifyMatchFailure( + op, "unsupported physical trunci source/result type"); + for (Value sourcePart : sourceParts) { + auto sourceType = dyn_cast(sourcePart.getType()); + if (!sourceType || sourceType != sourceType0) + return rewriter.notifyMatchFailure( + op, "trunci source physical parts must have matching 32-bit " + "integer type"); + } + + if (pto::getPTOStorageElemBitWidth(sourceType0.getElementType()) != 32) + return rewriter.notifyMatchFailure( + op, "trunci source physical element width must be 32-bit"); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + ArrayRef parts; + if (sourceParts.size() == 2 && resultBits == 16) { + static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; + parts = kEvenOddParts; + } else if (sourceParts.size() == 4 && resultBits == 8) { + static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; + parts = kPacked4Parts; + } else { + return rewriter.notifyMatchFailure( + op, "unsupported physical trunci source/result width relation"); + } + + FailureOr sourceMask = + createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); + FailureOr resultMask = + createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); + if (failed(sourceMask) || failed(resultMask)) + return rewriter.notifyMatchFailure(op, "failed to build trunci masks"); + + StringAttr sat = rewriter.getStringAttr("SAT"); + SmallVector partials; + partials.reserve(parts.size()); + for (auto [sourcePart, part] : llvm::zip_equal(sourceParts, parts)) { + partials.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *sourceMask, + /*rnd=*/nullptr, sat, + rewriter.getStringAttr(part)) + .getResult()); + } + + Value merged = partials.front(); + for (Value partial : llvm::drop_begin(partials)) + merged = rewriter + .create(op.getLoc(), resultType, merged, partial, + *resultMask) + .getResult(); + + rewriter.replaceOp(op, merged, adaptor.getResultMapping()); + return success(); + } +}; + struct OneToNVMIBitcastOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; @@ -6782,10 +7006,14 @@ void populateVMIOneToNConversionPatterns( OneToNVMISelectOpPattern, OneToNVMIActivePrefixIndexOpPattern, OneToNVMICompressOpPattern, OneToNVMICompressStoreOpPattern, OneToNVMIReduceAddIOpPattern, OneToNVMIReduceAddFOpPattern, - OneToNVMIGroupReduceAddFOpPattern, OneToNVMIGroupBroadcastOpPattern, + OneToNVMIGroupReduceAddOpPattern, + OneToNVMIGroupReduceAddOpPattern, + OneToNVMIGroupBroadcastOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIExtFOpPattern, OneToNVMITruncFOpPattern, + OneToNVMIExtIOpPattern, + OneToNVMIExtIOpPattern, OneToNVMITruncIOpPattern, OneToNVMIBitcastOpPattern, OneToNVMIChannelSplitOpPattern, OneToNVMIChannelMergeOpPattern, OneToNVMIShuffleOpPattern>( typeConverter, patterns.getContext()); @@ -6845,6 +7073,30 @@ LogicalResult checkSupportedTruncFShape(VMITruncFOp op, return success(); } +LogicalResult checkSupportedExtSIShape(VMIExtSIOp op, + std::string *reason = nullptr) { + VMILocalRecipeRegistry recipes; + if (failed(recipes.getExtSIRecipe(op, reason))) + return failure(); + return success(); +} + +LogicalResult checkSupportedExtUIShape(VMIExtUIOp op, + std::string *reason = nullptr) { + VMILocalRecipeRegistry recipes; + if (failed(recipes.getExtUIRecipe(op, reason))) + return failure(); + return success(); +} + +LogicalResult checkSupportedTruncIShape(VMITruncIOp op, + std::string *reason = nullptr) { + VMILocalRecipeRegistry recipes; + if (failed(recipes.getTruncIRecipe(op, reason))) + return failure(); + return success(); +} + LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, std::string *reason) { VMILocalRecipeRegistry recipes; if (failed(recipes.getBitcastRecipe(op, reason))) @@ -7143,8 +7395,9 @@ checkSupportedReduceShape(const VMITargetCapabilityRegistry &capabilities, return success(); } -LogicalResult checkSupportedGroupReduceAddFShape( - const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, +template +LogicalResult checkSupportedGroupReduceAddShape( + const VMITargetCapabilityRegistry &capabilities, OpTy op, std::string *reason = nullptr) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) @@ -7152,8 +7405,10 @@ LogicalResult checkSupportedGroupReduceAddFShape( return failure(); }; - if (!op->hasAttr("reassoc")) + if constexpr (std::is_same_v) { + if (!op->hasAttr("reassoc")) return fail("requires reassoc attr for pair-wise floating-point reduction"); + } auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); auto maskType = cast(op.getMask().getType()); @@ -7164,8 +7419,13 @@ LogicalResult checkSupportedGroupReduceAddFShape( return fail("requires assigned source, mask, and result layouts"); VMILocalRecipeRegistry recipes; - if (succeeded(recipes.getGroupReduceAddFRecipe(capabilities, op, nullptr))) - return success(); + if constexpr (std::is_same_v) { + if (succeeded(recipes.getGroupReduceAddFRecipe(capabilities, op, nullptr))) + return success(); + } else { + if (succeeded(recipes.getGroupReduceAddIRecipe(capabilities, op, nullptr))) + return success(); + } FailureOr groupSize = getGroupSizeFromNumGroups( sourceType, op.getNumGroupsAttr().getInt(), reason); @@ -7181,7 +7441,9 @@ LogicalResult checkSupportedGroupReduceAddFShape( return fail("requires contiguous source/mask layouts and matching " "num_groups result layout"); VMICapabilityResult elementCapability = - capabilities.supportsReductionElementType(VMIReductionKind::AddF, + capabilities.supportsReductionElementType( + std::is_same_v ? VMIReductionKind::GroupAddF + : VMIReductionKind::GroupAddI, sourceType.getElementType()); if (!elementCapability.isSupported()) return fail(elementCapability.reason); @@ -7204,10 +7466,15 @@ LogicalResult checkSupportedGroupReduceAddFShape( if (resultLayout.getSlots() <= 0) return success(); - if (!sourceLayout.isContiguous() || *groupSize != 64 || + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart)) + return fail("requires known physical chunk lane count"); + if (!sourceLayout.isContiguous() || *groupSize != *lanesPerPart || resultLayout.getSlots() != 1) - return fail("explicit group_slots group_reduce_addf chunk path requires " - "contiguous group size 64 source and slots=1 result layout"); + return fail("explicit group_slots group_reduce_add chunk path requires " + "contiguous full-physical-chunk group size source and slots=1 " + "result layout"); return success(); } @@ -7843,13 +8110,13 @@ verifySupportedVMIToVPTOOps(ModuleOp module, if (auto reduce = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedGroupReduceAddFShape(capabilities, reduce, - &reason))) + if (succeeded( + checkSupportedGroupReduceAddShape(capabilities, reduce, &reason))) return WalkResult::advance(); reduce.emitError() << kVMIDiagUnsupportedPrefix - << "pto.vmi.group_reduce_addf lowers through pto.vcgadd for f32 " - "32B groups or through pto.vcadd with reassoc, contiguous full " + << "pto.vmi.group_reduce_addf lowers through pto.vcgadd for 32B " + "VLane groups or through pto.vcadd with reassoc, contiguous full " "source/mask chunks, #pto.vmi.layout result " "chunks, and num_groups deriving a group size aligned to " "physical chunks (" @@ -7857,6 +8124,21 @@ verifySupportedVMIToVPTOOps(ModuleOp module, return WalkResult::interrupt(); } + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupReduceAddShape(capabilities, reduce, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_addi lowers through pto.vcgadd/vadd only " + "for i32 accumulator values; i8/i16 storage must be cast to i32 " + "before grouped reduction because narrow integer reductions " + "widen their result (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto reduce = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedReduceShape( @@ -7930,6 +8212,51 @@ verifySupportedVMIToVPTOOps(ModuleOp module, return WalkResult::interrupt(); } + if (auto extsi = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedExtSIShape(extsi, &reason))) + return WalkResult::advance(); + + extsi.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.extsi supports contiguous signed/signless 8-bit or " + "16-bit integer physical source chunks to 32-bit integer " + "deinterleaved=4/2 results (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto extui = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedExtUIShape(extui, &reason))) + return WalkResult::advance(); + + extui.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.extui supports contiguous unsigned 8-bit or 16-bit " + "integer physical source chunks to unsigned 32-bit integer " + "deinterleaved=4/2 results (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto trunci = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedTruncIShape(trunci, &reason))) + return WalkResult::advance(); + + trunci.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.trunci supports only 32-bit integer deinterleaved=2 " + "source parts to one contiguous 16-bit integer result chunk, " + "32-bit integer deinterleaved=4 source parts to one contiguous " + "8-bit integer result chunk, or 32-bit integer " + "group_slots(num_groups=G, slots=1) to 16-bit integer " + "group_slots(num_groups=G, slots=1) (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto bitcast = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedBitcastShape(bitcast, &reason))) diff --git a/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto b/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto new file mode 100644 index 0000000000..948dfe9c54 --- /dev/null +++ b/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @direct_i16_group_reduce_invalid( + %source: !pto.vmi.vreg<128xi16>, + %mask: !pto.vmi.mask<128xpred>) { + %sum = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi16>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xi16> + return + } +} + +// CHECK: requires i32 accumulator element type; cast i8/i16 storage to i32 before grouped reduction because integer reduction widens narrow inputs diff --git a/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto b/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto new file mode 100644 index 0000000000..578acc00b9 --- /dev/null +++ b/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @direct_i8_group_reduce_invalid( + %source: !pto.vmi.vreg<256xi8>, + %mask: !pto.vmi.mask<256xpred>) { + %sum = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} + : !pto.vmi.vreg<256xi8>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xi8> + return + } +} + +// CHECK: requires i32 accumulator element type; cast i8/i16 storage to i32 before grouped reduction because integer reduction widens narrow inputs diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto index eccb4e0007..b322e5700e 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto @@ -16,7 +16,7 @@ module { %off: index) { %c1 = arith.constant 1 : index // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_addf slots=8 recipes support group size 8, 16, or 32 + // CHECK-SAME: stable group_reduce_add slots=8 recipes support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<96xf32>, !pto.vmi.mask<96xpred> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto new file mode 100644 index 0000000000..34bf1c9633 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @typed_group_reduce_assignment( + %f16: !pto.vmi.vreg<256xf16>, + %mf16: !pto.vmi.mask<256xpred>, + %i16: !pto.vmi.vreg<128xi16>, + %mi16: !pto.vmi.mask<128xpred>, + %i32: !pto.vmi.vreg<128xi32>, + %mi32: !pto.vmi.mask<128xpred>) { + %sum_f16 = pto.vmi.group_reduce_addf %f16, %mf16 {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf16>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf16> + %wide_i16 = pto.vmi.extsi %i16 + : !pto.vmi.vreg<128xi16> -> !pto.vmi.vreg<128xi32> + %sum_i16 = pto.vmi.group_reduce_addi %wide_i16, %mi16 {num_groups = 8} + : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xi32> + %sum_i32 = pto.vmi.group_reduce_addi %i32, %mi32 {num_groups = 8} + : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xi32> + return + } +} + +// CHECK-LABEL: func.func @typed_group_reduce_assignment( +// CHECK: %[[F16_SPLIT:.*]] = pto.vmi.ensure_layout +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: %[[MF16_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// CHECK-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// CHECK: %[[MF16_B16:.*]] = pto.vmi.ensure_mask_granularity %[[MF16_SPLIT]] +// CHECK-SAME: -> !pto.vmi.mask<256xb16, #pto.vmi.layout> +// CHECK: pto.vmi.group_reduce_addf %[[F16_SPLIT]], %[[MF16_B16]] +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: %[[WIDE_I16:.*]] = pto.vmi.extsi %arg2 +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[MI16_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.group_reduce_addi %[[WIDE_I16]], %[[MI16_SPLIT]] +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[I32_SPLIT:.*]] = pto.vmi.ensure_layout +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK: %[[MI32_SPLIT:.*]] = pto.vmi.ensure_mask_layout +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: pto.vmi.group_reduce_addi %[[I32_SPLIT]], %[[MI32_SPLIT]] +// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto index 673f3ee47b..33a7bc0fae 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto @@ -13,7 +13,7 @@ module { %source: !pto.vmi.vreg<96xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<96xb32, #pto.vmi.layout>) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_addf slots=8 recipes support group size 8, 16, or 32 + // CHECK-SAME: stable group_reduce_add slots=8 recipes support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto index 6e0b04e8f6..d33315f88d 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto @@ -13,7 +13,7 @@ module { %source: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_addf slots=1 recipes support group sizes that are multiples of one physical chunk + // CHECK-SAME: stable group_reduce_add slots=1 recipes support group sizes that are multiples of one physical chunk // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} diff --git a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto index b8576fe3b7..c787f57fea 100644 --- a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto @@ -29,7 +29,7 @@ module { %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_addf local recipes currently require result layout slots=8 or slots=1 + // CHECK-SAME: stable group_reduce_add local recipes currently require result layout slots=8 or slots=1 %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto new file mode 100644 index 0000000000..f01c6865a1 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_group_reduce_addf_f16_vlane( + %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, + !pto.vmi.mask<128xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %part : !pto.vreg<128xf16> + } + + func.func @vmi_group_reduce_addi_i16_storage_to_i32_vlane( + %source: !pto.vmi.vreg<128xi16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %wide = pto.vmi.extsi %source + : !pto.vmi.vreg<128xi16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %out = pto.vmi.group_reduce_addi %wide, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } + + func.func @vmi_group_reduce_addi_i32_two_vlane( + %source: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> { + %out = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vreg<64xi32> + return %part : !pto.vreg<64xi32> + } +} + +// CHECK-LABEL: func.func @vmi_group_reduce_addf_f16_vlane( +// CHECK: %[[OUT:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT]] + +// CHECK-LABEL: func.func @vmi_group_reduce_addi_i16_storage_to_i32_vlane( +// CHECK: %[[EVEN:.*]] = pto.vcvt %arg0, {{.*}} {part = "EVEN"} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[ODD:.*]] = pto.vcvt %arg0, {{.*}} {part = "ODD"} : !pto.vreg<128xi16>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[S0:.*]] = pto.vcgadd %[[EVEN]], %arg1 +// CHECK: %[[S1:.*]] = pto.vcgadd %[[ODD]], %arg2 +// CHECK: %[[SUM:.*]] = pto.vadd %[[S0]], %[[S1]] +// CHECK: return %[[SUM]] + +// CHECK-LABEL: func.func @vmi_group_reduce_addi_i32_two_vlane( +// CHECK: %[[MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask +// CHECK: %[[SLO:.*]] = pto.vcgadd %arg0, %arg2 +// CHECK: %[[SHI:.*]] = pto.vcgadd %arg1, %arg3 +// CHECK: %[[SUM:.*]] = pto.vadd %[[SLO]], %[[SHI]], %[[MASK]] +// CHECK: return %[[SUM]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto b/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto new file mode 100644 index 0000000000..c3e7403e91 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_extsi_i8_to_i32_group_reduce( + %source: !pto.vmi.vreg<256xi8>, + %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> { + %wide = pto.vmi.extsi %source + : !pto.vmi.vreg<256xi8> + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + %sum = pto.vmi.group_reduce_addi %wide, %mask {num_groups = 8} + : !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + !pto.vmi.mask<256xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + return %sum : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_extsi_i8_to_i32_group_reduce( +// CHECK: %[[P0:.*]] = pto.vcvt %arg0, {{.*}} {part = "P0"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[P1:.*]] = pto.vcvt %arg0, {{.*}} {part = "P1"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[P2:.*]] = pto.vcvt %arg0, {{.*}} {part = "P2"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[P3:.*]] = pto.vcvt %arg0, {{.*}} {part = "P3"} : !pto.vreg<256xi8>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: %[[S0:.*]] = pto.vcgadd %[[P0]] +// CHECK: %[[S1:.*]] = pto.vcgadd %[[P1]] +// CHECK: %[[S2:.*]] = pto.vcgadd %[[P2]] +// CHECK: %[[S3:.*]] = pto.vcgadd %[[P3]] +// CHECK: %[[A01:.*]] = pto.vadd %[[S0]], %[[S1]] +// CHECK: %[[A23:.*]] = pto.vadd %[[S2]], %[[S3]] +// CHECK: %[[SUM:.*]] = pto.vadd %[[A01]], %[[A23]] +// CHECK: return %[[SUM]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_integer_casts.pto b/test/lit/vmi/vmi_to_vpto_integer_casts.pto new file mode 100644 index 0000000000..50051aab6d --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_integer_casts.pto @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_extui_u8_to_u32( + %input: !pto.vmi.vreg<256xui8, #pto.vmi.layout>) + -> (!pto.vreg<64xui32>, !pto.vreg<64xui32>, + !pto.vreg<64xui32>, !pto.vreg<64xui32>) { + %wide = pto.vmi.extui %input + : !pto.vmi.vreg<256xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xui32, #pto.vmi.layout>) + -> (!pto.vreg<64xui32>, !pto.vreg<64xui32>, + !pto.vreg<64xui32>, !pto.vreg<64xui32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xui32>, !pto.vreg<64xui32>, + !pto.vreg<64xui32>, !pto.vreg<64xui32> + } + + func.func @vmi_to_vpto_trunci_i32_to_ui8( + %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> !pto.vreg<256xui8> { + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<256xui8, #pto.vmi.layout>) + -> !pto.vreg<256xui8> + return %p : !pto.vreg<256xui8> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_extui_u8_to_u32( +// CHECK-SAME: %[[INPUT:.*]]: !pto.vreg<256xui8> +// CHECK: %[[MASK:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P0"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P1"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P2"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "P3"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_trunci_i32_to_ui8( +// CHECK: %[[P0:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P0", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[P1:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P1", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[P2:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P2", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[P3:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P3", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: %[[M01:.*]] = pto.vor %[[P0]], %[[P1]] +// CHECK: %[[M012:.*]] = pto.vor %[[M01]], %[[P2]] +// CHECK: pto.vor %[[M012]], %[[P3]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto b/test/lit/vmi/vmi_to_vpto_reduce_addf_f16.pto similarity index 55% rename from test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto rename to test/lit/vmi/vmi_to_vpto_reduce_addf_f16.pto index 4e24ee12a8..fc4ebdc92a 100644 --- a/test/lit/vmi/vmi_to_vpto_reduce_addf_f16_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_reduce_addf_f16.pto @@ -6,21 +6,36 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_reduce_addf_f16_invalid( + func.func @vmi_to_vpto_reduce_addf_f16( %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, %init: !pto.vmi.vreg<1xf16, #pto.vmi.layout>, - %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) { + %mask: !pto.vmi.mask<128xb16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> { %out = pto.vmi.reduce_addf %source, %init, %mask {reassoc} : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.vmi.vreg<1xf16, #pto.vmi.layout>, !pto.vmi.mask<128xb16, #pto.vmi.layout> -> !pto.vmi.vreg<1xf16, #pto.vmi.layout> - return + %p = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<1xf16, #pto.vmi.layout>) + -> !pto.vreg<128xf16> + return %p : !pto.vreg<128xf16> } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.reduce_addf lowers through pto.vcadd only with reassoc -// CHECK-SAME: currently supports only f32 elements +// CHECK-LABEL: func.func @vmi_to_vpto_reduce_addf_f16( +// CHECK-SAME: %arg0: !pto.vreg<128xf16> +// CHECK-SAME: %arg1: !pto.vreg<128xf16> +// CHECK-SAME: %arg2: !pto.mask +// CHECK: %[[LANE0:.*]] = pto.pge_b16 "PAT_VL1" : !pto.mask +// CHECK: %[[REDUCED:.*]] = pto.vcadd %arg0, %arg2 +// CHECK-SAME: !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[OUT:.*]] = pto.vadd %[[REDUCED]], %arg1, %[[LANE0]] +// CHECK-SAME: !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: return %[[OUT]] : !pto.vreg<128xf16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto b/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto new file mode 100644 index 0000000000..145ef2a7b9 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_trunci_i32_to_i8_invalid( + %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> !pto.vreg<256xi8> { + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xi8, #pto.vmi.layout> + %p = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<256xi8, #pto.vmi.layout>) + -> !pto.vreg<256xi8> + return %p : !pto.vreg<256xi8> + } +} + +// CHECK: VMI-UNSUPPORTED +// CHECK: pto.vmi.trunci supports only +// CHECK: 32-bit to 8-bit integer narrowing requires unsigned i8 result diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/compare.py b/test/vpto/cases/vmi/group-reduce-f16-addf-store/compare.py new file mode 100644 index 0000000000..fbba5d605b --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.float16) + output = np.fromfile("v2.bin", dtype=np.float16) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/golden.py b/test/vpto/cases/vmi/group-reduce-f16-addf-store/golden.py new file mode 100644 index 0000000000..beed48b5da --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/golden.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 16 + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, COLS), dtype=np.float16) + base = np.array([-3, -2, -1, 0, 1, 2, 3, 4], dtype=np.float16) + for row in range(ROWS): + src[row, :] = np.tile(np.roll(base, row), 2) + dst = np.full(ROWS, np.float16(-17), dtype=np.float16) + golden = np.sum(src, axis=1, dtype=np.float16).astype(np.float16) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto new file mode 100644 index 0000000000..b8d274c280 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_f16_addf_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf16> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<128xf16>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xf16> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xf16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c16_i64 + nburst(%c1_i64, %c16_i64, %c16_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-f16-addf-store/launch.cpp new file mode 100644 index 0000000000..8cfb1e58b5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/launch.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_f16_addf_store_kernel(__gm__ half *src, __gm__ half *dst); + +void LaunchVmi_group_reduce_f16_addf_store_kernel(uint16_t *src, uint16_t *dst, + void *stream) { + vmi_group_reduce_f16_addf_store_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ half *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/main.cpp b/test/vpto/cases/vmi/group-reduce-f16-addf-store/main.cpp new file mode 100644 index 0000000000..7a92e1a331 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/main.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_f16_addf_store_kernel(uint16_t *src, uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 128; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(uint16_t); + size_t dstBytes = kOutputElems * sizeof(uint16_t); + uint16_t *srcHost = nullptr; + uint16_t *dstHost = nullptr; + uint16_t *srcDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_f16_addf_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-f16-addf-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/compare.py b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/compare.py new file mode 100644 index 0000000000..612b15c3f6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.int32) + output = np.fromfile("v2.bin", dtype=np.int32) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/golden.py b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/golden.py new file mode 100644 index 0000000000..00097384f0 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/golden.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 16 + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, COLS), dtype=np.int16) + base = np.array([-5, -3, -1, 0, 2, 4, 6, 8], dtype=np.int16) + for row in range(ROWS): + src[row, :] = np.tile(np.roll(base, row), 2) + dst = np.full(ROWS, -777, dtype=np.int32) + golden = np.sum(src.astype(np.int32), axis=1, dtype=np.int32).astype(np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto new file mode 100644 index 0000000000..da95759e3c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_i16_extsi_i32_addi_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x16 = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xi16> + %x32 = pto.vmi.extsi %x16 : !pto.vmi.vreg<128xi16> -> !pto.vmi.vreg<128xi32> + %sum = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} + : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<128xi32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xi32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/launch.cpp new file mode 100644 index 0000000000..255de845bd --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/launch.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_i16_extsi_i32_addi_store_kernel(__gm__ int16_t *src, + __gm__ int32_t *dst); + +void LaunchVmi_group_reduce_i16_extsi_i32_addi_store_kernel(int16_t *src, + int32_t *dst, + void *stream) { + vmi_group_reduce_i16_extsi_i32_addi_store_kernel<<<1, nullptr, stream>>>( + (__gm__ int16_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/main.cpp b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/main.cpp new file mode 100644 index 0000000000..277a78662f --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/main.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_i16_extsi_i32_addi_store_kernel(int16_t *src, + int32_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 128; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(int16_t); + size_t dstBytes = kOutputElems * sizeof(int32_t); + int16_t *srcHost = nullptr; + int32_t *dstHost = nullptr; + int16_t *srcDevice = nullptr; + int32_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_i16_extsi_i32_addi_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/compare.py b/test/vpto/cases/vmi/group-reduce-i32-addi-store/compare.py new file mode 100644 index 0000000000..612b15c3f6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.int32) + output = np.fromfile("v2.bin", dtype=np.int32) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/golden.py b/test/vpto/cases/vmi/group-reduce-i32-addi-store/golden.py new file mode 100644 index 0000000000..4153e74342 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/golden.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 8 + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, COLS), dtype=np.int32) + for row in range(ROWS): + src[row, :] = np.arange(COLS, dtype=np.int32) + row * 3 - 5 + dst = np.full(ROWS, -777, dtype=np.int32) + golden = np.sum(src, axis=1, dtype=np.int32).astype(np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto new file mode 100644 index 0000000000..783658e453 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_i32_addi_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<64xi32> + %sum = pto.vmi.group_reduce_addi %x, %mask {num_groups = 8} + : !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xi32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<64xi32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-i32-addi-store/launch.cpp new file mode 100644 index 0000000000..5783bfd5a8 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_i32_addi_store_kernel(__gm__ int32_t *src, + __gm__ int32_t *dst); + +void LaunchVmi_group_reduce_i32_addi_store_kernel(int32_t *src, int32_t *dst, + void *stream) { + vmi_group_reduce_i32_addi_store_kernel<<<1, nullptr, stream>>>( + (__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/main.cpp b/test/vpto/cases/vmi/group-reduce-i32-addi-store/main.cpp new file mode 100644 index 0000000000..385f3ae909 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/main.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_i32_addi_store_kernel(int32_t *src, int32_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 64; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(int32_t); + size_t dstBytes = kOutputElems * sizeof(int32_t); + int32_t *srcHost = nullptr; + int32_t *dstHost = nullptr; + int32_t *srcDevice = nullptr; + int32_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_i32_addi_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-i32-addi-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/compare.py b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/compare.py new file mode 100644 index 0000000000..612b15c3f6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.int32) + output = np.fromfile("v2.bin", dtype=np.int32) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/golden.py b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/golden.py new file mode 100644 index 0000000000..76d46fff4c --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 32 + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, COLS), dtype=np.int8) + for row in range(ROWS): + src[row, :] = ((np.arange(COLS, dtype=np.int16) * 3 + row * 5) % 41 - 20).astype( + np.int8 + ) + dst = np.full(ROWS, -777, dtype=np.int32) + golden = np.sum(src.astype(np.int32), axis=1, dtype=np.int32).astype(np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto new file mode 100644 index 0000000000..97154d0dd6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_i8_extsi_i32_addi_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x8 = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xi8> + %x32 = pto.vmi.extsi %x8 : !pto.vmi.vreg<256xi8> -> !pto.vmi.vreg<256xi32> + %sum = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xi32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/launch.cpp new file mode 100644 index 0000000000..1e046a8eb5 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/launch.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_i8_extsi_i32_addi_store_kernel(__gm__ int8_t *src, + __gm__ int32_t *dst); + +void LaunchVmi_group_reduce_i8_extsi_i32_addi_store_kernel(int8_t *src, + int32_t *dst, + void *stream) { + vmi_group_reduce_i8_extsi_i32_addi_store_kernel<<<1, nullptr, stream>>>( + (__gm__ int8_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/main.cpp b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/main.cpp new file mode 100644 index 0000000000..cef9801b4d --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/main.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_i8_extsi_i32_addi_store_kernel(int8_t *src, + int32_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 256; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(int8_t); + size_t dstBytes = kOutputElems * sizeof(int32_t); + int8_t *srcHost = nullptr; + int32_t *dstHost = nullptr; + int8_t *srcDevice = nullptr; + int32_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_i8_extsi_i32_addi_store_kernel(srcDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/ptoas.flags b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From 976e07e22282858a5104b4d6764ec6887a89ccf7 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 09:11:32 +0800 Subject: [PATCH 22/54] Implement VMI layout support lowering --- docs/designs/vmi-implementation-manual.md | 137 ++-- .../vmi-layout-assignment-implementation.md | 649 +++++++++--------- .../vmi-layout-assignment-lowering-design.md | 522 +++++++------- docs/designs/vmi-layout-lowering-cases.md | 195 +++--- include/PTO/Transforms/Passes.h | 3 +- include/PTO/Transforms/Passes.td | 10 +- include/PTO/Transforms/VMILayoutSupport.h | 287 ++++++++ .../PTO/Transforms/VMILocalRecipeRegistry.h | 234 ------- lib/PTO/Transforms/CMakeLists.txt | 2 +- lib/PTO/Transforms/PTOValidateVMIIR.cpp | 157 +++-- lib/PTO/Transforms/VMILayoutAssignment.cpp | 311 +++------ lib/PTO/Transforms/VMILayoutFoldConsumers.cpp | 6 +- .../VMILayoutSinkMaterialization.cpp | 274 +++++++- ...ecipeRegistry.cpp => VMILayoutSupport.cpp} | 505 +++++++++----- lib/PTO/Transforms/VMIToVPTO.cpp | 399 +++-------- ...assignment_broadcast_dense_group_users.pto | 6 +- .../vmi_layout_assignment_broadcast_remat.pto | 7 +- .../vmi_layout_assignment_constant_remat.pto | 8 +- ...ayout_assignment_create_group_mask_s16.pto | 6 +- ...signment_create_group_mask_s32_dynamic.pto | 6 +- ...ment_dense_group_reduce_multi_consumer.pto | 6 +- ...gnment_dense_store_group_slots_invalid.pto | 2 +- ..._layout_assignment_f32_f8_store_reduce.pto | 6 +- ...nment_group_load_block8_truncf_invalid.pto | 2 +- ...ut_assignment_group_reduce_s12_invalid.pto | 7 +- ...gnment_group_reduce_s32_tail_full_tile.pto | 12 +- ...p_reduce_s32_tail_no_full_tile_invalid.pto | 4 +- ...lot_load_slots1_dynamic_stride_invalid.pto | 2 +- ...t_load_slots1_unaligned_stride_invalid.pto | 2 +- ..._layout_assignment_group_slots_scf_for.pto | 6 +- ...group_store_slots1_unit_stride_invalid.pto | 2 +- .../vmi/vmi_layout_assignment_iota_remat.pto | 7 +- ...ignment_mask_granularity_f32_f16_store.pto | 8 +- .../vmi/vmi_layout_assignment_mask_remat.pto | 63 +- ...signment_masked_load_dense_group_users.pto | 4 +- ..._assignment_masked_load_group_tail_s32.pto | 4 +- ..._layout_assignment_non_load_s32_reduce.pto | 4 +- ...ment_packed_group_slots_truncf_invalid.pto | 2 +- ...yout_assignment_widen_f16_store_reduce.pto | 6 +- ...ayout_gate_bitcast_group_slots_invalid.pto | 2 +- ...i_layout_gate_bitcast_support_invalid.pto} | 4 +- ... vmi_layout_gate_extf_support_invalid.pto} | 4 +- ..._gate_group_broadcast_support_invalid.pto} | 4 +- ...ayout_gate_group_load_support_invalid.pto} | 4 +- ...e_group_reduce_slots1_support_invalid.pto} | 6 +- ...out_gate_group_reduce_support_invalid.pto} | 6 +- ..._gate_group_slot_load_support_invalid.pto} | 4 +- ..._group_slots_unsupported_slots_invalid.pto | 6 +- ...yout_gate_group_store_support_invalid.pto} | 4 +- ...e_helper_materialization_shape_invalid.pto | 4 +- ...mi_layout_gate_helper_support_invalid.pto} | 4 +- ...vmi_layout_gate_store_support_invalid.pto} | 4 +- ...recipe.pto => vmi_layout_gate_support.pto} | 4 +- ...mi_layout_gate_truncf_support_invalid.pto} | 4 +- .../lit/vmi/vmi_layout_rematerialize_data.pto | 17 + ...vmi_layout_sink_materialization_binary.pto | 122 ++++ ...mi_to_vpto_constant_mask_rematerialize.pto | 2 +- .../vmi_to_vpto_create_mask_rematerialize.pto | 2 +- ...o_vpto_group_broadcast_slots8_support.pto} | 4 +- .../vmi/vmi_to_vpto_group_broadcast_vselr.pto | 4 +- ...pto => vmi_to_vpto_group_load_support.pto} | 4 +- test/lit/vmi/vmi_to_vpto_group_ops.pto | 4 +- ...vpto_group_reduce_legacy_slots_invalid.pto | 27 + ... vmi_to_vpto_group_reduce_s64_support.pto} | 4 +- ...i_to_vpto_group_reduce_slots8_support.pto} | 4 +- .../vmi/vmi_to_vpto_group_reduce_vcgadd.pto | 4 +- ...to_vpto_group_reduce_vcgadd_multichunk.pto | 4 +- ...> vmi_to_vpto_group_slot_load_support.pto} | 4 +- ...vpto_group_slot_truncf_slots1_support.pto} | 4 +- test/lit/vmi/vmi_to_vpto_quant_dequant.pto | 11 +- ...vpto_truncf_fp8_128_contiguous_invalid.pto | 4 +- 71 files changed, 2369 insertions(+), 1793 deletions(-) create mode 100644 include/PTO/Transforms/VMILayoutSupport.h delete mode 100644 include/PTO/Transforms/VMILocalRecipeRegistry.h rename lib/PTO/Transforms/{VMILocalRecipeRegistry.cpp => VMILayoutSupport.cpp} (73%) rename test/lit/vmi/{vmi_layout_gate_bitcast_recipe_invalid.pto => vmi_layout_gate_bitcast_support_invalid.pto} (94%) rename test/lit/vmi/{vmi_layout_gate_extf_recipe_invalid.pto => vmi_layout_gate_extf_support_invalid.pto} (94%) rename test/lit/vmi/{vmi_layout_gate_group_broadcast_recipe_invalid.pto => vmi_layout_gate_group_broadcast_support_invalid.pto} (93%) rename test/lit/vmi/{vmi_layout_gate_group_load_recipe_invalid.pto => vmi_layout_gate_group_load_support_invalid.pto} (93%) rename test/lit/vmi/{vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto => vmi_layout_gate_group_reduce_slots1_support_invalid.pto} (85%) rename test/lit/vmi/{vmi_layout_gate_group_reduce_recipe_invalid.pto => vmi_layout_gate_group_reduce_support_invalid.pto} (85%) rename test/lit/vmi/{vmi_layout_gate_group_slot_load_recipe_invalid.pto => vmi_layout_gate_group_slot_load_support_invalid.pto} (93%) rename test/lit/vmi/{vmi_layout_gate_group_store_recipe_invalid.pto => vmi_layout_gate_group_store_support_invalid.pto} (93%) rename test/lit/vmi/{vmi_layout_gate_helper_recipe_invalid.pto => vmi_layout_gate_helper_support_invalid.pto} (92%) rename test/lit/vmi/{vmi_layout_gate_store_recipe_invalid.pto => vmi_layout_gate_store_support_invalid.pto} (95%) rename test/lit/vmi/{vmi_layout_gate_local_recipe.pto => vmi_layout_gate_support.pto} (92%) rename test/lit/vmi/{vmi_layout_gate_truncf_recipe_invalid.pto => vmi_layout_gate_truncf_support_invalid.pto} (94%) rename test/lit/vmi/{vmi_to_vpto_group_broadcast_slots8_local_recipe.pto => vmi_to_vpto_group_broadcast_slots8_support.pto} (94%) rename test/lit/vmi/{vmi_to_vpto_group_load_local_recipe.pto => vmi_to_vpto_group_load_support.pto} (94%) create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto rename test/lit/vmi/{vmi_to_vpto_group_reduce_s64_local_recipe.pto => vmi_to_vpto_group_reduce_s64_support.pto} (94%) rename test/lit/vmi/{vmi_to_vpto_group_reduce_slots8_local_recipe.pto => vmi_to_vpto_group_reduce_slots8_support.pto} (91%) rename test/lit/vmi/{vmi_to_vpto_group_slot_load_local_recipe.pto => vmi_to_vpto_group_slot_load_support.pto} (91%) rename test/lit/vmi/{vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto => vmi_to_vpto_group_slot_truncf_slots1_support.pto} (96%) diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 497e951e73..6bb7a7e0fe 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -143,8 +143,10 @@ values and ops. It is not part of the default PTOAS pipeline; existing PTO/VPTO unless the flag is set. The `ptoas --enable-vmi` user-facing entry also rejects public functions whose signature contains `!pto.vmi.*`. -Internal/private VMI-typed functions may still be specialized by `vmi-layout-assignment` and physicalized by -`vmi-to-vpto`, but a public VMI ABI requires an explicit materialization plan and must not be inferred from the +Internal/private VMI-typed functions are materialized at explicit boundary +helpers by baseline `vmi-layout-assignment` and physicalized by `vmi-to-vpto`. +A later optimization pass may specialize private signatures. A public VMI ABI +requires an explicit materialization plan and must not be inferred from the layout solver. CLI coverage: @@ -198,7 +200,7 @@ vmi-to-vpto: 写成 `pto.vmi.ensure_*`,physicalization 后不允许残留 `pto.vmi.*`、`!pto.vmi.*` 或 `unrealized_conversion_cast`。不能把 layout 决策藏在 pass-private side table 里让后续 pass 猜。 -源码级实现应该进一步拆成六个独立层次: +源码级实现应该进一步拆成七个独立层次: ```text IR layer: @@ -221,6 +223,20 @@ Layout solving layer: 负责从 producer/consumer/control-flow/call 关系解出每个 logical value 的 layout, 然后把结果写回 type 或 ensure_* helper。 +Layout support query layer: + include/PTO/Transforms/VMILayoutSupport.h + lib/PTO/Transforms/VMILayoutSupport.cpp + + 只放跨阶段共享的纯查询:cast layout fact、group_reduce layout fact、 + ensure_* materialization support、layout-aware store support 等。它可以被 + assignment、validation、layout optimization 和 vmi-to-vpto 调用,但不能保存 + per-value 状态,不能返回 VPTO 指令计划,不能决定 clone/rematerialize,也不能 + 通过 producer/user/control-flow context 恢复 lowering 决策。 + + 加新 query 的标准是:至少两个阶段需要同一个语义事实,并且重复实现会导致 + assignment、validation、lowering 对同一个 layout shape 得出不同结论。只有 + 一个 lowering pattern 自己使用的分支应该留在该 pattern 内。 + Layout optimization layer: lib/PTO/Transforms/VMILayoutFoldConsumers.cpp lib/PTO/Transforms/VMILayoutRematerialize.cpp @@ -257,7 +273,8 @@ Union-find + DenseMap: 用于 layout assignment 的 per-SSA-value 等价类求解。 IRRewriter/RewriterBase: - 用于 layout assignment 之后的 type rewrite、helper insertion、cheap producer rematerialization。 + 用于 layout assignment 之后的 type rewrite、helper insertion;cheap producer + rematerialization 属于后续 layout optimization pass。 OneToNTypeConverter + OneToNOpConversionPattern: 只用于 vmi-to-vpto,把一个 logical VMI value 展成多个 VPTO value。 @@ -1002,7 +1019,7 @@ SymbolTable: 解析 direct internal func.call;带 VMI type 的 external/indirect call 先拒绝。 IRRewriter: - 改写 function/block/result type,插入 ensure_*,必要时 rematerialize cheap producer。 + 改写 function/block/result type,插入 ensure_*。 verifyLayoutAssignedVMIIR: pass 末尾 hard gate,确认所有决策已经 materialize 到 IR。 @@ -1248,7 +1265,7 @@ The solver runs in phases: 3. add producer natural-layout constraints 4. add consumer layout/granularity requests 5. solve each equivalence class -6. insert ensure_* or rematerialize producers for non-class-compatible uses +6. insert ensure_* for non-class-compatible uses 7. rewrite value types and function signatures 8. run pto-validate-vmi-layout-ir ``` @@ -1300,9 +1317,10 @@ store/tile_write: consumer requests contiguous externally visible order ``` -If one equivalence class has incompatible natural layouts, the pass must diagnose `VMI-LAYOUT-CONTRACT` unless a -defined rematerialization path can split the value before the conflict. The first version should only rematerialize -trivially replayable producers: +If one equivalence class has incompatible natural layouts, the pass must diagnose `VMI-LAYOUT-CONTRACT` unless an +explicit use-site `ensure_*` can represent the requested materialization. Baseline layout assignment does not +clone/rematerialize producers. The separate `vmi-layout-rematerialize` optimization may replace an `ensure_*` +with a cloned trivially replayable producer after the materialization request is visible in IR: ```text constant @@ -1595,8 +1613,8 @@ Layout assignment completion checks: 2. No surface !pto.vmi.mask remains. 3. Every VMI function argument, result, block argument, branch operand, call operand, and return operand has the layout-assigned type selected by the solved equivalence class. -4. Every consumer-specific mismatch is represented either by a rematerialized cheap producer or by an explicit - pto.vmi.ensure_* op immediately before that consumer. +4. Every consumer-specific mismatch is represented by an explicit pto.vmi.ensure_* op immediately before that + consumer. Optional optimization passes may later replace selected helpers with rematerialized cheap producers. 5. External declarations with VMI types are rejected; they are not rewritten into an implicit ABI. ``` @@ -2430,7 +2448,6 @@ allowed layouts: bitset {contiguous, deinterleaved2, deinterleaved4} required mask granularity: pred/b8/b16/b32 or unknown natural layout preference hard constraints -soft costs ``` No information required by later passes may live only in this data structure. After the pass, type/attr/op @@ -2495,10 +2512,12 @@ bitcast: bitcast contract is defined. load/tile_read: - result layout chosen by consumers unless memory plan has a cheaper registered sink/source + baseline result layout is deterministic from explicit layout attrs or the + producer natural layout; consumer-specific alternatives are represented by + ensure_layout and optimized later store/tile_write: - can consume any layout only if target registry has preserving store path + baseline requests contiguous source layout current implementation records a contiguous use-site request for vmi.store and inserts pto.vmi.ensure_layout when the stored value class solved to a non-contiguous layout. This makes externally visible memory order explicit in @@ -2508,7 +2527,8 @@ store/tile_write: the same physical chunk count and therefore forms complete intlv groups. shuffle/channel_split/channel_merge: - default result layout contiguous unless target registry provides direct layout-preserving path + default result layout contiguous unless the current op explicitly carries a + supported layout-preserving contract current implementation supports pto.vmi.shuffle when every result physical chunk forwards one source physical chunk with identical lane positions for all non-padding result lanes. Result padding lanes are ignored by the @@ -2538,12 +2558,12 @@ Implement deterministic solving: ```text 1. Collect region/SCC constraints, including scf/cf/function/call boundaries. 2. Propagate impossible layouts and required mask granularities. -3. Pick a layout per node using minimum cost. -4. Tie-break: explicit layout already present on the VMI type, then natural layout, then contiguous. +3. Pick one layout per node using deterministic priority, not a cost model: + explicit layout already present on the VMI type, then unique natural layout, + then hard non-contiguous request, then contiguous. 5. Rewrite result/block/function types to layout-assigned VMI types. 6. Insert ensure_layout / ensure_mask_layout / ensure_mask_granularity at uses that need conversion. -7. Clone rematerializable producers per use when cheaper than conversion. -8. Run verifier gate. +7. Run verifier gate. ``` Current implementation status: @@ -2591,7 +2611,8 @@ Do not implement a local greedy pattern pass that ignores block arguments or fun CFG 处理分两层。第一层是必须做的 layout equivalence:同一个控制流值在 result、yield、region/block argument 之间必须形成同一个 layout/mask 约束组。第二层才是 layout conflict resolution:当同一个 producer 的不同 consumers 希望不同 layout 时,插入 -`ensure_layout`、`ensure_mask_layout` 或 rematerialize producer。 +`ensure_layout` 或 `ensure_mask_layout`。后续 `vmi-layout-rematerialize` 可以把部分 helper +替换成重放的纯构造 producer。 当前可落地的最小实现先做第一层。它不尝试在 branch 边界自动插入 conversion,因此下面这些 关系一旦因为 natural layout 或 mask granularity 冲突无法合并,必须报 `VMI-LAYOUT-CONTRACT`, @@ -3051,17 +3072,13 @@ pto.vmi.group_reduce_addf: requires {reassoc} N = logical lane count; G = num_groups; S = N / G L = physical lanes per 256B chunk for the element type. - The result carries #pto.vmi.layout, a sparse group-slot - layout. It is not a dense vector layout: only group_slot(g) lanes have - semantic values. - group_slot(g) is canonical and derived from N, G, and L: - if S < L: - low_elems = L / S - chunk_stride = 1 - if S >= L: - low_elems = 1 - chunk_stride = S / L - group_slot(g) = (g / low_elems) * chunk_stride * L + (g % low_elems) + The result carries #pto.vmi.layout, a sparse + group-slot layout. It is not a dense vector layout: only slot lanes have + semantic values. Supported K values are: + K = 8 for VCGADD-style packed results, where group g is stored in + physical chunk floor(g / 8), lane g % 8. + K = 1 for row-local VCADD results, where group g is stored in physical + chunk g, lane 0. for each group g: result[group_slot(g)] = reduce_add(source[g * S .. (g + 1) * S), mask in same range) @@ -3069,10 +3086,10 @@ pto.vmi.group_reduce_addf: direct lowering materializes them as zero where the hardware path does not already define them. The result remains a VMI vector with the same element type and logical lane - count as the source, but its layout is #pto.vmi.layout. + count as the source, but its layout is an explicit group-slot layout. layout assignment: source use is requested as contiguous - result natural layout is #pto.vmi.layout + result natural layout is #pto.vmi.layout mask use is requested as contiguous with granularity derived from source element width current direct lowering: @@ -3085,8 +3102,8 @@ pto.vmi.group_reduce_addf: Otherwise: derived group size S must be a multiple of physical lanes per part lower each source chunk with pto.vcadd, combine chunks in the same group - with pto.vadd under PAT_VL1, then place group g at group_slot(g) in the - #pto.vmi.layout result. All other result chunks/lane values + with pto.vadd under PAT_VL1, then place group g in the slot lane defined by + K. All other result chunks/lane values are zero. unsupported cases: missing reassoc attr @@ -3097,17 +3114,17 @@ pto.vmi.group_reduce_addf: pto.vmi.group_broadcast: semantic: N = logical lane count; G = num_groups; S = N / G - source must carry #pto.vmi.layout. For each group g, the - source value is read from group_slot(g), using the same canonical group_slot - definition as pto.vmi.group_reduce_addf. The result broadcasts it back to + source must carry #pto.vmi.layout. For each group + g, the source value is read from the slot lane defined by K. The result broadcasts it back to each logical group: result[g * S + i] = source[group_slot(g)] layout assignment: - source use is requested as #pto.vmi.layout + source use is requested as #pto.vmi.layout result is consumer-driven. If no consumer requests another layout, it defaults to contiguous. current direct lowering: - source must carry #pto.vmi.layout with full physical chunks + source must carry #pto.vmi.layout with full + physical chunks result may be contiguous with full physical chunks result may also be deinterleaved when S is large enough that every physical result chunk stays inside one logical group, for example N=512, G=2, S=256, @@ -4011,23 +4028,27 @@ Slice 5 完成条件: writeMask fallback paths must report `VMI-UNSUPPORTED`. ``` -## 8. Target Capability Registry +## 8. Target Capabilities And Layout Fact Helpers -Add one explicit registry object, passed into layout assignment and VMI-to-VPTO: +Keep target capabilities separate from layout assignment policy. The shared +helpers expose target support and small layout/materialization facts; they do +not select a global lowering plan and are not a shared lowering-plan registry +between assignment and VMI-to-VPTO. ```text supportsElementType(type, purpose) -getNaturalLayout(op) -supportsLayoutConversion(srcLayout, dstLayout, elementType) -getLayoutMaterializationPlan(srcLayout, dstLayout, elementType) +getPreferredCastLayoutFact(sourceType, resultType) +getPreferredGroupReduceLayoutFact(sourceType, numGroups) +canMaterializeDataLayout(sourceType, resultType) +canMaterializeMaskLayout(sourceType, resultType) supportsMaskGranularityConversion(srcG, dstG) -supportsMemoryAccessPlan(plan) +supportsMemoryAccessProof(proof) supportsPrefixPopcount(maskType) supportsReductionScanContract(op) getScratchResource(plan) ``` -The registry returns structured results: +Capability and materialization helpers return structured results: ```text supported @@ -4170,7 +4191,7 @@ If any answer is no, the slice is not ready to be treated as complete. ## 13. Adding One VMI Op End To End -新增一个 `pto.vmi.*` op 时,不要只补 ODS 和 lowering pattern。它必须穿过固定的六个落点, +新增一个 `pto.vmi.*` op 时,不要只补 ODS 和 lowering pattern。它必须穿过固定的七个落点, 否则很容易出现 verifier 能过、layout pass 不知道怎么约束、或控制流 physicalization 后残留 VMI type。 ```text @@ -4180,20 +4201,24 @@ If any answer is no, the slice is not ready to be treated as complete. 2. semantic verifier: lib/PTO/IR/VMI.cpp -3. layout facts: +3. layout assignment facts: lib/PTO/Transforms/VMILayoutAssignment.cpp -4. vmi-to-vpto preflight: +4. shared layout support, when the fact crosses stages: + include/PTO/Transforms/VMILayoutSupport.h + lib/PTO/Transforms/VMILayoutSupport.cpp + +5. vmi-to-vpto preflight: lib/PTO/Transforms/VMIToVPTO.cpp::verifySupportedVMIToVPTOOps -5. OneToN lowering pattern: +6. OneToN lowering pattern: lib/PTO/Transforms/VMIToVPTO.cpp::populateVMIOneToNConversionPatterns -6. focused lit tests: +7. focused lit tests: test/lit/vmi/ ``` -这六个落点的职责不同: +这七个落点的职责不同: ```text ODS: @@ -4211,6 +4236,12 @@ LayoutAssignment: - mask consumer required granularity 不能在 collect 阶段改 IR。 +VMILayoutSupport: + 只放跨 assignment、validation、optimization、lowering 中至少两个阶段共享的纯查询。 + 典型内容是 cast layout fact、group_reduce layout fact、ensure_* materialization support。 + 不能返回 VPTO instruction sequence、不能决定 clone/rematerialize、不能读取 producer/user context。 + 只有一个 lowering pattern 自己使用的判断不要抽到这里。 + VMIToVPTO preflight: 在 rewrite 前拒绝当前 lowering 不支持但语义合法的 case。 典型例子是 partial physical chunk、non-prefix mask constant、dynamic create_mask、unsupported shuffle。 diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index a6583a3d8b..e1fa19cc7e 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -41,8 +41,8 @@ pto-validate-vmi-ir: vmi-layout-assignment: solve hard value layout constraints - choose explicit layouts and local recipe carriers visible in IR - insert ensure/rematerialization helpers + choose explicit layouts visible in IR + insert ensure_layout / ensure_mask_layout / ensure_mask_granularity helpers make internal function boundary layouts explicit rewrite VMI types with layout attrs @@ -54,11 +54,12 @@ vmi-layout-fold-consumers: fold use-site materialization into consumers that can directly consume the source layout while preserving the same logical effect example: ensure_layout(deinterleaved=2 -> contiguous) feeding store may become - a store of deinterleaved=2 when the store has a local vstsx2 INTLV recipe + a store of deinterleaved=2 when the store has a layout-aware vstsx2 INTLV + lowering current implementation: pto.vmi.store, pto.vmi.tile_write, and the value operand of pto.vmi.masked_store when the existing mask arity matches, fed by ensure_layout from deinterleaved=2/4, block_elems=1 to contiguous. factor=2 - uses the store's vstsx2 INTLV recipe; factor=4 is still store-local, but it + uses the store's vstsx2 INTLV lowering; factor=4 is still store-local, but it materializes through physical interleave before vsts. vmi-layout-rematerialize: @@ -73,15 +74,17 @@ vmi-layout-rematerialize: vmi-layout-sink-materialization: move ensure_layout across pure layout-transparent elementwise chains when the - rewritten IR reduces materialization cost and keeps every op locally legal + rewritten IR reduces materialization overhead and keeps every op locally legal current implementation: sink two identical operand ensure_layout helpers - across binary add/sub/mul/div/min/max/and/or/xor/shl/shru VMI ops, or one - source ensure_layout across unary neg/abs/sqrt/exp/ln/relu/not VMI ops, - producing one result ensure_layout. It also sinks matching - ensure_mask_layout or ensure_mask_granularity helpers across - mask_and/mask_or/mask_xor/mask_not, producing one result mask helper. It - does not sink through select, fma, cast, load, store, reduce, - group_broadcast, or control-flow ops + across binary add/sub/mul/div/min/max/and/or/xor/shl/shru VMI ops, three + identical operand ensure_layout helpers across fma, or one source + ensure_layout across unary neg/abs/sqrt/exp/ln/relu/not VMI ops, producing + one result ensure_layout. It also sinks compare data helpers to one result + ensure_mask_layout, and sinks select only when both selected values and the + mask carry matching explicit helpers. Matching ensure_mask_layout or + ensure_mask_granularity helpers are sunk across mask_and/mask_or/mask_xor/ + mask_not, producing one result mask helper. It does not sink through cast, + load, store, reduce, group_broadcast, or control-flow ops. vmi-legalize-arith-select: restore scalar-condition arith.select with VMI result type back to scf.if @@ -92,11 +95,12 @@ pto-validate-vmi-layout-ir: verify every VMI data/mask value has layout verify every VMI value has an assigned layout and every non-local lowering choice has been serialized explicitly - verify helper ops have registered materialization recipes. Current + verify helper ops have supported materialization paths. Current implementation checks `ensure_layout`, `ensure_mask_layout`, and - `ensure_mask_granularity` at the layout gate, so unsupported helper recipes - fail before `vmi-to-vpto`. It also checks the first semantic local-recipe - families, non-contiguous `pto.vmi.store`/`pto.vmi.tile_write`, block8 + `ensure_mask_granularity` at the layout gate, so unsupported helper + materializations fail before `vmi-to-vpto`. It also checks the first + semantic local lowering families, non-contiguous + `pto.vmi.store`/`pto.vmi.tile_write`, block8 `pto.vmi.group_load`, `pto.vmi.group_slot_load`, group_slots `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_add{f|i}`, explicit-slots `pto.vmi.group_broadcast`, `pto.vmi.truncf`, @@ -168,7 +172,7 @@ include/PTO/Transforms/Passes.td lib/PTO/Transforms/PTOValidateVMIIR.cpp lib/PTO/Transforms/VMILayoutAssignment.cpp lib/PTO/Transforms/VMIToVPTO.cpp -lib/PTO/Transforms/VMILocalRecipeRegistry.cpp +small layout fact/materialization helpers under lib/PTO/Transforms test/lit/vmi/vmi_layout_assignment_*.pto test/lit/vmi/vmi_to_vpto_*.pto @@ -224,7 +228,7 @@ contiguous: deinterleaved: F > 1 B > 0 - direct full-chunk recipes require N % (F * B) == 0 + direct full-chunk lowerings require N % (F * B) == 0 group_slots: G > 0 @@ -236,12 +240,13 @@ group_slots: Parser compatibility during migration: ```text -#pto.vmi.layout +#pto.vmi.layout ``` -is accepted as a legacy spelling for the pre-design implicit group layout. New -`vmi-layout-assignment` output must not rely on that implicit form. It must -print one of: +is the lowering contract for group-slot values. The parser still accepts +`#pto.vmi.layout` as a legacy spelling for the pre-design +implicit group layout, but `vmi-to-vpto` support queries require explicit slots. +New `vmi-layout-assignment` output must print one of: ```text #pto.vmi.layout @@ -270,10 +275,10 @@ Layout-assigned: Surface VMI types are legal before assignment. Layout-assigned VMI types are required after assignment. -### 3.3 Explicit Recipe Carriers +### 3.3 Explicit Lowering Carriers Lowering decisions are carried by the current op and its types, not by a -separate recipe string. The allowed carriers are: +separate lowering-plan string. The allowed carriers are: ```text op attrs and operands @@ -297,9 +302,9 @@ group_slot_load result group_slots layout and source_group_stride group_reduce_add{f|i} source/mask/result layouts, num_groups, typed reduce semantics group_broadcast source/result layouts and num_groups truncf source/result layouts and element widths -ensure_layout always carries source/result layouts instead of recipe -ensure_mask_layout always carries source/result layouts instead of recipe -ensure_mask_granularity always carries source/result granularities instead of recipe +ensure_layout always carries source/result layouts +ensure_mask_layout always carries source/result layouts +ensure_mask_granularity always carries source/result granularities ``` Layout/attr-only decisions today: @@ -318,11 +323,12 @@ Implementation rule: validate-assigned-vmi validates assigned layouts, mask granularity, boundaries, and helper placement. vmi-to-vpto emits VMI-LAYOUT-CONTRACT for missing local proof. -If a layout/attr-only op later gains a second legal recipe that cannot be -distinguished from current-op information, that recipe must be represented by a +If a layout/attr-only op later gains a second legal lowering that cannot be +distinguished from current-op information, that lowering must be represented by a new attr, helper op, or rematerialized op before vmi-to-vpto can emit it. -Unsupported shapes that have no registered recipe still diagnose through their -specific capability check rather than failing with a generic missing-recipe +Unsupported shapes that have no explicit materialization/lowering path still +diagnose through their specific capability check rather than failing with a generic +missing-lowering error. ``` @@ -386,7 +392,7 @@ cast boundary: op semantic, not a VMI type spelling. Current VPTO lowering supports 32-bit integer narrowing to unsigned i8 storage, matching the available VCVTII s32/u32 -> u8 forms; signed i8 - narrowing needs a separate target recipe. + narrowing needs a separate target lowering. compute / accumulator: floating compute baseline: f16/f32, with reassoc required for reductions @@ -394,7 +400,7 @@ compute / accumulator: integer compute baseline: i32 for grouped reduction; i8/i16 storage must first cast to i32 because integer reduction instructions widen narrow inputs. f8/i8 are not baseline accumulator/compute types. Supporting direct 8-bit - compute requires a target capability entry and a separate recipe family. + compute requires a target capability entry and a separate lowering family. ``` Important semantic split: @@ -412,129 +418,169 @@ group_slot_load: loads one scalar per group and produces group_slots ``` -## 5. Local Recipe Registry +## 5. Layout Fact Helpers And Ensure-Based Optimization Hooks -Create one target-aware local recipe registry shared by assignment and lowering. -It is not serialized as a separate recipe-selection attribute. It answers local legality -questions from op kind, explicit attrs/operands, layouts, and target capability. +Do not implement a target-aware lowering-plan registry shared by assignment and +lowering. The shared contract is the IR itself: assigned VMI layouts, explicit +`ensure_layout` / `ensure_mask_layout` / `ensure_mask_granularity` helpers, +semantic op attrs/operands, and target capability diagnostics. -```c++ -class VMILocalRecipeRegistry { -public: - SmallVector getProducerRecipes(Operation *op); - SmallVector getConsumerRecipes(OpOperand &use); - SmallVector getTransferRecipes(Operation *op); - FailureOr - getMaterializationRecipe(Type valueType, VMILayoutKey from, - VMILayoutKey to); - bool isCheaplyRematerializable(Operation *op); - bool hasTargetCapability(RecipeID recipe) const; -}; +Small pure helpers are still useful when they remove duplicated layout math. +They must return semantic layout facts, not VPTO instruction plans, costs, +clone decisions, or multi-user plans. + +Keep the support layer small. A query belongs in `VMILayoutSupport` only when +at least two stages need the same fact and a mismatch would create an +assignment-vs-lowering bug. Current valid shared facts are: + +```text +cast layout fact: + shared by layout assignment, layout validation, and vmi-to-vpto. + Example: f32->f8 must see deinterleaved=4 source and contiguous result in + every stage. + +group_reduce layout fact: + shared by layout assignment, layout validation, and vmi-to-vpto. + Example: S=2*VLaneElems means deinterleaved=2 source/mask and + group_slots(G, slots=8) result in every stage. + +layout materialization support: + shared by layout validation, vmi-to-vpto, and helper-based optimizations. + Example: ensure_layout from deinterleaved=2 f32 to contiguous f32 is the same + materialization whether it survives to lowering or is folded into a store. + +contiguous store support: + shared by fold-consumers and vmi-to-vpto because both must preserve the same + row-major memory effect when consuming a non-contiguous value. ``` -Recipe record: +Do not add a support query for a single private branch such as "this exact op +uses this exact VPTO mnemonic". Keep that branch in the lowering pattern until +another stage needs the same semantic fact. This prevents `VMILayoutSupport` +from becoming a second copy of the lowering pass. ```c++ -struct VMILayoutRecipe { - RecipeID id; - SmallVector operandLayouts; - SmallVector resultLayouts; - int64_t cost; - bool requiresFullTileReadable; - bool mayReadInactivePhysicalLanes; - DiagnosticBuilder (*explainFailure)(...); +struct VMICastLayoutFact { + VMICastLayoutKind kind; + VMILayoutAttr sourceLayout; + VMILayoutAttr resultLayout; + int64_t factor; }; + +struct VMIGroupReduceLayoutFact { + VMILayoutAttr sourceLayout; + VMILayoutAttr maskLayout; + VMILayoutAttr resultLayout; + int64_t groupSize; + int64_t vlaneElems; +}; + +FailureOr +getPreferredCastLayoutFact(VMIVRegType sourceType, VMIVRegType resultType); + +FailureOr +getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, int64_t numGroups); + +LogicalResult canMaterializeDataLayout(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason); ``` -The registry must be target-aware but deterministic. It should not read global -mutable state. Pass options configure fallback availability: +Baseline assignment uses these helpers only to produce assigned layouts and +use-site helpers. It does not clone producers, rematerialize cheap ops, choose +memory-fused layouts by cost, or specialize private function signatures for +performance. + +Optimization passes are deliberately helper-driven: ```text -enableScratchFallback -enableGatherFallback -enablePublicVMIABI -diagnosticVerbosity +fold-consumers: + input shape: ensure_layout feeding a layout-aware consumer. + support query: can this consumer preserve the same logical memory effect from + the source layout? + output shape: the consumer directly uses the source value. + +rematerialize: + input shape: cheap producer feeding ensure_layout / ensure_mask_layout. + support query: can the cloned producer directly create the requested type? + output shape: a cloned producer at the use. + +sink-materialization: + input shape: pure elementwise op whose operands are matching ensure_* helpers. + support query: can the result helper be materialized if it remains? + output shape: the op runs in the source layout and one helper remains on the + result. ``` -Assignment and optimization passes may query the registry to decide which IR -shape to produce. `vmi-to-vpto` may query the same registry to verify the -current op is locally lowerable. If the same op, attrs, operands, and -operand/result layouts could map to two different physical recipes with -different observable preconditions, the IR is under-specified; add an explicit -attr, operand, helper op, or distinct VMI semantic op before implementing that -recipe. +These passes may improve multi-consumer cases without asking assignment to solve +a global cost problem. Assignment guarantees a legal baseline with helpers; +optimization removes or moves those helpers locally when the rewritten IR still +contains enough information for `vmi-to-vpto`. -Current implementation status: `VMILocalRecipeRegistry` exists and currently -owns nine local recipe families: +Implementation-relevant layout facts: ```text -contiguous store/tile_write consumer recipes: - contiguous vsts - deinterleaved=2 vstsx2 INTLV - deinterleaved=4 materialize-then-vsts +dense store/tile_write: + requests contiguous source. If the value is assigned deinterleaved, + assignment inserts ensure_layout at the store use. A later optimization may + fold ensure_layout + store into a layout-aware store lowering. -helper materialization recipes: - data/mask layout identity - data/mask contiguous <-> deinterleaved=2/4 when source/result physical - arity matches and the physical part shape can be materialized - mask granularity identity or b8/b16/b32 predicate cast +data/mask helper materialization: + identity conversions are always legal. + contiguous <-> deinterleaved=2/4 is legal only when source/result physical + arity and physical chunk shapes make the same logical value materializable. + unsupported conversions remain explicit diagnostics. -group_slot_load semantic recipes: - slots=8 unit-stride vsldb - slots=1 aligned lane-0 vsldb per group - -block8 group_load semantic recipes: - S=16 deinterleaved=2, block_elems=8 vsldb per row fragment - S=32 deinterleaved=4, block_elems=8 vsldb per row fragment +group_slot_load: + assigned result layout is group_slots(G, slots=8) for packed slots or + group_slots(G, slots=1) for row-local slots. -group_slots group_store semantic recipes: - slots=8 unit-stride vsts - slots=1 aligned lane-0 vsts per group +block8 group_load: + assigned result layout is deinterleaved=2/4 with block_elems=8 only when the + op carries the required constant stride and memory-safety proof. -group_slots group_reduce_add{f|i} semantic recipes: - define E = sizeof(T), VLaneElems = 32B / E, L = 256B / E, S = N / G. - T is the accumulator/reduce element type after any required storage cast. - f8 storage reduces through f32; i8 storage reduces through an explicit - signed/unsigned integer cast to an accumulator type such as i32. In the - baseline contract, f8/i8 are cast-boundary storage types rather than - accumulator/compute types. - S=VLaneElems contiguous vcgadd - S=2*VLaneElems deinterleaved=2 vcgadd+vadd - S=4*VLaneElems deinterleaved=4 vcgadd+vadd tree - S>=L && S%L==0 contiguous slots=1 vcadd/vadd/vsel row-local reduction, - with one physical result part per group. For 32-bit element types this covers - S=64, S=128, S=256, ...; for 16-bit element types this covers S=128, S=256, - ... +group_store: + consumes group_slots(G,K). Explicit output stride attrs/operands decide + whether slots=8 packed or slots=1 row-local stores are legal. + +group_reduce_add{f|i}: + define E = sizeof(accumulator T), VLaneElems = 32B / E, L = 256B / E, + S = N / G. T is the accumulator/reduce element type after any required + storage cast. + S=VLaneElems uses contiguous source/mask and group_slots(G, slots=8). + S=2*VLaneElems uses deinterleaved=2 source/mask and group_slots(G, slots=8). + S=4*VLaneElems uses deinterleaved=4 source/mask and group_slots(G, slots=8). + S>=L && S%L==0 uses contiguous source/mask and group_slots(G, slots=1). -explicit-slots group_broadcast semantic recipes: - slots=8/slots=1 vselr materialization to contiguous or supported - deinterleaved result layouts +group_broadcast: + consumes group_slots(G,K) and produces one assigned dense layout. If another + consumer wants a different dense layout, assignment inserts ensure_layout. + Optimization may clone/rematerialize group_broadcast per use. -extf/truncf semantic recipes: +extf/truncf: contiguous f16/bf16 -> deinterleaved=2 f32 contiguous f8-like -> deinterleaved=4 f32 deinterleaved=2 f32 -> contiguous f16 deinterleaved=4 f32 -> contiguous f8-like - group_slots(G, slots=1) f32 -> f16 + group_slots(G, slots=1) f32 -> f16 remains a slot-preserving transform. -extsi/extui/trunci semantic recipes: - contiguous i8 -> deinterleaved=2 i16 through VCVTII.{s,u}82{s,u}16 #part - contiguous i8 -> deinterleaved=4 i32 through VCVTII.{s,u}82{s,u}32 #pp - deinterleaved=2 i16 -> contiguous i8 through VCVTII.*162*8 #part - deinterleaved=4 i32 -> contiguous ui8 through VCVTII.*322u8 #pp +extsi/extui/trunci: + contiguous i8/i16 -> deinterleaved i32 according to widening factor. + deinterleaved i32 -> contiguous i8/i16 according to narrowing factor. packed group_slots integer width-changing cast is unsupported until a - slot-wise cast recipe is defined. + slot-wise transform is represented explicitly. -bitcast semantic recipes: - per-part vbitcast for contiguous/deinterleaved layouts when source/result - layouts match, physical arity matches, and every physical chunk carries the - same logical bit footprint; this does not require each deinterleaved part to - contain the same number of chunks. group_slots bitcast is unsupported until a - slot-wise bitcast contract is defined. +bitcast: + per-part vbitcast is valid for contiguous/deinterleaved layouts when + source/result layouts match, physical arity matches, and every physical chunk + carries the same logical bit footprint. group_slots bitcast is unsupported + until a slot-wise bitcast contract is defined. ``` -`vmi-layout-fold-consumers`, `pto-validate-vmi-layout-ir`, and `vmi-to-vpto` -query this registry for the decisions implemented above. +`vmi-layout-fold-consumers`, rematerialization, sink/hoist, and private +function specialization passes consume explicit helper IR. They may replace +helpers with cheaper equivalent IR, but they must not introduce hidden lowering +plans that `vmi-to-vpto` has to rediscover from producer/user context. ## 6. Layout Assignment Data Model @@ -544,23 +590,17 @@ query this registry for the decisions implemented above. struct ValueLayoutState { Value value; Type logicalType; - SmallVector candidates; std::optional chosen; + std::optional naturalLayout; SmallVector useRequests; }; struct UseRequest { OpOperand *operand; VMILayoutKey requestedLayout; - RecipeID requestingRecipe; + Operation *requestingOp; bool hard; }; - -struct OpRecipeState { - Operation *op; - SmallVector candidates; - std::optional chosen; -}; ``` ### 6.2 Collection Phase @@ -571,7 +611,7 @@ Walk the module and collect: 1. every VMI value 2. every VMI block argument 3. every VMI function argument/result -4. every VMI op with candidate local recipes +4. every VMI op with natural producer layouts or use-site layout requests 5. every branch/yield/call/return edge carrying VMI ``` @@ -602,13 +642,11 @@ truncf f32->f16: result contiguous group_reduce S=16: - source candidate deinterleaved=2, block_elems=1 - source candidate deinterleaved=2, block_elems=8 + source request deinterleaved=2, block_elems=1 result group_slots(G, slots=8) group_reduce S=32: - source candidate deinterleaved=4, block_elems=1 - source candidate deinterleaved=4, block_elems=8 + source request deinterleaved=4, block_elems=1 result group_slots(G, slots=8) group_reduce S=64: @@ -617,8 +655,8 @@ group_reduce S=64: group_broadcast: source request group_slots(G,K) - result candidate comes from each dense consumer request - op is rematerializable per use + result receives one assigned dense layout + incompatible dense uses are represented by ensure_layout ordinary dense add/mul/select: operands/results same dense layout @@ -634,33 +672,21 @@ group_store: source request group_slots(G,K) ``` -Consumer-driven adoption is limited to producers that are layout-transparent or -can produce the requested memory layout directly: +Baseline assignment does not perform consumer-driven adoption for performance. +It records natural producer layouts and hard use-site requests. If a request +does not match the assigned layout, the pass inserts an explicit helper at that +use. ```text -direct layout producer: - load, tile_read - -layout-transparent producer: - broadcast, constant, iota - add/sub/mul/fma/div/min/max/neg/abs/sqrt/exp/ln/relu - integer bitwise/shift/not - select, bitcast -``` +natural layout producer: + extf/truncf, group_reduce, group_slot_load, group_load when the op itself + carries a layout-producing contract -For a non-load layout-transparent producer, only non-contiguous consumer -requests may be adopted by the producer equivalence class. Contiguous requests -from ordinary stores are handled by use-site `ensure_layout` or -rematerialization instead. This prevents a dense store from overwriting a -natural `deinterleaved` cast layout while still allowing: - -```text -load -> broadcast -> addf -> S=32 group_reduce +layout equality producer: + dense add/mul/select and CFG-carried values tie operands/results but do not + pick a cheaper layout by cost ``` -to assign the whole producer chain as -`deinterleaved = 4, block_elems = 8` before `vmi-to-vpto`. - Memory legality constraints: ```text @@ -676,12 +702,12 @@ compact S=12 logical S=16: ### 6.3.1 Request Builders Implement request generation as small per-op builders. The builders produce -candidate recipes and use-site requests; they do not rewrite IR. +natural layouts, use-site requests, equality constraints, and diagnostics; they +do not choose optimization plans. ```text buildStoreRequests: - ordinary store -> dense contiguous request unless a layout-aware store recipe is - selected + ordinary store -> dense contiguous request group_store -> group_slots(G,K) request plus stride/alignment capability checks @@ -690,25 +716,25 @@ buildCastRequests: extf f8->f32 -> source contiguous, result deinterleaved=4 truncf f32->f16 -> source deinterleaved=2/block_elems=1, result contiguous truncf f32->f8 -> source deinterleaved=4/block_elems=1, result contiguous - group_slots slots=1 f32->f16 -> slot-preserving recipe - group_slots slots=8 width-changing cast -> diagnostic unless a packed recipe - exists + group_slots slots=1 f32->f16 -> explicit slot-preserving transform + group_slots slots=8 width-changing cast -> diagnostic unless a packed + transform is explicitly represented buildGroupReduceRequests: derive E = sizeof(accumulator type), VLaneElems = 32B / E, L = 256B / E, and S = logical_lanes / num_groups S=VLaneElems -> contiguous source, group_slots(G,8) result - S=2*VLaneElems -> deinterleaved=2/block_elems=1 or block_elems=8 source, + S=2*VLaneElems -> deinterleaved=2/block_elems=1 source, group_slots(G,8) result - S=4*VLaneElems -> deinterleaved=4/block_elems=1 or block_elems=8 source, + S=4*VLaneElems -> deinterleaved=4/block_elems=1 source, group_slots(G,8) result S>=L && S%L==0 -> contiguous source, group_slots(G,1) result 8-bit storage must be cast to an accumulator type before this request builder - other S -> diagnostic unless an explicit fallback recipe is enabled + other S -> diagnostic unless an explicit fallback op/helper is enabled buildGroupMemoryRequests: - group_load S=16/S=32 with aligned constant stride -> block_elems=8 recipe - group_load row-local full chunks -> contiguous recipe + group_load S=16/S=32 with aligned constant stride -> natural block_elems=8 + group_load row-local full chunks -> natural contiguous group_slot_load unit stride -> group_slots(G,8) group_slot_load aligned row-local stride -> group_slots(G,1) unsupported dynamic/unaligned grouped memory -> diagnostic @@ -723,8 +749,8 @@ buildElementwiseRequests: buildMaskRequests: mask layout follows each consuming data layout predicate granularity follows each consuming element type - create_mask/create_group_mask may be cloned for incompatible mask layout or - granularity requests + create_mask/create_group_mask produce one assigned mask layout and use + ensure_mask_layout / ensure_mask_granularity for incompatible uses masked_store requests source layout, mask layout, and store predicate granularity explicitly @@ -733,20 +759,22 @@ buildControlFlowRequests: create equality requests on the carried VMI layout variable buildFunctionBoundaryRequests: - private/internal function argument/result layouts are specialized or - materialized with callee-entry/return-site helpers + private/internal function argument/result layouts are materialized with + callee-entry/return-site helpers in baseline assignment; signature + specialization is an optimization pass public/external VMI arguments/results diagnose unless enablePublicVMIABI has - a real ABI recipe + a real ABI contract ``` Request builders must record the requesting op. Diagnostics and inserted helpers are use-site operations, so the user can see which consumer forced a layout. -### 6.3.2 Producer Classes +### 6.3.2 Optimization Producer Classes -The solver uses producer classes to decide whether a conflict can be solved by -cloning, equivalence propagation, or materialization. +Baseline assignment does not use producer classes to solve conflicts. It +inserts helpers. Later optimization passes may classify producers to replace +helpers with cheaper equivalent IR. ```text cheap rematerializable producers: @@ -757,7 +785,7 @@ cheap rematerializable producers: create_group_mask group_broadcast group_slot_load when the same address/no-alias/proof conditions as load hold - and the memory recipe is legal at the clone site + and the memory access remains legal at the clone site layout-transparent producers: add/sub/mul/fma/min/max/neg/abs @@ -766,13 +794,13 @@ layout-transparent producers: integer bitwise and shift ops fixed-layout producers: - extf/truncf physical conversion recipes - group_load block-fragment recipes + extf/truncf physical conversion layouts + group_load block-fragment layouts group_reduce result group_slots - masked_load when the physical memory-safety proof fixes a full-read recipe + masked_load when the physical memory-safety proof fixes a full-read lowering ``` -Conflict policy: +Optimization conflict policy: ```text cheap producer: @@ -784,31 +812,28 @@ layout-transparent producer: only at incompatible uses fixed-layout producer: - use registered materialization only; otherwise diagnose + use explicit helper materialization only; otherwise diagnose ``` -This is the rule that keeps case 3.32 legal: a plain `load` can be assigned to -`deinterleaved=4, block_elems=1` for both `truncf f32->f8` and S=32 -`group_reduce`. It also keeps case 3.19.2 diagnostic: a strided `group_load` -that selected `block_elems=8` is fixed unless a block8-to-parity -materialization or rematerialized memory recipe is registered. +These classes are not assignment constraints. They are rewrite preconditions +for passes that consume `ensure_layout` and decide whether the helper can be +folded, sunk, hoisted, or replaced by rematerialization. ### 6.4 Solving And Rewriting Algorithm: ```text -1. Pick candidate recipe sets for every op. -2. Propagate hard constraints through SCCs. -3. Resolve transfer-equivalent dense values. -4. Choose multi-recipe ops by cost: - - S=16 parity vs block8 - - load memory-fused vs load+materialize - - group_slot_load slots=8 vs slots=1 -5. For conflicting uses: - - rematerialize cheap producer where legal - - otherwise insert ensure_layout at use - - otherwise diagnose +1. Collect natural layouts, use-site requests, equality constraints, and + memory-safety proofs. +2. Propagate equality constraints through SCCs. +3. Choose one deterministic assigned layout per value/equivalence class: + explicit user layout, then unique producer natural layout, then hard + non-contiguous layout, then contiguous. +4. For conflicting uses, insert ensure_layout / ensure_mask_layout / + ensure_mask_granularity at the use. +5. Emit diagnostics for unsupported semantic constraints or missing explicit + materialization/memory-safety proof. 6. Rewrite VMI result/block/function types with chosen layouts. 7. Insert helper ops with source/result layout attrs. ``` @@ -818,9 +843,12 @@ Rewrite invariants: ```text No VMI data/mask value after assignment has a null layout. Any non-local choice is represented by op attrs, operand/result layouts, a -helper op, a clone, or an explicit diagnostic. -Every ensure_* helper has a registered materialization recipe. -Every function/call signature carrying VMI is specialized or diagnosed. +helper op, or an explicit diagnostic. Cloned/rematerialized producers may +appear only after later layout optimization passes. +Every ensure_* helper has an explicit supported materialization path or a +diagnostic. +Every function/call boundary carrying VMI is materialized, kept in an explicit +ABI contract, or diagnosed. ``` ### 6.5 Rewrite Artifacts @@ -831,24 +859,21 @@ Assignment rewrites the IR so that later lowering has no hidden choices. type rewrite: every VMI data/mask result and block argument receives a layout attr -clone rewrite: - cheap producers are cloned before their divergent use sites - each clone receives its own layout and attrs - ensure rewrite: - non-cheap values use pto.vmi.ensure_layout or ensure_mask_layout at the use + mismatched uses get pto.vmi.ensure_layout or ensure_mask_layout at the use site, with source and target layouts visible in the types granularity rewrite: one semantic mask used by f32 and f16 consumers gets - ensure_mask_granularity or cloned mask producers + ensure_mask_granularity at the use site control-flow rewrite: scf.if/scf.for yields and block arguments are rewritten to one agreed layout; materialization is inserted before yield when branches differ function rewrite: - private VMI functions are specialized or get callee-entry ensure_layout + baseline private VMI functions get callee-entry/return-site ensure_layout; + signature specialization is an optimization pass public/external VMI functions are diagnosed ``` @@ -865,7 +890,8 @@ Canonical assigned IR shape for a conflicting load: pto.vmi.store %x_dense, ... ``` -Canonical assigned IR shape for a cloned cheap producer: +Optional future optimized IR shape for a cloned load with an explicit +safe-read/execution proof: ```text %x_s16 = pto.vmi.load ... @@ -878,12 +904,12 @@ Canonical assigned IR shape for a cloned cheap producer: Canonical assigned IR shape for `group_broadcast` multi-use: ```text -%b0 = pto.vmi.group_broadcast %slots +%b = pto.vmi.group_broadcast %slots : !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -%b1 = pto.vmi.group_broadcast %slots - : !pto.vmi.vreg<256xf32, #pto.vmi.layout> +%b_c = pto.vmi.ensure_layout %b + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> ``` @@ -912,10 +938,10 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.4 32-bit S=8 reduce buildGroupReduceRequests one_vlane contiguous recipe -3.5 32-bit S=16 reduce buildGroupReduceRequests two_vlane parity/block8 recipe -3.6 32-bit S=32 reduce buildGroupReduceRequests four_vlane dintlv4/block8 recipe -3.7 32-bit S=64 reduce buildGroupReduceRequests full_chunk row_local recipe +3.4 32-bit S=8 reduce buildGroupReduceRequests one_vlane contiguous lowering +3.5 32-bit S=16 reduce buildGroupReduceRequests two_vlane parity/block8 layout +3.6 32-bit S=32 reduce buildGroupReduceRequests four_vlane dintlv4/block8 layout +3.7 32-bit S=64 reduce buildGroupReduceRequests full_chunk row_local lowering 3.11.1 S=64 active-row tail buildMaskRequests active-row store/reduce masks 3.19.1 S=16 block_elems choice buildGroupReduceRequests explicit block_elems layout 3.38 multi-tile S=32 reduce buildGroupReduceRequests multiple group_slots chunks @@ -931,12 +957,12 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load recipe +3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load layout 3.15.2 S=16 row stride > 16 buildGroupMemoryRequests strided block_elems=8 plan 3.16.1 group_slot_load slots=8 buildGroupMemoryRequests unit-stride packed slots plan 3.16.2 group_slot_load slots=1 buildGroupMemoryRequests row-local aligned slots plan 3.27 strided group_load buildGroupMemoryRequests positive block_elems=8 plan -3.28 slots=1 non-unit load buildGroupMemoryRequests row-local group_slot_load recipe +3.28 slots=1 non-unit load buildGroupMemoryRequests row-local group_slot_load layout 3.37 slots=1 strided store buildStoreRequests group_store stride/alignment proof 3.39 strided load fanout conflict resolver preserving layout or materialization @@ -944,7 +970,7 @@ vmi-to-vpto contract: consume only explicit memory stride/alignment attrs, current op operands, and layouts. It must not infer safe read/write placement from neighboring compute ops. Unsupported dynamic, unaligned, or compact-row gather shapes - stay diagnostics until a gather recipe is explicit in the current op. + stay diagnostics until a gather fallback is explicit in the current op. ``` ```text @@ -952,13 +978,13 @@ case family builder / owner assignment artifact 3.8 reduce->truncf->broadcast conflict resolver slot cast plus dense materialization 3.10 non-load S=32 producer buildElementwiseRequests transparent deinterleaved chain 3.17 broadcast deint consumer conflict resolver use-site group_broadcast layout -3.18 dense + reduce users conflict resolver clone/rematerialize/ensure_layout -3.23 broadcast multi-user conflict resolver cloned group_broadcast -3.33 S=16 + S=32 users conflict resolver cloned load or materialization +3.18 dense + reduce users conflict resolver ensure_layout; optional remat/fold +3.23 broadcast multi-user conflict resolver per-op group_broadcast layout +3.33 S=16 + S=32 users conflict resolver use-site materialization; optional cloned load 3.34 S=64 slots=1 cast buildCastRequests group_slot_cast layout 3.35 slots fanout buildElementwiseRequests same group_slots layout on users -3.36 scalar slots=8/slots=1 conflict resolver cloned group_slot_load/broadcast -3.40 scalar dense + grouped conflict resolver cloned broadcast +3.36 scalar slots=8/slots=1 conflict resolver explicit slots=8/slots=1 producers +3.40 scalar dense + grouped conflict resolver ensure_layout; optional broadcast remat 3.41 incompatible fixed value conflict resolver diagnostic or ensure_layout vmi-to-vpto contract: @@ -985,17 +1011,17 @@ vmi-to-vpto contract: ```text diagnostic family builder / owner required failure -3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store recipe +3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store path 3.9 dense store of group slots buildStoreRequests use group_store/group_broadcast 3.11.2 S=32 unsafe tail buildMaskRequests missing full_tile_readable/gather -3.13 slots=8 width cast buildCastRequests no packed slot cast recipe -3.14 unsupported group size buildGroupReduceRequests no registered reduce recipe +3.13 slots=8 width cast buildCastRequests no packed slot cast transform +3.14 unsupported group size buildGroupReduceRequests no supported reduce layout/lowering 3.15.3 compact S=12 buildGroupMemoryRequests no compact gather plan -3.16.1 slots=8 non-unit load buildGroupMemoryRequests no packed strided slot load recipe +3.16.1 slots=8 non-unit load buildGroupMemoryRequests no packed strided slot load path 3.16.2 slots=1 bad stride buildGroupMemoryRequests no dynamic/unaligned row-local plan 3.19.2 invalid block_elems use conflict resolver no preserving materialization 3.25.2 public/external ABI buildFunctionBoundary no stable public VMI ABI -3.27 unaligned group_load buildGroupMemoryRequests no gather/block fallback recipe +3.27 unaligned group_load buildGroupMemoryRequests no gather/block fallback path 3.30 masked_load unsafe tail buildMaskRequests no padding/gather fallback vmi-to-vpto contract: @@ -1068,75 +1094,75 @@ adaptor physical values Each pattern rejects: ```text -missing current-op proof for an otherwise unsafe memory recipe +missing current-op proof for an otherwise unsafe memory lowering missing target capability unexpected group_slots dense consumer ``` -Target local recipe matrix: +Target local lowering matrix: ```text -load, recipe=dense_load_norm: +load, lowering=dense_load_norm: result layout contiguous emits pto.vlds / pto.vsts NORM paths covers dense store users and full-chunk row-local reduce input -load, recipe=load_dintlv2: +load, lowering=load_dintlv2: result layout deinterleaved=2, block_elems=1 emits vldsx2 DINTLV_B32 or normal load + vdintlv materialization covers f32->f16, S=16 parity reduce, f16->f32 widened values -load, recipe=load_dintlv4: +load, lowering=load_dintlv4: result layout deinterleaved=4, block_elems=1 emits two vldsx2 DINTLV_B32 plus vdintlv covers f32->f8, S=32 dintlv4 reduce -group_load, recipe=s16_group_load_block8_unit_stride: +group_load, lowering=s16_group_load_block8_unit_stride: result layout deinterleaved=2, block_elems=8 emits vldsx2/BDINTLV for 8 rows of 16xf32 covers compact logical S=16 when source_group_stride == 16 -group_load, recipe=s16_group_load_block8_stride: +group_load, lowering=s16_group_load_block8_stride: result layout deinterleaved=2, block_elems=8 emits two vsldb strided 32B block loads requires source_group_stride % 8 == 0 -group_load, recipe=s32_group_load_block8_stride: +group_load, lowering=s32_group_load_block8_stride: result layout deinterleaved=4, block_elems=8 emits four vsldb strided 32B block loads requires source_group_stride % 8 == 0 -group_load, recipe=group_load_contiguous_chunks: +group_load, lowering=group_load_contiguous_chunks: result layout contiguous emits one vlds per physical group chunk using row_stride address arithmetic covers the currently implemented full-chunk row-local group_load path -group_reduce_add{f|i}, recipe=one_vlane_reduce_contiguous: +group_reduce_add{f|i}, lowering=one_vlane_reduce_contiguous: consumes contiguous accumulator type T with group size VLaneElems(T) produces group_slots(G, slots=8) emits one vcgadd -group_reduce_add{f|i}, recipe=two_vlane_reduce_deinterleaved: +group_reduce_add{f|i}, lowering=two_vlane_reduce_deinterleaved: consumes deinterleaved=2, block_elems=1 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_add{f|i}, recipe=two_vlane_reduce_block8: +group_reduce_add{f|i}, lowering=two_vlane_reduce_block8: consumes deinterleaved=2, block_elems=8 produces group_slots(G, slots=8) emits two vcgadd operations and one vadd -group_reduce_add{f|i}, recipe=four_vlane_reduce_dintlv4: +group_reduce_add{f|i}, lowering=four_vlane_reduce_dintlv4: consumes deinterleaved=4, block_elems=1 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_add{f|i}, recipe=four_vlane_reduce_block8_stride: +group_reduce_add{f|i}, lowering=four_vlane_reduce_block8_stride: consumes deinterleaved=4, block_elems=8 produces group_slots(G, slots=8) emits four vcgadd operations and a vadd tree -group_reduce_add{f|i}, recipe=full_chunk_reduce_row_local: +group_reduce_add{f|i}, lowering=full_chunk_reduce_row_local: consumes contiguous accumulator type T with group size that is a multiple of one physical chunk L(T) produces group_slots(G, slots=1) @@ -1144,31 +1170,31 @@ group_reduce_add{f|i}, recipe=full_chunk_reduce_row_local: the existing row-local VCADD/VADD/VSEL sequence while preserving the same group_slots(G, slots=1) value contract -group_slot_load, recipe=group_slot_load_slots8_unit_stride: +group_slot_load, lowering=group_slot_load_slots8_unit_stride: result group_slots(G, slots=8) requires source_group_stride == 1 emits one packed vsldb load -group_slot_load, recipe=group_slot_load_slots1_row_local: +group_slot_load, lowering=group_slot_load_slots1_row_local: result group_slots(G, slots=1) supports aligned non-unit source_group_stride requires constant positive source_group_stride divisible by 256 / elementBits emits one lane-0 vsldb per group -group_broadcast, recipe=group_broadcast_slots8_vselr: +group_broadcast, lowering=group_broadcast_slots8_vselr: source group_slots(G, slots=8) result dense layout selected per use emits vselr using assigned result layout -group_broadcast, recipe=group_broadcast_slots1_vselr: +group_broadcast, lowering=group_broadcast_slots1_vselr: source group_slots(G, slots=1) result dense layout selected per use emits vdup/vselr row-local materialization -truncf, recipe=group_slot_cast_slots1_f32_to_f16: +truncf, lowering=group_slot_cast_slots1_f32_to_f16: source/result group_slots(G, slots=1) emits one lane-0 vcvt per group slot block - rejects packed slots=8 unless another plan is registered + rejects packed slots=8 unless slot-preserving cast support exists ``` The target matrix is the implementation contract. The staged status below @@ -1195,14 +1221,15 @@ group_reduce_addf: Full-chunk row-local assignment, including S=64 and S=256 f32 cases, uses #pto.vmi.layout and has focused layout-assignment/vmi-to-vpto lit coverage; the explicit slots=1 generic - VCADD row-local path is registered and selected locally. + VCADD row-local lowering is selected locally from the current op attrs and + assigned layouts. group_reduce_addi is implemented for i32 accumulator values. i8/i16 storage must be widened explicitly before grouped reduction because narrow integer reduction instructions widen their result. group_broadcast: explicit slots=8/1 source layouts select - packed or row-local VSELR recipes locally. Deinterleaved block-fragment + packed or row-local VSELR lowerings locally. Deinterleaved block-fragment results use the result layout block_elems as the local vselr selection group, so `deinterleaved = 4, block_elems = 8` broadcasts one group slot across each @@ -1219,9 +1246,9 @@ group_load: contiguous full-chunk path is selected from a contiguous result layout. S=16/S=32 block-aligned strided loads are selected from #pto.vmi.layout, and lower to one - vsldb per 32B row fragment and physical chunk. The explicit block8 recipe - is registered and checked by pto-validate-vmi-layout-ir before vmi-to-vpto. - The dedicated S=16 unit-stride vldsx2/BDINTLV recipe remains a local + vsldb per 32B row fragment and physical chunk. The explicit block8 support + is checked by pto-validate-vmi-layout-ir before vmi-to-vpto. + The dedicated S=16 unit-stride vldsx2/BDINTLV lowering remains a local peephole target. S=16/S=32 group_load with a non-constant, non-positive, or non-8-f32-aligned row_stride is rejected by vmi-layout-assignment because the stable gather @@ -1240,12 +1267,12 @@ group_store: multiple of the 32B store alignment in destination elements: 8 for f32, 16 for f16, and 32 for f8. Unit-stride f32 output is rejected because only the first row-local store is 32B-aligned; later `group_off + r` stores are - 4B apart. A future pack-to-slots=8 or unaligned-store recipe is required before + 4B apart. A future pack-to-slots=8 or unaligned-store lowering is required before contiguous `%c1` slots=1 group_store can be accepted. Packed group_slots(G, slots=8) group_store is implemented only when num_groups is a multiple of 8 and row_stride is constant 1; it emits one PAT_VL8 store per packed slot block. Non-unit packed group stores remain a - design target unless a strided packed-lane store recipe is made explicit. + design target unless a strided packed-lane store lowering is made explicit. ``` Current implementation contract for type-generic grouped reduction: @@ -1268,17 +1295,18 @@ Layout assignment: route f8 storage through extf to f32 before group_reduce_addf. route i8/i16 storage through extsi/extui to i32 before group_reduce_addi. route integer narrowing to i8 through trunci; direct i8 compute remains - illegal unless the target capability registry exposes an explicit recipe. + illegal unless target capability and explicit op semantics define that + lowering. diagnose direct f8/i8 compute use with a message that points at the offending op and suggests inserting the explicit cast when the op is meant to consume storage data. -Local recipe registry: - replace f32-shaped recipe keys with width-parametric recipe classes: - one_vlane_reduce - two_vlane_reduce_deinterleaved - four_vlane_reduce_deinterleaved - full_chunk_row_local_reduce +Layout fact helpers: + replace f32-shaped checks with width-parametric group-reduce classifiers: + one_vlane_reduce layout fact + two_vlane_reduce_deinterleaved layout fact + four_vlane_reduce_deinterleaved layout fact + full_chunk_row_local_reduce layout fact key legality on accumulator byte width, source/mask layout, result group_slots layout, num_groups, and target instruction capability. @@ -1288,8 +1316,8 @@ VMI-to-VPTO: materialize integer casts explicitly before reduction; direct i8 group reduce and direct i16 group reduce must not silently become a widening reduction in this pass. - keep VPTO lowering local: it consumes assigned layouts and registered local - recipes, but does not invent a new global layout plan. + keep VPTO lowering local: it consumes assigned layouts and current-op + attrs/operands, but does not invent a new global layout plan. Tests: cover f16 direct and i16-storage-to-i32 grouped reductions. @@ -1302,15 +1330,15 @@ Tests: Examples: ```text -group_reduce_add{f|i}, recipe=two_vlane_reduce_deinterleaved: +group_reduce_add{f|i}, lowering=two_vlane_reduce_deinterleaved: consume deinterleaved=2, block_elems=1 emit two VCGADDs and one VADD -group_reduce_add{f|i}, recipe=two_vlane_reduce_block8: +group_reduce_add{f|i}, lowering=two_vlane_reduce_block8: consume deinterleaved=2, block_elems=8 emit two VCGADDs and one VADD -group_reduce_add{f|i}, recipe=four_vlane_reduce_dintlv4: +group_reduce_add{f|i}, lowering=four_vlane_reduce_dintlv4: consume deinterleaved=4 emit four VCGADDs and reduction tree @@ -1346,7 +1374,7 @@ After assignment: Every VMI value has layout. Every VMI mask has layout and granularity plan. Every lowering choice is locally deterministic or explicit in attrs/layouts. -Every ensure_* helper has a materialization recipe. +Every ensure_* helper has a materialization path. Every control-flow edge has matching VMI layouts. ``` @@ -1364,8 +1392,8 @@ allowed: diagnostic not allowed: - walking from a consumer to a producer to decide a recipe - walking from a consumer to a mask producer to decide whether a recipe is legal + walking from a consumer to a producer to decide a lowering + walking from a consumer to a mask producer to decide whether a lowering is legal inspecting users to choose a result layout or materialization recovering full_tile_readable from surrounding MTE/caller context ``` @@ -1378,7 +1406,7 @@ Current audit result: lowering the deinterleaved create_group_mask itself, vmi-to-vpto first materializes contiguous grouped predicate chunks and then applies predicate pdintlv in the same tree shape as the data vdintlv. It still does not walk - from group_reduce_addf to the mask defining op to choose or reject the plan. + from group_reduce_addf to the mask defining op to choose or reject lowering. The dynamic active_elems_per_group form is also op-local: vmi-to-vpto lowers contiguous chunks with vci/vshrs/vshls/vsub/vcmps, then uses the same predicate pdintlv tree for S=32 deinterleaved masks. @@ -1413,7 +1441,7 @@ op name logical type actual layout requested layout -selected/missing plan +selected/missing support path recommended rewrite or option ``` @@ -1424,8 +1452,8 @@ VMI-LAYOUT-CONTRACT: pto.vmi.truncf requires #pto.vmi.layout, but the source value is fixed to #pto.vmi.layout by the selected - strided group_load recipe. Register a rematerialization or preserving - materialization recipe, or avoid consuming this block-loaded value with truncf. + strided group_load layout. Register a rematerialization or preserving + materialization path, or avoid consuming this block-loaded value with truncf. ``` ## 11. Test And Simulator Acceptance @@ -1493,14 +1521,13 @@ the case catalog. Current broad runtime sweep: ```text -WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-layout-gate CASE_PREFIX='vmi/' JOBS=4 \ +WORK_SPACE=$PWD/.tmp/vmi-runtime-batch-final CASE_PREFIX='vmi/' JOBS=4 \ test/vpto/scripts/run_host_vpto_validation_parallel.sh -PASS=43 FAIL=0 -summary: .tmp/vmi-runtime-batch-layout-gate/parallel-summary.tsv -log scan: rg -n "RV_|alignment|\[ERROR\]|\[error\]|ERROR" \ - .tmp/vmi-runtime-batch-layout-gate.log -result: no matches +TOTAL_CASES=47 +PASS=47 FAIL=0 +summary: .tmp/vmi-runtime-batch-final/parallel-summary.tsv +result: all summary entries are PASS ``` The `find: Permission denied` messages printed while discovering CANN simulator @@ -1576,10 +1603,10 @@ diagnostic endpoints: repository evidence: all concrete lit/runtime paths listed below exist - all 43 runtime case directories contain kernel.pto, launch.cpp, main.cpp, + all 47 runtime case directories contain kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py - latest broad VMI runtime sweep passed: PASS=43 FAIL=0 - latest full VMI lit sweep passed: 340/340 + latest broad VMI runtime sweep passed: PASS=47 FAIL=0 + latest full VMI lit sweep passed: 350/350 ``` Current checked-in coverage for 3.3 dense f8->f32->compute->f8: @@ -1958,7 +1985,7 @@ Current checked-in lit coverage for the first `vmi-layout-sink-materialization` optimization is: ```text -test/lit/vmi/vmi_layout_sink_materialization_binary.pto +test/lit/vmi/vmi_layout_sink_materialization_binary.pto // unary, binary, fma, cmp, and select data ops test/lit/vmi/vmi_layout_sink_materialization_mask.pto ``` @@ -1969,27 +1996,27 @@ test/lit/vmi/vmi_legalize_arith_select.pto test/lit/vmi/vmi_ptoas_cli_control_flow.pto ``` -Current checked-in lit coverage for the first semantic local-recipe layout gate +Current checked-in lit coverage for the first semantic local-lowering layout gate is: ```text -test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto -test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_store_support_invalid.pto test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto -test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto -test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto +test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto +test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto +test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto +test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto +test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto ``` Current checked-in direct `vmi-to-vpto` preflight coverage for bitcast local -recipes is: +lowering is: ```text test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto @@ -2050,7 +2077,7 @@ Diagnostic-only cases: 3.16.1 group_slot_load slots=8 non-unit stride 3.16.2 group_slot_load slots=1 dynamic or unaligned stride 3.27 S=32 source_group_stride not divisible by 8 f32 elements -3.19.2 block_elems=8 value consumed by truncf without materialization recipe +3.19.2 block_elems=8 value consumed by truncf without materialization path 3.25.2 public/external VMI boundary 3.30 unsafe masked_load tail without stable masked/gather fallback ``` @@ -2069,7 +2096,7 @@ entries: ```text lit: - test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto + test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto @@ -2141,15 +2168,15 @@ group_store ```text 3.8 cast commute through group_broadcast 3.18 dense/group-reduce multi-consumer -3.19 block_elems recipe selection +3.19 block_elems layout selection 3.23 group_broadcast multi-consumer 3.32 f32 feeding f8 store and S=32 reduce 3.33 S=16/S=32 reduce multi-consumer rematerialization 3.34 slots=1 group-slot f32->f16 cast 3.35 group_slots fanout to group_store and group_broadcast -3.36 group_slot_load rematerialized for slots=8/slots=1 +3.36 group_slot_load expressed as explicit slots=8/slots=1 producers 3.38 multi-tile group_slots arity -3.40 scalar broadcast rematerialized for dense/grouped users +3.40 scalar broadcast materialized for dense/grouped users 3.41 non-rematerializable value with ensure_layout ``` @@ -2192,12 +2219,12 @@ Current evidence for the case-catalog objective: checked-in runtime case directory 3. every runtime case directory contains kernel.pto, launch.cpp, main.cpp, golden.py, and compare.py -4. the latest broad VMI runtime sweep passed: PASS=43 FAIL=0 -5. the latest full VMI lit sweep passed: 342/342 +4. the latest broad VMI runtime sweep passed: PASS=47 FAIL=0 +5. the latest full VMI lit sweep passed: 350/350 6. every unsupported endpoint listed in section 11.3 has a diagnostic lit test 7. vmi-to-vpto decisions are represented by current-op attrs/operands, assigned layouts, helper ops, rematerialization, or diagnostics -8. no separate recipe string attr is emitted or consumed +8. no separate lowering-plan string attr is emitted or consumed 9. release docs remain untouched; this is still a design/implementation plan under docs/designs ``` diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 82a84082c6..42e62e8b3a 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -35,7 +35,7 @@ vmi-to-vpto 不允许通过上下文猜 lowering。 必须在 vmi-layout-assignment 或后续 VMI layout optimization 阶段变成显式 IR: 1. vmi.vreg/vmi.mask 的 layout -2. current-op attrs/operands that make the local recipe deterministic +2. current-op attrs/operands that make the local lowering deterministic 3. use-site ensure_layout / ensure_mask_layout / ensure_mask_granularity 4. rematerialized or cloned producer 5. target capability diagnostic @@ -50,7 +50,7 @@ separates correctness from optimization: hard legalization: produces legal layout-assigned VMI IR for all supported semantics inserts conservative ensure_* helpers at incompatible uses - may choose a simple canonical layout even when a fused consumer recipe exists + may choose a simple canonical layout even when a fused consumer lowering exists must diagnose unsupported semantics before vmi-to-vpto has to guess layout optimization: @@ -80,7 +80,7 @@ A later optimization may replace that use with: pto.vmi.store %x : deinterleaved=2 ``` -only if the store op itself has a local deterministic recipe for preserving the +only if the store op itself has a local deterministic lowering for preserving the same row-major memory effect, such as a layout-aware `vstsx2 INTLV` lowering. Both forms are semantically complete. The second form is an optimization, not a hard requirement for correctness. @@ -115,10 +115,10 @@ group reduce: layout conflict: one value with dense and group-reduce consumers one value with S=16 and S=32 group-reduce consumers - one scalar broadcast rematerialized for dense and grouped users + one scalar broadcast materialized for dense and grouped users, with optional rematerialization one non-rematerializable value materialized with use-site ensure_layout - one scalar group-slot source rematerialized as slots=8 and slots=1 - S=16 block_elems=1/8 recipe selection + one scalar group-slot source expressed as explicit slots=8 and slots=1 producers + S=16 block_elems=1/8 layout selection dense consumer of group_slots diagnostic packed group-slot width-changing cast diagnostic S=64 slots=1 group-slot width-changing cast @@ -173,7 +173,9 @@ consumer-driven pressure: elementwise/select, masked_load/masked_store conflict resolution: - cheap rematerialization, explicit ensure_layout, explicit diagnostics + explicit ensure_layout, explicit ensure_mask_layout, explicit diagnostics + optimization passes may later replace the helpers with rematerialization or + layout-aware consumers control-flow propagation: scf.if, scf.for iter_args/results, internal/private function boundaries, @@ -185,7 +187,7 @@ memory legality: ``` No extra layout kind should be added unless a new case proves that the existing -layouts and recipes cannot express the logical behavior. The remaining open +layouts and explicit helper contracts cannot express the logical behavior. The remaining open items are not missing layout semantics: ```text @@ -224,7 +226,7 @@ cast boundary: i8 participates through extsi/extui/trunci. Signedness is carried by the cast op semantics, not by a separate layout. On the current VPTO target, 32-bit to 8-bit integer narrowing is only a - baseline recipe for unsigned i8 results because the available VCVTII forms + baseline lowering for unsigned i8 results because the available VCVTII forms are s32/u32 -> u8. compute boundary: @@ -289,7 +291,7 @@ slot_lane(g) = g % K All non-slot lanes are undefined and may only be read by group-aware operations. Ordinary dense `add/mul/store/truncf` cannot consume `group_slots`. -`K` is selected by the producer/consumer local recipe: +`K` is selected by the assigned producer/result contract: ```text S=8/16/32 packed VCG result -> slots=8 @@ -337,7 +339,7 @@ or explicit helper: pto.vmi.ensure_mask_granularity ``` -`vmi-to-vpto` is allowed to choose a deterministic recipe from local +`vmi-to-vpto` is allowed to choose a deterministic lowering from local information on the current op: ```text @@ -354,104 +356,80 @@ ops to recover a lowering decision or a memory-safety proof. If a decision cannot be made from that local information, layout assignment must rewrite the IR until the decision is explicit in attrs, operand/result -layouts, helper ops, cloned producers, or diagnostics. `vmi-to-vpto` must not -consume a separate string recipe attr. +layouts, helper ops, or diagnostics. Later optimization passes may replace +helpers with cloned/rematerialized producers, but `vmi-to-vpto` must not +consume a separate string lowering-plan attr. -### 4.1 Local Recipe Contract +### 4.1 Local Lowering Contract -The lowering recipe is derived from op + assigned operand/result layouts + -explicit attrs/operands. If two legal recipes cannot be distinguished from +The lowering path is derived from op + assigned operand/result layouts + +explicit attrs/operands. If two legal lowerings cannot be distinguished from that local information, the IR is missing a semantic carrier and must be -extended before the recipe is implemented. +extended before that lowering is implemented. -Locally deterministic decisions in the current implementation: +The shared abstraction is a layout fact classifier, not a central lowering-plan +registry. A classifier may answer questions such as: ```text -group_load: - result layout, num_groups, row_stride, source type, and target capability - decide contiguous chunks versus S=16/S=32 block8 vsldb lowering. Unit-stride - vldsx2/BDINTLV can be a local peephole for the same block8 layout. - -group_slot_load: - result group_slots layout and source_group_stride decide packed slots=8 - versus row-local slots=1 vsldb lowering. A single source op may still be - rematerialized into two ops when different users require different result - layouts; each clone is then locally deterministic. - -group_reduce_add{f|i}: - source/mask layout, result group_slots layout, num_groups, element type, and - the typed reduce semantics decide the local reduction recipe. The recipe is - not keyed by f32 shape names. It is derived from the element byte width. - Floating-point `group_reduce_addf` carries `reassoc`; integer - `group_reduce_addi` does not. - - VLaneElems = 32B / sizeof(T) - L = 256B / sizeof(T) - S = logical_lane_count / num_groups - - S == VLaneElems -> contiguous vcgadd, result slots=8 - S == 2 * VLaneElems -> deinterleaved=2 vcgadd tree, result slots=8 - S == 4 * VLaneElems -> deinterleaved=4 vcgadd tree, result slots=8 - S >= L && S % L == 0 -> contiguous row-local vcadd/vsel, result slots=1 - - Type support is controlled by the typed reduce op semantics and target - capability, not by separate per-type shape rules. Once a type is legal for a - reduce op, the same formula above selects its layout and local recipe. The - current checked-in implementation may lag this design target; that is staged - implementation status, not a design boundary. - - The formula is applied to the accumulator/reduce element type, not - necessarily the storage element type. 8-bit floating-point storage first - casts to f32 for `group_reduce_addf`; 8-bit and 16-bit integer storage first - casts to a signed/unsigned i32 accumulator for - `group_reduce_addi`. In the baseline VMI contract, f8/i8 are storage and - cast-boundary types: they may be the source or destination of cast, load, and - store, but they are not accumulator/compute types for group reduce. Direct - 8-bit grouped reduction is illegal unless the target exposes an explicit - 8-bit compute recipe. - -group_broadcast: - source group_slots layout, result dense layout, num_groups, and element type - decide vdup/vselr materialization. - -truncf: - source/result group_slots layouts and element widths decide the slots=1 - f32->f16 slot-preserving vcvt path. +cast layout fact: + f16/i16 -> f32/i32 requires contiguous source and deinterleaved=2 result + f8/i8 -> f32/i32 requires contiguous source and deinterleaved=4 result + f32/i32 -> f16/i16 requires deinterleaved=2 source and contiguous result + f32/i32 -> f8/i8 requires deinterleaved=4 source and contiguous result + +group_reduce layout fact: + define E = sizeof(accumulator T), VLaneElems = 32B / E, + L = 256B / E, S = N / G. + S == VLaneElems requires contiguous source/mask and + group_slots(G, slots=8) result. + S == 2 * VLaneElems requires deinterleaved=2 source/mask and + group_slots(G, slots=8) result. + S == 4 * VLaneElems requires deinterleaved=4 source/mask and + group_slots(G, slots=8) result. + S >= L && S % L == 0 requires contiguous source/mask and + group_slots(G, slots=1) result. + +memory safety fact: + full_read_elems, shaped safe-tail memref, or explicit fallback option + proves whether rounded-up physical reads are legal. ``` -Other layout-only or attr-only decisions in the current implementation: +These helpers return semantic layout requirements and capability diagnostics. +They do not return VPTO instruction names, cost decisions, clone decisions, or +multi-user plans. -```text -load: - result layout plus explicit memory attrs decide the lowering. full_read_elems - is the memory-safety proof; vmi-to-vpto may not recover that proof from MTE or - caller context. - -group_store: - source group_slots layout and explicit output stride decide packed slots=8 - versus row-local slots=1 store legality. If another legal store recipe - needs more information, assignment must make that information explicit in the - op or helper IR before vmi-to-vpto uses it. +The useful shared fact is the part that would otherwise be recomputed by two or +more stages and must stay identical for correctness: -masked_load: - explicit passthrough, mask layout, full physical read, shaped safe-tail memref, - or an explicit diagnostic decide legality. A future stable gather fallback - must be made explicit by assignment before vmi-to-vpto lowers it. - -masked_store/select/elementwise: - operand/result layouts and explicit mask granularity decide the lowering. - They remain transfer ops unless a future case introduces competing recipes. - -extf/truncf: - dense width-changing paths are layout-determined today. Any future - commute-through-group-broadcast or alternative VCVT recipe must have an - explicit IR carrier first. +```text +cast width ratio: + assignment uses it to request source/result layouts and insert ensure_layout. + validation uses it to reject unsupported assigned cast shapes. + lowering uses it to check the local op shape before emitting VPTO. + +group_reduce lane partition: + assignment uses N/G and accumulator element width to request source/mask and + result layouts. + validation uses the same math to reject legacy or incomplete group_slots. + lowering uses the already assigned layouts to select the local VPTO sequence. + +layout materialization shape: + assignment may insert ensure_layout without proving every physical sequence. + validation and lowering use one support query to decide whether that explicit + helper is materializable on the target. + optimization uses the same query only when it wants to fold/sink/remove an + explicit helper. ``` -Forbidden non-local recipe recovery: +The helper is not useful when it only renames one local pattern. A single +`if (is this op with this attr)` that is not shared by assignment, validation, +lowering, or an optimization should stay local to that pass. The support layer +exists to prevent divergent layout math, not to move every branch into a table. + +Forbidden non-local lowering recovery: ```text -No pattern may synthesize a recipe or memory proof by: +No pattern may recover a lowering decision or memory proof by: - walking from group_reduce to the load/group_load producer - walking from store/broadcast/truncf to the group_reduce producer - scanning sibling users of a group_slots value @@ -463,48 +441,59 @@ If the current op lacks enough local information, `vmi-to-vpto` emits `VMI-LAYOUT-CONTRACT` at the current op and prints the op name, logical type, assigned layouts, and the missing decision class. -## 5. Local Recipe Registry - -The compiler owns a target-aware local recipe registry. Layout assignment and -layout optimization query this registry to decide which explicit IR shape to -produce. `vmi-to-vpto` queries the same registry only to verify and lower the -current op from local information. - -The registry is not serialized as a separate recipe-selection attribute. If -two legal physical recipes cannot be distinguished by the current op's name, -attrs, operands, operand/result layouts, helper ops, and target options, the -VMI IR is missing a carrier. Add an explicit attr, operand, helper op, or -semantic op before implementing that recipe. +## 5. Layout Requests, Helpers, And Optimization -### 5.1 Recipe Kinds +The compiler must not carry a target-aware lowering-plan registry as the shared +contract between assignment, optimization, validation, and lowering. The +shared contract is: ```text -ProducerRecipe: - op can produce result layout L - example: load -> deinterleaved=4 using DINTLV_B32 + vdintlv - -ConsumerRecipe: - op can consume operand layout L - example: group_reduce S=32 consumes deinterleaved=4 - -TransferRecipe: - op ties operand/result layouts - example: addf requires same dense layout for operands/result - -MaterializationRecipe: - layout A -> layout B without changing logical value - example: deinterleaved=4 -> contiguous by vintlv tree +1. assigned layouts on VMI types +2. explicit use-site helpers: ensure_layout, ensure_mask_layout, + ensure_mask_granularity +3. explicit op attrs/operands that are part of the semantic op +4. small layout fact classifiers shared only where they remove duplicated + layout math +5. target capability diagnostics +``` -RematerializationRecipe: - cheap producer can be cloned for a use-site layout - example: broadcast/create_mask/group_broadcast +This split makes optimization simpler only when optimization is phrased as +rewriting explicit helper IR: -DiagnosticRecipe: - known unsupported semantic/capability boundary - example: compact S=12 requires gather materialization +```text +baseline: + %x_d2 = pto.vmi.extf %x_f16 + %a = pto.vmi.addf %x_d2, %k_d2 + %a_c = pto.vmi.ensure_layout %a : deinterleaved=2 -> contiguous + pto.vmi.store %a_c, %out0 + %x_c = pto.vmi.ensure_layout %x_d2 : deinterleaved=2 -> contiguous + pto.vmi.store %x_c, %out1 + +fold-consumers: + checks only each local ensure_layout + store use. + If VMILayoutSupport says the store can preserve row-major memory from the + source layout, rewrite that use to store the source directly. + It does not inspect sibling users of %x_d2 and does not recompute the layout + assignment. + +rematerialize: + checks only cheap producer + ensure_layout. + If the producer can directly create the requested layout, clone/rematerialize + that producer for the use. + Memory producers such as group_slot_load are excluded until a separate proof + says cloning is semantically and economically valid. + +sink-materialization: + checks only explicit ensure_* operands of a layout-transparent op. + If every operand helper is compatible, rebuild the op in the source layout and + leave one ensure_* on the result. ``` -### 5.2 Dense Recipes From Cases +If an optimization needs a global cost decision, it should produce a new +explicit IR shape and then rely on canonicalize/CSE. It must not communicate a +private decision to `vmi-to-vpto`. + +### 5.1 Baseline Dense Layout Requests ```text f16 -> f32: @@ -526,47 +515,43 @@ f32 -> f8: elementwise dense: all dense operands/results share the same layout -broadcast scalar: - rematerializable to any dense layout requested by the consumer - -load: - may be rematerialized per use when two consumers request incompatible dense - layouts, such as S=16 deinterleaved=2 and S=32 deinterleaved=4 +dense store: + requests contiguous source + if the stored value is assigned deinterleaved, baseline assignment inserts + ensure_layout at the store use ``` -### 5.3 Group Recipes From Cases +### 5.2 Baseline Group Layout Requests ```text -group_reduce_add{f|i} typed shape classification: - define E = sizeof(T), VLaneElems = 32B / E, L = 256B / E, S = N / G. - S=VLaneElems uses contiguous input and group_slots(G, slots=8). - S=2*VLaneElems uses deinterleaved=2 input/mask and group_slots(G, slots=8). - S=4*VLaneElems uses deinterleaved=4 input/mask and group_slots(G, slots=8). - S>=L && S%L==0 uses contiguous input/mask and group_slots(G, slots=1); - lowering reduces each full physical chunk, accumulates all chunks in the - same logical group through lane0, and writes one physical result part per - group. +group_reduce_add{f|i}: + uses the group_reduce layout fact in section 4.1. + The source and mask operands request the computed dense layout. + The result is assigned group_slots(G, slots=8) or group_slots(G, slots=1). + Floating-point `group_reduce_addf` carries `reassoc`; integer + `group_reduce_addi` does not. group_slot_load: result group_slots(G, slots=8) for packed slots result group_slots(G, slots=1) for row-local slots group_broadcast: - source group_slots(G,K) - result is dense layout requested by each consumer - rematerialize per use instead of forcing one result layout + source requests group_slots(G,K) + result requests one dense layout + incompatible dense consumers are represented by ensure_layout after the + broadcast result; a later optimization may clone/rematerialize the broadcast group_store: - source group_slots(G,K) + source requests group_slots(G,K) + explicit output stride attrs/operands decide store legality group_slot_cast f32 -> f16: - slots=1 row-local source/result is legal with - group_slot_cast_slots1_f32_to_f16 - slots=8 packed source is illegal unless a packed slot-preserving recipe is - registered + slots=1 row-local source/result is legal + slots=8 packed source is illegal unless a future explicit helper or semantic + op defines the packed slot-preserving transform ``` -### 5.4 Tail And Memory Safety Recipes +### 5.3 Tail And Memory Safety Mask semantics and memory legality are separate: @@ -607,7 +592,7 @@ one mask used by f32 and f16 consumers: vmi-to-vpto consumes the assigned per-use mask materialization ``` -### 5.5 Case-Driven Request Matrix +### 5.4 Case-Driven Request Matrix The first implementation should build requests from the following finite table. This table is deliberately case-derived; adding a new request kind requires a @@ -616,8 +601,9 @@ new catalog case or a proof that it is equivalent to one listed here. ```text dense store: requests dense contiguous source - if source is deinterleaved, assignment must insert ensure_layout or select a - store recipe such as vstsx2 that consumes the assigned layout explicitly + if source is deinterleaved, baseline assignment inserts ensure_layout at the + store use. A later optimization may fold that helper into a layout-aware + store lowering such as vstsx2. truncf f32 -> f16: requests source deinterleaved=2, block_elems=1 @@ -642,8 +628,9 @@ group_reduce_add{f|i}: group_broadcast: requests source group_slots(num_groups, slots=K) - produces one dense result layout per consumer request - is cloned per incompatible dense consumer + produces one assigned dense result layout + incompatible dense consumers are represented by ensure_layout uses; a later + optimization may clone/rematerialize the group_broadcast per consumer group_store: requests source group_slots(num_groups, slots=K) @@ -666,7 +653,7 @@ group_slot_load: group_load: requests result deinterleaved=2/4, block_elems=8 for S=16/S=32 block - fragment recipes, or contiguous for row-local full-chunk recipes + fragments, or contiguous for row-local full chunks masked_load: requests result layout from its consumers @@ -674,18 +661,20 @@ masked_load: requires explicit passthrough; padding is not synthesized masked_store: - requests dense source layout selected by the store recipe + requests dense source layout required by the store op requests mask layout matching the source layout and store element granularity does not choose memory safety for an earlier load create_mask/create_group_mask: - produces whichever mask layout each consumer requests - may be cloned per incompatible mask layout or granularity + produces one assigned mask layout and granularity + incompatible mask consumers are represented by ensure_mask_layout or + ensure_mask_granularity; optimization may clone/rematerialize the mask op scf.if/scf.for/call/return: requests equality across carried VMI values, yielded values, call operands, callee arguments, and function results - private/internal functions may specialize or materialize at boundaries + baseline private/internal functions materialize at boundaries; optimization + may specialize signatures public/external VMI boundaries are diagnostics until an ABI is defined ``` @@ -694,48 +683,50 @@ Important negative requests: ```text ordinary dense add/mul/store/truncf cannot request group_slots packed group_slots(slots=8) cannot request width-changing cast unless a packed -slot-preserving cast recipe is registered +slot-preserving cast transform is explicitly represented slots=1 group_store cannot request unit-stride row-major output until a pack or -unaligned-store recipe exists +unaligned-store transform is explicitly represented ``` -### 5.6 Conflict Resolution Matrix +### 5.5 Optimization Hooks + +Baseline assignment resolves incompatible use-site requests by keeping one +assigned layout on the value and inserting explicit helpers at the use sites +that need another layout. It does not clone producers, rematerialize cheap +ops, choose memory-fused layouts by cost, or specialize private function +signatures for performance. -When one value receives incompatible requests, assignment resolves it using the -first legal row below. `vmi-to-vpto` never repeats this decision. +Those choices belong to later VMI layout optimization passes. They consume +the explicit helper IR and may rewrite it when the rewrite preserves the same +logical value and externally visible memory effect: ```text -cheap producer with multiple requested layouts: - clone the producer and assign each clone independently - examples: load, broadcast, create_mask, create_group_mask, group_broadcast - memory-read producers require the same explicit no-alias and safe-read proof - at each clone site +ensure_layout + store: + fold into a layout-aware store if the store can directly consume the source + layout and still write row-major memory -non-cheap value with registered materialization: - keep one chosen layout on the value and insert ensure_layout at the use site - examples: deinterleaved=4 -> contiguous before dense store +producer + ensure_layout: + clone/rematerialize the producer for that use only when the producer is cheap + or has an explicit safe-read proof -layout-transparent chain: - assign the whole equivalence class to the non-contiguous consumer request when - that avoids materialization - examples: broadcast -> addf -> S=32 group_reduce +elementwise chain + ensure_layout: + sink or hoist materialization through pure layout-transparent ops -control-flow join: - all incoming values must be materialized to one layout before yield/branch - examples: scf.if yielding group_slots, scf.for loop-carried group_slots +group_broadcast + incompatible dense consumers: + type each group_broadcast op for its consumer layout; do not force one result + layout across independent group_broadcast users -private function boundary: - specialize or materialize at call/callee-entry before vmi-to-vpto +create_mask/create_group_mask + incompatible mask consumers: + clone/rematerialize the mask producer per layout or predicate granularity -no clone/materialization/specialization recipe: - emit a diagnostic naming the requesting op and both layouts +private function boundary: + specialize function signatures only in an optimization pass; baseline + assignment materializes at boundary uses ``` -The cost model may choose between legal rows only when the observable contract -is identical. For example, S=16 `block_elems=1` and `block_elems=8` are both -valid reduce inputs, but `block_elems=8` is selected only when a producer recipe -such as strided `group_load` naturally creates 32B row fragments or when cost -proves it cheaper without breaking another consumer such as `truncf`. +If no helper materialization or optimization rewrite is legal, the diagnostic +must name the value's assigned layout, the use-site requested layout, and the +op that requested it. ## 6. Layout Assignment Algorithm @@ -758,7 +749,7 @@ Create a use-site request for: ```text 1. every operand use that requires a specific layout 2. every control-flow yield/branch/call/return edge -3. every memory operation that requires a memory legality recipe +3. every memory operation that requires an explicit memory legality proof ``` ### 6.2 Constraints @@ -767,14 +758,14 @@ Hard constraints: ```text group_slots cannot feed ordinary dense consumers -direct group-slot width-changing cast requires a slot-preserving recipe +direct group-slot width-changing cast requires an explicit slot-preserving transform public/external VMI function boundary requires a stable ABI or diagnostic S=32 fast tail load requires full_tile_readable or gather fallback ``` -`slots = 1` row-local cast may satisfy the slot-preserving recipe requirement. +`slots = 1` row-local cast may satisfy the slot-preserving transform requirement. Packed `slots = 8` f32->f16 remains a diagnostic unless a separate packed cast -or unpack/materialization recipe is registered. +or unpack/materialization transform is represented explicitly. Equivalence constraints: @@ -788,22 +779,23 @@ scf.if/scf.for: as the region result/iter_arg ``` -Candidate constraints: +Canonical baseline constraints: ```text S=16 group_reduce: - choose block_elems=1 or block_elems=8 by cost and explicit assignment constraints + request deinterleaved=2; baseline uses block_elems=1 unless the producer + result already carries block_elems=8 as an explicit layout one dense value feeding S=16 and S=32 group_reduce: - rematerialize a cheap producer per consumer layout, or insert an explicit - materialization recipe; the final lowering pass must not pick one layout after - seeing both users + keep the value's assigned layout and insert ensure_layout at both use sites + that need deinterleaved=2 or deinterleaved=4 load/group_load: - choose memory recipe and result layout together + use the op's assigned result layout and explicit memory-safety attrs only group_broadcast: - rematerialize per dense consumer layout + keep one assigned dense result layout and communicate other dense use layouts + through ensure_layout ``` ### 6.3 Solving @@ -812,26 +804,24 @@ Recommended solving order: ```text 1. Build function/control-flow SCCs. -2. Collect candidate recipes for every op. -3. Propagate hard required layouts from consumers. -4. Propagate producer natural layouts where they are unique. -5. Resolve multi-recipe ops by cost. -6. Insert use-site materialization where a value has multiple incompatible uses. -7. Rematerialize cheap producers instead of materializing when cheaper. -8. Specialize internal function signatures. -9. Emit diagnostics for unsatisfied hard constraints. -10. Rewrite VMI types and insert explicit helper/rematerialized ops. +2. Collect natural producer layouts and hard use-site layout requests. +3. Propagate equality constraints through dense elementwise ops and CFG edges. +4. Choose one deterministic assigned layout for each value or equivalence + class. +5. Insert ensure_layout / ensure_mask_layout / ensure_mask_granularity at uses + whose requested layout differs from the assigned layout. +6. Emit diagnostics for unsupported semantic constraints or missing explicit + memory-safety proofs. +7. Rewrite VMI types and insert explicit helper ops. ``` -Tie-breaking must be deterministic. Suggested priority: +Tie-breaking must be deterministic and deliberately simple. Suggested priority: ```text -1. Avoid unsupported recipes. -2. Prefer rematerializing cheap producers over register materialization. -3. Prefer layouts accepted by all consumers without conversion. -4. Prefer memory-fused layout recipes over load + register rearrange. -5. Prefer fewer VPTO instructions. -6. Prefer contiguous only when cost ties and no consumer requests a special layout. +1. Preserve an explicit user-provided layout attr. +2. Preserve a unique producer natural layout when present. +3. Preserve an equality-class non-contiguous layout when required by a hard op. +4. Otherwise choose contiguous. ``` ## 7. Control Flow And Functions @@ -889,7 +879,7 @@ For each op, the pattern: 1. reads operand/result layouts 2. reads current op attrs and operand values 3. asks TypeConverter for ordered physical values -4. emits the locally implied VPTO recipe +4. emits the locally implied VPTO lowering 5. fails if target capability or required local proof is absent ``` @@ -916,7 +906,7 @@ current VMI op body/attrs: helper materialization chain: allowed only to strip ensure_mask_layout / ensure_mask_granularity for - static predicate analysis that does not choose a different layout or recipe + static predicate analysis that does not choose a different layout or lowering diagnostic embellishment: allowed only to improve an already-failed capability message, such as naming @@ -930,7 +920,7 @@ grouped masks: assignment emits explicit contiguous and deinterleaved mask values, and `vmi-to-vpto` lowers the deinterleaved mask op itself through contiguous grouped-mask materialization followed by predicate deinterleave. It does not walk from `group_reduce_addf` to the mask producer to choose or reject -the recipe. Dynamic `active_elems_per_group` follows the same rule: the +the lowering. Dynamic `active_elems_per_group` follows the same rule: the `create_group_mask` op lowers its own SSA scalar with vci/vshrs/vshls/vsub/vcmps for contiguous chunks before any predicate deinterleave. @@ -952,8 +942,8 @@ group_slots(G,K): slot_block0, slot_block1, ... ``` -Two physical bundle entries may alias the same VPTO SSA value when the local -recipe proves they have the same contents, such as group_broadcast feeding both +Two physical bundle entries may alias the same VPTO SSA value when the current +op semantics prove they have the same contents, such as group_broadcast feeding both parts of a `deinterleaved=2` broadcast result. Arity still follows the layout; aliasing is not a different layout. @@ -989,15 +979,87 @@ public VMI function boundary: make function internal, inline before assignment, or define ABI layout ``` -## 11. Design Completion Criteria +## 11. Implementation Migration Checks + +The design is useful only if the implementation removes duplicated decision +points instead of renaming them. The migration target is: + +```text +assignment: + computes assigned layouts, records use-site requests, inserts ensure_* helpers, + and diagnoses unsupported semantics + does not clone/rematerialize producers + does not choose memory-fused layouts by cost + does not inspect sibling users to optimize a value + +layout optimization: + consumes explicit ensure_* helpers + may fold ensure_layout into layout-aware consumers + may clone/rematerialize cheap producers + may sink/hoist materialization through pure elementwise chains + may specialize private function signatures + +vmi-to-vpto: + consumes current op attrs/operands, assigned operand/result layouts, and + explicit helper ops + performs local physical shape and target-capability checks + does not recover layout plans from producers, sibling users, CFG regions, or + callees/callers +``` + +Concrete implementation debt to remove: + +```text +1. Move assignment-side data/mask rematerialization into + vmi-layout-rematerialize. Baseline assignment should insert ensure_* for + mismatched uses. +2. Keep `VMILayoutSupport` as target capability and layout-shape queries, not + as a shared plan table. Group-reduce layout math now lives in + `getPreferredGroupReduceLayoutFact`. Dense cast layout shape now lives in + `getPreferredCastLayoutFact`. Helper materialization gates use + `canMaterializeDataLayout`, `canMaterializeMaskLayout`, and + `canMaterializeMaskGranularity`. +3. Assignment, validation, and lowering may call layout fact helpers, but must + not each independently derive VLaneElems/groupSize/factor/slots rules. +4. Keep store-fold, rematerialization, and sink/hoist as local rewrites over + explicit ensure_* IR. They must not walk sibling users to rediscover why the + helper exists. +5. Update pass descriptions, diagnostics, and tests so "assignment only" output + is legal with helpers, and optimized output is a separate, equivalent IR + form. +``` + +Regression tests should prove the boundary: + +```text +assignment only: + multi-consumer values keep one assigned layout and use ensure_* at mismatched + uses + +fold-consumers: + ensure_layout + store becomes a layout-aware store only when the consumer can + preserve the same row-major memory effect + +rematerialize: + cheap producer + ensure_layout becomes a cloned/rematerialized producer; with + the pass disabled, the ensure_layout form remains legal + +vmi-to-vpto: + rejects any residual need for producer/user context with VMI-LAYOUT-CONTRACT +``` + +## 12. Design Completion Criteria The design is complete only when: ```text -1. every case in vmi-layout-lowering-cases.md maps to a local recipe -2. every local recipe can be emitted without looking at producer/user context +1. every case in vmi-layout-lowering-cases.md maps to assignment requests, + explicit helpers, or a precise diagnostic +2. every VMI-to-VPTO lowering can be emitted without looking at producer/user + context 3. every unsupported case has a precise capability diagnostic -4. every control-flow/function boundary either specializes layout or diagnoses +4. every control-flow/function boundary materializes, specializes in an + optimization pass, or diagnoses 5. every mask has explicit data layout and predicate granularity 6. every positive case has end-to-end lit coverage 7. every simulator-supported positive case has simulator validation diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index e17c14844b..d2d7b3835d 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -83,7 +83,7 @@ G % K == 0 K must fit in the physical vreg element count ``` -`K` is selected by the producer/consumer local recipe. It is not always 8. For +`K` is selected by the producer/consumer layout support rule. It is not always 8. For `VCGADD`-packed results, `K = 8` matches the eight 32B block results written to the low lanes of one destination vreg. For row-local reductions where each logical group already occupies one full 256B vreg, `K = 1` keeps each group's @@ -99,10 +99,11 @@ physical slot block slot_block(g), lane slot_lane(g) All other lanes are undefined for ordinary VMI consumers. They may only be read by group-aware ops that define how to interpret group slots. -## 2. Recipe Selection Rules +## 2. Layout Support Selection Rules -VMI cast ops must not hard-code one physical `vcvt` recipe as their semantic -layout rule. +VMI cast ops must not hard-code one physical `vcvt` lowering as their semantic +layout rule. Layout assignment records the required value layout; target +support queries only answer whether that layout can be materialized or lowered. ```text dense cast: @@ -112,7 +113,7 @@ dense cast: group-slot cast: source/result are both group_slots(G,K). lowering preserves slot_block(g) and slot_lane(g). Width-changing casts are - legal only when a slot-preserving VPTO recipe is registered, or when the cast + legal only when slot-preserving VPTO lowering support exists, or when the cast can be commuted through a later group-aware consumer such as group_broadcast. ``` @@ -171,7 +172,7 @@ the immediately following complete endpoints. 3.16 group_slot_load layout contract complete 3.17 group_broadcast feeding deinterleaved consumer complete 3.18 one value with dense and group-reduce consumers complete/materialization -3.19 S=16 reduce block_elems recipe selection complete/diagnostic +3.19 S=16 reduce block_elems support selection complete/diagnostic 3.20 group_slots control-flow join complete 3.21 S=32 tail with full-tile-readable source complete 3.22 scf.for loop-carried layout complete @@ -752,7 +753,7 @@ row-major store of this layout must be rejected with: VMI-LAYOUT-CONTRACT: pto.vmi.store requires materializing #pto.vmi.layout to contiguous, but no - VPTO block-interleave materialization/store plan is registered. + VPTO block-interleave materialization/store support exists. ``` #### 3.5.3 Reduce Result, Elementwise, Store @@ -1182,8 +1183,8 @@ slot_lane(r) = 0 Trying to canonicalize this result to `slots = 8` would require packing lane 0 from eight different physical vregs into lanes 0..7 of one vreg. This document -does not use that plan. `slots = 1` is the canonical layout for S=64 row-local -group reductions. +does not use that packing transform. `slots = 1` is the canonical layout for +S=64 row-local group reductions. #### 3.7.1 Reduce And Store Group Sums @@ -1464,8 +1465,8 @@ group_off + 0, group_off + 1, group_off + 2, ... Only the first address is necessarily 32B-aligned. The remaining f32 addresses are 4B apart and are not valid for this `vsts` lowering. The compiler must not -accept this as a clean lowering until either a pack-to-slots=8 plan or an -unaligned-store plan is selected. +accept this as a clean lowering until either pack-to-slots=8 materialization +support or unaligned-store support exists. VMI input: @@ -1523,7 +1524,7 @@ layout transition explicit: `group_broadcast` first produces a dense contiguous f32 value, then `pto.vmi.ensure_layout` materializes the deinterleaved=2 f32 view required by dense `f32 -> f16` truncation. A future direct `group_broadcast -> deinterleaved=2` lowering may remove that materialization, -but the `group_broadcast` result layout must make that recipe explicit rather +but the `group_broadcast` result layout must make that support path explicit rather than hiding it inside `truncf` lowering. VPTO lowering result for one full 8-row tile: @@ -1776,13 +1777,13 @@ contract: ```text VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf with group size 32 and num_groups tail 6 requires - materializing #pto.vmi.layout. The registered fast plan + materializing #pto.vmi.layout. The fast lowering support uses vldsx2 DINTLV_B32 over a full 8-row tile. This source is not marked - full-tile-readable, and the stable gather tail plan is not implemented. + full-tile-readable, and the stable gather tail fallback is not implemented. ``` -If a future option enables the stable gather tail plan, the same VMI input may -lower by gathering only the active lanes. Until that plan is registered, the +If a future option enables the stable gather tail fallback, the same VMI input +may lower by gathering only the active lanes. Until that support exists, the converter must not silently issue the full-tile `vldsx2` loads. ### 3.12 Control-Flow Join Before `group_reduce` @@ -1848,7 +1849,8 @@ VPTO lowering result for the join: } ``` -The consumer after the join is the same S=32 reduction plan as section 3.6: +The consumer after the join uses the same S=32 reduction lowering support as +section 3.6: ```text %all_b32 = pto.pge_b32 "PAT_ALL" @@ -1876,7 +1878,7 @@ for r = 0..7: ``` If the two branches cannot be assigned the same layout and no materialization -plan exists before `scf.yield`, the required diagnostic is: +support exists before `scf.yield`, the required diagnostic is: ```text VMI-LAYOUT-CONTRACT: @@ -1917,7 +1919,7 @@ Required diagnostic: VMI-LAYOUT-CONTRACT: pto.vmi.truncf cannot lower from #pto.vmi.layout f32 to f16 because no - slot-preserving width-changing VPTO plan is registered. f32->f16 vcvt writes + slot-preserving width-changing VPTO support exists. f32->f16 vcvt writes even/odd sub-lanes, not lanes 0..7. Use group_broadcast before truncf, or keep the group_store element type as f32. ``` @@ -1936,8 +1938,8 @@ VMI input: pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} ``` -Here `S = 96 / 8 = 12` f32 elements per group. The current VCG-based plans use -32B groups, i.e. 8 f32 elements per row fragment: +Here `S = 96 / 8 = 12` f32 elements per group. The current VCG-based lowering +support uses 32B groups, i.e. 8 f32 elements per row fragment: ```text S = 8 -> one VCGADD block per group @@ -1950,8 +1952,8 @@ Required diagnostic: ```text VMI-LAYOUT-CONTRACT: - pto.vmi.group_reduce_addf with f32 group size 12 has no registered VPTO - layout plan. Supported VCG-based f32 group sizes are 8, 16, 32, and 64. + pto.vmi.group_reduce_addf with f32 group size 12 has no supported VPTO + layout/lowering path. Supported VCG-based f32 group sizes are 8, 16, 32, and 64. A scalar/gather fallback or a rewrite to logical group size 16 with an explicit per-group mask is required. ``` @@ -2291,8 +2293,8 @@ for g = 0..7: out[group_off + g] = rhs_base[rhs_off + g] ``` -If `source_group_stride != 1`, this packed `slots = 8` plan requires a -strided/gather group-slot load materializer. Until that plan is registered, +If `source_group_stride != 1`, this packed `slots = 8` layout requires a +strided/gather group-slot load materializer. Until that support exists, `group_slot_load` with `slots = 8` and non-unit stride must diagnose instead of silently using full-group `group_load`. @@ -2461,8 +2463,9 @@ look through the defining `group_broadcast` and choose a hidden broadcast shape. This case forces layout assignment to handle a solvable use-site conflict. One consumer requires an S=32 group-reduce layout; another consumer requires dense row-major store. This is not semantically illegal. It must be solved by -use-site materialization or producer rematerialization when a registered plan -exists. +explicit use-site materialization. A later optimization pass may fold the +materialization into a store or rematerialize a cheap producer when the required +support exists. VMI input: @@ -2487,10 +2490,10 @@ Assigned layouts: requires #pto.vmi.layout ``` -If `%x` is cheap to rematerialize, layout assignment may clone the producer for -the dense store. Otherwise, if the registry has a `deinterleaved = 4 -> -contiguous` materialization plan, layout assignment may keep `%x` in -`deinterleaved = 4` and insert `ensure_layout` before the dense store. +Baseline layout assignment keeps `%x` in the group-reduce layout and inserts +`ensure_layout` before the dense store use. A later rematerialization pass may +clone the load for the dense store if that is profitable. A later fold-consumer +pass may also fold `ensure_layout + store` into a layout-aware store lowering. VPTO lowering result: @@ -2551,18 +2554,17 @@ for i = 0..255: copy_out[off + i] = base[off + i] ``` -If the `deinterleaved = 4 -> contiguous` plan is not registered, the required -diagnostic is: +If `deinterleaved = 4 -> contiguous` materialization support does not exist, the +required diagnostic is: ```text VMI-LAYOUT-CONTRACT: value %x is required as #pto.vmi.layout by pto.vmi.group_reduce_addf and as #pto.vmi.layout by - pto.vmi.store, but no registered materialization plan exists at the store - use site. + pto.vmi.store, but no materialization support exists at the store use site. ``` -### 3.19 S=16 Reduce `block_elems` Recipe Selection +### 3.19 S=16 Reduce `block_elems` Support Selection S=16 f32 group reduction has two legal dense input layouts: @@ -2576,10 +2578,11 @@ It is also a valid S=16 reduction layout: each physical part contains eight values per row, so `VCGADD` can reduce each part and `VADD` can combine the two partial sums. -`block_elems = 8` is still useful when the producer is a block load plan such -as `BDINTLV` or `vsldb` over 32B row fragments. Layout assignment must select -between these plans by producer/consumer cost. It must not hard-code S=16 -reduce to `block_elems = 8`. +`block_elems = 8` is still useful when the producer is a block load shape such +as `BDINTLV` or `vsldb` over 32B row fragments. Baseline layout assignment must +express any mismatch with an explicit `ensure_layout`; producer rematerialization +or consumer folding can choose the cheaper equivalent form later. Assignment +must not hard-code S=16 reduce to `block_elems = 8`. #### 3.19.1 Continuous S=16 Reduce And Truncf, `block_elems = 1` @@ -2662,7 +2665,7 @@ for i = 0..127: #### 3.19.2 Block-Load Producer Fixed To `block_elems = 8` This is the real conflict case. The value is fixed to `block_elems = 8` -because the producer is a registered block-load plan. A later `truncf` +because the producer uses block-load support. A later `truncf` requires element-parity `block_elems = 1`. VMI input: @@ -2691,7 +2694,8 @@ Assigned layouts before the conflicting `truncf` use: ``` The reduction path is legal and uses the same `vsldb` block-load shape as -section 3.15.2. The `truncf` path is legal only if one of these plans exists: +section 3.15.2. The `truncf` path is legal only if one of these transforms +exists: ```text 1. rematerialize the original memory producer as block_elems=1 @@ -2699,15 +2703,15 @@ section 3.15.2. The `truncf` path is legal only if one of these plans exists: 3. use an explicitly enabled scratch/reload fallback ``` -If no such plan is registered, the required diagnostic is: +If no such transform exists, the required diagnostic is: ```text VMI-LAYOUT-CONTRACT: pto.vmi.truncf requires #pto.vmi.layout, but the source value is - fixed to #pto.vmi.layout by the selected - strided group_load plan. Register a rematerialization or preserving - materialization plan, or avoid consuming this block-loaded value with truncf. + fixed to #pto.vmi.layout by the strided + group_load. Add rematerialization or preserving materialization support, or + avoid consuming this block-loaded value with truncf. ``` ### 3.20 `group_slots` Control-Flow Join @@ -2987,8 +2991,9 @@ for r = 0..7: ### 3.23 `group_broadcast` With Multiple Dense Consumers One `group_slots` value may feed multiple `group_broadcast` uses with different -dense result layout requirements. Layout assignment should rematerialize the -broadcast per use instead of forcing one result layout onto all consumers. +dense result layout requirements. Each `group_broadcast` op has its own result +layout, so layout assignment should type each op at its use site instead of +forcing one result layout onto all consumers. VMI input: @@ -3046,7 +3051,7 @@ layout. It is that each use has an explicit layout boundary: %b_for_cast_split = pto.vmi.ensure_layout %b_for_cast ``` -If a future direct `group_broadcast -> deinterleaved` recipe is added, layout +If a future direct `group_broadcast -> deinterleaved` support path is added, layout assignment may assign `%b_for_mul` or `%b_for_cast` directly to that layout, but the choice must still be visible in the assigned IR. @@ -3498,21 +3503,21 @@ Required diagnostic when the stride is not block-aligned: ```text VMI-LAYOUT-CONTRACT: pto.vmi.group_load group_size 32 with source_group_stride not divisible by - 8 f32 elements cannot use the registered vsldb strided-block plan. Enable a - stable gather plan or choose a block-aligned source_group_stride. + 8 f32 elements cannot use the vsldb strided-block lowering support. Enable a + stable gather fallback or choose a block-aligned source_group_stride. ``` Required assignment rule: ```text -This producer selects the S=32 block-fragment plan: +This producer requires the S=32 block-fragment layout: #pto.vmi.layout It must not be unified with the contiguous-load S=32 plan from section 3.6: #pto.vmi.layout Both layouts are legal inputs to group_reduce_addf S=32, but they require -different producer materialization plans. +different producer materialization/lowering support. ``` ### 3.28 `group_slot_load` `slots = 1` With Aligned Non-Unit Stride @@ -3703,7 +3708,7 @@ Required assignment rule: the per-use typed mask materialization inserted by vmi-layout-assignment. For a rematerializable `create_mask`, assignment may clone it as b32/b16 masks. For a non-rematerializable mask producer, assignment must insert -`ensure_mask_granularity` or diagnose if no materialization plan is registered. +`ensure_mask_granularity` or diagnose if no materialization support exists. ``` ### 3.30 `masked_load` Tail Without Padding @@ -4002,10 +4007,10 @@ S=32 reduce over 8 groups: #pto.vmi.layout ``` -The program is semantically legal. Layout assignment must solve it by cloning -or rematerializing the cheap load for one use, or by inserting an explicit -registered materialization plan. `vmi-to-vpto` must not inspect both users and -choose one locally. +The program is semantically legal. Baseline layout assignment solves it by +inserting an explicit use-site `ensure_layout`. A later optimization pass may +clone or rematerialize the cheap load for one use. `vmi-to-vpto` must not +inspect both users and choose one locally. VMI input: @@ -4114,11 +4119,11 @@ for r = 0..7: Required assignment rule: ```text -If a cheap producer such as load can produce both requested layouts, clone or -rematerialize it at the use sites and assign each clone independently. If the -producer is not rematerializable and no deinterleaved=2 <-> deinterleaved=4 -materialization plan is registered, emit a layout-contract diagnostic naming -both consumers and both required layouts. +Baseline assignment inserts `ensure_layout` at the mismatched use. A later +rematerialization pass may clone a cheap producer such as load and assign each +clone independently. If no deinterleaved=2 <-> deinterleaved=4 materialization +support exists, emit a layout-contract diagnostic naming both consumers and +both required layouts. ``` ### 3.34 S=64 Group-Slot Result `f32 -> f16` Cast @@ -4380,7 +4385,7 @@ VPTO lowering result: pto.vsts %out16_block, %out16[%group_off16], %slot8 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask -// Row-local S=64 RHS: rematerialize the same scalar stream into one lane-0 +// Row-local S=64 RHS: a separate group_slot_load op produces one lane-0 // value per physical row-local result. %rhs64_r = pto.vsldb %rhs_base[%rhs_off_plus_r], %c0_i16, %c0_i16, %one_b32 : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> @@ -4406,9 +4411,12 @@ for r = 0..7: Required assignment rule: ```text -`group_slot_load` is cheaply rematerializable. If two use sites request -different `group_slots` layouts, clone/rematerialize the load per use. Do not -invent a common layout or make `vmi-to-vpto` inspect both users. +`group_slot_load` is a memory op, so the baseline rematerialization pass must +not clone it as a generic cheap producer. If two use sites need different +`group_slots` layouts, the legal first-stage shape is to write two explicit +`group_slot_load` ops, as above, or to introduce a future load-cloning +optimization with an explicit memory-safety proof. Do not invent a common +layout or make `vmi-to-vpto` inspect both users. ``` ### 3.37 S=64 `group_store` With Non-Unit Output Stride @@ -4467,15 +4475,15 @@ Required assignment rule: ```text If `group_store` has non-unit row_stride and the source can legally use `slots = 1`, assignment may select `slots = 1` to keep the store legal. If the -source is fixed to `slots = 8`, the current target plan must diagnose unless a -strided packed store materializer is registered. +source is fixed to `slots = 8`, current target support must diagnose unless a +strided packed store materializer exists. ``` ### 3.38 Multi-Tile S=32 `group_reduce` The S=32 plan is not only a one-tile special case. For more than eight groups, layout assignment keeps the same layout and `vmi-to-vpto` emits the same -8-row tile recipe for each physical tile. +8-row tile lowering sequence for each physical tile. VMI input: @@ -4663,10 +4671,11 @@ though both have four physical parts. ### 3.40 Scalar Broadcast Feeding Dense And Grouped Users This case fixes the rule for ordinary scalar broadcasts. A scalar broadcast is -not born with a physical layout. Layout assignment may either rematerialize it -per use, or assign the transfer-equivalent producer chain to the non-contiguous -layout requested by the grouped consumer and insert an explicit materialization -at the dense store use. The latter is the concrete plan below. +not born with a physical layout. Baseline layout assignment assigns the +transfer-equivalent producer chain to the non-contiguous layout requested by the +grouped consumer and inserts an explicit materialization at the dense store use. +The later `vmi-layout-rematerialize` pass may replace that helper with a cloned +broadcast when profitable. VMI input: @@ -4785,20 +4794,21 @@ for r = 0..7: Required assignment rule: ```text -`broadcast` is layout-transparent and cheaply rematerializable, but assignment -does not have to force a separate contiguous broadcast just because a dense -store exists. It may choose a common deinterleaved compute layout for -transfer-equivalent elementwise ops and insert `ensure_layout` at the dense -store. The required invariant is that this choice is explicit in the assigned -IR; `vmi-to-vpto` must not infer it by inspecting both users. +`broadcast` is layout-transparent and cheaply rematerializable by the optional +`vmi-layout-rematerialize` pass, but baseline assignment does not have to force +a separate contiguous broadcast just because a dense store exists. It may +choose a common deinterleaved compute layout for transfer-equivalent elementwise +ops and insert `ensure_layout` at the dense store. The required invariant is +that this choice is explicit in the assigned IR; `vmi-to-vpto` must not infer it +by inspecting both users. ``` ### 3.41 Non-Rematerializable Value With Incompatible Users This is the non-cheap counterpart to section 3.18. A `masked_load` has explicit mask and passthrough semantics, so layout assignment should not clone it as a -normal cheap load unless the registry explicitly marks that clone legal. The -conflict is solved by inserting `ensure_layout` at one use site. +normal cheap load unless a dedicated rematerialization rule proves that clone +legal. The conflict is solved by inserting `ensure_layout` at one use site. VMI input: @@ -4898,10 +4908,11 @@ for r = 0..7: Required assignment rule: ```text -For non-rematerializable producers, assignment must insert a registered -use-site materialization plan, such as contiguous -> deinterleaved=4. If no -plan exists, it must diagnose at assignment time. `vmi-to-vpto` must not clone -the masked_load or choose a materialization after seeing both users. +For non-rematerializable producers, assignment must insert an explicit use-site +materialization helper, such as contiguous -> deinterleaved=4. If that helper +has no supported materialization, the layout gate must diagnose before +vmi-to-vpto. `vmi-to-vpto` must not clone the masked_load or choose a +materialization after seeing both users. ``` ### 3.42 `group_slots` `scf.for` Loop-Carried Accumulator @@ -5267,9 +5278,9 @@ one contiguous value for `masked_load`, and one deinterleaved value for `create_group_mask` by materializing the contiguous grouped predicate chunks and then applying `pdintlv_b32` in the same tree shape as the data `vdintlv`. It does not walk from `group_reduce_addf` to the mask producer to -choose or reject the recipe. +choose or reject the support path. -Assignment may select a deinterleaved S=32 load plan only when the rounded +Assignment may select a deinterleaved S=32 load layout only when the rounded physical reads are memory-safe; otherwise it must diagnose or use a future stable gather fallback. @@ -5437,7 +5448,7 @@ Optimization pass result: ```text // vmi-layout-fold-consumers may remove both ensure_layout ops if the target -// supports a store recipe that consumes deinterleaved=2 and writes contiguous +// supports store lowering that consumes deinterleaved=2 and writes contiguous // row-major memory. pto.vmi.store %t1, %out1[%off] pto.vmi.store %w, %out2[%off] @@ -6002,7 +6013,7 @@ pto.vmi.group_reduce_addi %x8, %mask -> verifier or layout-contract diagnostic ``` -An optimized row-local i8 full-chunk recipe may be added later for +An optimized row-local i8 full-chunk lowering path may be added later for `S = 256` by using widening `vcadd`, but that requires a widening `group_slots` result contract and must not change the baseline cast-to-accumulator semantics above. @@ -6016,6 +6027,6 @@ accumulator computation: pto.vmi.group_store %sum8, %out_i8[%group_off], %c1 {num_groups = 8} ``` -That packed group-slot `trunci` path is not a baseline recipe yet; the -implementation must either define a slot-wise VCVTII recipe or diagnose at +That packed group-slot `trunci` path is not baseline lowering support yet; the +implementation must either define slot-wise VCVTII lowering support or diagnose at layout assignment. diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 9207993a14..78cb8bc78e 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -110,7 +110,8 @@ std::unique_ptr createPTOValidateVPTOEmissionIRPass(); LogicalResult validateVMIProducerBoundaryIR(ModuleOp module, llvm::raw_ostream *diagOS = nullptr); LogicalResult validateVMILayoutAssignedIR(ModuleOp module, - llvm::raw_ostream *diagOS = nullptr); + llvm::raw_ostream *diagOS = nullptr, + bool verifyHelperSupport = true); std::unique_ptr createPTOValidateVMIIRPass(); std::unique_ptr createPTOValidateVMILayoutIRPass(); std::unique_ptr createVMILayoutAssignmentPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 74c9bf607e..fdbe82b5bf 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -811,10 +811,12 @@ def PTOValidateVMILayoutIR a concrete VMI layout, every VMI mask must have concrete b8/b16/b32 granularity and layout, physical VPTO register values must not appear yet, and VMI typed values must stay inside VMI semantic/helper or structural ops. - vmi-to-vpto chooses deterministic local recipes from the current op's attrs, - operand/result types, layouts, and operand values; non-local choices must - be represented as explicit attrs, helper ops, cloned producers, or - diagnostics before this stage. + vmi-to-vpto chooses deterministic lowerings from the current op's attrs, + operand/result types, layouts, and operand values. Non-local choices must + be represented as explicit attrs, helper ops, or diagnostics before this + stage. Later VMI layout optimization passes may replace helpers with + cloned/rematerialized producers, but the layout gate must not depend on + hidden producer/user context. }]; let constructor = "mlir::pto::createPTOValidateVMILayoutIRPass()"; let dependentDialects = ["mlir::cf::ControlFlowDialect", diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h new file mode 100644 index 0000000000..9a274a2a9b --- /dev/null +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -0,0 +1,287 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMILayoutSupport.h - VMI layout support queries ------*- C++ -*-===// +//===----------------------------------------------------------------------===// + +#ifndef PTO_TRANSFORMS_VMILAYOUTSUPPORT_H +#define PTO_TRANSFORMS_VMILAYOUTSUPPORT_H + +#include "PTO/IR/PTO.h" +#include "mlir/Support/LLVM.h" + +#include + +namespace mlir::pto { + +class VMITargetCapabilityRegistry; + +enum class VMIContiguousStoreSupportKind { + ContiguousVsts, + Deinterleaved2Vstsx2, + DeinterleavedMaterializeThenVsts, +}; + +struct VMIContiguousStoreSupport { + VMIContiguousStoreSupportKind kind = + VMIContiguousStoreSupportKind::ContiguousVsts; +}; + +enum class VMILayoutMaterializationSupportKind { + Identity, + ContiguousToDeinterleaved, + DeinterleavedToContiguous, +}; + +struct VMILayoutMaterializationSupport { + VMILayoutMaterializationSupportKind kind = + VMILayoutMaterializationSupportKind::Identity; +}; + +enum class VMIMaskGranularityMaterializationSupportKind { + Identity, + PredicateCast, +}; + +struct VMIMaskGranularityMaterializationSupport { + VMIMaskGranularityMaterializationSupportKind kind = + VMIMaskGranularityMaterializationSupportKind::Identity; +}; + +enum class VMICastLayoutKind { + Widen2x, + Widen4x, + Narrow2x, + Narrow4x, +}; + +struct VMICastLayoutFact { + VMICastLayoutKind kind = VMICastLayoutKind::Widen2x; + VMILayoutAttr sourceLayout; + VMILayoutAttr resultLayout; + int64_t sourceBits = 0; + int64_t resultBits = 0; + int64_t factor = 1; +}; + +enum class VMIGroupSlotLoadSupportKind { + Slots8UnitStrideVsldb, + Slots1AlignedLane0Vsldb, +}; + +struct VMIGroupSlotLoadSupport { + VMIGroupSlotLoadSupportKind kind = + VMIGroupSlotLoadSupportKind::Slots8UnitStrideVsldb; +}; + +enum class VMIGroupLoadSupportKind { + S16Block8Vsldb, + S32Block8Vsldb, +}; + +struct VMIGroupLoadSupport { + VMIGroupLoadSupportKind kind = VMIGroupLoadSupportKind::S16Block8Vsldb; +}; + +enum class VMIGroupSlotsStoreSupportKind { + Slots8UnitStrideVsts, + Slots1AlignedLane0Vsts, +}; + +struct VMIGroupSlotsStoreSupport { + VMIGroupSlotsStoreSupportKind kind = + VMIGroupSlotsStoreSupportKind::Slots8UnitStrideVsts; +}; + +enum class VMIGroupReduceLayoutKind { + OneVLane, + TwoVLane, + FourVLane, + RowLocal, +}; + +struct VMIGroupReduceLayoutFact { + VMIGroupReduceLayoutKind kind = VMIGroupReduceLayoutKind::OneVLane; + VMILayoutAttr sourceLayout; + VMILayoutAttr maskLayout; + VMILayoutAttr resultLayout; + int64_t groupSize = 0; + int64_t lanesPerPart = 0; + int64_t vlaneElems = 0; +}; + +enum class VMIGroupReduceAddFSupportKind { + OneVLaneVcgadd, + TwoVLaneDeinterleaved2VcgaddVadd, + FourVLaneDeinterleaved4VcgaddTree, + ContiguousVcaddRows, +}; + +struct VMIGroupReduceAddFSupport { + VMIGroupReduceAddFSupportKind kind = + VMIGroupReduceAddFSupportKind::OneVLaneVcgadd; +}; + +enum class VMIGroupBroadcastSupportKind { + GroupSlotsVselr, +}; + +struct VMIGroupBroadcastSupport { + VMIGroupBroadcastSupportKind kind = + VMIGroupBroadcastSupportKind::GroupSlotsVselr; +}; + +enum class VMITruncFSupportKind { + Deinterleaved2F32ToContiguousF16, + Deinterleaved4F32ToContiguousF8, + GroupSlots1F32ToF16, +}; + +struct VMITruncFSupport { + VMITruncFSupportKind kind = + VMITruncFSupportKind::Deinterleaved2F32ToContiguousF16; +}; + +enum class VMIExtFSupportKind { + ContiguousF16ToDeinterleaved2F32, + ContiguousF8ToDeinterleaved4F32, +}; + +struct VMIExtFSupport { + VMIExtFSupportKind kind = + VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32; +}; + +enum class VMITruncISupportKind { + Deinterleaved2I32ToContiguousI16, + Deinterleaved4I32ToContiguousI8, + GroupSlots1I32ToI16, +}; + +struct VMITruncISupport { + VMITruncISupportKind kind = + VMITruncISupportKind::Deinterleaved2I32ToContiguousI16; +}; + +enum class VMIExtISupportKind { + ContiguousI16ToDeinterleaved2I32, + ContiguousI8ToDeinterleaved4I32, +}; + +struct VMIExtISupport { + VMIExtISupportKind kind = + VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32; +}; + +enum class VMIBitcastSupportKind { + PerPartVbitcast, +}; + +struct VMIBitcastSupport { + VMIBitcastSupportKind kind = VMIBitcastSupportKind::PerPartVbitcast; +}; + +class VMILayoutSupport { +public: + FailureOr + getContiguousStoreSupport(VMIVRegType valueType, + std::string *reason = nullptr) const; + + LogicalResult canFoldContiguousStoreMaterialization( + VMIVRegType sourceType, VMIVRegType resultType, + std::string *reason = nullptr) const; + + FailureOr + getDataLayoutMaterializationSupport(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason = nullptr) const; + + LogicalResult canMaterializeDataLayout(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason = nullptr) const; + + FailureOr + getMaskLayoutMaterializationSupport(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + LogicalResult canMaterializeMaskLayout(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + FailureOr + getMaskGranularityMaterializationSupport(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason = nullptr) const; + + LogicalResult canMaterializeMaskGranularity( + VMIMaskType sourceType, VMIMaskType resultType, + std::string *reason = nullptr) const; + + FailureOr + getPreferredCastLayoutFact(VMIVRegType sourceType, VMIVRegType resultType, + std::string *reason = nullptr) const; + + FailureOr + getGroupSlotLoadSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupSlotLoadOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupLoadSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupLoadOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupSlotsStoreSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupStoreOp op, + std::string *reason = nullptr) const; + + FailureOr + getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, int64_t numGroups, + std::string *reason = nullptr) const; + + FailureOr + getGroupReduceAddFSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceAddFOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupReduceAddISupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceAddIOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupBroadcastSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupBroadcastOp op, + std::string *reason = nullptr) const; + + FailureOr + getTruncFSupport(VMITruncFOp op, std::string *reason = nullptr) const; + + FailureOr + getExtFSupport(VMIExtFOp op, std::string *reason = nullptr) const; + + FailureOr + getExtSISupport(VMIExtSIOp op, std::string *reason = nullptr) const; + + FailureOr + getExtUISupport(VMIExtUIOp op, std::string *reason = nullptr) const; + + FailureOr + getTruncISupport(VMITruncIOp op, std::string *reason = nullptr) const; + + FailureOr + getBitcastSupport(VMIBitcastOp op, std::string *reason = nullptr) const; +}; + +} // namespace mlir::pto + +#endif // PTO_TRANSFORMS_VMILAYOUTSUPPORT_H diff --git a/include/PTO/Transforms/VMILocalRecipeRegistry.h b/include/PTO/Transforms/VMILocalRecipeRegistry.h deleted file mode 100644 index 8472a32c4c..0000000000 --- a/include/PTO/Transforms/VMILocalRecipeRegistry.h +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under -// the terms and conditions of CANN Open Software License Agreement Version 2.0 -// (the "License"). Please refer to the License for details. You may not use -// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON -// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS -// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository -// for the full text of the License. - -//===- VMILocalRecipeRegistry.h - VMI local recipe queries ------*- C++ -*-===// -//===----------------------------------------------------------------------===// - -#ifndef PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H -#define PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H - -#include "PTO/IR/PTO.h" -#include "mlir/Support/LLVM.h" - -#include - -namespace mlir::pto { - -class VMITargetCapabilityRegistry; - -enum class VMIContiguousStoreRecipeKind { - ContiguousVsts, - Deinterleaved2Vstsx2, - DeinterleavedMaterializeThenVsts, -}; - -struct VMIContiguousStoreRecipe { - VMIContiguousStoreRecipeKind kind = - VMIContiguousStoreRecipeKind::ContiguousVsts; -}; - -enum class VMILayoutMaterializationRecipeKind { - Identity, - ContiguousToDeinterleaved, - DeinterleavedToContiguous, -}; - -struct VMILayoutMaterializationRecipe { - VMILayoutMaterializationRecipeKind kind = - VMILayoutMaterializationRecipeKind::Identity; -}; - -enum class VMIMaskGranularityMaterializationRecipeKind { - Identity, - PredicateCast, -}; - -struct VMIMaskGranularityMaterializationRecipe { - VMIMaskGranularityMaterializationRecipeKind kind = - VMIMaskGranularityMaterializationRecipeKind::Identity; -}; - -enum class VMIGroupSlotLoadRecipeKind { - Slots8UnitStrideVsldb, - Slots1AlignedLane0Vsldb, -}; - -struct VMIGroupSlotLoadRecipe { - VMIGroupSlotLoadRecipeKind kind = - VMIGroupSlotLoadRecipeKind::Slots8UnitStrideVsldb; -}; - -enum class VMIGroupLoadRecipeKind { - S16Block8Vsldb, - S32Block8Vsldb, -}; - -struct VMIGroupLoadRecipe { - VMIGroupLoadRecipeKind kind = VMIGroupLoadRecipeKind::S16Block8Vsldb; -}; - -enum class VMIGroupSlotsStoreRecipeKind { - Slots8UnitStrideVsts, - Slots1AlignedLane0Vsts, -}; - -struct VMIGroupSlotsStoreRecipe { - VMIGroupSlotsStoreRecipeKind kind = - VMIGroupSlotsStoreRecipeKind::Slots8UnitStrideVsts; -}; - -enum class VMIGroupReduceAddFRecipeKind { - OneVLaneVcgadd, - TwoVLaneDeinterleaved2VcgaddVadd, - FourVLaneDeinterleaved4VcgaddTree, - ContiguousVcaddRows, -}; - -struct VMIGroupReduceAddFRecipe { - VMIGroupReduceAddFRecipeKind kind = - VMIGroupReduceAddFRecipeKind::OneVLaneVcgadd; -}; - -enum class VMIGroupBroadcastRecipeKind { - GroupSlotsVselr, -}; - -struct VMIGroupBroadcastRecipe { - VMIGroupBroadcastRecipeKind kind = - VMIGroupBroadcastRecipeKind::GroupSlotsVselr; -}; - -enum class VMITruncFRecipeKind { - Deinterleaved2F32ToContiguousF16, - Deinterleaved4F32ToContiguousF8, - GroupSlots1F32ToF16, -}; - -struct VMITruncFRecipe { - VMITruncFRecipeKind kind = - VMITruncFRecipeKind::Deinterleaved2F32ToContiguousF16; -}; - -enum class VMIExtFRecipeKind { - ContiguousF16ToDeinterleaved2F32, - ContiguousF8ToDeinterleaved4F32, -}; - -struct VMIExtFRecipe { - VMIExtFRecipeKind kind = - VMIExtFRecipeKind::ContiguousF16ToDeinterleaved2F32; -}; - -enum class VMITruncIRecipeKind { - Deinterleaved2I32ToContiguousI16, - Deinterleaved4I32ToContiguousI8, - GroupSlots1I32ToI16, -}; - -struct VMITruncIRecipe { - VMITruncIRecipeKind kind = - VMITruncIRecipeKind::Deinterleaved2I32ToContiguousI16; -}; - -enum class VMIExtIRecipeKind { - ContiguousI16ToDeinterleaved2I32, - ContiguousI8ToDeinterleaved4I32, -}; - -struct VMIExtIRecipe { - VMIExtIRecipeKind kind = - VMIExtIRecipeKind::ContiguousI16ToDeinterleaved2I32; -}; - -enum class VMIBitcastRecipeKind { - PerPartVbitcast, -}; - -struct VMIBitcastRecipe { - VMIBitcastRecipeKind kind = VMIBitcastRecipeKind::PerPartVbitcast; -}; - -class VMILocalRecipeRegistry { -public: - FailureOr - getContiguousStoreRecipe(VMIVRegType valueType, - std::string *reason = nullptr) const; - - LogicalResult canFoldContiguousStoreMaterialization( - VMIVRegType sourceType, VMIVRegType resultType, - std::string *reason = nullptr) const; - - FailureOr - getDataLayoutMaterializationRecipe(VMIVRegType sourceType, - VMIVRegType resultType, - std::string *reason = nullptr) const; - - FailureOr - getMaskLayoutMaterializationRecipe(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason = nullptr) const; - - FailureOr - getMaskGranularityMaterializationRecipe(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason = nullptr) const; - - FailureOr - getGroupSlotLoadRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupSlotLoadOp op, - std::string *reason = nullptr) const; - - FailureOr - getGroupLoadRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupLoadOp op, - std::string *reason = nullptr) const; - - FailureOr - getGroupSlotsStoreRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupStoreOp op, - std::string *reason = nullptr) const; - - FailureOr - getGroupReduceAddFRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupReduceAddFOp op, - std::string *reason = nullptr) const; - - FailureOr - getGroupReduceAddIRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupReduceAddIOp op, - std::string *reason = nullptr) const; - - FailureOr - getGroupBroadcastRecipe(const VMITargetCapabilityRegistry &capabilities, - VMIGroupBroadcastOp op, - std::string *reason = nullptr) const; - - FailureOr - getTruncFRecipe(VMITruncFOp op, std::string *reason = nullptr) const; - - FailureOr - getExtFRecipe(VMIExtFOp op, std::string *reason = nullptr) const; - - FailureOr - getExtSIRecipe(VMIExtSIOp op, std::string *reason = nullptr) const; - - FailureOr - getExtUIRecipe(VMIExtUIOp op, std::string *reason = nullptr) const; - - FailureOr - getTruncIRecipe(VMITruncIOp op, std::string *reason = nullptr) const; - - FailureOr - getBitcastRecipe(VMIBitcastOp op, std::string *reason = nullptr) const; -}; - -} // namespace mlir::pto - -#endif // PTO_TRANSFORMS_VMILOCALRECIPEREGISTRY_H diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 12dbb7c8e9..3f808c3072 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -39,7 +39,7 @@ add_mlir_dialect_library(PTOTransforms VMILegalizeArithSelect.cpp VMILayoutAssignment.cpp VMILayoutFoldConsumers.cpp - VMILocalRecipeRegistry.cpp + VMILayoutSupport.cpp VMILayoutRematerialize.cpp VMILayoutSinkMaterialization.cpp VMIToVPTO.cpp diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 7234084c47..6fdf6acf07 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -12,7 +12,7 @@ #include "PTO/IR/PTO.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" -#include "PTO/Transforms/VMILocalRecipeRegistry.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "PTO/Transforms/VMITargetCapabilities.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -22,6 +22,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/raw_ostream.h" @@ -170,6 +171,33 @@ LogicalResult emitLayoutContract(Operation *op, llvm::raw_ostream *diagOS, return failure(); } +LogicalResult emitLayoutSupportContract(Operation *op, + llvm::raw_ostream *diagOS, + Twine message, StringRef reason) { + std::string text; + llvm::raw_string_ostream os(text); + os << message << ": " << reason; + + bool printedAny = false; + auto printValueType = [&](StringRef kind, int64_t index, Type type) { + if (!isVMIType(type)) + return; + if (!printedAny) { + os << "; VMI types:"; + printedAny = true; + } + os << " " << kind << "#" << index << "=" << type; + }; + + for (auto [index, operand] : llvm::enumerate(op->getOperands())) + printValueType("operand", static_cast(index), operand.getType()); + for (auto [index, result] : llvm::enumerate(op->getResults())) + printValueType("result", static_cast(index), result.getType()); + + os.flush(); + return emitLayoutContract(op, diagOS, text); +} + LogicalResult emitHelperMaterializationContract(Operation *helper, Type sourceType, Type resultType, @@ -179,7 +207,7 @@ LogicalResult emitHelperMaterializationContract(Operation *helper, auto emitFallback = [&]() { return emitLayoutContract( helper, diagOS, - Twine(helperName) + " has no registered materialization recipe: " + + Twine(helperName) + " has no registered materialization support: " + reason); }; @@ -192,7 +220,7 @@ LogicalResult emitHelperMaterializationContract(Operation *helper, llvm::raw_string_ostream os(message); os << requester->getName() << " operand #" << use.getOperandNumber() << " has type " << sourceType << " but requires " << resultType << "; " - << helperName << " has no registered materialization recipe: " << reason; + << helperName << " has no registered materialization support: " << reason; os.flush(); InFlightDiagnostic diag = @@ -395,10 +423,10 @@ LogicalResult verifyLayoutAssignedOperationTypes(Operation *op, return success(); } -LogicalResult verifyLayoutHelperRecipe(Operation *op, +LogicalResult verifyLayoutHelperSupport(Operation *op, llvm::raw_ostream *diagOS); -LogicalResult verifyLayoutSemanticRecipe(Operation *op, +LogicalResult verifyLayoutSemanticSupport(Operation *op, llvm::raw_ostream *diagOS); LogicalResult verifyOperationBoundary(Operation *op, @@ -422,7 +450,8 @@ LogicalResult verifyOperationBoundary(Operation *op, } LogicalResult verifyLayoutAssignedOperation(Operation *op, - llvm::raw_ostream *diagOS) { + llvm::raw_ostream *diagOS, + bool verifyHelperSupports = true) { if (failed(verifyLayoutAssignedOperationTypes(op, diagOS))) return failure(); @@ -431,14 +460,15 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, if (isVMIHelperOp(op)) { if (isVMILayoutHelperOp(op)) - return verifyLayoutHelperRecipe(op, diagOS); + return verifyHelperSupports ? verifyLayoutHelperSupport(op, diagOS) + : success(); return emitInvariant( op, diagOS, "VMI pack/unpack helper appears before VMI-to-VPTO physicalization"); } if (isVMISemanticOp(op)) - return verifyLayoutSemanticRecipe(op, diagOS); + return verifyLayoutSemanticSupport(op, diagOS); if (isStructuralOp(op)) return success(); @@ -446,17 +476,16 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, "VMI typed value is used by a non-VMI semantic op"); } -LogicalResult verifyLayoutHelperRecipe(Operation *op, +LogicalResult verifyLayoutHelperSupport(Operation *op, llvm::raw_ostream *diagOS) { - VMILocalRecipeRegistry recipes; + VMILayoutSupport supports; if (auto ensure = dyn_cast(op)) { auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); std::string reason; - if (failed(recipes.getDataLayoutMaterializationRecipe(sourceType, - resultType, - &reason))) + if (failed(supports.canMaterializeDataLayout(sourceType, resultType, + &reason))) return emitHelperMaterializationContract( op, sourceType, resultType, "pto.vmi.ensure_layout", reason, diagOS); return success(); @@ -466,9 +495,8 @@ LogicalResult verifyLayoutHelperRecipe(Operation *op, auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); std::string reason; - if (failed(recipes.getMaskLayoutMaterializationRecipe(sourceType, - resultType, - &reason))) + if (failed(supports.canMaterializeMaskLayout(sourceType, resultType, + &reason))) return emitHelperMaterializationContract( op, sourceType, resultType, "pto.vmi.ensure_mask_layout", reason, diagOS); @@ -479,12 +507,12 @@ LogicalResult verifyLayoutHelperRecipe(Operation *op, auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); std::string reason; - if (failed(recipes.getMaskGranularityMaterializationRecipe( - sourceType, resultType, &reason))) + if (failed(supports.canMaterializeMaskGranularity(sourceType, resultType, + &reason))) return emitLayoutContract( op, diagOS, Twine("pto.vmi.ensure_mask_granularity has no registered " - "materialization recipe: ") + + "materialization support: ") + reason); return success(); } @@ -492,9 +520,9 @@ LogicalResult verifyLayoutHelperRecipe(Operation *op, return success(); } -LogicalResult verifyLayoutSemanticRecipe(Operation *op, +LogicalResult verifyLayoutSemanticSupport(Operation *op, llvm::raw_ostream *diagOS) { - VMILocalRecipeRegistry recipes; + VMILayoutSupport supports; VMITargetCapabilityRegistry capabilities; if (auto store = dyn_cast(op)) { @@ -504,12 +532,11 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getContiguousStoreRecipe(valueType, &reason))) - return emitLayoutContract( + if (failed(supports.getContiguousStoreSupport(valueType, &reason))) + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.store has no registered contiguous-memory local " - "recipe: ") + - reason); + "pto.vmi.store has no registered contiguous-memory layout support", + reason); return success(); } @@ -520,12 +547,12 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getContiguousStoreRecipe(valueType, &reason))) - return emitLayoutContract( + if (failed(supports.getContiguousStoreSupport(valueType, &reason))) + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.tile_write has no registered contiguous-memory local " - "recipe: ") + - reason); + "pto.vmi.tile_write has no registered contiguous-memory layout " + "support", + reason); return success(); } @@ -537,21 +564,19 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getGroupLoadRecipe(capabilities, load, &reason))) - return emitLayoutContract( + if (failed(supports.getGroupLoadSupport(capabilities, load, &reason))) + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.group_load has no registered block8 local recipe: ") + - reason); + "pto.vmi.group_load has no registered block8 layout support", reason); return success(); } if (auto load = dyn_cast(op)) { std::string reason; - if (failed(recipes.getGroupSlotLoadRecipe(capabilities, load, &reason))) - return emitLayoutContract( + if (failed(supports.getGroupSlotLoadSupport(capabilities, load, &reason))) + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.group_slot_load has no registered local recipe: ") + - reason); + "pto.vmi.group_slot_load has no registered layout support", reason); return success(); } @@ -562,12 +587,11 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getGroupSlotsStoreRecipe(capabilities, store, &reason))) - return emitLayoutContract( + if (failed(supports.getGroupSlotsStoreSupport(capabilities, store, &reason))) + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.group_store has no registered group_slots local " - "recipe: ") + - reason); + "pto.vmi.group_store has no registered group_slots layout support", + reason); return success(); } @@ -578,13 +602,13 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getGroupReduceAddFRecipe(capabilities, reduce, + if (failed(supports.getGroupReduceAddFSupport(capabilities, reduce, &reason))) - return emitLayoutContract( + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.group_reduce_addf has no registered group_slots " - "local recipe: ") + - reason); + "pto.vmi.group_reduce_addf has no registered group_slots layout " + "support", + reason); return success(); } @@ -595,39 +619,37 @@ LogicalResult verifyLayoutSemanticRecipe(Operation *op, return success(); std::string reason; - if (failed(recipes.getGroupBroadcastRecipe(capabilities, broadcast, + if (failed(supports.getGroupBroadcastSupport(capabilities, broadcast, &reason))) - return emitLayoutContract( + return emitLayoutSupportContract( op, diagOS, - Twine("pto.vmi.group_broadcast has no registered local recipe: ") + - reason); + "pto.vmi.group_broadcast has no registered layout support", reason); return success(); } if (auto truncf = dyn_cast(op)) { std::string reason; - if (failed(recipes.getTruncFRecipe(truncf, &reason))) - return emitLayoutContract( - op, diagOS, - Twine("pto.vmi.truncf has no registered local recipe: ") + reason); + if (failed(supports.getTruncFSupport(truncf, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.truncf has no registered layout support", + reason); return success(); } if (auto extf = dyn_cast(op)) { std::string reason; - if (failed(recipes.getExtFRecipe(extf, &reason))) - return emitLayoutContract( - op, diagOS, - Twine("pto.vmi.extf has no registered local recipe: ") + reason); + if (failed(supports.getExtFSupport(extf, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.extf has no registered layout support", reason); return success(); } if (auto bitcast = dyn_cast(op)) { std::string reason; - if (failed(recipes.getBitcastRecipe(bitcast, &reason))) - return emitLayoutContract( - op, diagOS, - Twine("pto.vmi.bitcast has no registered local recipe: ") + reason); + if (failed(supports.getBitcastSupport(bitcast, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.bitcast has no registered layout support", + reason); return success(); } @@ -668,9 +690,10 @@ LogicalResult mlir::pto::validateVMIProducerBoundaryIR( } LogicalResult mlir::pto::validateVMILayoutAssignedIR( - ModuleOp module, llvm::raw_ostream *diagOS) { + ModuleOp module, llvm::raw_ostream *diagOS, bool verifyHelperSupports) { WalkResult result = module.walk([&](Operation *op) { - if (failed(verifyLayoutAssignedOperation(op, diagOS))) + if (failed(verifyLayoutAssignedOperation(op, diagOS, + verifyHelperSupports))) return WalkResult::interrupt(); return WalkResult::advance(); }); diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 5f30ba82e0..eb3593c9ee 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -15,6 +15,7 @@ #include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "PTO/Transforms/VMITargetCapabilities.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -248,19 +249,11 @@ struct LayoutSolver { if (VMILayoutAttr existing = type.getLayoutAttr()) if (existing.isGroupSlots() && existing.getSlots() > 0) return existing; - if (numGroups > 0 && type.getElementCount() % numGroups == 0) { - int64_t groupSize = type.getElementCount() / numGroups; - std::optional vlaneElems = getVLaneElems(type.getElementType()); - if (vlaneElems && (groupSize == *vlaneElems || - groupSize == 2 * *vlaneElems || - groupSize == 4 * *vlaneElems)) - return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); - FailureOr lanesPerPart = - getDataLanesPerPart(type.getElementType()); - if (succeeded(lanesPerPart) && groupSize >= *lanesPerPart && - groupSize % *lanesPerPart == 0) - return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); - } + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(type, numGroups); + if (succeeded(fact)) + return fact->resultLayout; return getGroupSlotsLayout(numGroups); } @@ -268,14 +261,11 @@ struct LayoutSolver { int64_t numGroups) { if (VMILayoutAttr existing = type.getLayoutAttr()) return existing; - if (numGroups > 0 && type.getElementCount() % numGroups == 0) { - int64_t groupSize = type.getElementCount() / numGroups; - std::optional vlaneElems = getVLaneElems(type.getElementType()); - if (vlaneElems && groupSize == 2 * *vlaneElems) - return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); - if (vlaneElems && groupSize == 4 * *vlaneElems) - return VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); - } + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(type, numGroups); + if (succeeded(fact)) + return fact->sourceLayout; return getContiguousLayout(); } @@ -406,6 +396,27 @@ struct LayoutSolver { return false; } + bool isCompatibleGroupReduceSourceLayout(VMIGroupReduceLayoutFact fact, + VMILayoutAttr layout) { + if (!layout) + return false; + if (fact.kind == VMIGroupReduceLayoutKind::OneVLane || + fact.kind == VMIGroupReduceLayoutKind::RowLocal) + return layout.isContiguous(); + int64_t factor = fact.kind == VMIGroupReduceLayoutKind::TwoVLane ? 2 : 4; + return layout.isDeinterleaved() && layout.getFactor() == factor && + (layout.getBlockElems() == 1 || layout.getBlockElems() == 8); + } + + VMILayoutAttr getTruncFCompatibleGroupReduceSourceLayout( + VMIGroupReduceLayoutFact fact) { + if (fact.kind == VMIGroupReduceLayoutKind::TwoVLane) + return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); + if (fact.kind == VMIGroupReduceLayoutKind::FourVLane) + return VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/1); + return {}; + } + LogicalResult requestMask(Value mask, VMILayoutAttr layout, StringRef granularity, Operation *op) { unsigned id = addMaskValue(mask); @@ -813,41 +824,23 @@ struct LayoutSolver { if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); auto resultType = cast(reduce.getResult().getType()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( - sourceType, reduce.getNumGroupsAttr().getInt()); + sourceType, numGroups); VMILayoutAttr solvedSourceLayout = getExplicitDataLayout(reduce.getSource()); - int64_t numGroups = reduce.getNumGroupsAttr().getInt(); - if (solvedSourceLayout && numGroups > 0 && - sourceType.getElementCount() % numGroups == 0) { - int64_t groupSize = sourceType.getElementCount() / numGroups; - std::optional vlaneElems = - getVLaneElems(sourceType.getElementType()); - if (vlaneElems && groupSize == 2 * *vlaneElems && - solvedSourceLayout.isDeinterleaved() && - solvedSourceLayout.getFactor() == 2 && - (solvedSourceLayout.getBlockElems() == 1 || - solvedSourceLayout.getBlockElems() == 8)) - sourceLayout = solvedSourceLayout; - if (vlaneElems && groupSize == 4 * *vlaneElems && - solvedSourceLayout.isDeinterleaved() && - solvedSourceLayout.getFactor() == 4 && - (solvedSourceLayout.getBlockElems() == 1 || - solvedSourceLayout.getBlockElems() == 8)) - sourceLayout = solvedSourceLayout; - } else if (!sourceType.getLayoutAttr() && numGroups > 0 && - sourceType.getElementCount() % numGroups == 0) { - int64_t groupSize = sourceType.getElementCount() / numGroups; + if (solvedSourceLayout && succeeded(fact) && + isCompatibleGroupReduceSourceLayout(*fact, solvedSourceLayout)) { + sourceLayout = solvedSourceLayout; + } else if (!sourceType.getLayoutAttr() && succeeded(fact)) { if (hasCompatibleTruncFUseForGroupReduce(reduce.getSource(), - groupSize)) { - std::optional vlaneElems = - getVLaneElems(sourceType.getElementType()); - if (vlaneElems && groupSize == 2 * *vlaneElems) - sourceLayout = - VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); - if (vlaneElems && groupSize == 4 * *vlaneElems) - sourceLayout = - VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/1); + fact->groupSize)) { + if (VMILayoutAttr truncLayout = + getTruncFCompatibleGroupReduceSourceLayout(*fact)) + sourceLayout = truncLayout; } } requestDataUse(reduce.getSourceMutable(), sourceLayout); @@ -857,8 +850,9 @@ struct LayoutSolver { return WalkResult::interrupt(); if (failed(setNaturalLayout( reduce.getResult(), - getPreferredGroupSlotsLayout( - resultType, reduce.getNumGroupsAttr().getInt()), + succeeded(fact) ? fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, + numGroups), op))) return WalkResult::interrupt(); return WalkResult::advance(); @@ -866,29 +860,17 @@ struct LayoutSolver { if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); auto resultType = cast(reduce.getResult().getType()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( - sourceType, reduce.getNumGroupsAttr().getInt()); + sourceType, numGroups); VMILayoutAttr solvedSourceLayout = getExplicitDataLayout(reduce.getSource()); - int64_t numGroups = reduce.getNumGroupsAttr().getInt(); - if (solvedSourceLayout && numGroups > 0 && - sourceType.getElementCount() % numGroups == 0) { - int64_t groupSize = sourceType.getElementCount() / numGroups; - std::optional vlaneElems = - getVLaneElems(sourceType.getElementType()); - if (vlaneElems && groupSize == 2 * *vlaneElems && - solvedSourceLayout.isDeinterleaved() && - solvedSourceLayout.getFactor() == 2 && - (solvedSourceLayout.getBlockElems() == 1 || - solvedSourceLayout.getBlockElems() == 8)) - sourceLayout = solvedSourceLayout; - if (vlaneElems && groupSize == 4 * *vlaneElems && - solvedSourceLayout.isDeinterleaved() && - solvedSourceLayout.getFactor() == 4 && - (solvedSourceLayout.getBlockElems() == 1 || - solvedSourceLayout.getBlockElems() == 8)) - sourceLayout = solvedSourceLayout; - } + if (solvedSourceLayout && succeeded(fact) && + isCompatibleGroupReduceSourceLayout(*fact, solvedSourceLayout)) + sourceLayout = solvedSourceLayout; requestDataUse(reduce.getSourceMutable(), sourceLayout); if (failed(requestMaskUse( reduce.getMaskMutable(), sourceLayout, @@ -896,8 +878,9 @@ struct LayoutSolver { return WalkResult::interrupt(); if (failed(setNaturalLayout( reduce.getResult(), - getPreferredGroupSlotsLayout( - resultType, reduce.getNumGroupsAttr().getInt()), + succeeded(fact) ? fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, + numGroups), op))) return WalkResult::interrupt(); return WalkResult::advance(); @@ -912,19 +895,15 @@ struct LayoutSolver { if (auto extf = dyn_cast(op)) { auto sourceType = cast(extf.getSource().getType()); auto resultType = cast(extf.getResult().getType()); - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); - if (sourceBits == 16 && resultBits == 32) { - requestDataUse(extf.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extf.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 2), - op))) - return WalkResult::interrupt(); - } else if (sourceBits == 8 && resultBits == 32) { - requestDataUse(extf.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extf.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 4), - op))) + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { + requestDataUse(extf.getSourceMutable(), fact->sourceLayout); + if (failed( + setNaturalLayout(extf.getResult(), fact->resultLayout, op))) return WalkResult::interrupt(); } return WalkResult::advance(); @@ -932,19 +911,15 @@ struct LayoutSolver { if (auto extsi = dyn_cast(op)) { auto sourceType = cast(extsi.getSource().getType()); auto resultType = cast(extsi.getResult().getType()); - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); - if (sourceBits == 16 && resultBits == 32) { - requestDataUse(extsi.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extsi.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 2), - op))) - return WalkResult::interrupt(); - } else if (sourceBits == 8 && resultBits == 32) { - requestDataUse(extsi.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extsi.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 4), - op))) + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { + requestDataUse(extsi.getSourceMutable(), fact->sourceLayout); + if (failed( + setNaturalLayout(extsi.getResult(), fact->resultLayout, op))) return WalkResult::interrupt(); } return WalkResult::advance(); @@ -952,19 +927,15 @@ struct LayoutSolver { if (auto extui = dyn_cast(op)) { auto sourceType = cast(extui.getSource().getType()); auto resultType = cast(extui.getResult().getType()); - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); - if (sourceBits == 16 && resultBits == 32) { - requestDataUse(extui.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extui.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 2), - op))) - return WalkResult::interrupt(); - } else if (sourceBits == 8 && resultBits == 32) { - requestDataUse(extui.getSourceMutable(), getContiguousLayout()); - if (failed(setNaturalLayout(extui.getResult(), - VMILayoutAttr::getDeinterleaved(ctx, 4), - op))) + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { + requestDataUse(extui.getSourceMutable(), fact->sourceLayout); + if (failed( + setNaturalLayout(extui.getResult(), fact->resultLayout, op))) return WalkResult::interrupt(); } return WalkResult::advance(); @@ -972,48 +943,50 @@ struct LayoutSolver { if (auto truncf = dyn_cast(op)) { auto sourceType = cast(truncf.getSource().getType()); auto resultType = cast(truncf.getResult().getType()); - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); VMILayoutAttr sourceLayout = getDataLayout(truncf.getSource()); - if (sourceBits == 32 && resultBits == 16 && sourceLayout && + if (succeeded(fact) && fact->kind == VMICastLayoutKind::Narrow2x && + sourceLayout && sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { requestDataUse(truncf.getSourceMutable(), sourceLayout); if (failed(setNaturalLayout(truncf.getResult(), sourceLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } - if (sourceBits == 32 && resultBits == 16) - requestDataUse(truncf.getSourceMutable(), - VMILayoutAttr::getDeinterleaved(ctx, 2)); - else if (sourceBits == 32 && resultBits == 8) - requestDataUse(truncf.getSourceMutable(), - VMILayoutAttr::getDeinterleaved(ctx, 4)); - if (failed(setNaturalLayout(truncf.getResult(), getContiguousLayout(), - op))) + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) + requestDataUse(truncf.getSourceMutable(), fact->sourceLayout); + VMILayoutAttr resultLayout = + succeeded(fact) ? fact->resultLayout : getContiguousLayout(); + if (failed(setNaturalLayout(truncf.getResult(), resultLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto trunci = dyn_cast(op)) { auto sourceType = cast(trunci.getSource().getType()); auto resultType = cast(trunci.getResult().getType()); - unsigned sourceBits = getElementBitWidth(sourceType.getElementType()); - unsigned resultBits = getElementBitWidth(resultType.getElementType()); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(sourceType, resultType); VMILayoutAttr sourceLayout = getDataLayout(trunci.getSource()); - if (sourceBits == 32 && resultBits == 16 && sourceLayout && + if (succeeded(fact) && fact->kind == VMICastLayoutKind::Narrow2x && + sourceLayout && sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { requestDataUse(trunci.getSourceMutable(), sourceLayout); if (failed(setNaturalLayout(trunci.getResult(), sourceLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } - if (sourceBits == 32 && resultBits == 16) - requestDataUse(trunci.getSourceMutable(), - VMILayoutAttr::getDeinterleaved(ctx, 2)); - else if (sourceBits == 32 && resultBits == 8) - requestDataUse(trunci.getSourceMutable(), - VMILayoutAttr::getDeinterleaved(ctx, 4)); - if (failed(setNaturalLayout(trunci.getResult(), getContiguousLayout(), - op))) + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) + requestDataUse(trunci.getSourceMutable(), fact->sourceLayout); + VMILayoutAttr resultLayout = + succeeded(fact) ? fact->resultLayout : getContiguousLayout(); + if (failed(setNaturalLayout(trunci.getResult(), resultLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -1447,27 +1420,6 @@ struct LayoutSolver { } } - std::optional rematerializeDataUse(Value value, VMIVRegType resultType, - Location loc, OpBuilder &builder) { - if (auto constant = value.getDefiningOp()) { - auto denseAttr = dyn_cast(constant.getValue()); - if (denseAttr && denseAttr.isSplat()) - return builder - .create(loc, resultType, constant.getValue()) - .getResult(); - } - if (auto broadcast = value.getDefiningOp()) - return builder - .create(loc, resultType, broadcast.getValue()) - .getResult(); - if (auto iota = value.getDefiningOp()) - return builder - .create(loc, resultType, iota.getBase(), - iota.getOrderAttr()) - .getResult(); - return std::nullopt; - } - LogicalResult insertDataUseMaterializations() { OpBuilder builder(ctx); for (DataUseRequest request : dataUseRequests) { @@ -1488,12 +1440,6 @@ struct LayoutSolver { VMIVRegType::get(ctx, sourceType.getElementCount(), sourceType.getElementType(), request.layout); builder.setInsertionPoint(request.operand->getOwner()); - std::optional rematerialized = rematerializeDataUse( - value, resultType, request.operand->getOwner()->getLoc(), builder); - if (rematerialized) { - request.operand->set(*rematerialized); - continue; - } auto ensure = builder.create( request.operand->getOwner()->getLoc(), resultType, value); request.operand->set(ensure.getResult()); @@ -1625,27 +1571,6 @@ struct LayoutSolver { } } - std::optional rematerializeMaskUse(Value value, VMIMaskType resultType, - Location loc, OpBuilder &builder) { - if (auto createMask = value.getDefiningOp()) - return builder - .create(loc, resultType, createMask.getActiveLanes()) - .getResult(); - if (auto createGroupMask = value.getDefiningOp()) - return builder - .create( - loc, resultType, createGroupMask.getActiveElemsPerGroup(), - createGroupMask.getNumGroupsAttr(), - createGroupMask.getGroupSizeAttr()) - .getResult(); - if (auto constantMask = value.getDefiningOp()) - return builder - .create(loc, resultType, - constantMask.getValueAttr()) - .getResult(); - return std::nullopt; - } - LogicalResult insertMaskUseMaterializations() { OpBuilder builder(ctx); for (MaskUseRequest request : maskUseRequests) { @@ -1663,19 +1588,6 @@ struct LayoutSolver { builder.setInsertionPoint(request.operand->getOwner()); Value current = value; VMIMaskType currentType = sourceType; - auto requestedType = - VMIMaskType::get(ctx, sourceType.getElementCount(), - request.granularity, request.layout); - if (sourceType != requestedType) { - std::optional rematerialized = rematerializeMaskUse( - value, requestedType, request.operand->getOwner()->getLoc(), - builder); - if (rematerialized) { - request.operand->set(*rematerialized); - continue; - } - } - if (sourceLayout != request.layout) { auto layoutType = VMIMaskType::get(ctx, currentType.getElementCount(), @@ -1753,7 +1665,8 @@ struct LayoutSolver { if (failed(insertMaskUseMaterializations())) return failure(); rewriteFunctionType(); - return validateVMILayoutAssignedIR(module); + return validateVMILayoutAssignedIR(module, /*diagOS=*/nullptr, + /*verifyHelperSupport=*/false); } ModuleOp module; diff --git a/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp b/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp index 26536f196d..fda374f661 100644 --- a/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp +++ b/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp @@ -14,7 +14,7 @@ #include "PTO/IR/PTO.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" -#include "PTO/Transforms/VMILocalRecipeRegistry.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -43,9 +43,9 @@ static bool isFoldableStoreEnsure(VMIEnsureLayoutOp ensure) { if (!sourceType || !resultType) return false; - VMILocalRecipeRegistry recipes; + VMILayoutSupport supports; return succeeded( - recipes.canFoldContiguousStoreMaterialization(sourceType, resultType)); + supports.canFoldContiguousStoreMaterialization(sourceType, resultType)); } static void tryFoldEnsureLayoutIntoOperand( diff --git a/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp b/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp index c3bbf67731..3027d919f7 100644 --- a/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp +++ b/lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp @@ -13,6 +13,7 @@ #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -39,6 +40,18 @@ struct BinaryVRegOperands { OpOperand *rhs = nullptr; }; +struct TernaryVRegOperands { + OpOperand *lhs = nullptr; + OpOperand *rhs = nullptr; + OpOperand *acc = nullptr; +}; + +struct SelectOperands { + OpOperand *mask = nullptr; + OpOperand *trueValue = nullptr; + OpOperand *falseValue = nullptr; +}; + struct UnaryVRegOperand { OpOperand *source = nullptr; }; @@ -84,6 +97,31 @@ static std::optional getSinkableBinaryOperands(Operation *op return std::nullopt; } +static std::optional +getSinkableCompareOperands(Operation *op) { + if (auto cmpf = dyn_cast(op)) + return BinaryVRegOperands{&cmpf.getLhsMutable(), &cmpf.getRhsMutable()}; + if (auto cmpi = dyn_cast(op)) + return BinaryVRegOperands{&cmpi.getLhsMutable(), &cmpi.getRhsMutable()}; + return std::nullopt; +} + +static std::optional getSinkableSelectOperands(Operation *op) { + if (auto select = dyn_cast(op)) + return SelectOperands{&select.getMaskMutable(), + &select.getTrueValueMutable(), + &select.getFalseValueMutable()}; + return std::nullopt; +} + +static std::optional +getSinkableTernaryOperands(Operation *op) { + if (auto fma = dyn_cast(op)) + return TernaryVRegOperands{&fma.getLhsMutable(), &fma.getRhsMutable(), + &fma.getAccMutable()}; + return std::nullopt; +} + static std::optional getSinkableUnaryOperand(Operation *op) { if (auto negf = dyn_cast(op)) return UnaryVRegOperand{&negf.getSourceMutable()}; @@ -155,6 +193,34 @@ static bool isSameMaterialization(VMIEnsureLayoutOp lhsEnsure, lhsResultType == resultType && lhsSourceType != resultType; } +static bool isSameMaterialization(VMIEnsureLayoutOp lhsEnsure, + VMIEnsureLayoutOp rhsEnsure, + VMIEnsureLayoutOp accEnsure, + VMIVRegType resultType) { + if (!lhsEnsure || !rhsEnsure || !accEnsure || !resultType) + return false; + + auto lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto accSourceType = dyn_cast(accEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = dyn_cast(rhsEnsure.getResult().getType()); + auto accResultType = dyn_cast(accEnsure.getResult().getType()); + if (!lhsSourceType || !rhsSourceType || !accSourceType || !lhsResultType || + !rhsResultType || !accResultType) + return false; + + return lhsSourceType == rhsSourceType && lhsSourceType == accSourceType && + lhsResultType == rhsResultType && lhsResultType == accResultType && + lhsResultType == resultType && lhsSourceType != resultType; +} + +static bool canMaterializeDataLayout(VMIVRegType sourceType, + VMIVRegType resultType) { + VMILayoutSupport supports; + return succeeded(supports.canMaterializeDataLayout(sourceType, resultType)); +} + template static bool isSameMaskMaterialization(EnsureOp ensure, VMIMaskType resultType) { if (!ensure || !resultType) @@ -185,6 +251,20 @@ static bool isSameMaskMaterialization(EnsureOp lhsEnsure, EnsureOp rhsEnsure, lhsResultType == resultType && lhsSourceType != resultType; } +static bool canMaterializeMask(VMIEnsureMaskLayoutOp, VMIMaskType sourceType, + VMIMaskType resultType) { + VMILayoutSupport supports; + return succeeded(supports.canMaterializeMaskLayout(sourceType, resultType)); +} + +static bool canMaterializeMask(VMIEnsureMaskGranularityOp, + VMIMaskType sourceType, + VMIMaskType resultType) { + VMILayoutSupport supports; + return succeeded( + supports.canMaterializeMaskGranularity(sourceType, resultType)); +} + static bool trySinkBinaryMaterialization(Operation *op) { std::optional operands = getSinkableBinaryOperands(op); if (!operands || op->getNumResults() != 1) @@ -200,9 +280,175 @@ static bool trySinkBinaryMaterialization(Operation *op) { return false; auto sourceType = cast(lhsEnsure.getSource().getType()); + if (!canMaterializeDataLayout(sourceType, resultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); + state.addTypes(sourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + return true; +} + +static bool trySinkSelectMaterialization(Operation *op) { + std::optional operands = getSinkableSelectOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto maskEnsure = + operands->mask->get().getDefiningOp(); + auto trueEnsure = + operands->trueValue->get().getDefiningOp(); + auto falseEnsure = + operands->falseValue->get().getDefiningOp(); + if (!maskEnsure || !trueEnsure || !falseEnsure) + return false; + + auto trueSourceType = dyn_cast(trueEnsure.getSource().getType()); + auto falseSourceType = + dyn_cast(falseEnsure.getSource().getType()); + auto trueResultType = dyn_cast(trueEnsure.getResult().getType()); + auto falseResultType = + dyn_cast(falseEnsure.getResult().getType()); + auto maskSourceType = dyn_cast(maskEnsure.getSource().getType()); + auto maskResultType = dyn_cast(maskEnsure.getResult().getType()); + if (!trueSourceType || !falseSourceType || !trueResultType || + !falseResultType || !maskSourceType || !maskResultType) + return false; + + if (trueSourceType != falseSourceType || trueResultType != falseResultType || + trueResultType != resultType || trueSourceType == resultType) + return false; + if (maskResultType != operands->mask->get().getType()) + return false; + if (maskResultType.getLayoutAttr() != resultType.getLayoutAttr() || + maskSourceType.getLayoutAttr() != trueSourceType.getLayoutAttr()) + return false; + if (maskSourceType.getElementCount() != trueSourceType.getElementCount() || + maskResultType.getElementCount() != resultType.getElementCount() || + maskSourceType.getGranularity() != maskResultType.getGranularity()) + return false; + if (!canMaterializeDataLayout(trueSourceType, resultType) || + !canMaterializeMask(maskEnsure, maskSourceType, maskResultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands({maskEnsure.getSource(), trueEnsure.getSource(), + falseEnsure.getSource()}); + state.addTypes(trueSourceType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (maskEnsure->use_empty()) + maskEnsure.erase(); + if (trueEnsure->use_empty()) + trueEnsure.erase(); + if (falseEnsure != trueEnsure && falseEnsure->use_empty()) + falseEnsure.erase(); + return true; +} + +static bool trySinkCompareMaterialization(Operation *op) { + std::optional operands = getSinkableCompareOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultMaskType = dyn_cast(op->getResult(0).getType()); + if (!resultMaskType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + if (!lhsEnsure || !rhsEnsure) + return false; + + auto lhsSourceType = dyn_cast(lhsEnsure.getSource().getType()); + auto rhsSourceType = dyn_cast(rhsEnsure.getSource().getType()); + auto lhsResultType = dyn_cast(lhsEnsure.getResult().getType()); + auto rhsResultType = dyn_cast(rhsEnsure.getResult().getType()); + if (!lhsSourceType || !rhsSourceType || !lhsResultType || !rhsResultType) + return false; + if (lhsSourceType != rhsSourceType || lhsResultType != rhsResultType || + lhsSourceType == lhsResultType) + return false; + if (lhsResultType.getElementCount() != resultMaskType.getElementCount() || + lhsResultType.getLayoutAttr() != resultMaskType.getLayoutAttr()) + return false; + + auto sourceMaskType = VMIMaskType::get( + op->getContext(), resultMaskType.getElementCount(), + resultMaskType.getGranularity(), lhsSourceType.getLayoutAttr()); + VMILayoutSupport supports; + if (failed(supports.canMaterializeMaskLayout(sourceMaskType, resultMaskType))) + return false; + OpBuilder builder(op); OperationState state(op->getLoc(), op->getName()); state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); + state.addTypes(sourceMaskType); + state.addAttributes(op->getAttrs()); + Operation *newOp = builder.create(state); + + builder.setInsertionPointAfter(newOp); + auto resultEnsure = builder.create( + op->getLoc(), resultMaskType, newOp->getResult(0)); + op->getResult(0).replaceAllUsesWith(resultEnsure.getResult()); + op->erase(); + + if (lhsEnsure->use_empty()) + lhsEnsure.erase(); + if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) + rhsEnsure.erase(); + return true; +} + +static bool trySinkTernaryMaterialization(Operation *op) { + std::optional operands = getSinkableTernaryOperands(op); + if (!operands || op->getNumResults() != 1) + return false; + + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType) + return false; + + auto lhsEnsure = operands->lhs->get().getDefiningOp(); + auto rhsEnsure = operands->rhs->get().getDefiningOp(); + auto accEnsure = operands->acc->get().getDefiningOp(); + if (!isSameMaterialization(lhsEnsure, rhsEnsure, accEnsure, resultType)) + return false; + + auto sourceType = cast(lhsEnsure.getSource().getType()); + if (!canMaterializeDataLayout(sourceType, resultType)) + return false; + + OpBuilder builder(op); + OperationState state(op->getLoc(), op->getName()); + state.addOperands( + {lhsEnsure.getSource(), rhsEnsure.getSource(), accEnsure.getSource()}); state.addTypes(sourceType); state.addAttributes(op->getAttrs()); Operation *newOp = builder.create(state); @@ -217,6 +463,9 @@ static bool trySinkBinaryMaterialization(Operation *op) { lhsEnsure.erase(); if (rhsEnsure != lhsEnsure && rhsEnsure->use_empty()) rhsEnsure.erase(); + if (accEnsure != lhsEnsure && accEnsure != rhsEnsure && + accEnsure->use_empty()) + accEnsure.erase(); return true; } @@ -236,6 +485,9 @@ static bool trySinkBinaryMaskMaterialization(Operation *op) { return false; auto sourceType = cast(lhsEnsure.getSource().getType()); + if (!canMaterializeMask(lhsEnsure, sourceType, resultType)) + return false; + OpBuilder builder(op); OperationState state(op->getLoc(), op->getName()); state.addOperands({lhsEnsure.getSource(), rhsEnsure.getSource()}); @@ -271,6 +523,9 @@ static bool trySinkUnaryMaterialization(Operation *op) { return false; auto sourceType = cast(sourceEnsure.getSource().getType()); + if (!canMaterializeDataLayout(sourceType, resultType)) + return false; + OpBuilder builder(op); OperationState state(op->getLoc(), op->getName()); state.addOperands(sourceEnsure.getSource()); @@ -305,6 +560,9 @@ static bool trySinkUnaryMaskMaterialization(Operation *op) { return false; auto sourceType = cast(sourceEnsure.getSource().getType()); + if (!canMaterializeMask(sourceEnsure, sourceType, resultType)) + return false; + OpBuilder builder(op); OperationState state(op->getLoc(), op->getName()); state.addOperands(sourceEnsure.getSource()); @@ -340,8 +598,10 @@ struct VMILayoutSinkMaterializationPass ModuleOp module = getOperation(); SmallVector candidates; module.walk([&](Operation *op) { - if (getSinkableBinaryOperands(op) || getSinkableUnaryOperand(op) || - getSinkableBinaryMaskOperands(op) || getSinkableUnaryMaskOperand(op)) + if (getSinkableBinaryOperands(op) || getSinkableCompareOperands(op) || + getSinkableSelectOperands(op) || getSinkableTernaryOperands(op) || + getSinkableUnaryOperand(op) || getSinkableBinaryMaskOperands(op) || + getSinkableUnaryMaskOperand(op)) candidates.push_back(op); }); @@ -349,8 +609,14 @@ struct VMILayoutSinkMaterializationPass if (op->getBlock() == nullptr) continue; if (!trySinkBinaryMaterialization(op)) { - if (!trySinkUnaryMaterialization(op)) - trySinkMaskMaterialization(op); + if (!trySinkCompareMaterialization(op)) { + if (!trySinkSelectMaterialization(op)) { + if (!trySinkTernaryMaterialization(op)) { + if (!trySinkUnaryMaterialization(op)) + trySinkMaskMaterialization(op); + } + } + } } } } diff --git a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp similarity index 73% rename from lib/PTO/Transforms/VMILocalRecipeRegistry.cpp rename to lib/PTO/Transforms/VMILayoutSupport.cpp index 34b843737c..27a994ba55 100644 --- a/lib/PTO/Transforms/VMILocalRecipeRegistry.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -8,10 +8,10 @@ // FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository // for the full text of the License. -//===- VMILocalRecipeRegistry.cpp - VMI local recipe queries --------------===// +//===- VMILayoutSupport.cpp - VMI layout support queries --------------===// //===----------------------------------------------------------------------===// -#include "PTO/Transforms/VMILocalRecipeRegistry.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/VMIUtils.h" @@ -311,12 +311,12 @@ getPhysicalLogicalBitFootprint(VMIVRegType type) { return bits; } -static FailureOr -getLayoutMaterializationRecipe(VMILayoutAttr sourceLayout, +static FailureOr +getLayoutMaterializationSupport(VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, std::string *reason) { auto fail = [&](const Twine &message) - -> FailureOr { + -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -325,26 +325,156 @@ getLayoutMaterializationRecipe(VMILayoutAttr sourceLayout, if (!sourceLayout || !resultLayout) return fail("requires assigned source/result layouts"); if (sourceLayout == resultLayout) - return VMILayoutMaterializationRecipe{ - VMILayoutMaterializationRecipeKind::Identity}; + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::Identity}; if (sourceLayout.isContiguous() && resultLayout.isDeinterleaved() && (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) - return VMILayoutMaterializationRecipe{ - VMILayoutMaterializationRecipeKind::ContiguousToDeinterleaved}; + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::ContiguousToDeinterleaved}; if (sourceLayout.isDeinterleaved() && resultLayout.isContiguous() && (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) - return VMILayoutMaterializationRecipe{ - VMILayoutMaterializationRecipeKind::DeinterleavedToContiguous}; + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::DeinterleavedToContiguous}; return fail("unsupported source/result layout pair"); } } // namespace -FailureOr -VMILocalRecipeRegistry::getContiguousStoreRecipe(VMIVRegType valueType, +FailureOr +VMILayoutSupport::getPreferredGroupReduceLayoutFact( + VMIVRegType sourceType, int64_t numGroups, std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + FailureOr groupSize = + getGroupSizeFromNumGroups(sourceType, numGroups, reason); + if (failed(groupSize)) + return failure(); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) + return fail("requires element type with known physical VLane width"); + + MLIRContext *ctx = sourceType.getContext(); + int64_t vlaneElems = *lanesPerPart / 8; + VMIGroupReduceLayoutFact fact; + fact.groupSize = *groupSize; + fact.lanesPerPart = *lanesPerPart; + fact.vlaneElems = vlaneElems; + + if (*groupSize == vlaneElems) { + fact.kind = VMIGroupReduceLayoutKind::OneVLane; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + fact.maskLayout = fact.sourceLayout; + fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + return fact; + } + + if (*groupSize == 2 * vlaneElems) { + fact.kind = VMIGroupReduceLayoutKind::TwoVLane; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/8); + fact.maskLayout = fact.sourceLayout; + fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + return fact; + } + + if (*groupSize == 4 * vlaneElems) { + fact.kind = VMIGroupReduceLayoutKind::FourVLane; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, 4, /*blockElems=*/8); + fact.maskLayout = fact.sourceLayout; + fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + return fact; + } + + if (*groupSize >= *lanesPerPart && *groupSize % *lanesPerPart == 0) { + fact.kind = VMIGroupReduceLayoutKind::RowLocal; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + fact.maskLayout = fact.sourceLayout; + fact.resultLayout = + VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); + return fact; + } + + return fail("group_reduce layout supports group sizes VLaneElems, " + "2*VLaneElems, 4*VLaneElems, or full physical chunk multiples"); +} + +FailureOr +VMILayoutSupport::getPreferredCastLayoutFact(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (sourceBits == 0 || resultBits == 0) + return fail("requires source/result element types with known storage width"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source/result lane count to match"); + + MLIRContext *ctx = sourceType.getContext(); + VMICastLayoutFact fact; + fact.sourceBits = sourceBits; + fact.resultBits = resultBits; + + if (resultBits == 32 && sourceBits == 16) { + fact.kind = VMICastLayoutKind::Widen2x; + fact.factor = 2; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + fact.resultLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + return fact; + } + + if (resultBits == 32 && sourceBits == 8) { + fact.kind = VMICastLayoutKind::Widen4x; + fact.factor = 4; + fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); + fact.resultLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + return fact; + } + + if (sourceBits == 32 && resultBits == 16) { + fact.kind = VMICastLayoutKind::Narrow2x; + fact.factor = 2; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + fact.resultLayout = VMILayoutAttr::getContiguous(ctx); + return fact; + } + + if (sourceBits == 32 && resultBits == 8) { + fact.kind = VMICastLayoutKind::Narrow4x; + fact.factor = 4; + fact.sourceLayout = + VMILayoutAttr::getDeinterleaved(ctx, fact.factor, /*blockElems=*/1); + fact.resultLayout = VMILayoutAttr::getContiguous(ctx); + return fact; + } + + return fail("supports only 8/16-bit <-> 32-bit dense cast layout facts"); +} + +FailureOr +VMILayoutSupport::getContiguousStoreSupport(VMIVRegType valueType, std::string *reason) const { auto fail = [&](const Twine &message) - -> FailureOr { + -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -354,8 +484,8 @@ VMILocalRecipeRegistry::getContiguousStoreRecipe(VMIVRegType valueType, if (!layout) return fail("requires assigned value layout"); if (layout.isContiguous()) - return VMIContiguousStoreRecipe{ - VMIContiguousStoreRecipeKind::ContiguousVsts}; + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::ContiguousVsts}; if (!layout.isDeinterleaved()) return fail("requires contiguous or deinterleaved value layout"); if (layout.getBlockElems() != 1) @@ -366,18 +496,18 @@ VMILocalRecipeRegistry::getContiguousStoreRecipe(VMIVRegType valueType, if (layout.getFactor() == 2) { if (!hasX2MemoryDistToken(valueType.getElementType())) return fail("requires 8/16/32-bit element type for vstsx2 INTLV"); - return VMIContiguousStoreRecipe{ - VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2}; + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::Deinterleaved2Vstsx2}; } if (layout.getFactor() == 4) - return VMIContiguousStoreRecipe{ - VMIContiguousStoreRecipeKind::DeinterleavedMaterializeThenVsts}; + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::DeinterleavedMaterializeThenVsts}; return fail("requires deinterleaved factor 2 or 4"); } -LogicalResult VMILocalRecipeRegistry::canFoldContiguousStoreMaterialization( +LogicalResult VMILayoutSupport::canFoldContiguousStoreMaterialization( VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { if (sourceType.getElementType() != resultType.getElementType()) return failWithReason("source/result element types must match", reason); @@ -388,22 +518,22 @@ LogicalResult VMILocalRecipeRegistry::canFoldContiguousStoreMaterialization( if (!resultLayout || !resultLayout.isContiguous()) return failWithReason("result layout must be contiguous", reason); - FailureOr recipe = - getContiguousStoreRecipe(sourceType, reason); - if (failed(recipe)) + FailureOr support = + getContiguousStoreSupport(sourceType, reason); + if (failed(support)) return failure(); - if (recipe->kind == VMIContiguousStoreRecipeKind::ContiguousVsts) + if (support->kind == VMIContiguousStoreSupportKind::ContiguousVsts) return failWithReason("source layout is already contiguous", reason); return success(); } -FailureOr -VMILocalRecipeRegistry::getDataLayoutMaterializationRecipe( +FailureOr +VMILayoutSupport::getDataLayoutMaterializationSupport( VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { auto fail = [&](const Twine &message) - -> FailureOr { + -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -416,23 +546,33 @@ VMILocalRecipeRegistry::getDataLayoutMaterializationRecipe( VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - FailureOr recipe = - getLayoutMaterializationRecipe(sourceLayout, resultLayout, reason); - if (failed(recipe)) + FailureOr support = + getLayoutMaterializationSupport(sourceLayout, resultLayout, reason); + if (failed(support)) return failure(); if (failed(checkLayoutMaterializationShape(sourceType, resultType, sourceLayout, resultLayout, reason))) return failure(); - return recipe; + return support; } -FailureOr -VMILocalRecipeRegistry::getMaskLayoutMaterializationRecipe( +LogicalResult +VMILayoutSupport::canMaterializeDataLayout(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason) const { + if (failed(getDataLayoutMaterializationSupport(sourceType, resultType, + reason))) + return failure(); + return success(); +} + +FailureOr +VMILayoutSupport::getMaskLayoutMaterializationSupport( VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { auto fail = [&](const Twine &message) - -> FailureOr { + -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -445,23 +585,33 @@ VMILocalRecipeRegistry::getMaskLayoutMaterializationRecipe( VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - FailureOr recipe = - getLayoutMaterializationRecipe(sourceLayout, resultLayout, reason); - if (failed(recipe)) + FailureOr support = + getLayoutMaterializationSupport(sourceLayout, resultLayout, reason); + if (failed(support)) return failure(); if (failed(checkLayoutMaterializationShape(sourceType, resultType, sourceLayout, resultLayout, reason))) return failure(); - return recipe; + return support; +} + +LogicalResult +VMILayoutSupport::canMaterializeMaskLayout(VMIMaskType sourceType, + VMIMaskType resultType, + std::string *reason) const { + if (failed(getMaskLayoutMaterializationSupport(sourceType, resultType, + reason))) + return failure(); + return success(); } -FailureOr -VMILocalRecipeRegistry::getMaskGranularityMaterializationRecipe( +FailureOr +VMILayoutSupport::getMaskGranularityMaterializationSupport( VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { auto fail = [&](const Twine &message) - -> FailureOr { + -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -475,18 +625,27 @@ VMILocalRecipeRegistry::getMaskGranularityMaterializationRecipe( !VMIMaskType::isConcreteGranularity(resultType.getGranularity())) return fail("requires concrete b8/b16/b32 source and result granularities"); if (sourceType.getGranularity() == resultType.getGranularity()) - return VMIMaskGranularityMaterializationRecipe{ - VMIMaskGranularityMaterializationRecipeKind::Identity}; + return VMIMaskGranularityMaterializationSupport{ + VMIMaskGranularityMaterializationSupportKind::Identity}; - return VMIMaskGranularityMaterializationRecipe{ - VMIMaskGranularityMaterializationRecipeKind::PredicateCast}; + return VMIMaskGranularityMaterializationSupport{ + VMIMaskGranularityMaterializationSupportKind::PredicateCast}; +} + +LogicalResult VMILayoutSupport::canMaterializeMaskGranularity( + VMIMaskType sourceType, VMIMaskType resultType, + std::string *reason) const { + if (failed(getMaskGranularityMaterializationSupport(sourceType, resultType, + reason))) + return failure(); + return success(); } -FailureOr -VMILocalRecipeRegistry::getGroupSlotLoadRecipe( +FailureOr +VMILayoutSupport::getGroupSlotLoadSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -515,8 +674,8 @@ VMILocalRecipeRegistry::getGroupSlotLoadRecipe( if (!stride || *stride != 1) return fail("slots=8 group_slot_load requires constant unit " "source_group_stride"); - return VMIGroupSlotLoadRecipe{ - VMIGroupSlotLoadRecipeKind::Slots8UnitStrideVsldb}; + return VMIGroupSlotLoadSupport{ + VMIGroupSlotLoadSupportKind::Slots8UnitStrideVsldb}; } unsigned elementBits = @@ -533,14 +692,14 @@ VMILocalRecipeRegistry::getGroupSlotLoadRecipe( " elements for 32B load alignment; packed or unaligned " "scalar load lowering is not implemented"); - return VMIGroupSlotLoadRecipe{ - VMIGroupSlotLoadRecipeKind::Slots1AlignedLane0Vsldb}; + return VMIGroupSlotLoadSupport{ + VMIGroupSlotLoadSupportKind::Slots1AlignedLane0Vsldb}; } -FailureOr VMILocalRecipeRegistry::getGroupLoadRecipe( +FailureOr VMILayoutSupport::getGroupLoadSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupLoadOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -584,16 +743,16 @@ FailureOr VMILocalRecipeRegistry::getGroupLoadRecipe( fullChunkReason); if (*groupSize == 16) - return VMIGroupLoadRecipe{VMIGroupLoadRecipeKind::S16Block8Vsldb}; - return VMIGroupLoadRecipe{VMIGroupLoadRecipeKind::S32Block8Vsldb}; + return VMIGroupLoadSupport{VMIGroupLoadSupportKind::S16Block8Vsldb}; + return VMIGroupLoadSupport{VMIGroupLoadSupportKind::S32Block8Vsldb}; } -FailureOr -VMILocalRecipeRegistry::getGroupSlotsStoreRecipe( +FailureOr +VMILayoutSupport::getGroupSlotsStoreSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupStoreOp op, std::string *reason) const { auto fail = - [&](const Twine &message) -> FailureOr { + [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -636,8 +795,8 @@ VMILocalRecipeRegistry::getGroupSlotsStoreRecipe( Twine(alignedStrideElems) + " elements for 32B store alignment; packed or unaligned " "contiguous store lowering is not implemented"); - return VMIGroupSlotsStoreRecipe{ - VMIGroupSlotsStoreRecipeKind::Slots1AlignedLane0Vsts}; + return VMIGroupSlotsStoreSupport{ + VMIGroupSlotsStoreSupportKind::Slots1AlignedLane0Vsts}; } if (layout.getSlots() == 8) { @@ -648,23 +807,23 @@ VMILocalRecipeRegistry::getGroupSlotsStoreRecipe( if (*arity != ceilDivNonNegative(numGroups, 8)) return fail("slots=8 group_store arity must equal ceil(num_groups / " "8)"); - return VMIGroupSlotsStoreRecipe{ - VMIGroupSlotsStoreRecipeKind::Slots8UnitStrideVsts}; + return VMIGroupSlotsStoreSupport{ + VMIGroupSlotsStoreSupportKind::Slots8UnitStrideVsts}; } return fail("group_slots group_store currently supports only slots=1 or " "unit-stride slots=8"); } -FailureOr -getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, +FailureOr +getGroupReduceAddSupportImpl(const VMITargetCapabilityRegistry &capabilities, Operation *op, VMIVRegType sourceType, VMIMaskType maskType, VMIVRegType resultType, int64_t numGroups, bool requiresReassoc, VMIReductionKind reductionKind, std::string *reason) { auto fail = - [&](const Twine &message) -> FailureOr { + [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -692,9 +851,9 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, if (succeeded(groupSize) && resultLayout.getSlots() <= 0 && (*groupSize != vlaneElems && *groupSize != 2 * vlaneElems && *groupSize != 4 * vlaneElems)) - return fail("stable group_reduce_add slots=8 recipes support group " + return fail("stable group_reduce_add slots=8 support group " "sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems"); - return fail("stable group_reduce_add local recipes currently require " + return fail("stable group_reduce_add layout support currently requires " "result layout slots=8 or slots=1"); } @@ -704,7 +863,7 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, if (!elementCapability.isSupported()) return fail(elementCapability.reason); if (sourceType.getElementType() != resultType.getElementType()) - return fail("stable group_reduce_add local recipes require matching " + return fail("stable group_reduce_add layout support requires matching " "source/result element types"); if (sourceType.getElementCount() != resultType.getElementCount()) return fail("requires source/result lane count to match"); @@ -730,7 +889,7 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, if (resultLayout.getSlots() == 1) { if (failed(lanesPerPart) || *groupSize < *lanesPerPart || *groupSize % *lanesPerPart != 0) - return fail("stable group_reduce_add slots=1 recipes support group " + return fail("stable group_reduce_add slots=1 support group " "sizes that are multiples of one physical chunk"); if (!sourceLayout.isContiguous() || !maskLayout.isContiguous()) return fail("slots=1 group_reduce_add requires contiguous source/mask " @@ -743,8 +902,8 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, return fail(Twine("slots=1 group_reduce_add requires full source " "chunks; ") + sourceFullReason); - return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::ContiguousVcaddRows}; + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::ContiguousVcaddRows}; } if (*groupSize == vlaneElems) { @@ -759,8 +918,8 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, if (*resultArity != *sourceArity) return fail("one-vlane group_reduce_add requires source/result physical " "arity to match"); - return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::OneVLaneVcgadd}; + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::OneVLaneVcgadd}; } if (*groupSize == 2 * vlaneElems) { @@ -778,8 +937,8 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, *sourceArity != *resultArity * 2) return fail("two-vlane group_reduce_add requires two source/mask parts per " "result part"); - return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::TwoVLaneDeinterleaved2VcgaddVadd}; + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::TwoVLaneDeinterleaved2VcgaddVadd}; } if (*groupSize == 4 * vlaneElems) { @@ -797,19 +956,19 @@ getGroupReduceAddRecipeImpl(const VMITargetCapabilityRegistry &capabilities, *sourceArity != *resultArity * 4) return fail("four-vlane group_reduce_add requires four source/mask parts per " "result part"); - return VMIGroupReduceAddFRecipe{ - VMIGroupReduceAddFRecipeKind::FourVLaneDeinterleaved4VcgaddTree}; + return VMIGroupReduceAddFSupport{ + VMIGroupReduceAddFSupportKind::FourVLaneDeinterleaved4VcgaddTree}; } - return fail("stable group_reduce_add slots=8 recipes support group sizes " + return fail("stable group_reduce_add slots=8 support group sizes " "VLaneElems, 2*VLaneElems, or 4*VLaneElems"); } -FailureOr -VMILocalRecipeRegistry::getGroupReduceAddFRecipe( +FailureOr +VMILayoutSupport::getGroupReduceAddFSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, std::string *reason) const { - return getGroupReduceAddRecipeImpl( + return getGroupReduceAddSupportImpl( capabilities, op.getOperation(), cast(op.getSource().getType()), cast(op.getMask().getType()), cast(op.getResult().getType()), @@ -817,11 +976,11 @@ VMILocalRecipeRegistry::getGroupReduceAddFRecipe( VMIReductionKind::GroupAddF, reason); } -FailureOr -VMILocalRecipeRegistry::getGroupReduceAddIRecipe( +FailureOr +VMILayoutSupport::getGroupReduceAddISupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddIOp op, std::string *reason) const { - return getGroupReduceAddRecipeImpl( + return getGroupReduceAddSupportImpl( capabilities, op.getOperation(), cast(op.getSource().getType()), cast(op.getMask().getType()), cast(op.getResult().getType()), @@ -829,12 +988,12 @@ VMILocalRecipeRegistry::getGroupReduceAddIRecipe( VMIReductionKind::GroupAddI, reason); } -FailureOr -VMILocalRecipeRegistry::getGroupBroadcastRecipe( +FailureOr +VMILayoutSupport::getGroupBroadcastSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, std::string *reason) const { (void)capabilities; - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -888,30 +1047,30 @@ VMILocalRecipeRegistry::getGroupBroadcastRecipe( if (failed(resultFactor)) return fail("requires known result layout factor"); if (*resultFactor == 1) - return VMIGroupBroadcastRecipe{ - VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; + return VMIGroupBroadcastSupport{ + VMIGroupBroadcastSupportKind::GroupSlotsVselr}; bool blockFragmentSmallGroup = resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && *groupSize < *lanesPerPart && *lanesPerPart % resultLayout.getBlockElems() == 0; if (blockFragmentSmallGroup) - return VMIGroupBroadcastRecipe{ - VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; + return VMIGroupBroadcastSupport{ + VMIGroupBroadcastSupportKind::GroupSlotsVselr}; int64_t logicalSpanPerResultChunk = *lanesPerPart * *resultFactor; if (*groupSize < *lanesPerPart || *groupSize % logicalSpanPerResultChunk != 0) return fail("deinterleaved result requires every physical result chunk to " "stay within one logical group"); - return VMIGroupBroadcastRecipe{ - VMIGroupBroadcastRecipeKind::GroupSlotsVselr}; + return VMIGroupBroadcastSupport{ + VMIGroupBroadcastSupportKind::GroupSlotsVselr}; } -FailureOr -VMILocalRecipeRegistry::getTruncFRecipe(VMITruncFOp op, +FailureOr +VMILayoutSupport::getTruncFSupport(VMITruncFOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -928,10 +1087,9 @@ VMILocalRecipeRegistry::getTruncFRecipe(VMITruncFOp op, return fail("requires assigned source/result layouts and computable " "physical arity"); - unsigned resultBits = - pto::getPTOStorageElemBitWidth(resultType.getElementType()); - if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || sourceLayout.getNumGroups() != resultLayout.getNumGroups() || sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || @@ -940,28 +1098,37 @@ VMILocalRecipeRegistry::getTruncFRecipe(VMITruncFOp op, return fail("group-slot truncf requires matching " "group_slots(num_groups=G, slots=1) source/result layouts, " "f32 source, f16 result, and matching physical arity"); - return VMITruncFRecipe{VMITruncFRecipeKind::GroupSlots1F32ToF16}; + return VMITruncFSupport{VMITruncFSupportKind::GroupSlots1F32ToF16}; } if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || !sourceType.getElementType().isF32() || *resultArity != 1) return fail("requires f32 deinterleaved source and contiguous result"); - if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) - return VMITruncFRecipe{ - VMITruncFRecipeKind::Deinterleaved2F32ToContiguousF16}; - if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8) - return VMITruncFRecipe{ - VMITruncFRecipeKind::Deinterleaved4F32ToContiguousF8}; + FailureOr fact = + getPreferredCastLayoutFact(sourceType, resultType, reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Narrow2x && + fact->kind != VMICastLayoutKind::Narrow4x)) + return fail("unsupported deinterleaved truncf factor, arity, or result " + "element width"); + + if (fact->kind == VMICastLayoutKind::Narrow2x && + sourceLayout.getFactor() == fact->factor && *sourceArity == fact->factor) + return VMITruncFSupport{ + VMITruncFSupportKind::Deinterleaved2F32ToContiguousF16}; + if (fact->kind == VMICastLayoutKind::Narrow4x && + sourceLayout.getFactor() == fact->factor && *sourceArity == fact->factor) + return VMITruncFSupport{ + VMITruncFSupportKind::Deinterleaved4F32ToContiguousF8}; return fail("unsupported deinterleaved truncf factor, arity, or result " "element width"); } -FailureOr -VMILocalRecipeRegistry::getExtFRecipe(VMIExtFOp op, +FailureOr +VMILayoutSupport::getExtFSupport(VMIExtFOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -982,25 +1149,32 @@ VMILocalRecipeRegistry::getExtFRecipe(VMIExtFOp op, return fail("requires contiguous source layout and deinterleaved f32 " "result layout"); - unsigned sourceBits = - pto::getPTOStorageElemBitWidth(sourceType.getElementType()); - if (sourceBits == 16 && resultLayout.getFactor() == 2 && - *resultArity == 2 * *sourceArity) - return VMIExtFRecipe{ - VMIExtFRecipeKind::ContiguousF16ToDeinterleaved2F32}; - if (sourceBits == 8 && resultLayout.getFactor() == 4 && - *resultArity == 4 * *sourceArity) - return VMIExtFRecipe{ - VMIExtFRecipeKind::ContiguousF8ToDeinterleaved4F32}; + FailureOr fact = + getPreferredCastLayoutFact(sourceType, resultType, reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Widen2x && + fact->kind != VMICastLayoutKind::Widen4x)) + return fail("unsupported extf source element width, result factor, or " + "physical arity"); + + if (fact->kind == VMICastLayoutKind::Widen2x && + resultLayout.getFactor() == fact->factor && + *resultArity == fact->factor * *sourceArity) + return VMIExtFSupport{ + VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32}; + if (fact->kind == VMICastLayoutKind::Widen4x && + resultLayout.getFactor() == fact->factor && + *resultArity == fact->factor * *sourceArity) + return VMIExtFSupport{ + VMIExtFSupportKind::ContiguousF8ToDeinterleaved4F32}; return fail("unsupported extf source element width, result factor, or " "physical arity"); } template -static FailureOr getExtIRecipeImpl(OpT op, +static FailureOr getExtISupportImpl(OpT op, std::string *reason) { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -1022,39 +1196,45 @@ static FailureOr getExtIRecipeImpl(OpT op, return fail("requires contiguous integer source layout and deinterleaved " "integer result layout"); - unsigned sourceBits = - pto::getPTOStorageElemBitWidth(sourceType.getElementType()); - unsigned resultBits = - pto::getPTOStorageElemBitWidth(resultType.getElementType()); - if (sourceBits == 16 && resultBits == 32 && resultLayout.getFactor() == 2 && - *resultArity == 2 * *sourceArity) - return VMIExtIRecipe{ - VMIExtIRecipeKind::ContiguousI16ToDeinterleaved2I32}; - if (sourceBits == 8 && resultBits == 32 && resultLayout.getFactor() == 4 && - *resultArity == 4 * *sourceArity) - return VMIExtIRecipe{ - VMIExtIRecipeKind::ContiguousI8ToDeinterleaved4I32}; + FailureOr fact = + VMILayoutSupport().getPreferredCastLayoutFact(sourceType, resultType, + reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Widen2x && + fact->kind != VMICastLayoutKind::Widen4x)) + return fail("unsupported integer extension source/result element width, " + "result factor, or physical arity"); + + if (fact->kind == VMICastLayoutKind::Widen2x && + resultLayout.getFactor() == fact->factor && + *resultArity == fact->factor * *sourceArity) + return VMIExtISupport{ + VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32}; + if (fact->kind == VMICastLayoutKind::Widen4x && + resultLayout.getFactor() == fact->factor && + *resultArity == fact->factor * *sourceArity) + return VMIExtISupport{ + VMIExtISupportKind::ContiguousI8ToDeinterleaved4I32}; return fail("unsupported integer extension source/result element width, " "result factor, or physical arity"); } -FailureOr -VMILocalRecipeRegistry::getExtSIRecipe(VMIExtSIOp op, +FailureOr +VMILayoutSupport::getExtSISupport(VMIExtSIOp op, std::string *reason) const { - return getExtIRecipeImpl(op, reason); + return getExtISupportImpl(op, reason); } -FailureOr -VMILocalRecipeRegistry::getExtUIRecipe(VMIExtUIOp op, +FailureOr +VMILayoutSupport::getExtUISupport(VMIExtUIOp op, std::string *reason) const { - return getExtIRecipeImpl(op, reason); + return getExtISupportImpl(op, reason); } -FailureOr -VMILocalRecipeRegistry::getTruncIRecipe(VMITruncIOp op, +FailureOr +VMILayoutSupport::getTruncISupport(VMITruncIOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -1088,31 +1268,42 @@ VMILocalRecipeRegistry::getTruncIRecipe(VMITruncIOp op, "group_slots(num_groups=G, slots=1) source/result layouts, " "32-bit integer source, 16-bit integer result, and matching " "physical arity"); - return VMITruncIRecipe{VMITruncIRecipeKind::GroupSlots1I32ToI16}; + return VMITruncISupport{VMITruncISupportKind::GroupSlots1I32ToI16}; } if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || - sourceBits != 32 || *resultArity != 1) - return fail("requires 32-bit integer deinterleaved source and contiguous " + *resultArity != 1) + return fail("requires integer deinterleaved source and contiguous " "integer result"); - if (sourceLayout.getFactor() == 2 && *sourceArity == 2 && resultBits == 16) - return VMITruncIRecipe{ - VMITruncIRecipeKind::Deinterleaved2I32ToContiguousI16}; - if (sourceLayout.getFactor() == 4 && *sourceArity == 4 && resultBits == 8 && + FailureOr fact = + getPreferredCastLayoutFact(sourceType, resultType, reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Narrow2x && + fact->kind != VMICastLayoutKind::Narrow4x)) + return fail("unsupported deinterleaved trunci factor, arity, result " + "element width, or result signedness; 32-bit to 8-bit integer " + "narrowing requires unsigned i8 result"); + + if (fact->kind == VMICastLayoutKind::Narrow2x && + sourceLayout.getFactor() == fact->factor && *sourceArity == fact->factor) + return VMITruncISupport{ + VMITruncISupportKind::Deinterleaved2I32ToContiguousI16}; + if (fact->kind == VMICastLayoutKind::Narrow4x && + sourceLayout.getFactor() == fact->factor && + *sourceArity == fact->factor && cast(resultType.getElementType()).isUnsigned()) - return VMITruncIRecipe{ - VMITruncIRecipeKind::Deinterleaved4I32ToContiguousI8}; + return VMITruncISupport{ + VMITruncISupportKind::Deinterleaved4I32ToContiguousI8}; return fail("unsupported deinterleaved trunci factor, arity, result element " "width, or result signedness; 32-bit to 8-bit integer narrowing " "requires unsigned i8 result"); } -FailureOr -VMILocalRecipeRegistry::getBitcastRecipe(VMIBitcastOp op, +FailureOr +VMILayoutSupport::getBitcastSupport(VMIBitcastOp op, std::string *reason) const { - auto fail = [&](const Twine &message) -> FailureOr { + auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -1151,5 +1342,5 @@ VMILocalRecipeRegistry::getBitcastRecipe(VMIBitcastOp op, "chunk"); } - return VMIBitcastRecipe{VMIBitcastRecipeKind::PerPartVbitcast}; + return VMIBitcastSupport{VMIBitcastSupportKind::PerPartVbitcast}; } diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index c44fc114ec..7f10e39ea6 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -18,7 +18,7 @@ #include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" -#include "PTO/Transforms/VMILocalRecipeRegistry.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "PTO/Transforms/VMITargetCapabilities.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -1202,8 +1202,8 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, if (resultLayout.isDeinterleaved() && resultLayout.getBlockElems() == 8 && resultType.getElementType().isF32()) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getGroupLoadRecipe(capabilities, op, reason))) + VMILayoutSupport supports; + if (failed(supports.getGroupLoadSupport(capabilities, op, reason))) return failure(); return success(); } @@ -1214,8 +1214,8 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, LogicalResult checkSupportedGroupSlotLoadShape( const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, std::string *reason) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getGroupSlotLoadRecipe(capabilities, op, reason))) + VMILayoutSupport supports; + if (failed(supports.getGroupSlotLoadSupport(capabilities, op, reason))) return failure(); return success(); } @@ -1243,8 +1243,8 @@ checkSupportedGroupStoreShape(const VMITargetCapabilityRegistry &capabilities, if (!accessPlan.targetCapability.isSupported()) return fail(accessPlan.targetCapability.reason); - VMILocalRecipeRegistry recipes; - if (failed(recipes.getGroupSlotsStoreRecipe(capabilities, op, reason))) + VMILayoutSupport supports; + if (failed(supports.getGroupSlotsStoreSupport(capabilities, op, reason))) return failure(); return success(); } @@ -2424,171 +2424,6 @@ FailureOr createGroupSlotIndexVector(Location loc, VRegType indexType, return result; } -LogicalResult checkVcgaddGroupReduceShape(VMIVRegType sourceType, - VMIMaskType maskType, - VMIVRegType resultType, - int64_t groupSize, - std::string *reason) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); - return failure(); - }; - - if (sourceType.getElementType() != resultType.getElementType()) - return fail("vcgadd group_reduce_add path requires matching " - "source/result element types"); - FailureOr lanesPerPart = - getDataLanesPerPart(sourceType.getElementType()); - if (failed(lanesPerPart) || *lanesPerPart % 8 != 0) - return fail("vcgadd group_reduce_add path requires known VLane width"); - int64_t vlaneElems = *lanesPerPart / 8; - if (groupSize != vlaneElems) - return fail("vcgadd group_reduce_add path requires group size equal to " - "one 32-byte VLane"); - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - VMILayoutAttr maskLayout = maskType.getLayoutAttr(); - int64_t numGroups = sourceType.getElementCount() / groupSize; - if (!sourceLayout || !resultLayout || !maskLayout || - !sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != numGroups || !maskLayout.isContiguous()) - return fail("vcgadd group_reduce_add path requires contiguous source/mask " - "layouts and matching num_groups result layout"); - std::string sourceFullReason; - if (failed(checkFullDataPhysicalChunks(sourceType, &sourceFullReason))) - return fail(Twine("vcgadd group_reduce_add path requires full source " - "chunks; ") + - sourceFullReason); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr maskArity = getVMIPhysicalArity(maskType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("vcgadd group_reduce_add path requires computable physical " - "arity"); - if (*sourceArity < 1 || *sourceArity != *maskArity || - *sourceArity != *resultArity) - return fail("vcgadd group_reduce_add path requires matching non-empty " - "source/mask/result physical arity"); - return success(); -} - -template -LogicalResult checkS16Block8GroupReduceShape(OpTy op, - std::string *reason) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); - return failure(); - }; - - auto sourceType = cast(op.getSource().getType()); - auto maskType = cast(op.getMask().getType()); - auto resultType = cast(op.getResult().getType()); - if (sourceType.getElementType() != resultType.getElementType()) - return fail("two-vlane group_reduce_add requires matching source/result " - "element types"); - - FailureOr groupSize = - getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); - FailureOr lanesPerPart = - getDataLanesPerPart(sourceType.getElementType()); - if (failed(groupSize) || failed(lanesPerPart) || *lanesPerPart % 8 != 0 || - *groupSize != 2 * (*lanesPerPart / 8)) - return fail("two-vlane group_reduce_add requires group size equal to two " - "32-byte VLanes"); - - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr maskLayout = maskType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - int64_t numGroups = op.getNumGroupsAttr().getInt(); - if (!sourceLayout || !sourceLayout.isDeinterleaved() || - sourceLayout.getFactor() != 2 || - (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("two-vlane group_reduce_add requires source layout " - "deinterleaved=2 with block_elems=1 or block_elems=8"); - if (!maskLayout || !maskLayout.isDeinterleaved() || - maskLayout.getFactor() != 2 || - maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("two-vlane group_reduce_add requires matching mask layout " - "deinterleaved=2 with the same block_elems"); - if (!resultLayout || !resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) - return fail("two-vlane group_reduce_add requires " - "group_slots(num_groups, slots=8) result layout"); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr maskArity = getVMIPhysicalArity(maskType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("two-vlane group_reduce_add requires computable physical " - "arity"); - int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); - if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 2 || - *maskArity != *sourceArity) - return fail("two-vlane group_reduce_add requires two source/mask " - "parts per result part"); - - return success(); -} - -template -LogicalResult checkS32Block8GroupReduceShape(OpTy op, - std::string *reason) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); - return failure(); - }; - - auto sourceType = cast(op.getSource().getType()); - auto maskType = cast(op.getMask().getType()); - auto resultType = cast(op.getResult().getType()); - if (sourceType.getElementType() != resultType.getElementType()) - return fail("four-vlane group_reduce_add requires matching source/result " - "element types"); - - FailureOr groupSize = - getGroupSizeFromNumGroups(sourceType, op.getNumGroupsAttr().getInt()); - FailureOr lanesPerPart = - getDataLanesPerPart(sourceType.getElementType()); - if (failed(groupSize) || failed(lanesPerPart) || *lanesPerPart % 8 != 0 || - *groupSize != 4 * (*lanesPerPart / 8)) - return fail("four-vlane group_reduce_add requires group size equal to four " - "32-byte VLanes"); - - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr maskLayout = maskType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - int64_t numGroups = op.getNumGroupsAttr().getInt(); - if (!sourceLayout || !sourceLayout.isDeinterleaved() || - sourceLayout.getFactor() != 4 || - (sourceLayout.getBlockElems() != 1 && sourceLayout.getBlockElems() != 8)) - return fail("four-vlane group_reduce_add requires source layout " - "deinterleaved=4 with block_elems=1 or block_elems=8"); - if (!maskLayout || !maskLayout.isDeinterleaved() || - maskLayout.getFactor() != 4 || - maskLayout.getBlockElems() != sourceLayout.getBlockElems()) - return fail("four-vlane group_reduce_add requires matching mask layout " - "deinterleaved=4 with the same block_elems"); - if (!resultLayout || !resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != numGroups || resultLayout.getSlots() != 8) - return fail("four-vlane group_reduce_add requires " - "group_slots(num_groups, slots=8) result layout"); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr maskArity = getVMIPhysicalArity(maskType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - if (failed(sourceArity) || failed(maskArity) || failed(resultArity)) - return fail("four-vlane group_reduce_add requires computable physical " - "arity"); - int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); - if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 4 || - *maskArity != *sourceArity) - return fail("four-vlane group_reduce_add requires four source/mask " - "parts per result part"); - - return success(); -} - std::optional getX2MemoryDistToken(Type elementType, StringRef prefix) { unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); @@ -3230,13 +3065,13 @@ struct OneToNVMIEnsureLayoutOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); - VMILocalRecipeRegistry recipes; - std::string recipeReason; - if (failed(recipes.getDataLayoutMaterializationRecipe( - sourceType, resultType, &recipeReason))) + VMILayoutSupport supports; + std::string supportReason; + if (failed(supports.canMaterializeDataLayout(sourceType, resultType, + &supportReason))) return rewriter.notifyMatchFailure( - op, Twine("ensure_layout has no registered materialization recipe: ") + - recipeReason); + op, Twine("ensure_layout has no registered materialization support: ") + + supportReason); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); if (!sourceLayout || !resultLayout) @@ -3264,14 +3099,14 @@ struct OneToNVMIEnsureMaskLayoutOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); - VMILocalRecipeRegistry recipes; - std::string recipeReason; - if (failed(recipes.getMaskLayoutMaterializationRecipe( - sourceType, resultType, &recipeReason))) + VMILayoutSupport supports; + std::string supportReason; + if (failed(supports.canMaterializeMaskLayout(sourceType, resultType, + &supportReason))) return rewriter.notifyMatchFailure( op, - Twine("ensure_mask_layout has no registered materialization recipe: ") + - recipeReason); + Twine("ensure_mask_layout has no registered materialization support: ") + + supportReason); if (sourceType.getGranularity() != resultType.getGranularity()) return rewriter.notifyMatchFailure( op, "mask layout helper cannot also change granularity"); @@ -3303,14 +3138,14 @@ struct OneToNVMIEnsureMaskGranularityOpPattern OneToNPatternRewriter &rewriter) const override { auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); - VMILocalRecipeRegistry recipes; - std::string recipeReason; - if (failed(recipes.getMaskGranularityMaterializationRecipe( - sourceType, resultType, &recipeReason))) + VMILayoutSupport supports; + std::string supportReason; + if (failed(supports.canMaterializeMaskGranularity(sourceType, resultType, + &supportReason))) return rewriter.notifyMatchFailure( op, Twine("ensure_mask_granularity has no registered materialization " - "recipe: ") + - recipeReason); + "support: ") + + supportReason); if (sourceType.getLayout() != resultType.getLayout()) return rewriter.notifyMatchFailure( op, "mask granularity helper cannot also change layout"); @@ -4503,12 +4338,12 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { return failure(); ValueRange valueParts = adaptor.getValue(); - VMILocalRecipeRegistry localRecipes; - FailureOr storeRecipe = - localRecipes.getContiguousStoreRecipe(valueVMIType); - if (succeeded(storeRecipe) && - storeRecipe->kind == - VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2) { + VMILayoutSupport localSupports; + FailureOr storeSupport = + localSupports.getContiguousStoreSupport(valueVMIType); + if (succeeded(storeSupport) && + storeSupport->kind == + VMIContiguousStoreSupportKind::Deinterleaved2Vstsx2) { std::optional dist = getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { @@ -4961,12 +4796,12 @@ struct OneToNVMITileWriteOpPattern : OneToNOpConversionPattern { ValueRange valueParts = adaptor.getValue(); Value zero = rewriter.create(op.getLoc(), 0); - VMILocalRecipeRegistry localRecipes; - FailureOr storeRecipe = - localRecipes.getContiguousStoreRecipe(valueVMIType); - if (succeeded(storeRecipe) && - storeRecipe->kind == - VMIContiguousStoreRecipeKind::Deinterleaved2Vstsx2) { + VMILayoutSupport localSupports; + FailureOr storeSupport = + localSupports.getContiguousStoreSupport(valueVMIType); + if (succeeded(storeSupport) && + storeSupport->kind == + VMIContiguousStoreSupportKind::Deinterleaved2Vstsx2) { std::optional dist = getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { @@ -5569,25 +5404,38 @@ struct OneToNVMIReduceAddFOpPattern template struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + OneToNVMIGroupReduceAddOpPattern( + TypeConverter &typeConverter, MLIRContext *context, + const VMITargetCapabilityRegistry &capabilities) + : OneToNOpConversionPattern(typeConverter, context), + capabilities(capabilities) {} LogicalResult matchAndRewrite(OpTy op, typename OneToNOpConversionPattern::OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto sourceVMIType = cast(op.getSource().getType()); - auto maskVMIType = cast(op.getMask().getType()); auto resultVMIType = cast(op.getResult().getType()); ValueRange sourceParts = adaptor.getSource(); ValueRange maskParts = adaptor.getMask(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + + VMILayoutSupport supports; + std::string supportReason; + FailureOr support = + getSupport(supports, op, &supportReason); + if (failed(support)) + return rewriter.notifyMatchFailure( + op, Twine("group_reduce_add has no layout support: ") + + supportReason); + FailureOr groupSize = getGroupSizeFromNumGroups( sourceVMIType, op.getNumGroupsAttr().getInt()); if (failed(groupSize)) return rewriter.notifyMatchFailure( op, "group_reduce_addf requires num_groups to evenly divide lane count"); - if (succeeded(checkVcgaddGroupReduceShape( - sourceVMIType, maskVMIType, resultVMIType, *groupSize, nullptr))) { + + if (support->kind == VMIGroupReduceAddFSupportKind::OneVLaneVcgadd) { if (sourceParts.size() != maskParts.size() || sourceParts.size() != resultTypes.size() || sourceParts.empty()) return rewriter.notifyMatchFailure( @@ -5621,7 +5469,8 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { return success(); } - if (succeeded(checkS16Block8GroupReduceShape(op, nullptr))) { + if (support->kind == + VMIGroupReduceAddFSupportKind::TwoVLaneDeinterleaved2VcgaddVadd) { int64_t resultPartCount = resultTypes.size(); if (static_cast(sourceParts.size()) != resultPartCount * 2 || maskParts.size() != sourceParts.size()) @@ -5674,7 +5523,8 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { return success(); } - if (succeeded(checkS32Block8GroupReduceShape(op, nullptr))) { + if (support->kind == + VMIGroupReduceAddFSupportKind::FourVLaneDeinterleaved4VcgaddTree) { int64_t resultPartCount = resultTypes.size(); if (static_cast(sourceParts.size()) != resultPartCount * 4 || maskParts.size() != sourceParts.size()) @@ -5733,6 +5583,10 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { return success(); } + if (support->kind != VMIGroupReduceAddFSupportKind::ContiguousVcaddRows) + return rewriter.notifyMatchFailure(op, + "unknown group_reduce_add support"); + int64_t lanesPerPart = 0; int64_t groupCount = 0; int64_t chunksPerGroup = 0; @@ -5815,6 +5669,21 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { rewriter.replaceOp(op, results, adaptor.getResultMapping()); return success(); } + +private: + FailureOr + getSupport(VMILayoutSupport &supports, VMIGroupReduceAddFOp op, + std::string *reason) const { + return supports.getGroupReduceAddFSupport(capabilities, op, reason); + } + + FailureOr + getSupport(VMILayoutSupport &supports, VMIGroupReduceAddIOp op, + std::string *reason) const { + return supports.getGroupReduceAddISupport(capabilities, op, reason); + } + + const VMITargetCapabilityRegistry &capabilities; }; struct OneToNVMIGroupBroadcastOpPattern @@ -7006,8 +6875,6 @@ void populateVMIOneToNConversionPatterns( OneToNVMISelectOpPattern, OneToNVMIActivePrefixIndexOpPattern, OneToNVMICompressOpPattern, OneToNVMICompressStoreOpPattern, OneToNVMIReduceAddIOpPattern, OneToNVMIReduceAddFOpPattern, - OneToNVMIGroupReduceAddOpPattern, - OneToNVMIGroupReduceAddOpPattern, OneToNVMIGroupBroadcastOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIReduceMinMaxFOpPattern, @@ -7017,6 +6884,10 @@ void populateVMIOneToNConversionPatterns( OneToNVMIBitcastOpPattern, OneToNVMIChannelSplitOpPattern, OneToNVMIChannelMergeOpPattern, OneToNVMIShuffleOpPattern>( typeConverter, patterns.getContext()); + patterns + .add, + OneToNVMIGroupReduceAddOpPattern>( + typeConverter, patterns.getContext(), capabilities); patterns.add( typeConverter, patterns.getContext(), capabilities); } @@ -7059,47 +6930,47 @@ LogicalResult verifyNoResidualVMIIR(ModuleOp module) { LogicalResult checkSupportedExtFShape(VMIExtFOp op, std::string *reason = nullptr) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getExtFRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getExtFSupport(op, reason))) return failure(); return success(); } LogicalResult checkSupportedTruncFShape(VMITruncFOp op, std::string *reason = nullptr) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getTruncFRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getTruncFSupport(op, reason))) return failure(); return success(); } LogicalResult checkSupportedExtSIShape(VMIExtSIOp op, std::string *reason = nullptr) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getExtSIRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getExtSISupport(op, reason))) return failure(); return success(); } LogicalResult checkSupportedExtUIShape(VMIExtUIOp op, std::string *reason = nullptr) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getExtUIRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getExtUISupport(op, reason))) return failure(); return success(); } LogicalResult checkSupportedTruncIShape(VMITruncIOp op, std::string *reason = nullptr) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getTruncIRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getTruncISupport(op, reason))) return failure(); return success(); } LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, std::string *reason) { - VMILocalRecipeRegistry recipes; - if (failed(recipes.getBitcastRecipe(op, reason))) + VMILayoutSupport supports; + if (failed(supports.getBitcastSupport(op, reason))) return failure(); return success(); } @@ -7399,83 +7270,15 @@ template LogicalResult checkSupportedGroupReduceAddShape( const VMITargetCapabilityRegistry &capabilities, OpTy op, std::string *reason = nullptr) { - auto fail = [&](const Twine &message) -> LogicalResult { - if (reason) - *reason = message.str(); - return failure(); - }; - - if constexpr (std::is_same_v) { - if (!op->hasAttr("reassoc")) - return fail("requires reassoc attr for pair-wise floating-point reduction"); - } - auto sourceType = cast(op.getSource().getType()); - auto resultType = cast(op.getResult().getType()); - auto maskType = cast(op.getMask().getType()); - VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); - VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - VMILayoutAttr maskLayout = maskType.getLayoutAttr(); - if (!sourceLayout || !resultLayout || !maskLayout) - return fail("requires assigned source, mask, and result layouts"); - - VMILocalRecipeRegistry recipes; + VMILayoutSupport supports; if constexpr (std::is_same_v) { - if (succeeded(recipes.getGroupReduceAddFRecipe(capabilities, op, nullptr))) + if (succeeded(supports.getGroupReduceAddFSupport(capabilities, op, reason))) return success(); } else { - if (succeeded(recipes.getGroupReduceAddIRecipe(capabilities, op, nullptr))) + if (succeeded(supports.getGroupReduceAddISupport(capabilities, op, reason))) return success(); } - - FailureOr groupSize = getGroupSizeFromNumGroups( - sourceType, op.getNumGroupsAttr().getInt(), reason); - if (failed(groupSize)) - return failure(); - if (succeeded(checkS16Block8GroupReduceShape(op, reason))) - return success(); - if (succeeded(checkS32Block8GroupReduceShape(op, reason))) - return success(); - if (!sourceLayout.isContiguous() || !resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != op.getNumGroupsAttr().getInt() || - !maskLayout.isContiguous()) - return fail("requires contiguous source/mask layouts and matching " - "num_groups result layout"); - VMICapabilityResult elementCapability = - capabilities.supportsReductionElementType( - std::is_same_v ? VMIReductionKind::GroupAddF - : VMIReductionKind::GroupAddI, - sourceType.getElementType()); - if (!elementCapability.isSupported()) - return fail(elementCapability.reason); - if (sourceType.getElementType() != resultType.getElementType()) - return fail("requires source/result element type to match"); - if (sourceType.getElementCount() != resultType.getElementCount()) - return fail("requires source/result lane count to match"); - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr resultArity = getVMIPhysicalArity(resultType); - FailureOr maskArity = getVMIPhysicalArity(maskType); - if (failed(sourceArity) || failed(resultArity) || failed(maskArity)) - return fail("requires computable source/result/mask physical arity"); - if (*sourceArity != *resultArity || *sourceArity != *maskArity) - return fail("requires source/result/mask physical arity to match"); - if (succeeded(checkVcgaddGroupReduceShape(sourceType, maskType, resultType, - *groupSize, nullptr))) - return success(); - if (failed(checkSupportedGroupChunkShape(sourceType, *groupSize, reason))) - return failure(); - if (resultLayout.getSlots() <= 0) - return success(); - - FailureOr lanesPerPart = - getDataLanesPerPart(sourceType.getElementType()); - if (failed(lanesPerPart)) - return fail("requires known physical chunk lane count"); - if (!sourceLayout.isContiguous() || *groupSize != *lanesPerPart || - resultLayout.getSlots() != 1) - return fail("explicit group_slots group_reduce_add chunk path requires " - "contiguous full-physical-chunk group size source and slots=1 " - "result layout"); - return success(); + return failure(); } LogicalResult checkSupportedGroupBroadcastShape( @@ -7500,8 +7303,8 @@ LogicalResult checkSupportedGroupBroadcastShape( VMILayoutAttr resultLayout = resultType.getLayoutAttr(); if (!sourceLayout || !resultLayout) return fail("requires assigned source/result layouts"); - VMILocalRecipeRegistry recipes; - if (succeeded(recipes.getGroupBroadcastRecipe(capabilities, op, nullptr))) + VMILayoutSupport supports; + if (succeeded(supports.getGroupBroadcastSupport(capabilities, op, nullptr))) return success(); if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) @@ -7708,7 +7511,7 @@ verifySupportedVMIToVPTOOps(ModuleOp module, broadcast.emitError() << kVMIDiagUnsupportedPrefix << "pto.vmi.group_broadcast requires full source chunks with " - "#pto.vmi.layout, a dense full result layout, " + "#pto.vmi.layout, a dense full result layout, " "and num_groups deriving a group size that divides or is a " "multiple of physical chunk lanes (" << reason << ")"; @@ -8117,7 +7920,7 @@ verifySupportedVMIToVPTOOps(ModuleOp module, << kVMIDiagUnsupportedPrefix << "pto.vmi.group_reduce_addf lowers through pto.vcgadd for 32B " "VLane groups or through pto.vcadd with reassoc, contiguous full " - "source/mask chunks, #pto.vmi.layout result " + "source/mask chunks, #pto.vmi.layout result " "chunks, and num_groups deriving a group size aligned to " "physical chunks (" << reason << ")"; diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto index 51cd09053f..1cbdeea1d8 100644 --- a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto @@ -53,12 +53,12 @@ module { // ASSIGN: %[[COPY_DENSE:.*]] = pto.vmi.ensure_layout %[[COPY]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[COPY_DENSE]] -// ASSIGN: pto.vmi.create_group_mask +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[PROD:.*]] = pto.vmi.mulf %[[X]], %[[SCALE]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask -// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[PROD]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto index 6e165de8a0..eefe95d973 100644 --- a/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_remat.pto @@ -7,7 +7,7 @@ // See LICENSE in the root of the software repository for the full text of the License. // RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_broadcast_remat( @@ -37,9 +37,8 @@ module { // ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[BCAST_DEINT]], %[[WIDE]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN-NOT: pto.vmi.ensure_layout %[[BCAST_DEINT]] -// ASSIGN: %[[BCAST_CONTIG:.*]] = pto.vmi.broadcast %[[SCALAR]] -// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[BCAST_CONTIG:.*]] = pto.vmi.ensure_layout %[[BCAST_DEINT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[BCAST_CONTIG]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_constant_remat.pto b/test/lit/vmi/vmi_layout_assignment_constant_remat.pto index e387aa077d..a426621c15 100644 --- a/test/lit/vmi/vmi_layout_assignment_constant_remat.pto +++ b/test/lit/vmi/vmi_layout_assignment_constant_remat.pto @@ -7,7 +7,7 @@ // See LICENSE in the root of the software repository for the full text of the License. // RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_constant_remat( @@ -37,10 +37,8 @@ module { // ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[CONST_DEINT]], %[[WIDE]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN-NOT: pto.vmi.ensure_layout %[[CONST_DEINT]] -// ASSIGN: %[[CONST_CONTIG:.*]] = "pto.vmi.constant"() -// ASSIGN-SAME: dense<1.000000e+00> : tensor<128xf32> -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[CONST_CONTIG:.*]] = pto.vmi.ensure_layout %[[CONST_DEINT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[CONST_CONTIG]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto index 2bc648261f..5999ace148 100644 --- a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto @@ -31,10 +31,10 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_create_group_mask_s16( // ASSIGN: %[[X:.*]] = pto.vmi.group_load // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: pto.vmi.create_group_mask +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask -// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto index cb0e15864e..fe5920c07b 100644 --- a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto @@ -7,7 +7,7 @@ // See LICENSE in the root of the software repository for the full text of the License. // RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( @@ -39,8 +39,8 @@ module { // ASSIGN-SAME: %[[ACTIVE:arg[0-9]+]]: index) // ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask %[[ACTIVE]] // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK1:.*]] = pto.vmi.create_group_mask %[[ACTIVE]] -// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK1:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf // LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s32_dynamic( diff --git a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto index 8e8a86450d..6ffab1471d 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto @@ -33,10 +33,12 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_group_reduce_multi_consumer( // ASSIGN: %[[X:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto index ec29b4387a..c8ded49a2f 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto @@ -18,7 +18,7 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xf32> - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.store operand #0 has type !pto.vmi.vreg<64xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<64xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result layout pair + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store operand #0 has type '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<64xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion // CHECK: failed helper conversion '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' (unsupported source/result layout pair) pto.vmi.store %sum, %dst[%off] : !pto.vmi.vreg<64xf32>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto index 27e304ae27..b976ab518d 100644 --- a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -36,10 +36,10 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( // ASSIGN: %[[X32:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: pto.vmi.create_mask +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto index 6ed4e7f9e7..187a79d42b 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto @@ -26,7 +26,7 @@ module { -> !pto.vmi.vreg<128xf32> pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} : !pto.vmi.vreg<128xf32>, !pto.ptr - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type !pto.vmi.vreg<128xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<128xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result layout pair + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<128xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion // CHECK: failed helper conversion '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' (unsupported source/result layout pair) %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto index b322e5700e..b63d134392 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto @@ -15,8 +15,11 @@ module { %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_add slots=8 recipes support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add slots=8 support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems + // CHECK-SAME: VMI types: operand#0=!pto.vmi.vreg<96xf32, #pto.vmi.layout> + // CHECK-SAME: operand#1=!pto.vmi.mask<96xb32, #pto.vmi.layout> + // CHECK-SAME: result#0=!pto.vmi.vreg<96xf32, #pto.vmi.layout> %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<96xf32>, !pto.vmi.mask<96xpred> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto index d5fa902c56..602ac579ad 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto @@ -7,7 +7,7 @@ // See LICENSE in the root of the software repository for the full text of the License. // RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( @@ -46,10 +46,10 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( // ASSIGN: %[[X:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> -// ASSIGN: pto.vmi.create_mask +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<192xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> -> !pto.vmi.mask<192xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] @@ -73,7 +73,9 @@ module { // ASSIGN: %[[PX:.*]] = pto.vmi.load // ASSIGN-SAME: {full_read_elems = 256 : i64} // ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> -// ASSIGN: %[[PMASK:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[PMASK0:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN: %[[PMASK:.*]] = pto.vmi.ensure_mask_layout %[[PMASK0]] +// ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> -> !pto.vmi.mask<192xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf %[[PX]], %[[PMASK]] // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto index cface43bab..af78715f95 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto @@ -15,11 +15,11 @@ module { %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf operand #0 has type + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.group_reduce_addf operand #0 has type // CHECK-SAME: #pto.vmi.layout // CHECK-SAME: requires // CHECK-SAME: #pto.vmi.layout - // CHECK-SAME: pto.vmi.ensure_layout has no registered materialization recipe + // CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion // CHECK: requires source and result to have the same physical arity %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 6, reassoc} diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto index b8cd439d23..01dab5b003 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto @@ -11,7 +11,7 @@ module { func.func @vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid( %src: !pto.ptr, %off: index, %stride: index) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered layout support // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements // CHECK-SAME: packed or unaligned scalar load lowering is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto index b432d7c68c..1589e531dc 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto @@ -12,7 +12,7 @@ module { func.func @vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid( %src: !pto.ptr, %off: index) { %c2 = arith.constant 2 : index - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered layout support // CHECK-SAME: slots=1 group_slot_load currently lowers as one lane-0 vsldb per group // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements // CHECK-SAME: packed or unaligned scalar load lowering is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto index c30502a252..95fa93474d 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto @@ -51,10 +51,10 @@ module { // ASSIGN-SAME: -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) // ASSIGN: %[[X:.*]] = pto.vmi.group_load // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: pto.vmi.create_group_mask +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_group_mask -// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.addf %[[ARG]], %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto index 996760ed66..35959585de 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto @@ -19,7 +19,7 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> -> !pto.vmi.vreg<512xf32> - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots layout support // CHECK-SAME: slots=1 group_store currently lowers as one lane-0 vsts per group // CHECK-SAME: requires constant positive row_stride divisible by 8 elements // CHECK-SAME: packed or unaligned contiguous store lowering is not implemented diff --git a/test/lit/vmi/vmi_layout_assignment_iota_remat.pto b/test/lit/vmi/vmi_layout_assignment_iota_remat.pto index 773fd4187c..d79cdfddba 100644 --- a/test/lit/vmi/vmi_layout_assignment_iota_remat.pto +++ b/test/lit/vmi/vmi_layout_assignment_iota_remat.pto @@ -7,7 +7,7 @@ // See LICENSE in the root of the software repository for the full text of the License. // RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { func.func @vmi_layout_assignment_iota_remat( @@ -37,9 +37,8 @@ module { // ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.addf %[[IOTA_DEINT]], %[[WIDE]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN-NOT: pto.vmi.ensure_layout %[[IOTA_DEINT]] -// ASSIGN: %[[IOTA_CONTIG:.*]] = pto.vmi.iota %[[BASE]] -// ASSIGN-SAME: f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[IOTA_CONTIG:.*]] = pto.vmi.ensure_layout %[[IOTA_DEINT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[IOTA_CONTIG]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto index 8a74de4097..1d3a2f3d0b 100644 --- a/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto @@ -38,8 +38,8 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[X_SPLIT]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> -// ASSIGN: %[[M16:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: %[[M16:.*]] = pto.vmi.ensure_mask_granularity %[[M32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb16, #pto.vmi.layout> // ASSIGN: pto.vmi.masked_store %[[H]] // ASSIGN-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> @@ -54,7 +54,9 @@ module { // LOWER: pto.vcvt // LOWER: pto.vcvt // LOWER: pto.vor -// LOWER: pto.plt_b16 +// LOWER: pto.ppack +// LOWER: pto.ppack +// LOWER: pto.por // LOWER: pto.vsts // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_mask_remat.pto b/test/lit/vmi/vmi_layout_assignment_mask_remat.pto index b114643836..8e799c0704 100644 --- a/test/lit/vmi/vmi_layout_assignment_mask_remat.pto +++ b/test/lit/vmi/vmi_layout_assignment_mask_remat.pto @@ -6,7 +6,8 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize | FileCheck %s --check-prefix=REMAT module { func.func @vmi_layout_assignment_create_mask_remat( @@ -47,27 +48,41 @@ module { } } -// CHECK-LABEL: func.func @vmi_layout_assignment_create_mask_remat( -// CHECK-SAME: %[[ACTIVE:.*]]: index -// CHECK: %[[M32:.*]] = pto.vmi.create_mask %[[ACTIVE]] -// CHECK-SAME: index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> -// CHECK: %[[M16:.*]] = pto.vmi.create_mask %[[ACTIVE]] -// CHECK-SAME: index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> -// CHECK: pto.vmi.select %[[M16]] -// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> -// CHECK: pto.vmi.select %[[M32]] -// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -// CHECK-NOT: pto.vmi.ensure_mask_layout -// CHECK-NOT: pto.vmi.ensure_mask_granularity +// ASSIGN-LABEL: func.func @vmi_layout_assignment_create_mask_remat( +// ASSIGN-SAME: %[[ACTIVE:.*]]: index +// ASSIGN: %[[M32:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// ASSIGN-SAME: index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[M16:.*]] = pto.vmi.ensure_mask_granularity %[[M32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[M16]] +// ASSIGN-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[M32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -// CHECK-LABEL: func.func @vmi_layout_assignment_constant_mask_remat( -// CHECK: %[[CM32:.*]] = "pto.vmi.constant_mask"() -// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -// CHECK: %[[CM16:.*]] = "pto.vmi.constant_mask"() -// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> -// CHECK: pto.vmi.select %[[CM16]] -// CHECK-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> -// CHECK: pto.vmi.select %[[CM32]] -// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -// CHECK-NOT: pto.vmi.ensure_mask_layout -// CHECK-NOT: pto.vmi.ensure_mask_granularity +// ASSIGN-LABEL: func.func @vmi_layout_assignment_constant_mask_remat( +// ASSIGN: %[[CM32:.*]] = "pto.vmi.constant_mask"() +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[CM16:.*]] = pto.vmi.ensure_mask_granularity %[[CM32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[CM16]] +// ASSIGN-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.select %[[CM32]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> + +// REMAT-LABEL: func.func @vmi_layout_assignment_create_mask_remat( +// REMAT-SAME: %[[ACTIVE:.*]]: index +// REMAT: %[[M32:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// REMAT-SAME: index -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// REMAT: %[[M16:.*]] = pto.vmi.create_mask %[[ACTIVE]] +// REMAT-SAME: index -> !pto.vmi.mask<128xb16, #pto.vmi.layout> +// REMAT: pto.vmi.select %[[M16]] +// REMAT: pto.vmi.select %[[M32]] +// REMAT-LABEL: func.func @vmi_layout_assignment_constant_mask_remat( +// REMAT: %[[CM32:.*]] = "pto.vmi.constant_mask"() +// REMAT-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// REMAT: %[[CM16:.*]] = "pto.vmi.constant_mask"() +// REMAT-SAME: !pto.vmi.mask<128xb16, #pto.vmi.layout> +// REMAT: pto.vmi.select %[[CM16]] +// REMAT: pto.vmi.select %[[CM32]] +// REMAT-NOT: pto.vmi.ensure_mask_layout +// REMAT-NOT: pto.vmi.ensure_mask_granularity diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto index 6c0b2d2ece..796f446b60 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto @@ -48,8 +48,8 @@ module { // ASSIGN: pto.vmi.store %[[X]] // ASSIGN: %[[X_SPLIT:.*]] = pto.vmi.ensure_layout %[[X]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto index 968e8d03c2..9d3147aaea 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto @@ -40,8 +40,8 @@ module { // ASSIGN: pto.vmi.ensure_layout // ASSIGN-SAME: #pto.vmi.layout // ASSIGN-SAME: #pto.vmi.layout -// ASSIGN: pto.vmi.create_group_mask -// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: pto.vmi.ensure_mask_layout +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf // LOWER: pto.pdintlv_b32 // LOWER: pto.pdintlv_b32 diff --git a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto index 46f7ff71f2..dd8f2910ab 100644 --- a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto @@ -42,8 +42,8 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[X:.*]] = pto.vmi.addf %[[A]], %[[BIASV]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto index e57954b16e..71d282577a 100644 --- a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto @@ -19,7 +19,7 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<128xf32> - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type !pto.vmi.vreg<128xf32, #pto.vmi.layout> but requires !pto.vmi.vreg<128xf32, #pto.vmi.layout>; pto.vmi.ensure_layout has no registered materialization recipe: unsupported source/result layout pair + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<128xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion // CHECK: failed helper conversion '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' (unsupported source/result layout pair) %h = pto.vmi.truncf %sum : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> diff --git a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto index 63fc33cfe6..e9553c2c9d 100644 --- a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto @@ -36,10 +36,10 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: pto.vmi.create_mask +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> -// ASSIGN: %[[MASK:.*]] = pto.vmi.create_mask -// ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] +// ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto index e63567e48d..f946686b6f 100644 --- a/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto @@ -11,7 +11,7 @@ module { func.func @vmi_layout_gate_bitcast_group_slots_invalid( %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered layout support // CHECK-SAME: does not support group_slots layouts // CHECK: note: see current operation: %{{.*}} = "pto.vmi.bitcast" %out = pto.vmi.bitcast %source diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto similarity index 94% rename from test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto index 806aaa26dd..2acec47cd2 100644 --- a/test/lit/vmi/vmi_layout_gate_bitcast_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_bitcast_recipe_invalid( + func.func @vmi_layout_gate_bitcast_support_invalid( %source: !pto.vmi.vreg<65xf32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered layout support // CHECK-SAME: requires matching logical bit footprint in every physical chunk // CHECK: note: see current operation: %{{.*}} = "pto.vmi.bitcast" %out = pto.vmi.bitcast %source diff --git a/test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto similarity index 94% rename from test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto index 7bda214fed..4e14381743 100644 --- a/test/lit/vmi/vmi_layout_gate_extf_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_extf_recipe_invalid( + func.func @vmi_layout_gate_extf_support_invalid( %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.extf has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.extf has no registered layout support // CHECK-SAME: requires contiguous source layout and deinterleaved f32 result layout // CHECK: note: see current operation: %{{.*}} = "pto.vmi.extf" %out = pto.vmi.extf %source diff --git a/test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto similarity index 93% rename from test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto index 224858064c..64681b5dd3 100644 --- a/test/lit/vmi/vmi_layout_gate_group_broadcast_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_broadcast_recipe_invalid( + func.func @vmi_layout_gate_group_broadcast_support_invalid( %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_broadcast has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_broadcast has no registered layout support // CHECK-SAME: supports only slots=8 or slots=1 group_broadcast source layouts // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_broadcast" %out = pto.vmi.group_broadcast %source {num_groups = 8} diff --git a/test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto similarity index 93% rename from test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto index 8f9fb2c809..a14ff20a0b 100644 --- a/test/lit/vmi/vmi_layout_gate_group_load_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_load_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_load_recipe_invalid( + func.func @vmi_layout_gate_group_load_support_invalid( %src: !pto.ptr, %off: index, %stride: index) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load has no registered block8 local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_load has no registered block8 layout support // CHECK-SAME: block8 strided group_load requires constant positive row_stride divisible by 8 f32 elements // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_load" %out = pto.vmi.group_load %src[%off], %stride diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto similarity index 85% rename from test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto index d33315f88d..0c792693f3 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto @@ -9,11 +9,11 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_reduce_slots1_recipe_invalid( + func.func @vmi_layout_gate_group_reduce_slots1_support_invalid( %source: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_add slots=1 recipes support group sizes that are multiples of one physical chunk + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add slots=1 support group sizes that are multiples of one physical chunk // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto similarity index 85% rename from test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto index 33a7bc0fae..734c9dd497 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto @@ -9,11 +9,11 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_reduce_recipe_invalid( + func.func @vmi_layout_gate_group_reduce_support_invalid( %source: !pto.vmi.vreg<96xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<96xb32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_add slots=8 recipes support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add slots=8 support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_reduce_addf" %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} diff --git a/test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto similarity index 93% rename from test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto index 31e7f13c3e..334be3d744 100644 --- a/test/lit/vmi/vmi_layout_gate_group_slot_load_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_slot_load_recipe_invalid( + func.func @vmi_layout_gate_group_slot_load_support_invalid( %src: !pto.ptr, %off: index, %stride: index) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_slot_load has no registered layout support // CHECK-SAME: slots=8 group_slot_load requires constant unit source_group_stride // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" %out = pto.vmi.group_slot_load %src[%off], %stride diff --git a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto index c787f57fea..f3263148b3 100644 --- a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto @@ -12,7 +12,7 @@ module { func.func @vmi_layout_gate_group_store_slots2_invalid( %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index, %row_stride: index) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots layout support // CHECK-SAME: group_slots group_store currently supports only slots=1 or unit-stride slots=8 pto.vmi.group_store %value, %dst[%off], %row_stride {num_groups = 8} @@ -28,8 +28,8 @@ module { func.func @vmi_layout_gate_group_reduce_slots2_invalid( %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots local recipe - // CHECK-SAME: stable group_reduce_add local recipes currently require result layout slots=8 or slots=1 + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_reduce_addf has no registered group_slots layout support + // CHECK-SAME: stable group_reduce_add layout support currently requires result layout slots=8 or slots=1 %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto similarity index 93% rename from test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto index c7003a887d..db0794748d 100644 --- a/test/lit/vmi/vmi_layout_gate_group_store_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto @@ -9,10 +9,10 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_group_store_recipe_invalid( + func.func @vmi_layout_gate_group_store_support_invalid( %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index, %row_stride: index) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots layout support // CHECK-SAME: slots=8 group_store currently requires constant unit row_stride // CHECK: note: see current operation: "pto.vmi.group_store" pto.vmi.group_store %value, %dst[%off], %row_stride diff --git a/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto b/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto index 53cc5c2a12..4aa1f30cbb 100644 --- a/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_helper_materialization_shape_invalid.pto @@ -11,7 +11,7 @@ module { func.func @vmi_layout_gate_ensure_layout_shape_invalid( %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization support // CHECK-SAME: requires source and result to have the same physical arity %dense = pto.vmi.ensure_layout %value : !pto.vmi.vreg<128xf32, #pto.vmi.layout> @@ -25,7 +25,7 @@ module { module { func.func @vmi_layout_gate_ensure_mask_layout_shape_invalid( %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_mask_layout has no registered materialization recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_mask_layout has no registered materialization support // CHECK-SAME: requires source and result to have the same physical arity %dense = pto.vmi.ensure_mask_layout %mask : !pto.vmi.mask<128xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto similarity index 92% rename from test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto index 871e14eb5b..90e49c52dd 100644 --- a/test/lit/vmi/vmi_layout_gate_helper_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto @@ -9,7 +9,7 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_helper_recipe_invalid( + func.func @vmi_layout_gate_helper_support_invalid( %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { %bad = pto.vmi.ensure_layout %value : !pto.vmi.vreg<64xf32, #pto.vmi.layout> @@ -18,5 +18,5 @@ module { } } -// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization recipe +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.ensure_layout has no registered materialization support // CHECK-SAME: unsupported source/result layout pair diff --git a/test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_store_support_invalid.pto similarity index 95% rename from test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_store_support_invalid.pto index 3877eb1a3a..7c62871865 100644 --- a/test/lit/vmi/vmi_layout_gate_store_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_store_support_invalid.pto @@ -12,7 +12,7 @@ module { func.func @vmi_layout_gate_store_deint_tail_invalid( %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, %dst: !pto.ptr, %offset: index) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.store has no registered contiguous-memory local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.store has no registered contiguous-memory layout support // CHECK-SAME: requires arity divisible by layout factor pto.vmi.store %value, %dst[%offset] : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, @@ -27,7 +27,7 @@ module { func.func @vmi_layout_gate_tile_write_deint_tail_invalid( %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, %dst: memref<129xf32>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.tile_write has no registered contiguous-memory local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.tile_write has no registered contiguous-memory layout support // CHECK-SAME: requires arity divisible by layout factor pto.vmi.tile_write %value, %dst : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_layout_gate_local_recipe.pto b/test/lit/vmi/vmi_layout_gate_support.pto similarity index 92% rename from test/lit/vmi/vmi_layout_gate_local_recipe.pto rename to test/lit/vmi/vmi_layout_gate_support.pto index 7644fae1c6..629b85c208 100644 --- a/test/lit/vmi/vmi_layout_gate_local_recipe.pto +++ b/test/lit/vmi/vmi_layout_gate_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -pto-validate-vmi-layout-ir | FileCheck %s module { - func.func @vmi_layout_gate_local_recipe( + func.func @vmi_layout_gate_support( %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} @@ -20,5 +20,5 @@ module { } } -// CHECK-LABEL: func.func @vmi_layout_gate_local_recipe( +// CHECK-LABEL: func.func @vmi_layout_gate_support( // CHECK: pto.vmi.group_reduce_addf diff --git a/test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto b/test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto similarity index 94% rename from test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto rename to test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto index 68e7963b1b..3021b88a7d 100644 --- a/test/lit/vmi/vmi_layout_gate_truncf_recipe_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto @@ -9,9 +9,9 @@ // RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s module { - func.func @vmi_layout_gate_truncf_recipe_invalid( + func.func @vmi_layout_gate_truncf_support_invalid( %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf has no registered local recipe + // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf has no registered layout support // CHECK-SAME: group-slot truncf requires matching group_slots(num_groups=G, slots=1) // CHECK: note: see current operation: %{{.*}} = "pto.vmi.truncf" %out = pto.vmi.truncf %source diff --git a/test/lit/vmi/vmi_layout_rematerialize_data.pto b/test/lit/vmi/vmi_layout_rematerialize_data.pto index 29faa34fb1..22a03d88a5 100644 --- a/test/lit/vmi/vmi_layout_rematerialize_data.pto +++ b/test/lit/vmi/vmi_layout_rematerialize_data.pto @@ -39,6 +39,18 @@ module { !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> } + + func.func @vmi_layout_rematerialize_keeps_load_helper( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %load = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %load_deint = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %load_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } } // CHECK-LABEL: func.func @vmi_layout_rematerialize_data( @@ -47,3 +59,8 @@ module { // CHECK: %[[CONST:.*]] = "pto.vmi.constant"(){{.*}}dense<1.000000e+00> : tensor<128xf32>{{.*}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-NOT: pto.vmi.ensure_layout // CHECK: return %[[BCAST]], %[[IOTA]], %[[CONST]] + +// CHECK-LABEL: func.func @vmi_layout_rematerialize_keeps_load_helper( +// CHECK: %[[LOAD:.*]] = pto.vmi.load +// CHECK: %[[LOAD_DEINT:.*]] = pto.vmi.ensure_layout %[[LOAD]] +// CHECK: return %[[LOAD_DEINT]] diff --git a/test/lit/vmi/vmi_layout_sink_materialization_binary.pto b/test/lit/vmi/vmi_layout_sink_materialization_binary.pto index 9db3fcb22b..eb21fae758 100644 --- a/test/lit/vmi/vmi_layout_sink_materialization_binary.pto +++ b/test/lit/vmi/vmi_layout_sink_materialization_binary.pto @@ -57,6 +57,85 @@ module { return %sum : !pto.vmi.vreg<128xf32, #pto.vmi.layout> } + func.func @vmi_layout_sink_materialization_fma( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %acc: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %acc_deint = pto.vmi.ensure_layout %acc + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %out = pto.vmi.fma %lhs_deint, %rhs_deint, %acc_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_cmpf( + %lhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %mask = pto.vmi.cmpf "olt", %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %mask : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_cmpi( + %lhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %rhs: !pto.vmi.vreg<128xi32, #pto.vmi.layout>) + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> { + %lhs_deint = pto.vmi.ensure_layout %lhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %rhs_deint = pto.vmi.ensure_layout %rhs + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + %mask = pto.vmi.cmpi "slt", %lhs_deint, %rhs_deint + : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + return %mask : !pto.vmi.mask<128xb32, #pto.vmi.layout> + } + + func.func @vmi_layout_sink_materialization_select( + %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>, + %true_value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %false_value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %mask_deint = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<128xb32, #pto.vmi.layout> + -> !pto.vmi.mask<128xb32, #pto.vmi.layout> + %true_deint = pto.vmi.ensure_layout %true_value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %false_deint = pto.vmi.ensure_layout %false_value + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %selected = pto.vmi.select %mask_deint, %true_deint, %false_deint + : !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %selected + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + func.func @vmi_layout_sink_materialization_unary( %src: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { @@ -154,6 +233,49 @@ module { // CHECK: %[[SUM2:.*]] = pto.vmi.addf %[[LHS_DEINT]], %arg1 // CHECK: return %[[SUM2]] +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_fma( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK-NOT: pto.vmi.ensure_layout %arg2 +// CHECK: %[[FMA:.*]] = pto.vmi.fma %arg0, %arg1, %arg2 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[FMA_DEINT:.*]] = pto.vmi.ensure_layout %[[FMA]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[FMA_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_cmpf( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[CMPF:.*]] = pto.vmi.cmpf "olt", %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[CMPF_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[CMPF]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[CMPF_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_cmpi( +// CHECK-NOT: pto.vmi.ensure_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK: %[[CMPI:.*]] = pto.vmi.cmpi "slt", %arg0, %arg1 +// CHECK-SAME: !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK: %[[CMPI_DEINT:.*]] = pto.vmi.ensure_mask_layout %[[CMPI]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[CMPI_DEINT]] + +// CHECK-LABEL: func.func @vmi_layout_sink_materialization_select( +// CHECK-NOT: pto.vmi.ensure_mask_layout %arg0 +// CHECK-NOT: pto.vmi.ensure_layout %arg1 +// CHECK-NOT: pto.vmi.ensure_layout %arg2 +// CHECK: %[[SELECT:.*]] = pto.vmi.select %arg0, %arg1, %arg2 +// CHECK-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: %[[SELECT_DEINT:.*]] = pto.vmi.ensure_layout %[[SELECT]] +// CHECK-SAME: #pto.vmi.layout +// CHECK: return %[[SELECT_DEINT]] + // CHECK-LABEL: func.func @vmi_layout_sink_materialization_unary( // CHECK-NOT: pto.vmi.ensure_layout %arg0 // CHECK: %[[NEG:.*]] = pto.vmi.negf %arg0 diff --git a/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto b/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto index 3b2fc0d080..55e1308b4b 100644 --- a/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto +++ b/test/lit/vmi/vmi_to_vpto_constant_mask_rematerialize.pto @@ -6,7 +6,7 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s module { func.func @vmi_to_vpto_constant_mask_rematerialize( diff --git a/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto b/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto index 74ef8194d5..03add9ada4 100644 --- a/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto +++ b/test/lit/vmi/vmi_to_vpto_create_mask_rematerialize.pto @@ -6,7 +6,7 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-rematerialize -vmi-to-vpto | FileCheck %s module { func.func @vmi_to_vpto_create_mask_rematerialize( diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto similarity index 94% rename from test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto rename to test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto index dc1b938924..55ed864da1 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_broadcast_slots8_local_recipe( + func.func @vmi_to_vpto_group_broadcast_slots8_support( %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, @@ -34,7 +34,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_slots8_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_slots8_support( // CHECK-COUNT-16: pto.vselr // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto index 01d9711ef0..3c40457460 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto @@ -10,13 +10,13 @@ module { func.func @vmi_to_vpto_group_broadcast_vselr( - %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_broadcast %source {num_groups = 128} - : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) diff --git a/test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_load_support.pto similarity index 94% rename from test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto rename to test/lit/vmi/vmi_to_vpto_group_load_support.pto index a1c5959f98..1af77958af 100644 --- a/test/lit/vmi/vmi_to_vpto_group_load_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_load_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_load_local_recipe( + func.func @vmi_to_vpto_group_load_support( %source: !pto.ptr, %row_stride: index) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, @@ -30,7 +30,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_load_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_load_support( // CHECK-COUNT-8: pto.vlds // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_group_ops.pto b/test/lit/vmi/vmi_to_vpto_group_ops.pto index 380a090a71..019b45f7c5 100644 --- a/test/lit/vmi/vmi_to_vpto_group_ops.pto +++ b/test/lit/vmi/vmi_to_vpto_group_ops.pto @@ -21,9 +21,9 @@ module { %r = pto.vmi.group_reduce_addf %v, %mask {num_groups = 2, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.vmi.mask<512xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> %b = pto.vmi.group_broadcast %r {num_groups = 2} - : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + : !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> pto.vmi.group_store %b, %dst[%c0], %row_stride {num_groups = 2} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto new file mode 100644 index 0000000000..b3e48c56b4 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @legacy_group_slots_without_explicit_slots( + %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> { + // CHECK: pto.vmi.group_reduce_addf lowers through pto.vcgadd + // CHECK-SAME: stable group_reduce_add layout support currently requires result layout slots=8 or slots=1 + %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + -> !pto.vreg<64xf32> + return %part : !pto.vreg<64xf32> + } +} diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto similarity index 94% rename from test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto rename to test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto index 4b706dc08d..99359b1a8e 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_reduce_s64_local_recipe( + func.func @vmi_to_vpto_group_reduce_s64_support( %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<512xb32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, @@ -31,7 +31,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64_support( // CHECK-COUNT-8: pto.vcadd // CHECK: pto.vsel // CHECK-NOT: pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto similarity index 91% rename from test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto rename to test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto index a6737eae1f..9e6a9faf00 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_reduce_slots8_local_recipe( + func.func @vmi_to_vpto_group_reduce_slots8_support( %source: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) -> !pto.vreg<64xf32> { @@ -24,7 +24,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_slots8_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_slots8_support( // CHECK: pto.vcgadd // CHECK-NOT: pto.vcadd // CHECK-NOT: pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto index 27d246e6d2..d6b52468b4 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto @@ -16,9 +16,9 @@ module { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.mask<64xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> return %part : !pto.vreg<64xf32> } diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto index d3da9416b6..d6265bd490 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto @@ -19,10 +19,10 @@ module { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 128, reassoc} : !pto.vmi.vreg<1024xf32, #pto.vmi.layout>, !pto.vmi.mask<1024xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto similarity index 91% rename from test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto rename to test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto index 3a9aa117b5..e806b28b92 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_slot_load_local_recipe( + func.func @vmi_to_vpto_group_slot_load_support( %src: !pto.ptr, %off: index) -> !pto.vreg<64xf32> { %c1 = arith.constant 1 : index %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} @@ -22,7 +22,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_support( // CHECK: pto.vsldb // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto similarity index 96% rename from test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto rename to test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto index eec3c06d2a..4874117e69 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_local_recipe.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto @@ -9,7 +9,7 @@ // RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_group_slot_truncf_slots1_local_recipe( + func.func @vmi_to_vpto_group_slot_truncf_slots1_support( %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, @@ -29,7 +29,7 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_truncf_slots1_local_recipe( +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_truncf_slots1_support( // CHECK-COUNT-8: pto.vcvt // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto index a0cc8215cb..dd69bcfaa2 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto @@ -258,7 +258,9 @@ module { // CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> // CHECK: pto.vor {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> // CHECK: scf.if -// CHECK: pto.plt_b16 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask // CHECK: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask // CHECK-LABEL: func.func @vmi_to_vpto_dequant_matrix_fp8_to_f32( @@ -301,7 +303,12 @@ module { // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.plt_b8 {{.*}} : i32 -> !pto.mask, i32 +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask +// CHECK: pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask // CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto index f78e4ef5f2..5297123e5a 100644 --- a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto @@ -17,9 +17,9 @@ module { } } -// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf operand #0 has type {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: but requires {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> -// CHECK-SAME: pto.vmi.ensure_layout has no registered materialization recipe +// CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion // CHECK: failed helper conversion {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: requires source and result to have the same physical arity From 200798f6a4aa39d049c94f9f7f52a5ade20dde29 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 14:17:33 +0800 Subject: [PATCH 23/54] Support partial packed VMI group slots --- docs/designs/vmi-layout-lowering-cases.md | 155 ++++++++++++++++++ lib/PTO/IR/VMI.cpp | 5 +- ...assignment_group_reduce_partial_slots8.pto | 61 +++++++ .../vmi/vmi_layout_group_slots_invalid.pto | 4 +- ...mi_to_vpto_group_reduce_partial_slots8.pto | 94 +++++++++++ test/lit/vmi/vmi_type_attr_parse.pto | 7 +- 6 files changed, 320 insertions(+), 6 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index d2d7b3835d..5a007987a7 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -5494,6 +5494,12 @@ S = logical_lane_count / num_groups The canonical grouped-reduce layouts are: ```text +Packed group-slot rule: + K is the physical slot capacity of one packed group-result chunk. + For VCG-style packed reductions, K = 8. + G does not have to be divisible by K; the final chunk may be partial. + active_groups(chunk c) = min(K, G - c * K). + S == VLaneElems: source/mask layout = contiguous result layout = group_slots(num_groups=G, slots=8) @@ -5696,6 +5702,155 @@ for r = 0..7: out[group_off + r] = reduce_T16(base[off + r * 64 + 0 .. 63]) ``` +#### 3.50.1 Partial Packed `S = 64` Reductions + +This is the same `S = 4 * VLaneElems` lowering family as section 3.50, but it +covers `G` values that do not fill every packed group-result chunk. The key +point is that `slots = 8` is a physical capacity, not a promise that every +chunk contains eight valid group results. + +The result layout remains: + +```text +!pto.vmi.vreg<(G * 64)xf16, #pto.vmi.layout> +``` + +The lowering computes per result chunk: + +```text +K = 8 +chunk c active groups A(c) = min(K, G - c * K) + +source active lanes per deinterleaved part for chunk c: + A(c) * VLaneElems = A(c) * 16 f16 lanes + +reduce input mask: + PAT_VL(A(c) * 16) + +combine/store mask: + PAT_VL(A(c)) +``` + +For full chunks, `A(c) = 8`, so the reduce input mask is `PAT_ALL` for f16 +and the combine/store mask is `PAT_VL8`. For partial chunks, masks are +required for correctness. The semantic source mask produced by +`pto.vmi.create_group_mask` must also materialize only the valid source lanes; +the reduce lowering should not treat padding lanes as active data. + +##### `G = 4`: `256xf16, num_groups = 4` + +VMI-shaped input: + +```text +%x = pto.vmi.load %base[%off] + : memref<256xf16> -> !pto.vmi.vreg<256xf16> +%mask = pto.vmi.create_group_mask %c64 {num_groups = 4, group_size = 64} + : index -> !pto.vmi.mask<256xpred> +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 4, reassoc} +pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 4} +``` + +Assigned layouts: + +```text +%x, %mask: + #pto.vmi.layout + +%sum: + !pto.vmi.vreg<256xf16, #pto.vmi.layout> +``` + +VPTO lowering shape for the only result chunk: + +```text +%x_p0, %x_p1, %x_p2, %x_p3 = materialize deinterleaved=4, block_elems=8 input + : four !pto.vreg<128xf16> + +%lane64_b16 = pto.pge_b16 "PAT_VL64" // A * 16 = 4 * 16 +%slot4_b16 = pto.pge_b16 "PAT_VL4" + +%s0 = pto.vcgadd %x_p0, %lane64_b16 : !pto.vreg<128xf16> +%s1 = pto.vcgadd %x_p1, %lane64_b16 : !pto.vreg<128xf16> +%s2 = pto.vcgadd %x_p2, %lane64_b16 : !pto.vreg<128xf16> +%s3 = pto.vcgadd %x_p3, %lane64_b16 : !pto.vreg<128xf16> + +%s01 = pto.vadd %s0, %s1, %slot4_b16 : !pto.vreg<128xf16> +%s23 = pto.vadd %s2, %s3, %slot4_b16 : !pto.vreg<128xf16> +%sum0 = pto.vadd %s01, %s23, %slot4_b16 : !pto.vreg<128xf16> + +pto.vsts %sum0, %out[%group_off], %slot4_b16 {dist = "NORM_B16"} + : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +``` + +Memory result: + +```text +for r = 0..3: + out[group_off + r] = reduce_f16(base[off + r * 64 + 0 .. 63]) + +sum0 lanes 4..127 are not semantic for this VMI result. +``` + +##### `G = 8`: full packed chunk + +This is section 3.50. There is one result chunk with `A = 8`: + +```text +source mask = PAT_ALL // 8 * 16 = 128 f16 lanes +combine/store = PAT_VL8 +result layout = group_slots(num_groups=8, slots=8) +``` + +##### `G = 12`: full chunk plus partial chunk + +This case needs two packed result chunks: + +```text +result layout = group_slots(num_groups=12, slots=8) +result arity = ceil(12 / 8) = 2 +``` + +Chunk 0 handles groups `0..7`: + +```text +A(0) = 8 +source mask = PAT_ALL +combine/store = PAT_VL8 +``` + +Chunk 1 handles groups `8..11`: + +```text +A(1) = 4 +source mask = PAT_VL64 +combine/store = PAT_VL4 +``` + +Implementation checklist for this family: + +```text +layout attr: + slots=8 should be legal even when num_groups is not divisible by 8. + slot_block(g) = g / 8 and slot_lane(g) = g % 8 are still well-defined. + +layout assignment: + packed VCG-style group_reduce results keep slots=8. + +mask materialization: + create_group_mask must not activate padding lanes in partial chunks. + For chunk c, source active lanes are A(c) * VLaneElems. + +vmi-to-vpto group_reduce: + use A(c) from result layout slots and num_groups. + combine masks use PAT_VL(A(c)). + input vcgadd consumes the physical mask parts, which must already encode + PAT_VL(A(c) * VLaneElems) for all-true grouped masks. + +vmi-to-vpto group_store: + use A(c) to build the store predicate. + output group offset for chunk c is c * slots. +``` + ### 3.51 16-bit Typed Group Reduce, `S = L = 128` This is the first row-local full-physical-chunk case for both `f16` and `i16`. diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index b504de67f5..d3d2dc6b14 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -531,12 +531,11 @@ VMILayoutAttr::verify(function_ref emitError, if (blockElems != 1) return emitError() << "#pto.vmi.layout requires block_elems to be 1"; - if (slots < 0 || (slots != 0 && factor % slots != 0)) + if (slots < 0) return emitError() << "#pto.vmi.layout requires slots to be positive and divide num_groups when " - "specified"; + << "> requires slots to be omitted or positive"; return success(); } diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto new file mode 100644 index 0000000000..e828ba6b2d --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_reduce_f16_s64_g4( + %source: !pto.vmi.vreg<256xf16>, + %mask: !pto.vmi.mask<256xpred>) + -> !pto.vmi.vreg<256xf16> { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 4, reassoc} + : !pto.vmi.vreg<256xf16>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf16> + return %out : !pto.vmi.vreg<256xf16> + } + + func.func @vmi_layout_assignment_group_reduce_f16_s64_g12( + %source: !pto.vmi.vreg<768xf16>, + %mask: !pto.vmi.mask<768xpred>) + -> !pto.vmi.vreg<768xf16> { + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 12, reassoc} + : !pto.vmi.vreg<768xf16>, !pto.vmi.mask<768xpred> + -> !pto.vmi.vreg<768xf16> + return %out : !pto.vmi.vreg<768xf16> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_f16_s64_g4( +// CHECK-SAME: %arg0: !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<256xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: %[[SRC4:.*]] = pto.vmi.ensure_layout %arg0 +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: %[[MASK4_LAYOUT:.*]] = pto.vmi.ensure_mask_layout %arg1 +// CHECK-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// CHECK: %[[MASK4:.*]] = pto.vmi.ensure_mask_granularity %[[MASK4_LAYOUT]] +// CHECK-SAME: -> !pto.vmi.mask<256xb16, #pto.vmi.layout> +// CHECK: %[[OUT4:.*]] = pto.vmi.group_reduce_addf %[[SRC4]], %[[MASK4]] +// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK: return %[[OUT4]] + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_f16_s64_g12( +// CHECK-SAME: %arg0: !pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK-SAME: %arg1: !pto.vmi.mask<768xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK: %[[SRC12:.*]] = pto.vmi.ensure_layout %arg0 +// CHECK-SAME: -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK: %[[MASK12_LAYOUT:.*]] = pto.vmi.ensure_mask_layout %arg1 +// CHECK-SAME: -> !pto.vmi.mask<768xb32, #pto.vmi.layout> +// CHECK: %[[MASK12:.*]] = pto.vmi.ensure_mask_granularity %[[MASK12_LAYOUT]] +// CHECK-SAME: -> !pto.vmi.mask<768xb16, #pto.vmi.layout> +// CHECK: %[[OUT12:.*]] = pto.vmi.group_reduce_addf %[[SRC12]], %[[MASK12]] +// CHECK-SAME: -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK: return %[[OUT12]] diff --git a/test/lit/vmi/vmi_layout_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_group_slots_invalid.pto index f354adb6e8..1f4ccd2856 100644 --- a/test/lit/vmi/vmi_layout_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_group_slots_invalid.pto @@ -10,9 +10,9 @@ module { func.func @vmi_layout_group_slots_invalid( - %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { return } } -// CHECK: #pto.vmi.layout requires slots to be positive and divide num_groups when specified +// CHECK: #pto.vmi.layout requires slots to be omitted or positive diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto new file mode 100644 index 0000000000..8efe26cf22 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_reduce_f16_s64_g4( + %source: !pto.vmi.vreg<256xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb16, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 4, reassoc} + : !pto.vmi.vreg<256xf16, #pto.vmi.layout>, + !pto.vmi.mask<256xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 4} + : !pto.vmi.vreg<256xf16, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_reduce_f16_s64_g8( + %source: !pto.vmi.vreg<512xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<512xb16, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 8, reassoc} + : !pto.vmi.vreg<512xf16, #pto.vmi.layout>, + !pto.vmi.mask<512xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<512xf16, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_reduce_f16_s64_g12( + %source: !pto.vmi.vreg<768xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<768xb16, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_reduce_addf %source, %mask + {num_groups = 12, reassoc} + : !pto.vmi.vreg<768xf16, #pto.vmi.layout>, + !pto.vmi.mask<768xb16, #pto.vmi.layout> + -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> + pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 12} + : !pto.vmi.vreg<768xf16, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_f16_s64_g4( +// CHECK-DAG: %[[SLOT4:.*]] = pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT4]] +// CHECK: %[[STORE4:.*]] = pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE4]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_f16_s64_g8( +// CHECK-DAG: %[[SLOT8:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT8]] +// CHECK: %[[STORE8:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE8]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_f16_s64_g12( +// CHECK: %[[SLOT8_12:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT8_12]] +// CHECK: %[[SLOT4_12:.*]] = pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK-COUNT-4: pto.vcgadd +// CHECK-COUNT-3: pto.vadd {{.*}}, {{.*}}, %[[SLOT4_12]] +// CHECK: %[[STORE8_12:.*]] = pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE8_12]] +// CHECK: %[[STORE4_12:.*]] = pto.pge_b16 "PAT_VL4" : !pto.mask +// CHECK: pto.vsts {{.*}}, {{.*}}, %[[STORE4_12]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_type_attr_parse.pto b/test/lit/vmi/vmi_type_attr_parse.pto index 5798114cc7..b2001c29f0 100644 --- a/test/lit/vmi/vmi_type_attr_parse.pto +++ b/test/lit/vmi/vmi_type_attr_parse.pto @@ -14,7 +14,9 @@ module attributes { pto.vmi_deinterleaved4 = #pto.vmi.layout, pto.vmi_deinterleaved4_block8 = #pto.vmi.layout, - pto.vmi_group_slots8 = #pto.vmi.layout + pto.vmi_group_slots8 = #pto.vmi.layout, + pto.vmi_group_slots_partial = + #pto.vmi.layout } { func.func @vmi_type_attr_parse( %surface: !pto.vmi.vreg<128xf32>, @@ -23,6 +25,7 @@ module attributes { %wide4: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %wide4_block8: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %group_slots8: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %group_slots_partial: !pto.vmi.vreg<640xf32, #pto.vmi.layout>, %surface_mask: !pto.vmi.mask<128xpred>, %mask_b8: !pto.vmi.mask<128xb8, #pto.vmi.layout>, %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, @@ -37,6 +40,7 @@ module attributes { // CHECK: pto.vmi_deinterleaved4 = #pto.vmi.layout // CHECK: pto.vmi_deinterleaved4_block8 = #pto.vmi.layout // CHECK: pto.vmi_group_slots8 = #pto.vmi.layout +// CHECK: pto.vmi_group_slots_partial = #pto.vmi.layout // CHECK-LABEL: func.func @vmi_type_attr_parse( // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> @@ -44,6 +48,7 @@ module attributes { // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<640xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xpred> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb8, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb16, #pto.vmi.layout> From 21510d9586992bd41af7fb60fd0b6c1bbc661e39 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 15:05:06 +0800 Subject: [PATCH 24/54] Support arith select in VPTO LLVM lowering --- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 36 +++++++++++++ lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 36 +++++++++++++ test/lit/vpto/arith_select_vpto_llvm.pto | 54 +++++++++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 test/lit/vpto/arith_select_vpto_llvm.pto diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 8362aea64b..4d4b82f5a8 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -9356,6 +9356,41 @@ class ConvertVPTOUnrealizedCastOp final } }; +class ConvertArithSelectOp final : public OpConversionPattern { +public: + ConvertArithSelectOp(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context, + PatternBenefit(2)) {} + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for VPTO arith.select"); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + Value trueValue = adaptor.getTrueValue(); + Value falseValue = adaptor.getFalseValue(); + if (trueValue.getType() != convertedResultType || + falseValue.getType() != convertedResultType) + return rewriter.notifyMatchFailure( + op, "converted true/false values must match result type"); + + rewriter.replaceOpWithNewOp( + op, convertedResultType, adaptor.getCondition(), trueValue, + falseValue); + return success(); + } +}; + class ConvertPtoAddPtrOp final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -10231,6 +10266,7 @@ static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS) patterns.add( typeConverter, context, state); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 35f8cc51a3..bee22fed58 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -9300,6 +9300,41 @@ class ConvertVPTOUnrealizedCastOp final } }; +class ConvertArithSelectOp final : public OpConversionPattern { +public: + ConvertArithSelectOp(TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context, + PatternBenefit(2)) {} + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!hasVPTOConvertibleType(op->getOperandTypes()) && + !hasVPTOConvertibleType(op->getResultTypes())) + return failure(); + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + op, "only scalar i1 conditions supported for VPTO arith.select"); + + Type convertedResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (!convertedResultType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + Value trueValue = adaptor.getTrueValue(); + Value falseValue = adaptor.getFalseValue(); + if (trueValue.getType() != convertedResultType || + falseValue.getType() != convertedResultType) + return rewriter.notifyMatchFailure( + op, "converted true/false values must match result type"); + + rewriter.replaceOpWithNewOp( + op, convertedResultType, adaptor.getCondition(), trueValue, + falseValue); + return success(); + } +}; + class ConvertPtoAddPtrOp final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -10177,6 +10212,7 @@ static LogicalResult lowerVPTOTypes(ModuleOp module, llvm::raw_ostream &diagOS) patterns.add( typeConverter, context, state); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git a/test/lit/vpto/arith_select_vpto_llvm.pto b/test/lit/vpto/arith_select_vpto_llvm.pto new file mode 100644 index 0000000000..b32a7fe0de --- /dev/null +++ b/test/lit/vpto/arith_select_vpto_llvm.pto @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @arith_select_vreg(%cond: i1, %lhs_scalar: f32, %rhs_scalar: f32, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b32 "PAT_ALL" : !pto.mask + %lhs = pto.vdup %lhs_scalar, %mask + : f32, !pto.mask -> !pto.vreg<64xf32> + %rhs = pto.vdup %rhs_scalar, %mask + : f32, !pto.mask -> !pto.vreg<64xf32> + %chosen = arith.select %cond, %lhs, %rhs : !pto.vreg<64xf32> + pto.vsts %chosen, %dst[%c0], %mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + return + } + + func.func @arith_select_mask(%cond: i1, %value: f32, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %all = pto.pset_b32 "PAT_ALL" : !pto.mask + %tail = pto.pge_b32 "PAT_VL4" : !pto.mask + %chosen_mask = arith.select %cond, %all, %tail : !pto.mask + %vec = pto.vdup %value, %all + : f32, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %vec, %dst[%c0], %chosen_mask + : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: llvm.func @arith_select_vreg_mix_aiv +// CHECK: %[[LHS:.*]] = llvm.call @llvm.hivm.vdups{{.*}} +// CHECK: %[[RHS:.*]] = llvm.call @llvm.hivm.vdups{{.*}} +// CHECK: %[[CHOSEN:.*]] = llvm.select %arg0, %[[LHS]], %[[RHS]] : i1, vector<64xf32> +// CHECK: llvm.call @llvm.hivm.vstsx1.v64f32(%[[CHOSEN]] + +// CHECK-LABEL: llvm.func @arith_select_mask_mix_aiv +// CHECK: %[[ALL:.*]] = llvm.call @llvm.hivm.pset.b32 +// CHECK: %[[TAIL:.*]] = llvm.call @llvm.hivm.pge.b32 +// CHECK: %[[CHOSEN_MASK:.*]] = llvm.select %arg0, %[[ALL]], %[[TAIL]] : i1, vector<256xi1> +// CHECK: llvm.call @llvm.hivm.vstsx1.v64f32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[CHOSEN_MASK]]) From fe3aacb52cb1ba10dbddab0dbbe2631195ba93e9 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 15:09:44 +0800 Subject: [PATCH 25/54] Add VMI introduction design doc --- docs/designs/vmi-introduction.md | 658 +++++++++++++++++++++++++++++++ 1 file changed, 658 insertions(+) create mode 100644 docs/designs/vmi-introduction.md diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md new file mode 100644 index 0000000000..94ca638cb2 --- /dev/null +++ b/docs/designs/vmi-introduction.md @@ -0,0 +1,658 @@ +# VMI 介绍 + +本文介绍 VMI 的设计入口:VMI 解决什么问题,layout 有哪些,pass pipeline +如何分工,以及这些机制分别应对哪些典型场景。更完整的逐 case lowering 结果见 +`docs/designs/vmi-layout-lowering-cases.md`。 + +示例是设计级 IR,保留关键 type、layout、helper op 和 VPTO op 形状, +省略 module wrapper、完整 operand list 和不影响讨论的 SSA 细节。 + +## 1. VMI 表达什么 + +VMI 是 VPTO 之前的逻辑向量层。它让前端先表达“我要对 `NxT` 的逻辑向量做什么”, +再由 layout assignment 决定这个逻辑向量如何拆到 256B 物理 vector register 上。 + +Surface VMI 类型不携带布局: + +```mlir +!pto.vmi.vreg<128xf32> +!pto.vmi.mask<128xpred> +``` + +Layout-assigned VMI 类型携带具体布局和 mask granularity: + +```mlir +!pto.vmi.vreg<128xf32, #pto.vmi.layout> +!pto.vmi.mask<128xb32, #pto.vmi.layout> +``` + +VMI 的核心约束是:`vmi-to-vpto` 只从当前 op 的 attrs、operands、types、 +layouts 和显式 helper ops 做 lowering,不读取隐藏 plan/recipe,也不通过 +defining op 或 sibling user 恢复上下文。 + +## 2. Layout 类型 + +### 2.1 `contiguous` + +```mlir +#pto.vmi.layout +``` + +含义:logical lane 按顺序落入物理 register list。 + +```text +logical lanes: 0 1 2 ... 63 | 64 65 ... 127 +physical part: p0 | p1 +``` + +典型场景: + +```text +dense load/store +普通 elementwise compute +一个 group 天然适配当前 reduce op 时的 reduction input +caller/callee 约定 dense order 时的 control-flow/function boundary +``` + +### 2.2 `deinterleaved = F, block_elems = B` + +```mlir +#pto.vmi.layout +#pto.vmi.layout +``` + +`block_elems` 缺省为 `1`。逻辑 lane 到物理 part 的映射是: + +```text +logical lane i +block q = i / B +in-block lane r = i % B +part p = q % F +part block t = q / F + +physical part p, physical lane t * B + r +``` + +`deinterleaved=2` 的直观例子: + +```text +logical lanes: 0 1 2 3 4 5 ... +physical part0: 0 2 4 ... +physical part1: 1 3 5 ... +``` + +`deinterleaved=4, block_elems=8` 的直观例子: + +```text +logical group S=32: + lanes 0.. 7 -> part0 lanes 0..7 + lanes 8..15 -> part1 lanes 0..7 + lanes 16..23 -> part2 lanes 0..7 + lanes 24..31 -> part3 lanes 0..7 +``` + +典型场景: + +```text +f16 -> f32: + vcvt 天然产生 even/odd 两个 f32 part,所以结果使用 deinterleaved=2。 + +f32 -> f16: + vcvt 需要 f32 source 先拆成 even/odd 两个 part,所以 source 使用 + deinterleaved=2。 + +S=32 group_reduce f32: + 一个 group 有 32 个 f32 element。高效 reduce path 消费四个 8-lane block, + 所以 source/mask 使用 deinterleaved=4, block_elems=8。 +``` + +### 2.3 `num_groups = G, slots = K` + +```mlir +#pto.vmi.layout +#pto.vmi.layout +``` + +这是 sparse group-result layout。它不表示全部 `N` 个 logical lane 都有语义值。 +只有 `G` 个 group 结果 slot 有语义值。 + +```text +slot_block(g) = g / K +slot_lane(g) = g % K + +physical part slot_block(g) 的 lane slot_lane(g) 保存 group g 的结果 +``` + +`num_groups=16, slots=8` 的例子: + +```text +part0 lane0..7 = group result 0..7 +part1 lane0..7 = group result 8..15 +other lanes = 对普通 dense consumer 来说未定义 +``` + +为什么 group 信息也要放进 layout: + +```text +group_reduce 自身有 num_groups,但它的结果可能继续跨过 truncf、 +group_broadcast、group_store、scf.if、scf.for、function call 或多个 consumer。 + +这些后续 op 不应该回看 producer attr。value layout 因此需要记录有多少个 +group result,以及这些 result 如何 packed 到 physical slot。 +``` + +典型场景: + +```text +group_reduce result +group_slot_load result +group_store input +group_broadcast input +group-slot control-flow/function boundary +部分 row-local cast 路径,通常使用 slots=1 +``` + +## 3. Pass Pipeline + +```text +pto-validate-vmi-ir + -> vmi-layout-assignment + -> canonicalize/cse + -> vmi-layout-fold-consumers + -> canonicalize/cse + -> vmi-layout-rematerialize + -> canonicalize/cse + -> vmi-layout-sink-materialization + -> canonicalize/cse + -> vmi-legalize-arith-select + -> pto-validate-vmi-layout-ir + -> vmi-to-vpto +``` + +### 3.1 `pto-validate-vmi-ir` + +检查 surface VMI 边界。 + +合法输入: + +```mlir +%x = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf16> +``` + +非法输入: + +```mlir +%x = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +``` + +原因:具体 layout 由 `vmi-layout-assignment` 产生,不应该由 surface frontend +提前写入。 + +### 3.2 `vmi-layout-assignment` + +这是硬合法化 pass。它选择具体 value layout、具体 mask granularity, +并在 layout 不匹配的 use site 插入显式 helper op。 + +例子:`f16 -> f32 -> store`。 + +Surface VMI: + +```mlir +%x16 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<128xf16> +%x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> +pto.vmi.store %x32, %dst[%off] + : !pto.vmi.vreg<128xf32>, !pto.ptr +``` + +Assignment 之后: + +```mlir +%x16 = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + +%x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%x32_dense = pto.vmi.ensure_layout %x32 + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +pto.vmi.store %x32_dense, %dst[%off] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr +``` + +即使不跑任何优化 pass,这个 assignment 后的 IR 也已经是正确可降的。 + +### 3.3 `vmi-layout-fold-consumers` + +当 consumer 可以直接保持同样的外部效果时,把显式 materialization 折进 +consumer。 + +变换前: + +```mlir +%dense = pto.vmi.ensure_layout %x + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +pto.vmi.store %dense, %dst[%off] +``` + +变换后: + +```mlir +pto.vmi.store %x, %dst[%off] + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr +``` + +可能的 VPTO 形状: + +```text +fold 前:vintlv + vsts + vsts +fold 后:vstsx2,使用交错 store mode +``` + +### 3.4 `vmi-layout-rematerialize` + +通过 clone 低成本、layout-polymorphic 的 producer 来替换 `ensure_*`。 + +变换前: + +```mlir +%s = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +%s_split = pto.vmi.ensure_layout %s + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +变换后: + +```mlir +%s_split = pto.vmi.broadcast %scale + : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +预期可 rematerialize 的 producer: + +```text +splat constant +broadcast +iota +create_mask +create_group_mask +constant_mask +``` + +这个 pass 不 rematerialize: + +```text +load / masked_load / group_load / group_slot_load +reduce / group_reduce +control-flow results +``` + +### 3.5 `vmi-layout-sink-materialization` + +把匹配的 layout 转换跨过 layout-transparent elementwise op。 + +变换前: + +```mlir +%a_dense = pto.vmi.ensure_layout %a : deinterleaved=2 -> contiguous +%b_dense = pto.vmi.ensure_layout %b : deinterleaved=2 -> contiguous +%y_dense = pto.vmi.addf %a_dense, %b_dense : contiguous +``` + +变换后: + +```mlir +%y_split = pto.vmi.addf %a, %b : deinterleaved=2 +%y_dense = pto.vmi.ensure_layout %y_split : deinterleaved=2 -> contiguous +``` + +效果: + +```text +两个 input materialization -> 一个 result materialization +``` + +这个 pass 不会 sink 穿过 cast、load、store、reduce、group_broadcast 或 +control-flow op。 + +### 3.6 `vmi-legalize-arith-select` + +Canonicalization 可能把简单的 `scf.if` 折成 `arith.select`。VMI 希望把 +control-flow lowering 保持在结构化控制流里,所以这个 pass 会把 VMI value 上的 +`arith.select` 改回 `scf.if`。 + +```mlir +%r = arith.select %cond, %a, %b + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +``` + +改成: + +```mlir +%r = scf.if %cond + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + scf.yield %a : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +} else { + scf.yield %b : !pto.vmi.vreg<128xf32, #pto.vmi.layout> +} +``` + +### 3.7 `pto-validate-vmi-layout-ir` + +检查 post-assignment gate: + +```text +每个 VMI 数据值都有 concrete layout +每个 VMI mask 都有 concrete granularity 和 layout +helper op 有支持的 materialization path +semantic op/layout 组合有支持的 local lowering +vmi-to-vpto 之前没有物理 VPTO value 泄漏到 VMI IR 中 +``` + +非法例子: + +```mlir +%sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : ... -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + +pto.vmi.store %sum, %dst[%off] + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr +``` + +原因: + +```text +dense store 不能把 sparse group_slots 当 dense vector 读取。 +应使用 group_store、group_broadcast 或显式支持的 group-to-dense op。 +``` + +### 3.8 `vmi-to-vpto` + +把 layout-assigned VMI value 转换成有序物理 VPTO value 列表,并对每个 +VMI op 做 local lowering。 + +例子: + +```text +!pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> 两个 physical !pto.vreg<64xf32> part + +!pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> 两个 physical !pto.vreg<64xf32> part + part0 携带 even lanes,part1 携带 odd lanes + +!pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> 四个 physical part + part0 携带 group 0..7,part1 携带 group 8..15,... +``` + +`VMILayoutSupport` 不是 pass。它是 assignment、validation、optimization 和 +lowering 共享的查询库,用来避免重复实现 layout fact 和 supported +materialization 检查。 + +## 4. 典型场景 + +### 4.1 Dense Cast 与 Store + +```text +surface: + load f16,语义上连续 + extf 到 f32 + dense store f32 + +assignment: + load result = contiguous + extf result = deinterleaved=2 + store use = ensure_layout(deinterleaved=2 -> contiguous) + +baseline VPTO: + vlds + vcvt even / vcvt odd + vintlv + vsts + vsts + +fold-consumers 后的优化 VPTO: + vlds + vcvt even / vcvt odd + vstsx2,使用 interleaving store +``` + +这个场景说明为什么需要 `deinterleaved=2`,以及为什么 store-consumer folding +有价值。 + +### 4.2 Narrow Cast 与 Store + +```text +surface: + load f32 + truncf 到 f16 + dense store f16 + +assignment: + load result = deinterleaved=2 + truncf result = contiguous + +VPTO: + vldsx2 deinterleaving load + vcvt even / vcvt odd + vor + vsts +``` + +这个场景说明 memory op 可以直接产生 consumer 需要的 layout,但不需要保存隐藏 +plan。 + +### 4.3 一个 Producer 同时服务 Dense 和 Group Consumer + +```mlir +%x32 = pto.vmi.extf %x16 +%sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} +pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} +pto.vmi.store %x32, %dense_out[%off] +``` + +Assignment 形状: + +```text +%x32 layout = deinterleaved=2 +group_reduce 直接消费 %x32 +dense store 获得 ensure_layout(%x32 -> contiguous) +``` + +VPTO 形状: + +```text +vcvt even/odd +vcgadd + vcgadd + vadd -> group_store result +vintlv + dense stores -> 产生 dense store 结果 +``` + +这个场景说明为什么需要 use-site materialization。producer 不需要选择一个能同时 +满足所有 consumer 的唯一 layout。 + +### 4.4 按 Group Size 区分的 Group Reduce + +对于 `N` 个 f32 lane 和 `G = num_groups`,group size 是 `S = N / G`。 + +```text +S=8: + input layout 可以是 contiguous。 + group_reduce result 通常使用 layout。 + +S=16: + 如果 input 来自 f16->f32 vcvt,layout 可以是 deinterleaved=2。 + 如果 input 从 dense 拆出,layout 可以是 deinterleaved=2, block_elems=8。 + result 通常使用 layout。 + +S=32: + input layout 使用 deinterleaved=4, block_elems=8。 + VPTO 形状是四个部分 group reduction 后接 add tree。 + result 通常使用 layout。 + +S=64: + row-local path 在可行时让每个 group 使用一条 physical row。 + result 可以使用 layout,避免 unsupported packing。 +``` + +S=32 例子: + +```text +assignment: + source/mask = deinterleaved=4, block_elems=8 + result = group_slots(num_groups=8, slots=8) + +VPTO: + vdintlv / pdintlv_b32 + vcgadd x4 + 使用 PAT_VL8 做 vadd tree + 通过一次 PAT_VL8 store 完成 group_store +``` + +这个场景说明为什么需要 `block_elems`。 + +### 4.5 Group Result 继续作为 Dense Rows 使用 + +Surface 意图: + +```mlir +%sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} +%sum16 = pto.vmi.truncf %sum32 +%rows16 = pto.vmi.group_broadcast %sum16 {num_groups = 8} +pto.vmi.store %rows16, %dst[%off] +``` + +支持的 assignment 形状: + +```mlir +%sum32 = pto.vmi.group_reduce_addf ... + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%rows32 = pto.vmi.group_broadcast %sum32 {num_groups = 8} + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + +%rows32_split = pto.vmi.ensure_layout %rows32 + : contiguous -> deinterleaved=2 + +%rows16 = pto.vmi.truncf %rows32_split + : deinterleaved=2 -> contiguous + +pto.vmi.store %rows16, %dst[%off] +``` + +VPTO 形状: + +```text +group_reduce: + vcgadd partials + vadd tree + +group_broadcast: + vselr 风格 selection,把 group slots 展开到 dense row lanes + +truncf: + vcvt even/odd + merge + +store: + vsts +``` + +这个场景说明为什么 group 结果 layout 必须挂在 value 上:reduce 之后, +cast 和 broadcast 必须知道 group 结果在哪里,而不能回看 producer。 + +### 4.6 通过 Mask 表达 Tail + +VMI 通过 mask 表达 tail,不通过 padding 表达 tail。 + +```mlir +%mask = pto.vmi.create_mask %active_lanes +%x = pto.vmi.masked_load %src[%off], %mask +%y = pto.vmi.mulf %x, %scale +pto.vmi.masked_store %y, %dst[%off], %mask +``` + +Grouped tail: + +```mlir +%gmask = pto.vmi.create_group_mask %active_elems_per_group + {num_groups = 8, group_size = 32} +%sum = pto.vmi.group_reduce_addf %x, %gmask {num_groups = 8, reassoc} +``` + +同一个 semantic mask 面对 f8/f16/f32 user 时,可能需要不同 concrete +granularity。Assignment 会通过 mask helper op 显式表达这些转换。 + +### 4.7 控制流和函数边界 + +Concrete layout 必须显式跨过 CFG 和内部 function boundary。 + +```mlir +%r = scf.if %cond + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + %a_dense = pto.vmi.ensure_layout %a : deinterleaved=2 -> contiguous + scf.yield %a_dense +} else { + %b_dense = pto.vmi.ensure_layout %b : deinterleaved=2 -> contiguous + scf.yield %b_dense +} +``` + +`vmi-to-vpto` 之后,region result 会变成多个物理 VPTO value: + +```text +scf.if -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) +``` + +这个场景说明为什么 layout 应该是 type 的一部分,而不是依赖 defining op。 + +## 5. 当前边界 + +当前设计方向: + +```text +surface VMI: + 描述不带 layout 的逻辑向量语义。 + +layout assignment: + 选择 layout、mask granularity 和显式 materialization helper。 + +optimization: + 只在结果 IR 仍然可以 local lowering 时改写显式 helper。 + +vmi-to-vpto: + 严格 lower 它看到的 assigned/optimized IR。 +``` + +暂不支持或有意收紧的范围: + +```text +group_slots value 的普通 dense store: + 非法,除非先经过 group_broadcast 或其他显式 group-to-dense op。 + +packed group_slots f32->f16 cast: + 非法,除非 assignment 能把它 commute 到 group_broadcast 之后,或者使用 + 支持的 row-local slots=1 path。 + +extract: + 暂不作为支持的 VMI surface。 + +padding transfer_read: + 当前 tail 设计不需要;tail 使用 mask。 + +scan / contract / gather / scatter / compress / active_prefix_index: + dialect surface 中可以存在,但除非补充具体 case,否则不属于第一阶段聚焦的 + layout/lowering 实现集合。 +``` + +设计目标是优先保证语义完整:只要 VMI 接受某个 case,所需的 layout 沟通就必须 +在 IR 中显式表达,并且能被 `vmi-to-vpto` local lowering。 From 40376d799aca19e438bc946863a88ff35b940d9f Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 16:59:52 +0800 Subject: [PATCH 26/54] Fold deinterleaved VMI loads through vldsx2 --- docs/designs/vmi-introduction.md | 59 +++++++++------- docs/designs/vmi-layout-lowering-cases.md | 70 ++++++++++++++++--- lib/PTO/Transforms/VMIToVPTO.cpp | 62 ++++++++++++++++ ..._layout_assignment_f32_f8_store_reduce.pto | 4 +- test/lit/vmi/vmi_to_vpto_load_deint.pto | 12 ++-- .../vmi/vmi_to_vpto_load_deint_multichunk.pto | 36 ++++++++++ test/lit/vmi/vmi_to_vpto_quant_dequant.pto | 10 +-- test/lit/vmi/vmi_to_vpto_quant_fp8.pto | 8 +-- 8 files changed, 204 insertions(+), 57 deletions(-) diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index 94ca638cb2..fb1f1b7135 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -106,6 +106,16 @@ S=32 group_reduce f32: 所以 source/mask 使用 deinterleaved=4, block_elems=8。 ``` +`block_elems=8` 表示一种按 32B row fragment 组织的输入形态,不表示 +S=32 reduce 只能接受这一种形态。如果同一个 value 还要服务 narrow cast 等 +element-parity consumer,assignment 可以选择 `deinterleaved=4, block_elems=1` +作为共同 layout,再由 lowering 生成对应的物理指令序列。 + +`deinterleaved` 只描述最终物理 part 中有哪些 logical lane,不描述这个 layout +由哪条指令生成。不同 producer 可以用不同方式直接产生同一个 layout;如果不能 +直接产生,后续 lowering 再通过显式 materialization helper 把 source layout +转换成 consumer 需要的 layout。具体 lowering 形状见 case catalog。 + ### 2.3 `num_groups = G, slots = K` ```mlir @@ -196,39 +206,40 @@ pto-validate-vmi-ir 这是硬合法化 pass。它选择具体 value layout、具体 mask granularity, 并在 layout 不匹配的 use site 插入显式 helper op。 -例子:`f16 -> f32 -> store`。 +实现上它维护 data 和 mask 两套求解状态: -Surface VMI: +```text +data value: + 每个 !pto.vmi.vreg 是一个节点,节点记录最终选择的布局。 -```mlir -%x16 = pto.vmi.load %src[%off] - : !pto.ptr -> !pto.vmi.vreg<128xf16> -%x32 = pto.vmi.extf %x16 - : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> -pto.vmi.store %x32, %dst[%off] - : !pto.vmi.vreg<128xf32>, !pto.ptr +mask value: + 每个 !pto.vmi.mask 是一个节点,节点记录最终选择的布局和 predicate 粒度。 ``` -Assignment 之后: - -```mlir -%x16 = pto.vmi.load %src[%off] - : !pto.ptr - -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +data value 使用 union-find 表示“这些 value 必须共用 layout”。函数参数、 +call operand/result、return/yield、block argument、bitcast 等边界会把相关 +value 合并到同一个等价类里。等价类只能有一个最终 data layout。 -%x32 = pto.vmi.extf %x16 - : !pto.vmi.vreg<128xf16, #pto.vmi.layout> - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +assignment 遍历 IR 时,每类 op 向求解器贡献两种信息: -%x32_dense = pto.vmi.ensure_layout %x32 - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +```text +result 自然布局: + 这个 op 自己产生的 result 适合用什么 layout 表达。 -pto.vmi.store %x32_dense, %dst[%off] - : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr +operand 使用请求: + 这个 op 消费某个 operand 时希望 operand 是什么 layout。 ``` -即使不跑任何优化 pass,这个 assignment 后的 IR 也已经是正确可降的。 +有些 producer 生成的是同一个逻辑向量,但可以用多种物理 layout 表达。若它的 +所有 consumer 给出的使用请求一致,assignment 会把这个请求反推为 producer +result 的最终布局。否则,producer 保持自己的布局,assignment 在不匹配的 use +site 插入 `pto.vmi.ensure_layout`。mask 使用同样思路,但还会同时求解 predicate +粒度,必要时插入 `ensure_mask_layout` 或 `ensure_mask_granularity`。 + +最后,pass 会把所有 VMI data/mask type 改写成带 layout 的 type,并同步更新 +function type、call site、block argument 和 terminator operand。这个阶段之后, +IR 不再依赖隐藏 plan;后续 pass 和 `vmi-to-vpto` 都只读取 type 上的 layout +和显式 `ensure_*` helper。 ### 3.3 `vmi-layout-fold-consumers` diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 5a007987a7..855a7a486f 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -131,6 +131,52 @@ group_store future explicit group-pack op ``` +Contiguous memory loads may produce a non-contiguous physical value directly +when the requested result layout is a dense deinterleaved layout. This is a +lowering choice, not a separate layout family. + +```text +pto.vmi.load -> #pto.vmi.layout + lower as: + vlds NORM for each physical chunk + +pto.vmi.load -> #pto.vmi.layout + lower as: + vldsx2 DINTLV_B* for each pair of physical chunks + +pto.vmi.load -> #pto.vmi.layout + lower as: + two vldsx2 DINTLV_B* operations for each four-chunk group + followed by two vdintlv operations to split mod4 parts + +pto.vmi.load -> #pto.vmi.layout + lower using the producer-specific path or fall back to explicit + materialization. Do not treat DINTLV_B* as a block-fragment layout. +``` + +The `deinterleaved = 4` result order remains the normal VMI physical part +order: + +```text +results = [part0 chunks..., part1 chunks..., part2 chunks..., part3 chunks...] +``` + +For one full `256xf32` tile: + +```text +%even0, %odd0 = pto.vldsx2 %base[%off0], "DINTLV_B32" +%even1, %odd1 = pto.vldsx2 %base[%off128], "DINTLV_B32" + +%part0, %part2 = pto.vdintlv %even0, %even1 +%part1, %part3 = pto.vdintlv %odd0, %odd1 + +replace pto.vmi.load with [%part0, %part1, %part2, %part3] +``` + +This optimization is legal only for full physical chunks and supported +`DINTLV_B8/B16/B32` element widths. Tail and masked loads keep their explicit +safe lowering until a masked or guarded `vldsx2` strategy is designed. + ## 3. Lowering Results The following examples use symbolic VPTO names. `PAT_ALL_B*` means an all-true @@ -3927,18 +3973,14 @@ VPTO lowering result: %all_b32 = pto.pge_b32 "PAT_ALL" %sum_mask = pto.pge_b32 "PAT_VL8" -%x0 = pto.vlds %base[%off] : memref<256xf32> -> !pto.vreg<64xf32> -%x1 = pto.vlds %base[%off_plus_64] : memref<256xf32> -> !pto.vreg<64xf32> -%x2 = pto.vlds %base[%off_plus_128] : memref<256xf32> -> !pto.vreg<64xf32> -%x3 = pto.vlds %base[%off_plus_192] : memref<256xf32> -> !pto.vreg<64xf32> +%x_even_0, %x_odd_0 = pto.vldsx2 %base[%off], "DINTLV_B32" + : memref<256xf32>, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +%x_even_1, %x_odd_1 = pto.vldsx2 %base[%off_plus_128], "DINTLV_B32" + : memref<256xf32>, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x01_lo, %x01_hi = pto.vdintlv %x0, %x1 - : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x23_lo, %x23_hi = pto.vdintlv %x2, %x3 - : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo +%x_p0, %x_p2 = pto.vdintlv %x_even_0, %x_even_1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -%x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi +%x_p1, %x_p3 = pto.vdintlv %x_odd_0, %x_odd_1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> %s0 = pto.vcgadd %x_p0, %all_b32 : !pto.vreg<64xf32> @@ -3991,6 +4033,14 @@ The common layout selected for `%x32` is `truncf f32 -> f8` and S=32 `group_reduce_addf`. A later strided block-load producer may introduce `block_elems = 8`, but that is a different case and requires an explicit materialization/rematerialization decision. + +When `%x32` is produced by a full contiguous `pto.vmi.load`, `vmi-to-vpto` +should not first materialize four contiguous f32 chunks and then run a full +four-op `vdintlv` tree. The load lowering should fold the first deinterleave +level into two `vldsx2 DINTLV_B32` operations and then run only the second +`vdintlv` level, as shown above. The layout remains just +`deinterleaved = 4, block_elems = 1`; it does not encode the fact that `vldsx2` +was used. ``` ### 3.33 One Dense Value Feeding S=16 And S=32 Reduces diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 7f10e39ea6..ea286520bf 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -3794,6 +3794,68 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { } } + if (resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getFactor() == 4 && resultLayout.getBlockElems() == 1) { + std::optional dist = + getX2MemoryDistToken(resultVMIType.getElementType(), "DINTLV"); + if (dist && !resultTypes.empty() && resultTypes.size() % 4 == 0) { + int64_t groups = resultTypes.size() / 4; + SmallVector part0; + SmallVector part1; + SmallVector part2; + SmallVector part3; + part0.reserve(groups); + part1.reserve(groups); + part2.reserve(groups); + part3.reserve(groups); + for (int64_t group = 0; group < groups; ++group) { + Type part0Type = resultTypes[group]; + Type part1Type = resultTypes[groups + group]; + Type part2Type = resultTypes[2 * groups + group]; + Type part3Type = resultTypes[3 * groups + group]; + if (part0Type != part1Type || part0Type != part2Type || + part0Type != part3Type) + return rewriter.notifyMatchFailure( + op, "vldsx2 deinterleaved=4 load requires matching part " + "types"); + + Value firstOffset = createChunkOffset( + op.getLoc(), *offset, group * 4 * *lanesPerPart, rewriter); + Value secondOffset = createChunkOffset( + op.getLoc(), *offset, (group * 4 + 2) * *lanesPerPart, + rewriter); + auto first = + rewriter.create(op.getLoc(), part0Type, part1Type, + *source, firstOffset, + rewriter.getStringAttr(*dist)); + auto second = + rewriter.create(op.getLoc(), part2Type, part3Type, + *source, secondOffset, + rewriter.getStringAttr(*dist)); + + auto even = rewriter.create( + op.getLoc(), part0Type, part2Type, first.getLow(), + second.getLow()); + auto odd = rewriter.create( + op.getLoc(), part1Type, part3Type, first.getHigh(), + second.getHigh()); + part0.push_back(even.getLow()); + part1.push_back(odd.getLow()); + part2.push_back(even.getHigh()); + part3.push_back(odd.getHigh()); + } + + SmallVector results; + results.reserve(resultTypes.size()); + results.append(part0); + results.append(part1); + results.append(part2); + results.append(part3); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + } + SmallVector contiguousParts; contiguousParts.reserve(resultTypes.size()); for (auto [index, resultType] : llvm::enumerate(resultTypes)) { diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto index b976ab518d..8dfe2292cf 100644 --- a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -48,8 +48,8 @@ module { // ASSIGN: pto.vmi.store %[[X8]] // LOWER-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( -// LOWER-COUNT-4: pto.vlds -// LOWER-COUNT-4: pto.vdintlv +// LOWER-COUNT-2: pto.vldsx2 +// LOWER-COUNT-2: pto.vdintlv // LOWER-COUNT-4: pto.vcgadd // LOWER-COUNT-3: pto.vadd // LOWER: pto.vsts diff --git a/test/lit/vmi/vmi_to_vpto_load_deint.pto b/test/lit/vmi/vmi_to_vpto_load_deint.pto index 715dacdfa6..0f3c3f825a 100644 --- a/test/lit/vmi/vmi_to_vpto_load_deint.pto +++ b/test/lit/vmi/vmi_to_vpto_load_deint.pto @@ -39,14 +39,10 @@ module { // CHECK-NOT: unrealized_conversion_cast // CHECK-LABEL: func.func @vmi_to_vpto_load_deint4( -// CHECK: %[[D0:.*]] = pto.vlds %arg0[%arg1] -// CHECK: %[[D1:.*]] = pto.vlds %arg0[{{.*}}] -// CHECK: %[[D2:.*]] = pto.vlds %arg0[{{.*}}] -// CHECK: %[[D3:.*]] = pto.vlds %arg0[{{.*}}] -// CHECK: %[[A0:.*]], %[[B0:.*]] = pto.vdintlv %[[D0]], %[[D1]] -// CHECK: %[[A1:.*]], %[[B1:.*]] = pto.vdintlv %[[D2]], %[[D3]] -// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.vdintlv %[[A0]], %[[A1]] -// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.vdintlv %[[B0]], %[[B1]] +// CHECK: %[[E0:.*]], %[[O0:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: %[[E1:.*]], %[[O1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[P0:.*]], %[[P2:.*]] = pto.vdintlv %[[E0]], %[[E1]] +// CHECK: %[[P1:.*]], %[[P3:.*]] = pto.vdintlv %[[O0]], %[[O1]] // CHECK: return %[[P0]], %[[P1]], %[[P2]], %[[P3]] // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto b/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto index 433f222af3..200a1af04e 100644 --- a/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto +++ b/test/lit/vmi/vmi_to_vpto_load_deint_multichunk.pto @@ -20,6 +20,28 @@ module { return %p0, %p1, %p2, %p3 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> } + + func.func @vmi_to_vpto_load_deint4_multichunk( + %src: !pto.ptr, %offset: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %value = pto.vmi.load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + %p0_0, %p0_1, %p1_0, %p1_1, %p2_0, %p2_1, %p3_0, %p3_1 = + "pto.vmi.unpack"(%value) + : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0_0, %p0_1, %p1_0, %p1_1, %p2_0, %p2_1, %p3_0, %p3_1 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32>, + !pto.vreg<64xf32>, !pto.vreg<64xf32> + } } // CHECK-LABEL: func.func @vmi_to_vpto_load_deint2_multichunk( @@ -29,3 +51,17 @@ module { // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_load_deint4_multichunk( +// CHECK: %[[E0_0:.*]], %[[O0_0:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// CHECK: %[[E1_0:.*]], %[[O1_0:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[P0_0:.*]], %[[P2_0:.*]] = pto.vdintlv %[[E0_0]], %[[E1_0]] +// CHECK: %[[P1_0:.*]], %[[P3_0:.*]] = pto.vdintlv %[[O0_0]], %[[O1_0]] +// CHECK: %[[E0_1:.*]], %[[O0_1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[E1_1:.*]], %[[O1_1:.*]] = pto.vldsx2 %arg0[{{.*}}], "DINTLV_B32" +// CHECK: %[[P0_1:.*]], %[[P2_1:.*]] = pto.vdintlv %[[E0_1]], %[[E1_1]] +// CHECK: %[[P1_1:.*]], %[[P3_1:.*]] = pto.vdintlv %[[O0_1]], %[[O1_1]] +// CHECK: return %[[P0_0]], %[[P0_1]], %[[P1_0]], %[[P1_1]], %[[P2_0]], %[[P2_1]], %[[P3_0]], %[[P3_1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto index dd69bcfaa2..c3a1a0fede 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto @@ -287,15 +287,11 @@ module { // CHECK-SAME: %[[FQDST:[^,]+]]: !pto.ptr // CHECK: scf.for // CHECK: scf.for -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> diff --git a/test/lit/vmi/vmi_to_vpto_quant_fp8.pto b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto index c44de2ec84..01e92013cc 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_fp8.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto @@ -30,12 +30,8 @@ module { } // CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> From f1a643fecd6a4aadde39c58014172eba2ce0a239 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 17:31:33 +0800 Subject: [PATCH 27/54] Document VMI layout assignment mechanism --- docs/designs/vmi-introduction.md | 225 ++++++++++++++++++++++++++++--- 1 file changed, 203 insertions(+), 22 deletions(-) diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index fb1f1b7135..33852a69ce 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -206,40 +206,221 @@ pto-validate-vmi-ir 这是硬合法化 pass。它选择具体 value layout、具体 mask granularity, 并在 layout 不匹配的 use site 插入显式 helper op。 -实现上它维护 data 和 mask 两套求解状态: +这个 pass 的工作顺序是固定的: ```text -data value: - 每个 !pto.vmi.vreg 是一个节点,节点记录最终选择的布局。 +1. 做少量 VMI 内部规整,让后续 layout 规则面对稳定形态。 +2. 为 data value 建 union-find 求解器,并收集 data 约束和 data use request。 +3. 把可采纳的 consumer request 提升为 producer/result 的最终 layout。 +4. 改写所有 data value type,让 !pto.vmi.vreg 携带具体 layout。 +5. 对仍不匹配的 data use 插入 pto.vmi.ensure_layout。 +6. 基于已经确定的 data layout 推导 mask layout 和 predicate granularity。 +7. 改写所有 mask type,并对不匹配的 mask use 插入 ensure_mask_*。 +8. 同步更新 function type、call boundary 和 block argument type。 +9. 校验 layout-assigned VMI IR。 +``` + +Data 和 mask 分两轮求解。原因是 mask layout 通常依赖对应 data operand 或 result +的 layout;例如 `cmpf` 产生的 mask 跟比较输入的 data layout 对齐, +`select`/`reduce`/`masked_load` 消费的 mask 也要跟对应 data value 的 lane +layout 和元素 bitwidth 对齐。 + +Data 求解器为每个 `!pto.vmi.vreg` 建一个节点: -mask value: - 每个 !pto.vmi.mask 是一个节点,节点记录最终选择的布局和 predicate 粒度。 +```text +DataNode: + value = 对应 SSA value + original type = surface VMI type + parent = union-find parent + naturalLayout = 当前等价类选择的自然 layout,可能为空 ``` -data value 使用 union-find 表示“这些 value 必须共用 layout”。函数参数、 -call operand/result、return/yield、block argument、bitcast 等边界会把相关 -value 合并到同一个等价类里。等价类只能有一个最终 data layout。 +遍历 IR 时,每个 op 向 data 求解器贡献三类信息。 -assignment 遍历 IR 时,每类 op 向求解器贡献两种信息: +第一类是 layout 等价约束。它表示几个 value 必须使用同一个 physical layout, +也就是 union-find 中的同一个等价类。典型来源: ```text -result 自然布局: - 这个 op 自己产生的 result 适合用什么 layout 表达。 +layout-transparent elementwise: + addf/addi/subf/subi/mulf/muli/fma/divf/minf/maxf/... + L(operands...) = L(result) + +unary elementwise: + negf/absf/absi/sqrt/exp/ln/relu/not + L(source) = L(result) + +select: + L(true_value) = L(false_value) = L(result) + +bitcast: + L(source) = L(result) + +structured control flow: + scf.if result = then/else yield operand + scf.for result = init operand = iter_arg = yield operand + scf.while result = init/before/condition/after/yield carried value + +cf branch: + branch operand = destination block argument + +function boundary: + call operand = callee argument + call result = callee return operand + multiple returns of the same function agree per result index +``` + +这一步只说明“这些 value 如果存在布局,就必须一致”。它不等价于把某个 +consumer 的 request 无条件推过所有 producer 或控制流。 + +第二类是 result 自然布局。某些 op 的结果本身有目标相关的自然布局: + +```text +普通 reduce / compress / shuffle: + result 通常是 contiguous。 + +group_reduce: + source 需要适配 group reduce 指令形态; + result 使用 group_slots(num_groups, slots) 描述 sparse group result。 + +cast: + widening/narrowing 根据 cast support 决定 source request 和 result layout。 + +group_load / group_slot_load: + result 根据 group size、row stride 和目标能力选择 contiguous、deinterleaved + 或 group_slots。 + +active_prefix_index: + result 使用 contiguous。 +``` + +若同一个等价类已经有自然布局,再设置不同自然布局会报 layout contract 冲突。 + +第三类是 operand 使用请求。consumer 不直接修改 operand 的 type,而是记录 +“这个 use site 希望 operand 是什么 layout”: + +```text +store / tile_write / masked_store value: + wants contiguous + +ordinary reduce source/init: + wants contiguous + +group_reduce source: + wants preferred group-reduce source layout + +group_store value: + wants preferred group result layout + +truncf/trunci/extf/extsi/extui source: + wants cast support 给出的 source layout + +channel_split / channel_merge / shuffle: + wants 各自 lowering 需要的 source/input layout +``` + +收集完这些信息后,assignment 才尝试做 consumer-driven adoption。它逐个查看 +use request:如果 operand 的 producer 可以直接用 consumer 需要的 layout 产生 +同一个逻辑向量,并且多 use 时所有 use 都请求同一个 layout,那么这个 request +会被提升为该 value 所在 data 等价类的最终 layout。 + +可采纳 producer 是受限集合: + +```text +load / tile_read +broadcast / constant / iota +layout-transparent elementwise +select +bitcast +``` + +这就是 request 看起来能穿过 elemwise 的原因: + +```mlir +%x = pto.vmi.load ... +%k = pto.vmi.broadcast ... +%y = pto.vmi.mulf %x, %k +%q = pto.vmi.truncf %y +``` + +`mulf` 先把 `%x`、`%k`、`%y` 合成同一个 data 等价类。`truncf` 对 `%y` +的 source use 请求 `deinterleaved=4` 时,这个 request 作用到 `%y` 所在等价类; +因为 `mulf` 是可采纳 producer,assignment 可以把整个等价类选成 +`deinterleaved=4`,从而让 load/broadcast/mulf 直接在这个 layout 下产生数据。 + +控制流边界也会形成等价类,但它不是任意 request 的自动传播通道: + +```mlir +%y = scf.if %c -> !pto.vmi.vreg<128xf32> { + scf.yield %a +} else { + scf.yield %b +} +%q = pto.vmi.truncf %y +``` + +`%y`、`%a`、`%b` 的 layout 必须一致;但 `scf.if` result 本身不是 +consumer-driven adoption 的可采纳 producer。若 `%q` 需要的 layout 无法成为 +这个等价类的最终布局,assignment 会在 `%q` 的 use site 插 +`pto.vmi.ensure_layout`,而不是隐式重写两个 branch 的内部计算。 + +Data layout 确定后,pass 会把每个 `!pto.vmi.vreg` 改写成 +`!pto.vmi.vreg`。如果某个记录过的 use request 仍然和 operand +当前 layout 不一致,pass 在该 consumer 前插显式 materialization: + +```mlir +%x_req = pto.vmi.ensure_layout %x + : !pto.vmi.vreg + -> !pto.vmi.vreg +consumer %x_req +``` + +这个规则也处理多 consumer 冲突: + +```mlir +%y = pto.vmi.mulf %x, %k +pto.vmi.store %y, %out0 // wants contiguous +%q = pto.vmi.truncf %y // wants deinterleaved=4 source +``` + +一个 SSA value 只能属于一个 data layout 等价类。若两个 use 不能共同满足, +baseline assignment 保留一个等价类 layout,并在不匹配 use 前插 +`ensure_layout`。后续 `vmi-layout-fold-consumers`、`vmi-layout-rematerialize` +和 `vmi-layout-sink-materialization` 可以在显式 helper op 上做优化,但 +`vmi-to-vpto` 不读取隐藏 plan 或 sibling user。 + +Mask 求解发生在 data type 改写之后。它同样维护 union-find 等价类,但节点记录 +两件事: + +```text +mask layout +predicate granularity: b8 / b16 / b32 +``` + +mask request 从已经带 layout 的 data value 推导: + +```text +cmpf/cmpi result: + mask layout = lhs data layout + granularity = lhs element bitwidth 对应的 predicate 粒度 + +select mask: + mask layout = result data layout + granularity = result element bitwidth 对应的 predicate 粒度 -operand 使用请求: - 这个 op 消费某个 operand 时希望 operand 是什么 layout。 +reduce / group_reduce / masked_load / expand_load mask: + mask layout = source/result data layout + granularity = 对应 data element bitwidth 的 predicate 粒度 ``` -有些 producer 生成的是同一个逻辑向量,但可以用多种物理 layout 表达。若它的 -所有 consumer 给出的使用请求一致,assignment 会把这个请求反推为 producer -result 的最终布局。否则,producer 保持自己的布局,assignment 在不匹配的 use -site 插入 `pto.vmi.ensure_layout`。mask 使用同样思路,但还会同时求解 predicate -粒度,必要时插入 `ensure_mask_layout` 或 `ensure_mask_granularity`。 +若 mask use 的 layout 或 granularity 不匹配,pass 显式插 +`pto.vmi.ensure_mask_layout` 或 `pto.vmi.ensure_mask_granularity`。 -最后,pass 会把所有 VMI data/mask type 改写成带 layout 的 type,并同步更新 -function type、call site、block argument 和 terminator operand。这个阶段之后, -IR 不再依赖隐藏 plan;后续 pass 和 `vmi-to-vpto` 都只读取 type 上的 layout -和显式 `ensure_*` helper。 +完成 data/mask 改写和 helper 插入后,pass 会同步更新 function type。直接 +internal call 会把 call operand/result 与 callee argument/return operand 合成 +同一布局约束;带 VMI type 的 external declaration 或 indirect call 没有可见 +body,当前需要显式 ABI materialization 设计,因此 layout assignment 会拒绝。 +这个阶段之后,IR 不再依赖隐藏 plan;后续 pass 和 `vmi-to-vpto` 都只读取 type +上的 layout 和显式 `ensure_*` helper。 ### 3.3 `vmi-layout-fold-consumers` From 8fc9c04c0aacd69f8824e836ea624575df7176da Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 23 Jun 2026 18:23:51 +0800 Subject: [PATCH 28/54] Illustrate VMI layout equivalence classes --- docs/designs/vmi-introduction.md | 85 ++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index 33852a69ce..dbf09230f1 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -272,6 +272,91 @@ function boundary: 这一步只说明“这些 value 如果存在布局,就必须一致”。它不等价于把某个 consumer 的 request 无条件推过所有 producer 或控制流。 +等价类可以画成“同一个框里的 value 共用一个 layout 变量”。例如普通 +elementwise 链: + +```text +surface VMI: + + %x = pto.vmi.load ... + %k = pto.vmi.broadcast ... + %y = pto.vmi.mulf %x, %k + %q = pto.vmi.truncf %y + +data layout 等价类: + + class C0 + +--------------------------------------+ + | %x %k %y | + | load broadcast mulf result | + +--------------------------------------+ + ^ + | + use request from truncf source: + wants deinterleaved=4 + +若 %y 的 producer chain 可采纳该 request,assignment 可以选择: + + L(C0) = deinterleaved=4 +``` + +控制流 join 也是等价类,但 request adoption 的含义不同: + +```text +surface VMI: + + %y = scf.if %c -> !pto.vmi.vreg<128xf32> { + scf.yield %a + } else { + scf.yield %b + } + %q = pto.vmi.truncf %y + +data layout 等价类: + + class C1 + +--------------------------------------+ + | %a %b %y | + | then yield else yield if result | + +--------------------------------------+ + ^ + | + use request from truncf source: + wants deinterleaved=4 + +scf.if result 不是 consumer-driven adoption 的可采纳 producer。 +若 C1 不能直接选择 deinterleaved=4,assignment 保持 C1 的布局, +并在 use site materialize: + + %y_for_q = pto.vmi.ensure_layout %y : L(C1) -> deinterleaved=4 + %q = pto.vmi.truncf %y_for_q +``` + +多 consumer 冲突时,等价类仍然只有一个 layout: + +```text +surface VMI: + + %y = pto.vmi.mulf %x, %k + pto.vmi.store %y, %out0 + %q = pto.vmi.truncf %y + +data layout 等价类: + + class C2 + +-----------------------------+ + | %x %k %y | + +-----------------------------+ + |\ + | \ use request from truncf: deinterleaved=4 + | + +--- use request from store: contiguous + +两个 use request 不一致时,不能让 %y 同时拥有两个 layout。 +baseline assignment 保留 C2 已有的 natural layout;若没有 natural layout, +则使用默认 contiguous。与该 layout 不匹配的 edge 会插 ensure_layout。 +``` + 第二类是 result 自然布局。某些 op 的结果本身有目标相关的自然布局: ```text From 9398b606b614262480b8a09856105d266883af73 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 24 Jun 2026 12:53:10 +0800 Subject: [PATCH 29/54] Add VMI histogram lowering support --- docs/designs/vmi-implementation-manual.md | 117 +++++++++++ docs/designs/vmi-introduction.md | 53 +++++ .../vmi-layout-assignment-implementation.md | 97 +++++++-- .../vmi-layout-assignment-lowering-design.md | 46 +++++ docs/designs/vmi-layout-lowering-cases.md | 194 ++++++++++++++++++ include/PTO/IR/VMIOps.td | 17 ++ include/PTO/Transforms/VMILayoutSupport.h | 14 ++ lib/PTO/IR/VMI.cpp | 51 +++++ lib/PTO/Transforms/PTOValidateVMIIR.cpp | 18 ++ lib/PTO/Transforms/VMILayoutAssignment.cpp | 22 ++ lib/PTO/Transforms/VMILayoutSupport.cpp | 67 ++++++ lib/PTO/Transforms/VMIToVPTO.cpp | 110 +++++++++- test/lit/vmi/vmi_layout_assignment_dhist.pto | 37 ++++ .../vmi_to_vpto_chist_semantics_invalid.pto | 27 +++ test/lit/vmi/vmi_to_vpto_dhist.pto | 41 ++++ test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto | 49 +++++ .../vmi/dhist-tail-mask-store/compare.py | 36 ++++ .../cases/vmi/dhist-tail-mask-store/golden.py | 44 ++++ .../vmi/dhist-tail-mask-store/kernel.pto | 56 +++++ .../vmi/dhist-tail-mask-store/launch.cpp | 33 +++ .../cases/vmi/dhist-tail-mask-store/main.cpp | 94 +++++++++ .../vmi/dhist-tail-mask-store/ptoas.flags | 1 + 22 files changed, 1211 insertions(+), 13 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_assignment_dhist.pto create mode 100644 test/lit/vmi/vmi_to_vpto_chist_semantics_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_dhist.pto create mode 100644 test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/compare.py create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/golden.py create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/kernel.pto create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/launch.cpp create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/main.cpp create mode 100644 test/vpto/cases/vmi/dhist-tail-mask-store/ptoas.flags diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 6bb7a7e0fe..2cd72208a6 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -1905,6 +1905,9 @@ layout-producing conversion: externally ordered memory: load, store, tile_read, tile_write + +value-indexed accumulation: + dhist, chist ``` Per-part elementwise ops are straightforward only when all operands/results already share the same assigned layout: @@ -2219,6 +2222,17 @@ pto.vmi.tile_read pto.vmi.tile_write ``` +Value-indexed accumulation: + +```text +pto.vmi.dhist +pto.vmi.chist +``` + +`pto.vmi.dhist` is a first-stage semantic op when histogram support is enabled. +`pto.vmi.chist` may share the surface verifier, but its final lowering must be +gated until the target CHISTv2 high-range cumulative semantics are verified. + Current implementation scope note: ```text @@ -2299,6 +2313,18 @@ load/tile_read memory element type must match result VMI data element type when store/tile_write memory element type must match stored VMI data element type when the destination is PtrType or MemRefType ``` +Histogram op verifier: + +```text +dhist/chist acc type must be !pto.vmi.vreg<256xui16> +dhist/chist result type must match acc type +source type must be !pto.vmi.vreg +mask logical lane count must match source logical lane count +surface mask may be pred; after layout assignment it must be b8 contiguous +source/result/acc must not carry layout before vmi-layout-assignment +layout-assigned dhist/chist requires contiguous source, mask, acc, and result +``` + `shuffle` verifier: ```text @@ -3833,6 +3859,87 @@ vmi.tile_read / vmi.tile_write, current direct full-footprint path: any path that would expose padding lanes or reorder externally visible memory ``` +Histogram lowering: + +```text +vmi.dhist semantics: + source lanes are ui8 samples + mask selects active source lanes + acc/result are complete logical 256-bin ui16 histograms + result[b] = acc[b] + count(active source lanes whose value equals b) + +layout assignment: + source layout = contiguous + mask layout = contiguous, granularity b8 + acc/result layout = contiguous !pto.vmi.vreg<256xui16> + +physicalization: + acc/result physical arity is 2 because 256xui16 is 512B + part0 represents logical bins 0..127 + part1 represents logical bins 128..255 +``` + +`vmi-to-vpto` lowering for `pto.vmi.dhist` is local and deterministic from the +op and assigned types: + +```text +lo = converted acc part0 +hi = converted acc part1 + +for each converted source physical chunk c in logical order: + chunk_mask = converted b8 mask chunk c + + if source chunk c contains padding lanes because N is not a multiple of 256: + valid = pto.pge/plt_b8 prefix mask for the valid logical lanes in this chunk + chunk_mask = pto.pand chunk_mask, valid + + lo = pto.dhistv2 lo, src_c, chunk_mask, #bin=0 + hi = pto.dhistv2 hi, src_c, chunk_mask, #bin=1 + +return physical result parts [lo, hi] +``` + +Required preflight: + +```text +acc/result element type is ui16 and logical lane count is exactly 256 +source element type is ui8 +source and mask logical lane counts match +source/mask are contiguous +mask granularity is b8 +source physical chunks are 256-lane ui8 chunks; final partial chunk is allowed +only when the lowering can construct the valid-lane prefix mask +``` + +Diagnostics: + +```text +VMI-UNSUPPORTED: pto.vmi.dhist requires contiguous ui8 source, b8 mask, and +contiguous 256xui16 accumulator/result + +VMI-UNSUPPORTED: pto.vmi.dhist final partial source chunk requires valid-lane +b8 mask materialization +``` + +`pto.vmi.chist` has the same verifier and assignment requirements, but final +lowering is capability-gated: + +```text +if CHISTv2 high-range semantics are verified as global cumulative: + replace the two pto.dhistv2 calls above with pto.chistv2 calls + +elif CHISTv2 high-range semantics are verified as range-local cumulative: + lower low/high pto.chistv2 and add the low-half total count to every high-half bin, + but only after low-total materialization and broadcast support is explicit + +else: + VMI-UNSUPPORTED: pto.vmi.chist requires a verified CHISTv2 range semantics contract +``` + +Do not classify histogram as `group_reduce`. Its result location is selected +by source values, not by lane/group position, and its low/high split is caused +by the physical `128xui16` VPTO result width. + Final hard gate: ```text @@ -3895,6 +4002,12 @@ Slice 4 完成条件: 11. Same-family mask logic ops lower through the physical mask granularity instead of assuming b32 masks. Covered by vmi_to_vpto_mask_logic.pto for mask_and/mask_or/mask_xor/mask_not on b32 masks produced by cmpf and on direct b8/b16 mask operands. +12. `pto.vmi.dhist` lowers one logical 256-bin histogram into two VPTO low/high + bin-range histogram accumulator chains, and tail source chunks are masked + with a valid-lane b8 prefix. `pto.vmi.chist` is rejected until the target + CHISTv2 cumulative range semantics are classified. + Covered by vmi_to_vpto_dhist.pto, vmi_to_vpto_dhist_tail_mask.pto, and + vmi_to_vpto_chist_semantics_invalid.pto. ``` ## 7. Slice 5: Tile Memory And Padding @@ -4075,6 +4188,7 @@ currently routed through the registry: supported source/result layout conversion pairs supported b8/b16/b32 mask granularity conversion pairs pto.vmi.channel_split/channel_merge supported channel count + pto.vmi.dhist direct target support and pto.vmi.chist cumulative range semantics classification still legacy helper-based and should migrate into the registry as follow-up: full layout materialization plans and padding-safety checks @@ -4132,6 +4246,9 @@ vmi_to_vpto_deinterleaved2.mlir vmi_to_vpto_deinterleaved4.mlir vmi_to_vpto_compaction_deint_invalid.mlir vmi_to_vpto_non_full_tile.mlir +vmi_to_vpto_dhist.mlir +vmi_to_vpto_dhist_tail_mask.mlir +vmi_to_vpto_chist_semantics_invalid.mlir vmi_tile_read_padding.mlir vmi_tile_write_mask.mlir vmi_pipeline_hard_gates.mlir diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index dbf09230f1..e7161dc4a0 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -11,6 +11,10 @@ VMI 是 VPTO 之前的逻辑向量层。它让前端先表达“我要对 `NxT` 的逻辑向量做什么”, 再由 layout assignment 决定这个逻辑向量如何拆到 256B 物理 vector register 上。 +当 VPTO 指令因为物理 register 宽度只能暴露半宽接口时,VMI 也负责提供完整的 +逻辑语义。例如 `ui8` histogram 的完整结果是 `256xui16`,物理 VPTO histogram +一次只能返回 `128xui16`;VMI surface 应该表达完整 histogram,low/high bin +range 拆分属于 lowering 细节。 Surface VMI 类型不携带布局: @@ -892,6 +896,55 @@ scf.if -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) 这个场景说明为什么 layout 应该是 type 的一部分,而不是依赖 defining op。 +### 4.8 完整 Histogram 语义 + +VPTO 的 histogram 指令一次读取 `256xui8` source,但结果只能写 +`128xui16` accumulator。完整 `ui8` histogram 有 256 个 bin,因此物理 VPTO +接口需要通过 `#bin = 0/1` 分两次统计低半区和高半区。 + +VMI surface 不暴露这个物理 split: + +```mlir +%hist = pto.vmi.dhist %acc, %src, %mask + : !pto.vmi.vreg<256xui16>, + !pto.vmi.vreg, + !pto.vmi.mask + -> !pto.vmi.vreg<256xui16> +``` + +语义是完整 256-bin distribution histogram: + +```text +for b = 0..255: + hist[b] = acc[b] + count(i where mask[i] && src[i] == b) +``` + +Assignment 形状: + +```text +src/mask = contiguous, b8 mask granularity +acc/result = contiguous 256xui16 logical value +``` + +VPTO 形状: + +```text +acc/result part0 = bins 0..127 +acc/result part1 = bins 128..255 + +for each 256-lane source chunk: + part0 = dhistv2(part0, src_chunk, mask_chunk, #bin=0) + part1 = dhistv2(part1, src_chunk, mask_chunk, #bin=1) +``` + +这说明 VMI 的易用性不只来自 layout assignment。对于这种 value-indexed +accumulation,VMI 还应该隐藏 VPTO 为了物理 vreg 宽度暴露出来的 range +selector、lo/hi accumulator 和多条物理指令。 + +`pto.vmi.chist` 可以使用相同 surface 形状,但当前必须先验证 VPTO `CHISTv2` +在 high range 上返回的是全局累计还是 range-local 累计。这个差异会影响是否需要 +额外给 high half 加上 low half 的总计数,因此不能只按 op 名字猜 lowering。 + ## 5. 当前边界 当前设计方向: diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index e1fa19cc7e..1162818712 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -104,7 +104,7 @@ pto-validate-vmi-layout-ir: `pto.vmi.group_load`, `pto.vmi.group_slot_load`, group_slots `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_add{f|i}`, explicit-slots `pto.vmi.group_broadcast`, `pto.vmi.truncf`, - `pto.vmi.extf`, and `pto.vmi.bitcast`, at the layout gate. + `pto.vmi.extf`, `pto.vmi.bitcast`, and histogram family ops at the layout gate. vmi-to-vpto: use OneToN type conversion @@ -302,6 +302,7 @@ group_slot_load result group_slots layout and source_group_stride group_reduce_add{f|i} source/mask/result layouts, num_groups, typed reduce semantics group_broadcast source/result layouts and num_groups truncf source/result layouts and element widths +dhist/chist acc/source/mask/result layouts and target capability ensure_layout always carries source/result layouts ensure_mask_layout always carries source/result layouts ensure_mask_granularity always carries source/result granularities @@ -373,6 +374,8 @@ group_reduce_addf group_reduce_addi group_broadcast group_store +dhist +chist ensure_layout // internal ensure_mask_layout // internal @@ -444,6 +447,13 @@ group_reduce layout fact: Example: S=2*VLaneElems means deinterleaved=2 source/mask and group_slots(G, slots=8) result in every stage. +histogram layout fact: + shared by layout assignment, layout validation, and vmi-to-vpto. + Example: dhist requires contiguous Nxui8 source, contiguous b8 mask, and + contiguous 256xui16 acc/result. chist uses the same layout fact but also + requires a target capability that classifies CHISTv2 cumulative range + semantics. + layout materialization support: shared by layout validation, vmi-to-vpto, and helper-based optimizations. Example: ensure_layout from deinterleaved=2 f32 to contiguous f32 is the same @@ -670,6 +680,15 @@ ordinary store: group_store: source request group_slots(G,K) + +dhist: + acc/result request contiguous 256xui16 + source request contiguous Nxui8 + mask request contiguous b8 + +chist: + same layout requests as dhist + diagnostic unless CHISTv2 cumulative range semantics are classified ``` Baseline assignment does not perform consumer-driven adoption for performance. @@ -679,8 +698,8 @@ use. ```text natural layout producer: - extf/truncf, group_reduce, group_slot_load, group_load when the op itself - carries a layout-producing contract + extf/truncf, group_reduce, group_slot_load, group_load, dhist/chist when the + op itself carries a layout-producing contract layout equality producer: dense add/mul/select and CFG-carried values tie operands/results but do not @@ -754,6 +773,14 @@ buildMaskRequests: masked_store requests source layout, mask layout, and store predicate granularity explicitly +buildHistogramRequests: + dhist -> acc/result contiguous 256xui16, source contiguous Nxui8, + mask contiguous b8 + chist -> same layout requests, plus target capability diagnostic until + CHISTv2 high-range semantics are classified + do not create group_slots or group_reduce requests; histogram result bins are + selected by source values, not by lane/group position + buildControlFlowRequests: region yields, branch operands, loop iter_args, call operands, and returns create equality requests on the carried VMI layout variable @@ -797,6 +824,7 @@ fixed-layout producers: extf/truncf physical conversion layouts group_load block-fragment layouts group_reduce result group_slots + dhist/chist result contiguous 256xui16 and source/mask contiguous b8 contract masked_load when the physical memory-safety proof fixes a full-read lowering ``` @@ -955,6 +983,20 @@ vmi-to-vpto contract: dynamic mask generation. ``` +```text +case family builder / owner assignment artifact +3.56 full distribution hist buildHistogramRequests contiguous src/mask/acc/result +3.57 cumulative hist boundary buildHistogramRequests capability diagnostic or classified path + +vmi-to-vpto contract: + lower dhist from the current op and assigned layouts by carrying two physical + accumulator parts for bins 0..127 and 128..255. It must not expose the VPTO + #bin range selector on the VMI surface and must not model histogram as + group_reduce. chist remains rejected until the target records whether the + high-range cumulative result is global or range-local and, for range-local + behavior, until low-total materialization is explicit. +``` + ```text case family builder / owner assignment artifact 3.15.1 S=16 row stride 16 buildGroupMemoryRequests block_elems=8 group_load layout @@ -1170,6 +1212,19 @@ group_reduce_add{f|i}, lowering=full_chunk_reduce_row_local: the existing row-local VCADD/VADD/VSEL sequence while preserving the same group_slots(G, slots=1) value contract +dhist, lowering=full_256bin_histogram: + consumes contiguous Nxui8 source and contiguous b8 mask + consumes/produces contiguous 256xui16 accumulator/result + physical result parts are [bins 0..127, bins 128..255] + emits one low-range and one high-range histogram update for each 256-lane + source chunk + final partial source chunks require an explicit valid-lane b8 mask + +chist, lowering=capability_gated_cumulative_histogram: + uses the same layout shape as dhist + rejects until target capability classifies CHISTv2 high-range cumulative + semantics and any required low-total correction materialization is explicit + group_slot_load, lowering=group_slot_load_slots8_unit_stride: result group_slots(G, slots=8) requires source_group_stride == 1 @@ -1558,6 +1613,11 @@ strided/group-slot memory: function/control-flow: 3.12, 3.20, 3.22, 3.25.1, 3.42, 3.43 + +histogram: + 3.56 positive dhist layout/lowering and simulator case when backend support + is enabled + 3.57 diagnostic chist case until CHISTv2 range semantics are classified ``` Aggregate catalog headings are covered through their endpoint subcases: @@ -1607,6 +1667,8 @@ repository evidence: golden.py, and compare.py latest broad VMI runtime sweep passed: PASS=47 FAIL=0 latest full VMI lit sweep passed: 350/350 + this historical sweep predates 3.56/3.57; histogram endpoints require new + lit/SIM or diagnostic tests before they can be counted as implemented ``` Current checked-in coverage for 3.3 dense f8->f32->compute->f8: @@ -2208,23 +2270,34 @@ internal function argument boundary materialization public ABI diagnostic ``` +### Slice 7: Histogram + +```text +3.56 full 256-bin dhist logical op +3.57 chist semantic capability diagnostic +``` + ## 13. Completion Checklist Current evidence for the case-catalog objective: ```text -1. every catalog endpoint is mapped in section 6.6 to an assignment owner, - assignment artifact, and vmi-to-vpto contract -2. every SIM-backed positive endpoint is listed in section 11.3 and has a - checked-in runtime case directory -3. every runtime case directory contains kernel.pto, launch.cpp, main.cpp, - golden.py, and compare.py -4. the latest broad VMI runtime sweep passed: PASS=47 FAIL=0 -5. the latest full VMI lit sweep passed: 350/350 -6. every unsupported endpoint listed in section 11.3 has a diagnostic lit test +1. every pre-histogram catalog endpoint is mapped in section 6.6 to an + assignment owner, assignment artifact, and vmi-to-vpto contract +2. every pre-histogram SIM-backed positive endpoint is listed in section 11.3 + and has a checked-in runtime case directory +3. every existing runtime case directory contains kernel.pto, launch.cpp, + main.cpp, golden.py, and compare.py +4. the latest historical broad VMI runtime sweep passed: PASS=47 FAIL=0 +5. the latest historical full VMI lit sweep passed: 350/350 +6. every pre-histogram unsupported endpoint listed in section 11.3 has a + diagnostic lit test 7. vmi-to-vpto decisions are represented by current-op attrs/operands, assigned layouts, helper ops, rematerialization, or diagnostics 8. no separate lowering-plan string attr is emitted or consumed 9. release docs remain untouched; this is still a design/implementation plan under docs/designs +10. new histogram endpoints 3.56/3.57 are mapped in section 6.6, but their + implementation evidence is intentionally pending new lit/SIM or diagnostic + tests ``` diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 42e62e8b3a..00f69aae05 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -148,6 +148,11 @@ strided memory: strided group_load feeding broadcast and a second group_reduce group_slot_load slots=1 with non-unit source stride group_store slots=1 with non-unit output stride + +value-indexed accumulation: + full 256-bin distribution histogram over Nxui8 source lanes + VPTO low/high bin range split hidden behind one logical 256xui16 VMI result + cumulative histogram is a semantic boundary until CHISTv2 range semantics are verified ``` ### 1.1 Case-Set Sufficiency @@ -184,6 +189,10 @@ control-flow propagation: memory legality: full_tile_readable proof, grouped masks, predicate granularity, aligned strided group memory, stable gather diagnostic + +value-indexed accumulation: + histogram source/result shape, b8 source mask, and fixed low/high VPTO bin + split for a logical 256-bin result ``` No extra layout kind should be added unless a new case proves that the existing @@ -235,6 +244,13 @@ compute boundary: storage must be widened first because integer reduction instructions widen narrow inputs. f8/i8 are not baseline accumulator/compute element types. + +value-indexed accumulation boundary: + pto.vmi.dhist consumes ui8 source lanes and produces a logical 256xui16 + accumulator/result. It is not a group_reduce family member because result + bins are selected by source values rather than by source lane/group position. + pto.vmi.chist uses the same surface shape only after the target CHISTv2 + range semantics are verified. ``` ### 2.1 Dense Layouts @@ -298,6 +314,22 @@ S=8/16/32 packed VCG result -> slots=8 S=64 row-local result -> slots=1 ``` +Histogram does not add a layout family. A full logical histogram result uses: + +```text +!pto.vmi.vreg<256xui16, #pto.vmi.layout> +``` + +and physicalizes to two ordered VPTO parts: + +```text +part0 = logical bins 0..127 +part1 = logical bins 128..255 +``` + +The VPTO `#bin` selector is therefore an op-local lowering detail, not a VMI +layout attribute and not a user-visible operand on `pto.vmi.dhist`. + ## 3. Lowering Context Must Become Explicit IR Output `vmi-to-vpto` may inspect only: @@ -670,6 +702,20 @@ create_mask/create_group_mask: incompatible mask consumers are represented by ensure_mask_layout or ensure_mask_granularity; optimization may clone/rematerialize the mask op +dhist: + requests acc/result contiguous !pto.vmi.vreg<256xui16> + requests source contiguous !pto.vmi.vreg + requests mask contiguous with b8 granularity + lowers each 256-lane source chunk by carrying two accumulator parts: + bins 0..127 use VPTO histogram #bin=0, bins 128..255 use #bin=1 + final partial source chunks are represented by AND-ing the user mask with a + valid-lane prefix mask before the VPTO histogram op + +chist: + same layout requests as dhist + baseline lowering is disabled until target capability records whether the + high-range VPTO cumulative result is global or range-local + scf.if/scf.for/call/return: requests equality across carried VMI values, yielded values, call operands, callee arguments, and function results diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 855a7a486f..efb2a7c502 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -246,6 +246,9 @@ the immediately following complete endpoints. 3.44 masked_load grouped tail feeding S=32 reduce complete 3.45 dynamic S=32 create_group_mask complete 3.46 extf value and derived elemwise value both stored complete/optimization +3.47-3.55 typed group-reduce generalization complete/diagnostic +3.56 full 256-bin distribution histogram complete +3.57 full 256-bin cumulative histogram design boundary ``` ### 3.1 `f16 -> f32 -> store` @@ -6235,3 +6238,194 @@ pto.vmi.group_store %sum8, %out_i8[%group_off], %c1 {num_groups = 8} That packed group-slot `trunci` path is not baseline lowering support yet; the implementation must either define slot-wise VCVTII lowering support or diagnose at layout assignment. + +### 3.56 Full 256-Bin Distribution Histogram + +Histogram is not modeled as `group_reduce`. A group reduce maps source lanes to +result slots by lane/group position. A histogram maps each active source lane +to a result bin by the source value itself. + +VMI-shaped input: + +```text +%src = pto.vmi.load %src_base[%src_off] + : memref -> !pto.vmi.vreg +%mask = pto.vmi.create_mask %active_lanes + : index -> !pto.vmi.mask +%acc = pto.vmi.load %acc_base[%acc_off] + : memref<256xui16> -> !pto.vmi.vreg<256xui16> +%hist = pto.vmi.dhist %acc, %src, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg, + !pto.vmi.mask -> !pto.vmi.vreg<256xui16> +pto.vmi.store %hist, %out[%out_off] +``` + +Logical semantics: + +```text +for b = 0..255: + hist[b] = acc[b] + +for i = 0..N-1: + if mask[i]: + hist[src[i]] += 1 +``` + +Assigned layouts: + +```text +%src: + !pto.vmi.vreg> + +%mask: + !pto.vmi.mask> + +%acc, %hist: + !pto.vmi.vreg<256xui16, #pto.vmi.layout> +``` + +The `256xui16` accumulator/result is one logical VMI value but two physical +VPTO vector registers: + +```text +physical result part0 = logical bins 0..127 +physical result part1 = logical bins 128..255 +``` + +For `N = 256`, VPTO lowering shape: + +```text +%src0 = pto.vlds %src_base[%src_off] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<256xui8> + +%acc_lo = pto.vlds %acc_base[%acc_off + 0] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xui16> +%acc_hi = pto.vlds %acc_base[%acc_off + 128] {dist = "NORM"} + : !pto.ptr -> !pto.vreg<128xui16> + +%hist_lo = pto.dhistv2 %acc_lo, %src0, %mask0, %bin0 + : !pto.vreg<128xui16>, !pto.vreg<256xui8>, !pto.mask, i32 + -> !pto.vreg<128xui16> +%hist_hi = pto.dhistv2 %acc_hi, %src0, %mask0, %bin1 + : !pto.vreg<128xui16>, !pto.vreg<256xui8>, !pto.mask, i32 + -> !pto.vreg<128xui16> + +pto.vsts %hist_lo, %out[%out_off + 0], %all_b16 {dist = "NORM_B16"} +pto.vsts %hist_hi, %out[%out_off + 128], %all_b16 {dist = "NORM_B16"} +``` + +Memory result: + +```text +for b = 0..127: + out[out_off + b] = acc_base[acc_off + b] + + count(i where mask[i] && src_base[src_off + i] == b) + +for b = 128..255: + out[out_off + b] = acc_base[acc_off + b] + + count(i where mask[i] && src_base[src_off + i] == b) +``` + +For `N > 256`, the source is processed in contiguous 256-lane chunks. The two +histogram accumulator parts are carried through all chunks: + +```text +%lo = %acc_lo +%hi = %acc_hi + +for source chunk c in logical order: + %chunk_mask = mask chunk c + if c is the final partial chunk: + %chunk_mask = %chunk_mask & valid-lane-prefix-for-this-chunk + + %lo = pto.dhistv2 %lo, %src_c, %chunk_mask, %bin0 + %hi = pto.dhistv2 %hi, %src_c, %chunk_mask, %bin1 + +result physical parts = [%lo, %hi] +``` + +Tail source lanes are expressed only through the b8 mask. Padding lanes in the +last physical source chunk must be masked off before `pto.dhistv2`; they are +not padding values. + +The VMI op does not expose `#bin`. `#bin` is a VPTO range selector forced by +the physical result width: + +```text +ui8 value domain = 256 bins +complete histogram = 256 x ui16 = 512B +one VPTO vreg result = 128 x ui16 = 256B +``` + +Therefore VMI represents one logical `256xui16` result and `vmi-to-vpto` +locally emits the low-range and high-range VPTO histogram updates. + +### 3.57 Full 256-Bin Cumulative Histogram + +The desired VMI surface shape mirrors `dhist`: + +```text +%hist = pto.vmi.chist %acc, %src, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg, + !pto.vmi.mask -> !pto.vmi.vreg<256xui16> +``` + +The intended logical semantics is a full cumulative histogram: + +```text +dist[b] = count(i where mask[i] && src[i] == b) + +hist[0] = acc[0] + dist[0] +for b = 1..255: + hist[b] = acc[b] + dist[0] + dist[1] + ... + dist[b] +``` + +The current VPTO/VISA documentation only states that `CHISTv2` computes a +`uint16 Cumulative histogram` over the selected bin range. It does not state +whether the high-range call with `#bin = 1` returns: + +```text +global cumulative: + result[j] = count(src <= 128 + j) + +or range-local cumulative: + result[j] = count(128 <= src <= 128 + j) +``` + +These two interpretations have different VMI lowerings. If the hardware result +is global cumulative, the full VMI lowering is the same low/high split as +`dhist`, replacing `pto.dhistv2` with `pto.chistv2`. If the hardware result is +range-local cumulative, the high half also needs the total low-half count added +to every high-half bin: + +```text +%lo = pto.chistv2 %acc_lo, %src0, %mask0, %bin0 +%hi_local = pto.chistv2 %acc_hi, %src0, %mask0, %bin1 + +%low_total = materialize count(src <= 127) from the low-half result +%low_total_vec = broadcast %low_total to every high-half bin +%hi = pto.vadd %hi_local, %low_total_vec, %all_b16 +``` + +That correction path also requires a designed way to materialize and broadcast +the low-half total. Since baseline VMI does not support arbitrary vector +extract, the range-local CHISTv2 interpretation remains unsupported until that +materialization path is explicit. + +The baseline design therefore treats `pto.vmi.chist` as a semantic op whose +exact lowering is gated by a target semantic capability: + +```text +if target documents or validation proves CHISTv2 high range is global: + lower as two pto.chistv2 calls +elif target documents or validation proves CHISTv2 high range is range-local: + lower as pto.chistv2 low/high plus explicit high-half correction only after + low-total materialization support is designed +else: + VMI-UNSUPPORTED: pto.vmi.chist requires a verified CHISTv2 range semantics contract +``` + +This boundary is deliberate. `pto.vmi.dhist` is fully defined because +distribution bins are independent across the low/high split. `pto.vmi.chist` +has cross-range prefix semantics, so VMI must not guess the high-half behavior +from the VPTO op name alone. diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index d14b6fe8ee..98083fb687 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -437,6 +437,23 @@ def VMIGroupBroadcastOp : VMI_Op<"group_broadcast"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } +class VMIHistogramOp + : VMI_Op { + let summary = summaryText; + let arguments = (ins VMI_VRegTypeConstraint:$acc, + VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$acc `,` $source `,` $mask attr-dict `:` type($acc) `,` type($source) `,` type($mask) `->` type($result)"; +} + +def VMIDhistOp : VMIHistogramOp<"dhist", + "VMI full 256-bin distribution histogram over unsigned 8-bit source lanes">; + +def VMIChistOp : VMIHistogramOp<"chist", + "VMI full 256-bin cumulative histogram over unsigned 8-bit source lanes">; + def VMIExtFOp : VMI_Op<"extf"> { let summary = "VMI floating-point elementwise extension"; let arguments = (ins VMI_VRegTypeConstraint:$source); diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h index 9a274a2a9b..41b686b322 100644 --- a/include/PTO/Transforms/VMILayoutSupport.h +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -188,6 +188,14 @@ struct VMIBitcastSupport { VMIBitcastSupportKind kind = VMIBitcastSupportKind::PerPartVbitcast; }; +enum class VMIHistogramSupportKind { + Full256BinDhist, +}; + +struct VMIHistogramSupport { + VMIHistogramSupportKind kind = VMIHistogramSupportKind::Full256BinDhist; +}; + class VMILayoutSupport { public: FailureOr @@ -280,6 +288,12 @@ class VMILayoutSupport { FailureOr getBitcastSupport(VMIBitcastOp op, std::string *reason = nullptr) const; + + FailureOr + getDhistSupport(VMIDhistOp op, std::string *reason = nullptr) const; + + FailureOr + getChistSupport(VMIChistOp op, std::string *reason = nullptr) const; }; } // namespace mlir::pto diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index d3d2dc6b14..edd55a2bb3 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -1231,6 +1231,57 @@ LogicalResult VMIGroupBroadcastOp::verify() { getNumGroupsAttr().getInt()); } +template +static LogicalResult verifyVMIHistogramOp(OpTy op) { + auto accType = cast(op.getAcc().getType()); + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + + auto accElemType = dyn_cast(accType.getElementType()); + auto sourceElemType = dyn_cast(sourceType.getElementType()); + if (!accElemType || !accElemType.isUnsigned() || + accElemType.getWidth() != 16 || accType.getElementCount() != 256) + return op.emitOpError("requires acc type to be " + "!pto.vmi.vreg<256xui16>"); + if (resultType != accType) + return op.emitOpError("requires result type to match acc type"); + if (!sourceElemType || !sourceElemType.isUnsigned() || + sourceElemType.getWidth() != 8) + return op.emitOpError("requires source type to be " + "!pto.vmi.vreg"); + if (maskType.getElementCount() != sourceType.getElementCount()) + return op.emitOpError("requires mask logical lane count to match source"); + + if (auto accLayout = accType.getLayoutAttr()) { + if (!accLayout.isContiguous()) + return op.emitOpError("requires layout-assigned acc to use contiguous " + "layout"); + } + if (auto sourceLayout = sourceType.getLayoutAttr()) { + if (!sourceLayout.isContiguous()) + return op.emitOpError("requires layout-assigned source to use contiguous " + "layout"); + } + if (auto resultLayout = resultType.getLayoutAttr()) { + if (!resultLayout.isContiguous()) + return op.emitOpError("requires layout-assigned result to use " + "contiguous layout"); + } + if (auto maskLayout = maskType.getLayoutAttr()) { + if (!maskLayout.isContiguous()) + return op.emitOpError("requires layout-assigned mask to use contiguous " + "layout"); + if (maskType.getGranularity() != "b8") + return op.emitOpError("requires layout-assigned mask granularity b8"); + } + return success(); +} + +LogicalResult VMIDhistOp::verify() { return verifyVMIHistogramOp(*this); } + +LogicalResult VMIChistOp::verify() { return verifyVMIHistogramOp(*this); } + LogicalResult VMIExtFOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 6fdf6acf07..1186a25e26 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -627,6 +627,24 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, return success(); } + if (auto hist = dyn_cast(op)) { + std::string reason; + if (failed(supports.getDhistSupport(hist, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.dhist has no registered histogram support", + reason); + return success(); + } + + if (auto hist = dyn_cast(op)) { + std::string reason; + if (failed(supports.getChistSupport(hist, &reason))) + return emitLayoutSupportContract( + op, diagOS, "pto.vmi.chist has no registered histogram support", + reason); + return success(); + } + if (auto truncf = dyn_cast(op)) { std::string reason; if (failed(supports.getTruncFSupport(truncf, &reason))) diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index eb3593c9ee..99e4314cf9 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -892,6 +892,28 @@ struct LayoutSolver { sourceType, broadcast.getNumGroupsAttr().getInt())); return WalkResult::advance(); } + if (auto hist = dyn_cast(op)) { + requestDataUse(hist.getAccMutable(), getContiguousLayout()); + requestDataUse(hist.getSourceMutable(), getContiguousLayout()); + if (failed(requestMaskUse(hist.getMaskMutable(), getContiguousLayout(), + "b8", op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout(hist.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto hist = dyn_cast(op)) { + requestDataUse(hist.getAccMutable(), getContiguousLayout()); + requestDataUse(hist.getSourceMutable(), getContiguousLayout()); + if (failed(requestMaskUse(hist.getMaskMutable(), getContiguousLayout(), + "b8", op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout(hist.getResult(), getContiguousLayout(), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto extf = dyn_cast(op)) { auto sourceType = cast(extf.getSource().getType()); auto resultType = cast(extf.getResult().getType()); diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index 27a994ba55..acb687eed0 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -1344,3 +1344,70 @@ VMILayoutSupport::getBitcastSupport(VMIBitcastOp op, return VMIBitcastSupport{VMIBitcastSupportKind::PerPartVbitcast}; } + +template +static FailureOr +getHistogramSupportImpl(OpTy op, std::string *reason) { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto accType = cast(op.getAcc().getType()); + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + + VMILayoutAttr accLayout = accType.getLayoutAttr(); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!accLayout || !sourceLayout || !maskLayout || !resultLayout) + return fail("requires assigned acc/source/mask/result layouts"); + if (!accLayout.isContiguous() || !sourceLayout.isContiguous() || + !maskLayout.isContiguous() || !resultLayout.isContiguous()) + return fail("requires contiguous acc, source, mask, and result layouts"); + if (maskType.getGranularity() != "b8") + return fail("requires b8 mask granularity"); + if (maskType.getElementCount() != sourceType.getElementCount()) + return fail("requires mask lane count to match source lane count"); + + auto accElem = dyn_cast(accType.getElementType()); + auto sourceElem = dyn_cast(sourceType.getElementType()); + if (!accElem || !accElem.isUnsigned() || accElem.getWidth() != 16 || + accType.getElementCount() != 256 || resultType != accType) + return fail("requires contiguous 256xui16 acc/result"); + if (!sourceElem || !sourceElem.isUnsigned() || sourceElem.getWidth() != 8) + return fail("requires unsigned 8-bit source elements"); + + FailureOr accArity = getVMIPhysicalArity(accType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(accArity) || failed(resultArity) || failed(sourceArity) || + failed(maskArity)) + return fail("requires computable physical arity"); + if (*accArity != 2 || *resultArity != 2) + return fail("requires acc/result to physicalize to two 128xui16 parts"); + if (*sourceArity != *maskArity) + return fail("requires source and mask physical arity to match"); + if (*sourceArity < 1) + return fail("requires at least one source physical chunk"); + + return VMIHistogramSupport{VMIHistogramSupportKind::Full256BinDhist}; +} + +FailureOr +VMILayoutSupport::getDhistSupport(VMIDhistOp op, + std::string *reason) const { + return getHistogramSupportImpl(op, reason); +} + +FailureOr +VMILayoutSupport::getChistSupport(VMIChistOp op, + std::string *reason) const { + if (reason) + *reason = "CHISTv2 cumulative high-range semantics are not classified"; + return failure(); +} diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index ea286520bf..4115b90324 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -5986,6 +5986,76 @@ struct OneToNVMIGroupBroadcastOpPattern } }; +struct OneToNVMIDhistOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIDhistOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange accParts = adaptor.getAcc(); + ValueRange sourceParts = adaptor.getSource(); + ValueRange maskParts = adaptor.getMask(); + if (accParts.size() != 2 || sourceParts.empty() || + sourceParts.size() != maskParts.size()) + return rewriter.notifyMatchFailure( + op, "expected two accumulator parts and matching source/mask chunks"); + + auto loType = dyn_cast(accParts[0].getType()); + auto hiType = dyn_cast(accParts[1].getType()); + if (!loType || loType != hiType) + return rewriter.notifyMatchFailure(op, + "expected matching ui16 acc parts"); + auto sourceType = cast(op.getSource().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(sourceType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure(op, + "failed to compute source lanes"); + + Location loc = op.getLoc(); + Value bin0 = createI32Constant(loc, 0, rewriter); + Value bin1 = createI32Constant(loc, 1, rewriter); + Value lo = accParts[0]; + Value hi = accParts[1]; + + for (size_t index = 0, e = sourceParts.size(); index < e; ++index) { + Value source = sourceParts[index]; + Value userMask = maskParts[index]; + auto maskType = dyn_cast(userMask.getType()); + if (!maskType || !maskType.isB8()) + return rewriter.notifyMatchFailure(op, "expected b8 source mask"); + + Value chunkMask = userMask; + int64_t firstLane = static_cast(index) * *lanesPerPart; + int64_t activeLanes = + std::min(*lanesPerPart, + sourceType.getElementCount() - firstLane); + if (activeLanes < *lanesPerPart) { + FailureOr validMask = + createPrefixMaskForActiveLanes(loc, maskType, activeLanes, + rewriter); + FailureOr allMask = createAllTrueMask(loc, maskType, rewriter); + if (failed(validMask) || failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to materialize tail-valid b8 mask"); + chunkMask = rewriter + .create(loc, maskType, chunkMask, *validMask, + *allMask) + .getResult(); + } + + lo = rewriter.create(loc, loType, lo, source, chunkMask, bin0) + .getResult(); + hi = rewriter.create(loc, hiType, hi, source, chunkMask, bin1) + .getResult(); + } + + rewriter.replaceOp(op, SmallVector{lo, hi}, + adaptor.getResultMapping()); + return success(); + } +}; + template struct OneToNVMIReduceMinMaxFOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; @@ -6937,7 +7007,7 @@ void populateVMIOneToNConversionPatterns( OneToNVMISelectOpPattern, OneToNVMIActivePrefixIndexOpPattern, OneToNVMICompressOpPattern, OneToNVMICompressStoreOpPattern, OneToNVMIReduceAddIOpPattern, OneToNVMIReduceAddFOpPattern, - OneToNVMIGroupBroadcastOpPattern, + OneToNVMIGroupBroadcastOpPattern, OneToNVMIDhistOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIExtFOpPattern, OneToNVMITruncFOpPattern, @@ -7420,6 +7490,22 @@ LogicalResult checkSupportedGroupBroadcastShape( return success(); } +LogicalResult checkSupportedDhistShape(VMIDhistOp op, + std::string *reason = nullptr) { + VMILayoutSupport supports; + if (succeeded(supports.getDhistSupport(op, reason))) + return success(); + return failure(); +} + +LogicalResult checkSupportedChistShape(VMIChistOp op, + std::string *reason = nullptr) { + VMILayoutSupport supports; + if (succeeded(supports.getChistSupport(op, reason))) + return success(); + return failure(); +} + LogicalResult checkSupportedFmaShape(const VMITargetCapabilityRegistry &capabilities, VMIFmaOp op, std::string *reason = nullptr) { @@ -7579,6 +7665,28 @@ verifySupportedVMIToVPTOOps(ModuleOp module, << reason << ")"; return WalkResult::interrupt(); } + if (auto hist = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedDhistShape(hist, &reason))) + return WalkResult::advance(); + hist.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.dhist requires contiguous Nxui8 source, contiguous b8 " + "mask, and contiguous 256xui16 acc/result (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto hist = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedChistShape(hist, &reason))) + return WalkResult::advance(); + hist.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.chist requires a verified CHISTv2 range semantics " + "contract before lowering (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto load = dyn_cast(op)) { std::optional explicitFullReadElems; diff --git a/test/lit/vmi/vmi_layout_assignment_dhist.pto b/test/lit/vmi/vmi_layout_assignment_dhist.pto new file mode 100644 index 0000000000..89d9aee9e6 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_dhist.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s + +module { + func.func @vmi_layout_assignment_dhist( + %acc: !pto.vmi.vreg<256xui16>, + %source: !pto.vmi.vreg<300xui8>, + %mask: !pto.vmi.mask<300xpred>) + -> !pto.vmi.vreg<256xui16> { + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<300xui8>, + !pto.vmi.mask<300xpred> -> !pto.vmi.vreg<256xui16> + return %hist : !pto.vmi.vreg<256xui16> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_dhist( +// CHECK-SAME: %[[ACC:.*]]: !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK-SAME: %[[SRC:.*]]: !pto.vmi.vreg<300xui8, #pto.vmi.layout> +// CHECK-SAME: %[[MASK:.*]]: !pto.vmi.mask<300xb32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK: %[[MASK_B8:.*]] = pto.vmi.ensure_mask_granularity %[[MASK]] +// CHECK-SAME: !pto.vmi.mask<300xb32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<300xb8, #pto.vmi.layout> +// CHECK: %[[HIST:.*]] = pto.vmi.dhist %[[ACC]], %[[SRC]], %[[MASK_B8]] +// CHECK-SAME: !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<300xui8, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.mask<300xb8, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// CHECK: return %[[HIST]] diff --git a/test/lit/vmi/vmi_to_vpto_chist_semantics_invalid.pto b/test/lit/vmi/vmi_to_vpto_chist_semantics_invalid.pto new file mode 100644 index 0000000000..1049cbdf2e --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_chist_semantics_invalid.pto @@ -0,0 +1,27 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_chist_semantics_invalid( + %acc: !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + %source: !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb8, #pto.vmi.layout>) { + %hist = pto.vmi.chist %acc, %source, %mask + : !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + return + } +} + +// CHECK: VMI{{-}}UNSUP{{.*}} pto.vmi.chist requires a verified CHISTv2 range semantics contract before lowering diff --git a/test/lit/vmi/vmi_to_vpto_dhist.pto b/test/lit/vmi/vmi_to_vpto_dhist.pto new file mode 100644 index 0000000000..b8a1113534 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_dhist.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_dhist( + %acc: !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + %source: !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + %mask: !pto.vmi.mask<256xb8, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) { + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + !pto.vmi.mask<256xb8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + %lo, %hi = "pto.vmi.unpack"(%hist) + : (!pto.vmi.vreg<256xui16, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) + return %lo, %hi : !pto.vreg<128xui16>, !pto.vreg<128xui16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_dhist( +// CHECK-SAME: %[[ACC0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[ACC1:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[SRC:[^,]+]]: !pto.vreg<256xui8> +// CHECK-SAME: %[[MASK:[^)]+]]: !pto.mask +// CHECK-DAG: %[[BIN0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[BIN1:.*]] = arith.constant 1 : i32 +// CHECK: %[[LO:.*]] = pto.dhistv2 %[[ACC0]], %[[SRC]], %[[MASK]], %[[BIN0]] +// CHECK: %[[HI:.*]] = pto.dhistv2 %[[ACC1]], %[[SRC]], %[[MASK]], %[[BIN1]] +// CHECK: return %[[LO]], %[[HI]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto b/test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto new file mode 100644 index 0000000000..4aada7a188 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_dhist_tail_mask.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_dhist_tail_mask( + %acc: !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + %source: !pto.vmi.vreg<300xui8, #pto.vmi.layout>, + %mask: !pto.vmi.mask<300xb8, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) { + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16, #pto.vmi.layout>, + !pto.vmi.vreg<300xui8, #pto.vmi.layout>, + !pto.vmi.mask<300xb8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + %lo, %hi = "pto.vmi.unpack"(%hist) + : (!pto.vmi.vreg<256xui16, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) + return %lo, %hi : !pto.vreg<128xui16>, !pto.vreg<128xui16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_dhist_tail_mask( +// CHECK-SAME: %[[ACC0:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[ACC1:[^,]+]]: !pto.vreg<128xui16> +// CHECK-SAME: %[[SRC0:[^,]+]]: !pto.vreg<256xui8> +// CHECK-SAME: %[[SRC1:[^,]+]]: !pto.vreg<256xui8> +// CHECK-SAME: %[[MASK0:[^,]+]]: !pto.mask +// CHECK-SAME: %[[MASK1:[^)]+]]: !pto.mask +// CHECK-DAG: %[[BIN0:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[BIN1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C44:.*]] = arith.constant 44 : i32 +// CHECK: %[[LO0:.*]] = pto.dhistv2 %[[ACC0]], %[[SRC0]], %[[MASK0]], %[[BIN0]] +// CHECK: %[[HI0:.*]] = pto.dhistv2 %[[ACC1]], %[[SRC0]], %[[MASK0]], %[[BIN1]] +// CHECK: %[[TAIL:.*]], %{{.*}} = pto.plt_b8 %[[C44]] : i32 -> !pto.mask, i32 +// CHECK: %[[ALL:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: %[[MASK1_VALID:.*]] = pto.pand %[[MASK1]], %[[TAIL]], %[[ALL]] +// CHECK: %[[LO1:.*]] = pto.dhistv2 %[[LO0]], %[[SRC1]], %[[MASK1_VALID]], %[[BIN0]] +// CHECK: %[[HI1:.*]] = pto.dhistv2 %[[HI0]], %[[SRC1]], %[[MASK1_VALID]], %[[BIN1]] +// CHECK: return %[[LO1]], %[[HI1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/compare.py b/test/vpto/cases/vmi/dhist-tail-mask-store/compare.py new file mode 100644 index 0000000000..22aff69b5d --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + output = np.fromfile("v3.bin", dtype=np.uint16) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + + if golden.shape != output.shape: + print(f"[ERROR] compare failed v3.bin: shape golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v3.bin idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/golden.py b/test/vpto/cases/vmi/dhist-tail-mask-store/golden.py new file mode 100644 index 0000000000..0c09bb49d7 --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +SOURCE_ELEMS = 512 +LOGICAL_LANES = 300 +BINS = 256 + + +def generate(output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + + src = (np.arange(SOURCE_ELEMS, dtype=np.uint16) % BINS).astype(np.uint8) + acc = (np.arange(BINS, dtype=np.uint16) % np.uint16(5)).astype(np.uint16) + dst = np.full(BINS, np.uint16(0xcccc), dtype=np.uint16) + + counts = np.bincount(src[:LOGICAL_LANES].astype(np.int64), minlength=BINS) + golden = (acc.astype(np.uint32) + counts.astype(np.uint32)).astype(np.uint16) + + src.tofile(output_dir / "v1.bin") + acc.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/kernel.pto b/test/vpto/cases/vmi/dhist-tail-mask-store/kernel.pto new file mode 100644 index 0000000000..4fb1fe531c --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/kernel.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dhist_tail_mask_store_kernel( + %src_gm: !pto.ptr, %acc_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c300 = arith.constant 300 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_acc = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %acc_gm, %ub_acc, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %source = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<512xui8> + %acc = pto.vmi.load %ub_acc[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xui16> + %mask = pto.vmi.create_mask %c300 : index -> !pto.vmi.mask<512xpred> + %hist = pto.vmi.dhist %acc, %source, %mask + : !pto.vmi.vreg<256xui16>, !pto.vmi.vreg<512xui8>, + !pto.vmi.mask<512xpred> -> !pto.vmi.vreg<256xui16> + pto.vmi.store %hist, %ub_dst[%c0] + : !pto.vmi.vreg<256xui16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/launch.cpp b/test/vpto/cases/vmi/dhist-tail-mask-store/launch.cpp new file mode 100644 index 0000000000..4031c8131e --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/launch.cpp @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dhist_tail_mask_store_kernel(__gm__ uint8_t *src, __gm__ uint16_t *acc, + __gm__ uint16_t *dst); + +void LaunchVmi_dhist_tail_mask_store_kernel(uint8_t *src, uint16_t *acc, + uint16_t *dst, void *stream) { + vmi_dhist_tail_mask_store_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint16_t *)acc, (__gm__ uint16_t *)dst); +} diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/main.cpp b/test/vpto/cases/vmi/dhist-tail-mask-store/main.cpp new file mode 100644 index 0000000000..aa1288ab26 --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dhist_tail_mask_store_kernel(uint8_t *src, uint16_t *acc, + uint16_t *dst, void *stream); + +int main() { + constexpr size_t kSourceElems = 512; + constexpr size_t kBins = 256; + size_t srcBytes = kSourceElems * sizeof(uint8_t); + size_t accBytes = kBins * sizeof(uint16_t); + size_t dstBytes = kBins * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint16_t *accHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint16_t *accDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&accHost), accBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&accDevice, accBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", accBytes, accHost, accBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(accDevice, accBytes, accHost, accBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + + LaunchVmi_dhist_tail_mask_store_kernel(srcDevice, accDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(accDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(accHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/dhist-tail-mask-store/ptoas.flags b/test/vpto/cases/vmi/dhist-tail-mask-store/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/dhist-tail-mask-store/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From 4af06583755ca9c1d0475b540132f7743c94fb86 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 24 Jun 2026 14:44:35 +0800 Subject: [PATCH 30/54] Remove VMI load full read attribute --- .../vmi-layout-assignment-implementation.md | 22 ++++--- .../vmi-layout-assignment-lowering-design.md | 13 ++-- docs/designs/vmi-layout-lowering-cases.md | 14 ++--- include/PTO/IR/VMIOps.td | 3 +- lib/PTO/IR/VMI.cpp | 4 -- lib/PTO/Transforms/VMIToVPTO.cpp | 59 +++++++------------ ...gnment_group_reduce_s32_tail_full_tile.pto | 30 +++++----- .../vmi/vmi_load_full_read_elems_invalid.pto | 20 ------- .../golden.py | 2 +- .../kernel.pto | 16 ++--- 10 files changed, 70 insertions(+), 113 deletions(-) delete mode 100644 test/lit/vmi/vmi_load_full_read_elems_invalid.pto diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 1162818712..fcdf7fe292 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -311,7 +311,7 @@ ensure_mask_granularity always carries source/result granularities Layout/attr-only decisions today: ```text -load result layout plus full_read_elems/full chunk proof +load result layout plus full chunk or shaped memref proof group_store source group_slots layout plus explicit output stride masked_load explicit passthrough, mask layout, and memory proof masked_store/select operand/result layouts plus mask granularity @@ -410,9 +410,8 @@ Important semantic split: ```text load: - optional full_read_elems=N is a memory-safety contract for pointer sources. - It states that source[offset : offset + N) may be physically read even if the - VMI logical result has fewer active lanes. + pointer sources must load full physical chunks directly. Partial logical + loads require a shaped memref proof or a future guarded/scratch fallback. group_load: loads group_size data elements per group @@ -806,7 +805,7 @@ helpers with cheaper equivalent IR. ```text cheap rematerializable producers: load when address operands dominate the clone site, no intervening may-alias - write exists, and any full_read_elems proof is preserved + write exists, and any shaped memory proof is preserved broadcast create_mask create_group_mask @@ -1036,7 +1035,7 @@ vmi-to-vpto contract: ```text case family builder / owner assignment artifact -3.21 S=32 safe full-read tail buildMaskRequests full_read_elems memory proof +3.21 S=32 rounded tail mask buildMaskRequests rounded vector plus mask 3.24 mask/select/store buildMaskRequests explicit mask layout/granularity 3.12 scf.if before reduce buildControlFlowRequests common yielded layout 3.20 group_slots scf.if buildControlFlowRequests common group_slots layout @@ -1469,7 +1468,7 @@ Current audit result: masked_load: direct lowering is load + vsel. It does not inspect the mask producer to choose a different load form; memory safety is provided by full physical - chunks, shaped memref proof, or load full_read_elems. + chunks or shaped memref proof. memref.subview: mentioned only after identity lane-to-address planning fails. It is not used @@ -2187,11 +2186,10 @@ private physical function ABI: rejected until a stable VMI ABI is defined. memory-proof runtime coverage: - 3.21 S=32 full-tile-readable tail is covered by a runtime case that uses - `pto.vmi.load {full_read_elems = 256}` on a UB pointer source. The attr is - the explicit safe-read proof consumed by `vmi-to-vpto`; no surrounding MTE, - caller/body context, or producer/user scan is inspected to justify the - rounded-up physical reads. + 3.21 S=32 rounded tail-mask coverage is provided by a runtime case that loads + a full 256xf32 UB pointer vector and uses a 192-lane mask to define the active + logical rows. No surrounding MTE, caller/body context, or producer/user scan is + inspected to justify partial pointer reads. ``` ## 12. Implementation Slices diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 00f69aae05..497f6cad8c 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -422,8 +422,8 @@ group_reduce layout fact: group_slots(G, slots=1) result. memory safety fact: - full_read_elems, shaped safe-tail memref, or explicit fallback option - proves whether rounded-up physical reads are legal. + full physical chunks are legal for pointer sources. Partial logical loads + need a shaped safe-tail memref proof or an explicit fallback option. ``` These helpers return semantic layout requirements and capability diagnostics. @@ -596,10 +596,11 @@ full_tile_readable: ``` The full-tile-readable proof must be explicit. It may be carried by a -statically shaped memref source, or by `pto.vmi.load {full_read_elems = N}` for -pointer sources. `vmi-to-vpto` consumes only this proof carrier; it does not -inspect surrounding MTE copies, producer bodies, callers, or later consumers to -decide whether inactive physical lanes are safe to read. +statically shaped memref source. Pointer-source runtime kernels should load a +rounded physical vector and use a mask to express logical active lanes. +`vmi-to-vpto` consumes only the op/type-local proof carrier; it does not inspect +surrounding MTE copies, producer bodies, callers, or later consumers to decide +whether inactive physical lanes are safe to read. Example: diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index efb2a7c502..3fd0c4b7eb 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -2845,11 +2845,10 @@ for r = 0..7: This is the positive counterpart to section 3.11.2. Tail participation is still expressed by masks, but the source must provide a static proof that reading the rounded-up 8-row physical tile is memory-safe. That proof is -explicit: it can come from a statically shaped memref source, or from -`pto.vmi.load {full_read_elems = N}` on a pointer source. The pointer attr -means the memory interval starting at the load offset is safe to read for `N` -logical elements; it is not inferred from surrounding MTE copies or caller -context. +explicit for partial logical loads: it can come from a statically shaped memref +source. Pointer-source runtime kernels should instead load the rounded physical +vector and use a mask to express active logical lanes; this is not inferred from +surrounding MTE copies or caller context. VMI input: @@ -2864,8 +2863,9 @@ pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 6} Equivalent pointer-source VMI input for runtime kernels: ```text -%x = pto.vmi.load %base[%off] {full_read_elems = 256} - : !pto.ptr -> !pto.vmi.vreg<192xf32> +%x = pto.vmi.load %base[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> +%mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<256xpred> ``` Assigned layouts: diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 98083fb687..263146ec1f 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -504,8 +504,7 @@ def VMIBitcastOp : VMI_Op<"bitcast"> { def VMILoadOp : VMI_Op<"load", [DeclareOpInterfaceMethods]> { let summary = "VMI logical vector load"; - let arguments = (ins PtrOrMemRef:$source, Index:$offset, - OptionalAttr:$full_read_elems); + let arguments = (ins PtrOrMemRef:$source, Index:$offset); let results = (outs VMI_VRegTypeConstraint:$result); let hasVerifier = 1; let assemblyFormat = "$source `[` $offset `]` attr-dict `:` type($source) `->` type($result)"; diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index edd55a2bb3..3589ec603e 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -1394,10 +1394,6 @@ LogicalResult VMIBitcastOp::verify() { } LogicalResult VMILoadOp::verify() { - if (auto fullReadElems = getFullReadElemsAttr()) { - if (fullReadElems.getInt() <= 0) - return emitOpError("requires full_read_elems to be positive"); - } return verifyMemoryElementMatches(getOperation(), getSource().getType(), cast(getResult().getType()), "source"); diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 4115b90324..b0da679232 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -923,8 +923,7 @@ VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, StringRef role, VMIMemorySafeReadProof computeSafeFullReadProof( Type sourceType, std::optional constantOffset, - VMIVRegType resultType, - std::optional explicitFullReadElems = std::nullopt) { + VMIVRegType resultType) { VMIMemorySafeReadProof proof; proof.constantOffset = constantOffset; @@ -937,15 +936,11 @@ VMIMemorySafeReadProof computeSafeFullReadProof( if (!constantOffset) return fail("requires constant index offset"); - std::optional elements = explicitFullReadElems; - if (!elements) { - FailureOr staticElements = getStaticMemRefElementCount(sourceType); - if (failed(staticElements)) - return fail("requires statically shaped memref source or explicit " - "full_read_elems"); - elements = *staticElements; - } - proof.staticElementCount = *elements; + FailureOr staticElements = getStaticMemRefElementCount(sourceType); + if (failed(staticElements)) + return fail("requires statically shaped memref source"); + int64_t elements = *staticElements; + proof.staticElementCount = elements; if (*constantOffset < 0) return fail("requires non-negative offset"); @@ -959,11 +954,11 @@ VMIMemorySafeReadProof computeSafeFullReadProof( proof.laneAddressMap = *addressMap; proof.physicalFootprint = addressMap->physicalLaneFootprint; - if (addressMap->getExclusiveEndElement() > *elements) + if (addressMap->getExclusiveEndElement() > elements) return fail(Twine("full physical read footprint [") + Twine(addressMap->baseElementOffset) + ", " + Twine(addressMap->getExclusiveEndElement()) + - ") exceeds static memref element count " + Twine(*elements)); + ") exceeds static memref element count " + Twine(elements)); proof.proven = true; return proof; @@ -972,8 +967,7 @@ VMIMemorySafeReadProof computeSafeFullReadProof( VMIMemoryAccessPlan buildReadAccessPlan( const VMITargetCapabilityRegistry &capabilities, Value source, Type sourceType, VMIVRegType resultType, - std::optional constantOffset, VMIMemoryValidMaskKind validMask, - std::optional explicitFullReadElems = std::nullopt) { + std::optional constantOffset, VMIMemoryValidMaskKind validMask) { VMIMemoryAccessPlan plan; plan.baseType = sourceType; plan.valueType = resultType; @@ -982,8 +976,8 @@ VMIMemoryAccessPlan buildReadAccessPlan( plan.validMask = validMask; plan.permutation = VMIMemoryPermutationKind::Identity; plan.writeMask = VMIMemoryWriteMaskKind::AllTrue; - plan.safeReadProof = computeSafeFullReadProof( - sourceType, constantOffset, resultType, explicitFullReadElems); + plan.safeReadProof = + computeSafeFullReadProof(sourceType, constantOffset, resultType); plan.laneAddressMap = plan.safeReadProof.laneAddressMap; plan.targetCapability = capabilities.supportsDirectMemory(sourceType, "source"); @@ -1040,16 +1034,15 @@ void requireUnavailableReadFallback(VMIMemoryAccessPlan &plan) { FailureOr verifyFullOrSafeReadVRegChunks( Operation *op, VMIVRegType type, Type sourceType, Value offset, - PatternRewriter &rewriter, - std::optional explicitFullReadElems = std::nullopt) { + PatternRewriter &rewriter) { std::string fullChunkReason; FailureOr lanesPerPart = checkFullDataPhysicalChunks(type, &fullChunkReason); if (succeeded(lanesPerPart)) return *lanesPerPart; - VMIMemorySafeReadProof safeReadProof = computeSafeFullReadProof( - sourceType, getConstantIndexValue(offset), type, explicitFullReadElems); + VMIMemorySafeReadProof safeReadProof = + computeSafeFullReadProof(sourceType, getConstantIndexValue(offset), type); if (safeReadProof.proven) { lanesPerPart = getDataLanesPerPart(type.getElementType()); if (succeeded(lanesPerPart)) @@ -1065,7 +1058,7 @@ FailureOr verifyFullOrSafeReadVRegChunks( LogicalResult checkSupportedLoadShape( const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, Value source, Type sourceType, std::optional constantOffset, - std::optional explicitFullReadElems, std::string *reason) { + std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); @@ -1074,7 +1067,7 @@ LogicalResult checkSupportedLoadShape( VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( capabilities, source, sourceType, type, constantOffset, - VMIMemoryValidMaskKind::AllTrue, explicitFullReadElems); + VMIMemoryValidMaskKind::AllTrue); if (!accessPlan.targetCapability.isSupported()) return fail(accessPlan.targetCapability.reason); @@ -1195,7 +1188,7 @@ checkSupportedGroupLoadShape(const VMITargetCapabilityRegistry &capabilities, if (resultLayout.isContiguous()) { if (failed(checkSupportedLoadShape(capabilities, resultType, op.getSource(), op.getSource().getType(), std::nullopt, - std::nullopt, reason))) + reason))) return failure(); return checkSupportedGroupChunkShape(resultType, *groupSize, reason); } @@ -3750,12 +3743,8 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { "load offset must convert to one value", rewriter); if (failed(source) || failed(offset)) return failure(); - std::optional explicitFullReadElems; - if (auto attr = op.getFullReadElemsAttr()) - explicitFullReadElems = attr.getInt(); FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( - op, resultVMIType, op.getSource().getType(), *offset, rewriter, - explicitFullReadElems); + op, resultVMIType, op.getSource().getType(), *offset, rewriter); if (failed(lanesPerPart)) return failure(); @@ -7585,13 +7574,11 @@ verifySupportedVMIToVPTOOps(ModuleOp module, bool enableStableGatherMaskedLoad) { auto emitMemoryUnsupported = [&](Operation *op, StringRef opName, VMIVRegType type, Value source, - std::optional constantOffset, - std::optional explicitFullReadElems = - std::nullopt) -> WalkResult { + std::optional constantOffset) -> WalkResult { std::string reason; if (succeeded(checkSupportedLoadShape(capabilities, type, source, source.getType(), constantOffset, - explicitFullReadElems, &reason))) + &reason))) return WalkResult::advance(); op->emitError() @@ -7689,13 +7676,9 @@ verifySupportedVMIToVPTOOps(ModuleOp module, } if (auto load = dyn_cast(op)) { - std::optional explicitFullReadElems; - if (auto attr = load.getFullReadElemsAttr()) - explicitFullReadElems = attr.getInt(); return emitMemoryUnsupported( op, "pto.vmi.load", cast(load.getResult().getType()), - load.getSource(), getConstantIndexValue(load.getOffset()), - explicitFullReadElems); + load.getSource(), getConstantIndexValue(load.getOffset())); } if (auto load = dyn_cast(op)) { std::string reason; diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto index 602ac579ad..31e83e37d7 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto @@ -26,19 +26,19 @@ module { return } - func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( + func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_masked( %src: !pto.ptr, %dst: !pto.ptr, %off: index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c192 = arith.constant 192 : index - %x = pto.vmi.load %src[%c0] {full_read_elems = 256} - : !pto.ptr -> !pto.vmi.vreg<192xf32> - %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> - %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} - : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> - -> !pto.vmi.vreg<192xf32> - pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} - : !pto.vmi.vreg<192xf32>, !pto.ptr + %x = pto.vmi.load %src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<256xpred> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr return } } @@ -69,16 +69,16 @@ module { // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast -// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_masked( // ASSIGN: %[[PX:.*]] = pto.vmi.load -// ASSIGN-SAME: {full_read_elems = 256 : i64} -// ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> -// ASSIGN: %[[PMASK0:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[PMASK0:.*]] = pto.vmi.create_mask %{{.*}} : index -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[PMASK:.*]] = pto.vmi.ensure_mask_layout %[[PMASK0]] -// ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> -> !pto.vmi.mask<192xb32, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf %[[PX]], %[[PMASK]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_contract( +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_masked( // LOWER-COUNT-4: pto.vlds // LOWER-COUNT-3: pto.vdintlv // LOWER-COUNT-4: pto.vcgadd diff --git a/test/lit/vmi/vmi_load_full_read_elems_invalid.pto b/test/lit/vmi/vmi_load_full_read_elems_invalid.pto deleted file mode 100644 index 102efd4f0e..0000000000 --- a/test/lit/vmi/vmi_load_full_read_elems_invalid.pto +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s - -module { - func.func @vmi_load_full_read_elems_invalid(%src: !pto.ptr) { - %c0 = arith.constant 0 : index - %value = pto.vmi.load %src[%c0] {full_read_elems = 0} - : !pto.ptr -> !pto.vmi.vreg<100xf32> - return - } -} - -// CHECK: 'pto.vmi.load' op requires full_read_elems to be positive diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py index cf80936861..a521122803 100644 --- a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/golden.py @@ -22,7 +22,7 @@ def generate(output_dir: Path) -> None: src = np.empty(INPUT_ELEMS, dtype=np.float32) dst = np.full(PHYSICAL_ROWS, SENTINEL, dtype=np.float32) - golden = np.full(PHYSICAL_ROWS, SENTINEL, dtype=np.float32) + golden = np.zeros(PHYSICAL_ROWS, dtype=np.float32) base_row = np.linspace(-0.875, 0.625, GROUP_SIZE, dtype=np.float32) for row in range(PHYSICAL_ROWS): diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto index fabed4ee8b..4e311c0703 100644 --- a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto @@ -32,14 +32,14 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<192xpred> - %x = pto.vmi.load %ub_src[%c0] {full_read_elems = 256} - : !pto.ptr -> !pto.vmi.vreg<192xf32> - %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} - : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> - -> !pto.vmi.vreg<192xf32> - pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 6} - : !pto.vmi.vreg<192xf32>, !pto.ptr + %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] From 997bb3d08bb8efc535b7edea6fd76c645fe08797 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 24 Jun 2026 16:56:04 +0800 Subject: [PATCH 31/54] Define VMI scatter as unique-index op --- docs/designs/vmi-dialect-design.md | 19 ++++++------- docs/designs/vmi-implementation-manual.md | 8 +----- include/PTO/IR/VMIOps.td | 3 +-- lib/PTO/Transforms/VMIToVPTO.cpp | 11 +++----- .../lit/vmi/vmi_layout_assignment_scatter.pto | 3 +-- test/lit/vmi/vmi_scatter_indices_invalid.pto | 2 +- ...i_to_vpto_gather_scatter_shape_invalid.pto | 4 +-- test/lit/vmi/vmi_to_vpto_scatter.pto | 2 +- ...to_vpto_scatter_missing_unique_invalid.pto | 27 ------------------- 9 files changed, 20 insertions(+), 59 deletions(-) delete mode 100644 test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto diff --git a/docs/designs/vmi-dialect-design.md b/docs/designs/vmi-dialect-design.md index 7569b787a0..897b1661cc 100644 --- a/docs/designs/vmi-dialect-design.md +++ b/docs/designs/vmi-dialect-design.md @@ -1178,7 +1178,7 @@ interleave/deinterleave boundary: vldsx2/vstsx2 dist or explicit rearrangement indexed memory: - gather/scatter if inactive and duplicate-index semantics match + gather/scatter; ordinary scatter requires pairwise-distinct active indices ``` GM-backed VMI memory is semantic input, not a direct vector load/store target. @@ -1356,20 +1356,21 @@ lanes to preserve passthru, so the `vsel` is semantically required, not an optim gather, tail gather, non-contiguous layout, memref/gm source, and fallback through guarded scalar load or scratch are future target-capability paths. -当前 `scatter` direct lowering 只在 VMI IR 携带显式 no-conflict proof 时启用: +`scatter` 的基础语义要求所有 active logical lanes 的 `%indices` 两两不同。inactive lane 不写内存, +因此不参与这个唯一性约束。如果两个 active lane 的 index 相同,程序违反 `pto.vmi.scatter` 的 +语义前置条件;VMI 不为这种输入定义 logical lane order 或 winner。 ```mlir -pto.vmi.scatter %v, %base[%indices], %mask {indices_unique} +pto.vmi.scatter %v, %base[%indices], %mask : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> ``` -`indices_unique` 的含义是:所有 active logical lanes 的 `%indices` 两两不同。这个 proof 可以来自 -producer 的静态分析、前端语义或上游 canonicalization;VMI lowering 不从 runtime 值猜测它。direct -path 的其它限制与 gather 对齐:UB pointer destination、contiguous full physical chunks、32-bit value -element、i32 indices 和 b32 mask。没有 `indices_unique` 时,`vmi-to-vpto` 必须诊断,而不能直接发 -`VSCATTER`,因为 `VSCATTER` 对重复 index 的 grant procedure 是目标相关/未定义的,不等价于 VMI -logical lane order。 +当前 direct path 的其它限制与 gather 对齐:UB pointer destination、contiguous full physical chunks、 +32-bit value element、i32 indices 和 b32 mask。允许冲突的 scatter 不能复用普通 `pto.vmi.scatter`, +因为 `VSCATTER` 对重复 index 的 grant procedure 是目标相关/未定义的,不等价于确定的 VMI logical +lane order。后续如果需要定义 duplicate-index scatter,需要新增显式语义,例如 ordered fallback、 +atomic scatter、reduce-scatter 或 target-specific unordered scatter。 `expand_load/compress_store` 表达 masked contiguous stream,不是 arbitrary indexed access: diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 2cd72208a6..f86a812f05 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -3730,13 +3730,11 @@ vmi.scatter: if mask[lane] is true, memory[base + indices[lane]] = value[lane] if mask[lane] is false, no memory write occurs for that lane indices are interpreted in element units, not bytes - if two active lanes have the same index, VMI logical semantics require an ordered conflict policy or an explicit - no-conflict proof before direct target lowering + all active lanes must have pairwise-distinct indices; duplicate active indices violate the VMI scatter contract layout assignment: value and indices uses are requested as contiguous mask use is requested as contiguous with granularity derived from value element width current direct path: - op must carry {indices_unique} destination must be !pto.ptr T must be a 32-bit element type indices must be signless or unsigned i32 @@ -3744,11 +3742,7 @@ vmi.scatter: mask granularity must be b32 for each physical chunk i: pto.vscatter value_i, destination, indices_i, mask_i - reason for indices_unique: - VSCATTER false predicate lanes do not write, but duplicate active indices have target-defined/undefined grant - behavior. VMI cannot lower duplicate-index logical order semantics to VSCATTER without a proof or fallback. unsupported cases: - missing indices_unique proof f16/b16/f8/i8 value element types partial/tail chunks non-contiguous layouts diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 263146ec1f..7eccc093a5 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -590,8 +590,7 @@ def VMIScatterOp : VMI_Op<"scatter", [DeclareOpInterfaceMethods:$indices_unique); + VMI_MaskTypeConstraint:$mask); let results = (outs); let hasVerifier = 1; let assemblyFormat = "$value `,` $destination `[` $indices `]` `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($indices) `,` type($mask)"; diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index b0da679232..39ca049a1e 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -1374,10 +1374,6 @@ checkSupportedScatterShape(const VMITargetCapabilityRegistry &capabilities, return failure(); }; - if (!op->hasAttr("indices_unique")) - return fail("requires indices_unique proof because pto.vscatter does not " - "define logical-lane-order duplicate-index semantics"); - auto valueType = cast(op.getValue().getType()); auto indicesType = cast(op.getIndices().getType()); auto maskType = cast(op.getMask().getType()); @@ -7800,10 +7796,9 @@ verifySupportedVMIToVPTOOps(ModuleOp module, return WalkResult::advance(); scatter.emitError() << kVMIDiagUnsupportedPrefix - << "pto.vmi.scatter lowers through pto.vscatter only with an " - "indices_unique proof, UB pointer destination, contiguous full " - "physical chunks, 32-bit value elements, i32 indices, and b32 " - "masks (" + << "pto.vmi.scatter lowers through pto.vscatter only with a UB " + "pointer destination, contiguous full physical chunks, 32-bit " + "value elements, i32 indices, and b32 masks (" << reason << ")"; return WalkResult::interrupt(); } diff --git a/test/lit/vmi/vmi_layout_assignment_scatter.pto b/test/lit/vmi/vmi_layout_assignment_scatter.pto index 9560cfa981..b920cf4da4 100644 --- a/test/lit/vmi/vmi_layout_assignment_scatter.pto +++ b/test/lit/vmi/vmi_layout_assignment_scatter.pto @@ -14,7 +14,7 @@ module { %dst: !pto.ptr, %indices: !pto.vmi.vreg<64xi32>, %mask: !pto.vmi.mask<64xpred>) { - pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + pto.vmi.scatter %value, %dst[%indices], %mask : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> return @@ -26,7 +26,6 @@ module { // CHECK-SAME: %arg2: !pto.vmi.vreg<64xi32, #pto.vmi.layout> // CHECK-SAME: %arg3: !pto.vmi.mask<64xb32, #pto.vmi.layout> // CHECK: pto.vmi.scatter %arg0, %arg1[%arg2], %arg3 -// CHECK-SAME: indices_unique // CHECK-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout> // CHECK-SAME: !pto.vmi.vreg<64xi32, #pto.vmi.layout> // CHECK-SAME: !pto.vmi.mask<64xb32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_scatter_indices_invalid.pto b/test/lit/vmi/vmi_scatter_indices_invalid.pto index bd59b81b04..e16d6905f0 100644 --- a/test/lit/vmi/vmi_scatter_indices_invalid.pto +++ b/test/lit/vmi/vmi_scatter_indices_invalid.pto @@ -14,7 +14,7 @@ module { %dst: !pto.ptr, %indices: !pto.vmi.vreg<64xf32>, %mask: !pto.vmi.mask<64xpred>) { - pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + pto.vmi.scatter %value, %dst[%indices], %mask : !pto.vmi.vreg<64xf32>, !pto.ptr, !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> return diff --git a/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto index c271e9f446..2e5afb7708 100644 --- a/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_gather_scatter_shape_invalid.pto @@ -57,7 +57,7 @@ module { %dst: !pto.ptr, %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { - pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + pto.vmi.scatter %value, %dst[%indices], %mask : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr, !pto.vmi.vreg<64xi32, #pto.vmi.layout>, @@ -77,7 +77,7 @@ module { %dst: !pto.ptr, %indices: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, %mask: !pto.vmi.mask<32xb32, #pto.vmi.layout>) { - pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + pto.vmi.scatter %value, %dst[%indices], %mask : !pto.vmi.vreg<32xf32, #pto.vmi.layout>, !pto.ptr, !pto.vmi.vreg<32xi32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_to_vpto_scatter.pto b/test/lit/vmi/vmi_to_vpto_scatter.pto index 12799c01fc..4f898e3571 100644 --- a/test/lit/vmi/vmi_to_vpto_scatter.pto +++ b/test/lit/vmi/vmi_to_vpto_scatter.pto @@ -14,7 +14,7 @@ module { %dst: !pto.ptr, %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { - pto.vmi.scatter %value, %dst[%indices], %mask {indices_unique} + pto.vmi.scatter %value, %dst[%indices], %mask : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr, !pto.vmi.vreg<64xi32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto b/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto deleted file mode 100644 index 027162ac68..0000000000 --- a/test/lit/vmi/vmi_to_vpto_scatter_missing_unique_invalid.pto +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s - -module { - func.func @vmi_to_vpto_scatter_missing_unique_invalid( - %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, - %dst: !pto.ptr, - %indices: !pto.vmi.vreg<64xi32, #pto.vmi.layout>, - %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { - pto.vmi.scatter %value, %dst[%indices], %mask - : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, - !pto.ptr, - !pto.vmi.vreg<64xi32, #pto.vmi.layout>, - !pto.vmi.mask<64xb32, #pto.vmi.layout> - return - } -} - -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.scatter lowers through pto.vscatter only with an indices_unique proof -// CHECK-SAME: requires indices_unique proof From 96e5c9fbf425c2fbafcc1a8feb128af70b730f16 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 24 Jun 2026 22:16:41 +0800 Subject: [PATCH 32/54] Add VMI group max quant kernel case --- include/PTO/IR/VMIOps.td | 10 + include/PTO/Transforms/VMILayoutSupport.h | 75 +++-- .../PTO/Transforms/VMITargetCapabilities.h | 35 +- lib/PTO/IR/VMI.cpp | 54 ++-- lib/PTO/Transforms/PTOValidateVMIIR.cpp | 113 ++++--- lib/PTO/Transforms/VMILayoutAssignment.cpp | 156 ++++++--- lib/PTO/Transforms/VMILayoutSupport.cpp | 193 ++++++----- lib/PTO/Transforms/VMIToVPTO.cpp | 300 ++++++++++-------- ...out_assignment_group_reduce_maxf_quant.pto | 78 +++++ .../simdvf-per-token-cast-to-fp8/compare.py | 49 +++ .../simdvf-per-token-cast-to-fp8/golden.py | 62 ++++ .../simdvf-per-token-cast-to-fp8/kernel.pto | 79 +++++ .../simdvf-per-token-cast-to-fp8/launch.cpp | 43 +++ .../simdvf-per-token-cast-to-fp8/main.cpp | 91 ++++++ .../simdvf-per-token-cast-to-fp8/ptoas.flags | 1 + 15 files changed, 945 insertions(+), 394 deletions(-) create mode 100644 test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/compare.py create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/ptoas.flags diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 7eccc093a5..9acce8cd7b 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -418,6 +418,16 @@ def VMIGroupReduceAddFOp : VMI_Op<"group_reduce_addf"> { let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; } +def VMIGroupReduceMaxFOp : VMI_Op<"group_reduce_maxf"> { + let summary = "VMI masked floating-point maximum reduction within fixed logical groups"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + def VMIGroupReduceAddIOp : VMI_Op<"group_reduce_addi"> { let summary = "VMI masked integer add reduction within fixed logical groups"; let arguments = (ins VMI_VRegTypeConstraint:$source, diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h index 41b686b322..429a20bf0d 100644 --- a/include/PTO/Transforms/VMILayoutSupport.h +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -200,82 +200,93 @@ class VMILayoutSupport { public: FailureOr getContiguousStoreSupport(VMIVRegType valueType, - std::string *reason = nullptr) const; + std::string *reason = nullptr) const; - LogicalResult canFoldContiguousStoreMaterialization( - VMIVRegType sourceType, VMIVRegType resultType, - std::string *reason = nullptr) const; + LogicalResult + canFoldContiguousStoreMaterialization(VMIVRegType sourceType, + VMIVRegType resultType, + std::string *reason = nullptr) const; FailureOr getDataLayoutMaterializationSupport(VMIVRegType sourceType, - VMIVRegType resultType, - std::string *reason = nullptr) const; + VMIVRegType resultType, + std::string *reason = nullptr) const; LogicalResult canMaterializeDataLayout(VMIVRegType sourceType, - VMIVRegType resultType, - std::string *reason = nullptr) const; + VMIVRegType resultType, + std::string *reason = nullptr) const; FailureOr getMaskLayoutMaterializationSupport(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason = nullptr) const; + VMIMaskType resultType, + std::string *reason = nullptr) const; LogicalResult canMaterializeMaskLayout(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason = nullptr) const; + VMIMaskType resultType, + std::string *reason = nullptr) const; FailureOr getMaskGranularityMaterializationSupport(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason = nullptr) const; + VMIMaskType resultType, + std::string *reason = nullptr) const; - LogicalResult canMaterializeMaskGranularity( - VMIMaskType sourceType, VMIMaskType resultType, - std::string *reason = nullptr) const; + LogicalResult + canMaterializeMaskGranularity(VMIMaskType sourceType, VMIMaskType resultType, + std::string *reason = nullptr) const; FailureOr getPreferredCastLayoutFact(VMIVRegType sourceType, VMIVRegType resultType, - std::string *reason = nullptr) const; + std::string *reason = nullptr) const; FailureOr getGroupSlotLoadSupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupSlotLoadOp op, - std::string *reason = nullptr) const; + VMIGroupSlotLoadOp op, + std::string *reason = nullptr) const; FailureOr getGroupLoadSupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupLoadOp op, - std::string *reason = nullptr) const; + VMIGroupLoadOp op, std::string *reason = nullptr) const; FailureOr getGroupSlotsStoreSupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupStoreOp op, - std::string *reason = nullptr) const; + VMIGroupStoreOp op, + std::string *reason = nullptr) const; FailureOr getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, int64_t numGroups, - std::string *reason = nullptr) const; + std::string *reason = nullptr) const; FailureOr getGroupReduceAddFSupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupReduceAddFOp op, - std::string *reason = nullptr) const; + VMIGroupReduceAddFOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupReduceMaxFSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceMaxFOp op, + std::string *reason = nullptr) const; FailureOr getGroupReduceAddISupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupReduceAddIOp op, + VMIGroupReduceAddIOp op, + std::string *reason = nullptr) const; + + FailureOr + getGroupBroadcastSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupBroadcastOp op, std::string *reason = nullptr) const; FailureOr getGroupBroadcastSupport(const VMITargetCapabilityRegistry &capabilities, - VMIGroupBroadcastOp op, - std::string *reason = nullptr) const; + VMIVRegType sourceType, VMIVRegType resultType, + int64_t numGroups, + std::string *reason = nullptr) const; FailureOr getTruncFSupport(VMITruncFOp op, std::string *reason = nullptr) const; - FailureOr - getExtFSupport(VMIExtFOp op, std::string *reason = nullptr) const; + FailureOr getExtFSupport(VMIExtFOp op, + std::string *reason = nullptr) const; FailureOr getExtSISupport(VMIExtSIOp op, std::string *reason = nullptr) const; diff --git a/include/PTO/Transforms/VMITargetCapabilities.h b/include/PTO/Transforms/VMITargetCapabilities.h index a96a73a6d0..043da612e6 100644 --- a/include/PTO/Transforms/VMITargetCapabilities.h +++ b/include/PTO/Transforms/VMITargetCapabilities.h @@ -1,10 +1,12 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. //===- VMITargetCapabilities.h - VMI target capability registry -*- C++ -*-===// //===----------------------------------------------------------------------===// @@ -44,6 +46,7 @@ enum class VMIReductionKind { AddF, GroupAddI, GroupAddF, + GroupMaxF, MaxF, MinF, }; @@ -66,9 +69,7 @@ struct VMICapabilityResult { return result; } - bool isSupported() const { - return status == VMICapabilityStatus::supported; - } + bool isSupported() const { return status == VMICapabilityStatus::supported; } LogicalResult toLogicalResult(std::string *outReason = nullptr) const { if (isSupported()) @@ -188,8 +189,9 @@ class VMITargetCapabilityRegistry { "unsupported source/result layout pair"); } - VMICapabilityResult supportsMaskGranularityConversion( - StringRef sourceGranularity, StringRef resultGranularity) const { + VMICapabilityResult + supportsMaskGranularityConversion(StringRef sourceGranularity, + StringRef resultGranularity) const { if (!VMIMaskType::isConcreteGranularity(sourceGranularity) || !VMIMaskType::isConcreteGranularity(resultGranularity)) return VMICapabilityResult::missingCapability( @@ -207,8 +209,8 @@ class VMITargetCapabilityRegistry { "current VPTO pto.vlds surface has no mask operand"); } - VMICapabilityResult supportsFallbackResource( - VMIFallbackResourceKind kind) const { + VMICapabilityResult + supportsFallbackResource(VMIFallbackResourceKind kind) const { switch (kind) { case VMIFallbackResourceKind::ScratchMemory: return VMICapabilityResult::missingCapability( @@ -220,8 +222,8 @@ class VMITargetCapabilityRegistry { llvm_unreachable("unhandled VMI fallback resource kind"); } - VMICapabilityResult supportsReductionElementType( - VMIReductionKind kind, Type elementType) const { + VMICapabilityResult supportsReductionElementType(VMIReductionKind kind, + Type elementType) const { switch (kind) { case VMIReductionKind::AddI: if (pto::getPTOStorageElemBitWidth(elementType) == 32 && @@ -246,10 +248,11 @@ class VMITargetCapabilityRegistry { "cast i8/i16 storage before grouped reduction"); } case VMIReductionKind::GroupAddF: + case VMIReductionKind::GroupMaxF: if (elementType.isF16() || elementType.isF32()) return VMICapabilityResult::supported(); return VMICapabilityResult::missingCapability( - "grouped floating-point add reduction supports f16/f32 accumulator " + "grouped floating-point reduction supports f16/f32 accumulator " "elements"); case VMIReductionKind::MaxF: case VMIReductionKind::MinF: diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index 3589ec603e..25e08ac381 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -532,10 +532,9 @@ VMILayoutAttr::verify(function_ref emitError, return emitError() << "#pto.vmi.layout requires block_elems to be 1"; if (slots < 0) - return emitError() - << "#pto.vmi.layout requires slots to be omitted or positive"; + return emitError() << "#pto.vmi.layout requires slots to be omitted or positive"; return success(); } @@ -1121,21 +1120,23 @@ LogicalResult VMIReduceMaxFOp::verify() { return verifyReduceMinMaxFOp(*this); } LogicalResult VMIReduceMinFOp::verify() { return verifyReduceMinMaxFOp(*this); } -LogicalResult VMIGroupReduceAddFOp::verify() { - auto sourceType = cast(getSource().getType()); - auto maskType = cast(getMask().getType()); - auto resultType = cast(getResult().getType()); - if (!getOperation()->hasAttr("reassoc")) - return emitOpError( +template +static LogicalResult verifyGroupReduceFloatOp(OpTy op, bool requiresReassoc) { + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); + if (requiresReassoc && !op->hasAttr("reassoc")) + return op.emitOpError( "requires reassoc attr because grouped lowering uses pair-wise " "floating-point reductions"); if (!isVMIFloatLikeType(sourceType.getElementType())) - return emitOpError("requires floating-point-like VMI source element type"); + return op.emitOpError( + "requires floating-point-like VMI source element type"); if (sourceType.getElementCount() != resultType.getElementCount()) - return emitOpError( + return op.emitOpError( "requires source and result logical lane counts to match"); if (sourceType.getElementType() != resultType.getElementType()) - return emitOpError("requires source and result element types to match"); + return op.emitOpError("requires source and result element types to match"); if (auto sourceLayout = sourceType.getLayoutAttr()) { bool supportedSourceLayout = sourceLayout.isContiguous() || @@ -1146,21 +1147,29 @@ LogicalResult VMIGroupReduceAddFOp::verify() { (sourceLayout.getBlockElems() == 1 || sourceLayout.getBlockElems() == 8)); if (!supportedSourceLayout) - return emitOpError( + return op.emitOpError( "requires layout-assigned source to use contiguous layout or " "deinterleaved=2/4 layout with block_elems=1 or block_elems=8"); } if (auto resultLayout = resultType.getLayoutAttr()) { if (!resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) - return emitOpError() << "requires layout-assigned result to use " - "#pto.vmi.layout"; + resultLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) + return op.emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; } - if (failed(verifyMaskMatchesData(getOperation(), maskType, sourceType))) + if (failed(verifyMaskMatchesData(op.getOperation(), maskType, sourceType))) return failure(); - return verifyNumGroups(getOperation(), sourceType, - getNumGroupsAttr().getInt()); + return verifyNumGroups(op.getOperation(), sourceType, + op.getNumGroupsAttr().getInt()); +} + +LogicalResult VMIGroupReduceAddFOp::verify() { + return verifyGroupReduceFloatOp(*this, /*requiresReassoc=*/true); +} + +LogicalResult VMIGroupReduceMaxFOp::verify() { + return verifyGroupReduceFloatOp(*this, /*requiresReassoc=*/false); } LogicalResult VMIGroupReduceAddIOp::verify() { @@ -1231,8 +1240,7 @@ LogicalResult VMIGroupBroadcastOp::verify() { getNumGroupsAttr().getInt()); } -template -static LogicalResult verifyVMIHistogramOp(OpTy op) { +template static LogicalResult verifyVMIHistogramOp(OpTy op) { auto accType = cast(op.getAcc().getType()); auto sourceType = cast(op.getSource().getType()); auto maskType = cast(op.getMask().getType()); diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 1186a25e26..7529953fa5 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -1,10 +1,12 @@ // Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. //===- PTOValidateVMIIR.cpp - VMI boundary verifier ----------------------===// //===----------------------------------------------------------------------===// @@ -50,9 +52,9 @@ bool containsVMIOrPhysicalType(Type type) { return true; if (auto functionType = dyn_cast(type)) { - return llvm::any_of(functionType.getInputs(), [](Type input) { - return containsVMIOrPhysicalType(input); - }) || + return llvm::any_of( + functionType.getInputs(), + [](Type input) { return containsVMIOrPhysicalType(input); }) || llvm::any_of(functionType.getResults(), [](Type result) { return containsVMIOrPhysicalType(result); }); @@ -110,8 +112,8 @@ bool isVMIHelperOp(Operation *op) { StringRef name = op->getName().getStringRef(); return name == "pto.vmi.ensure_layout" || name == "pto.vmi.ensure_mask_layout" || - name == "pto.vmi.ensure_mask_granularity" || - name == "pto.vmi.pack" || name == "pto.vmi.unpack"; + name == "pto.vmi.ensure_mask_granularity" || name == "pto.vmi.pack" || + name == "pto.vmi.unpack"; } bool isVMILayoutHelperOp(Operation *op) { @@ -155,8 +157,8 @@ void mirrorDiagnostic(llvm::raw_ostream *diagOS, Twine message) { LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, Twine message) { - InFlightDiagnostic diag = - op->emitError() << kVMIDiagPassInvariantPrefix << message; + InFlightDiagnostic diag = op->emitError() + << kVMIDiagPassInvariantPrefix << message; (void)diag; mirrorDiagnostic(diagOS, Twine(kVMIDiagPassInvariantPrefix) + message); return failure(); @@ -164,8 +166,8 @@ LogicalResult emitInvariant(Operation *op, llvm::raw_ostream *diagOS, LogicalResult emitLayoutContract(Operation *op, llvm::raw_ostream *diagOS, Twine message) { - InFlightDiagnostic diag = - op->emitError() << kVMIDiagLayoutContractPrefix << message; + InFlightDiagnostic diag = op->emitError() + << kVMIDiagLayoutContractPrefix << message; (void)diag; mirrorDiagnostic(diagOS, Twine(kVMIDiagLayoutContractPrefix) + message); return failure(); @@ -198,17 +200,15 @@ LogicalResult emitLayoutSupportContract(Operation *op, return emitLayoutContract(op, diagOS, text); } -LogicalResult emitHelperMaterializationContract(Operation *helper, - Type sourceType, - Type resultType, - StringRef helperName, - StringRef reason, - llvm::raw_ostream *diagOS) { +LogicalResult +emitHelperMaterializationContract(Operation *helper, Type sourceType, + Type resultType, StringRef helperName, + StringRef reason, llvm::raw_ostream *diagOS) { auto emitFallback = [&]() { return emitLayoutContract( helper, diagOS, - Twine(helperName) + " has no registered materialization support: " + - reason); + Twine(helperName) + + " has no registered materialization support: " + reason); }; if (helper->getNumResults() != 1 || !helper->getResult(0).hasOneUse()) @@ -223,8 +223,8 @@ LogicalResult emitHelperMaterializationContract(Operation *helper, << helperName << " has no registered materialization support: " << reason; os.flush(); - InFlightDiagnostic diag = - requester->emitError() << kVMIDiagLayoutContractPrefix << message; + InFlightDiagnostic diag = requester->emitError() + << kVMIDiagLayoutContractPrefix << message; diag.attachNote(helper->getLoc()) << "failed helper conversion " << sourceType << " -> " << resultType << " (" << reason << ")"; @@ -340,8 +340,7 @@ bool isFunctionTypeAttr(Operation *op, NamedAttribute attr) { return isa(op) && attr.getName() == "function_type"; } -LogicalResult verifyNoHiddenVMIAttributeType(Operation *op, - NamedAttribute attr, +LogicalResult verifyNoHiddenVMIAttributeType(Operation *op, NamedAttribute attr, llvm::raw_ostream *diagOS) { if (isFunctionTypeAttr(op, attr)) return success(); @@ -424,10 +423,10 @@ LogicalResult verifyLayoutAssignedOperationTypes(Operation *op, } LogicalResult verifyLayoutHelperSupport(Operation *op, - llvm::raw_ostream *diagOS); + llvm::raw_ostream *diagOS); LogicalResult verifyLayoutSemanticSupport(Operation *op, - llvm::raw_ostream *diagOS); + llvm::raw_ostream *diagOS); LogicalResult verifyOperationBoundary(Operation *op, llvm::raw_ostream *diagOS) { @@ -461,7 +460,7 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, if (isVMIHelperOp(op)) { if (isVMILayoutHelperOp(op)) return verifyHelperSupports ? verifyLayoutHelperSupport(op, diagOS) - : success(); + : success(); return emitInvariant( op, diagOS, "VMI pack/unpack helper appears before VMI-to-VPTO physicalization"); @@ -477,15 +476,15 @@ LogicalResult verifyLayoutAssignedOperation(Operation *op, } LogicalResult verifyLayoutHelperSupport(Operation *op, - llvm::raw_ostream *diagOS) { + llvm::raw_ostream *diagOS) { VMILayoutSupport supports; if (auto ensure = dyn_cast(op)) { auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); std::string reason; - if (failed(supports.canMaterializeDataLayout(sourceType, resultType, - &reason))) + if (failed( + supports.canMaterializeDataLayout(sourceType, resultType, &reason))) return emitHelperMaterializationContract( op, sourceType, resultType, "pto.vmi.ensure_layout", reason, diagOS); return success(); @@ -495,11 +494,11 @@ LogicalResult verifyLayoutHelperSupport(Operation *op, auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); std::string reason; - if (failed(supports.canMaterializeMaskLayout(sourceType, resultType, - &reason))) - return emitHelperMaterializationContract( - op, sourceType, resultType, "pto.vmi.ensure_mask_layout", reason, - diagOS); + if (failed( + supports.canMaterializeMaskLayout(sourceType, resultType, &reason))) + return emitHelperMaterializationContract(op, sourceType, resultType, + "pto.vmi.ensure_mask_layout", + reason, diagOS); return success(); } @@ -508,7 +507,7 @@ LogicalResult verifyLayoutHelperSupport(Operation *op, auto resultType = cast(ensure.getResult().getType()); std::string reason; if (failed(supports.canMaterializeMaskGranularity(sourceType, resultType, - &reason))) + &reason))) return emitLayoutContract( op, diagOS, Twine("pto.vmi.ensure_mask_granularity has no registered " @@ -521,7 +520,7 @@ LogicalResult verifyLayoutHelperSupport(Operation *op, } LogicalResult verifyLayoutSemanticSupport(Operation *op, - llvm::raw_ostream *diagOS) { + llvm::raw_ostream *diagOS) { VMILayoutSupport supports; VMITargetCapabilityRegistry capabilities; @@ -587,7 +586,8 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, return success(); std::string reason; - if (failed(supports.getGroupSlotsStoreSupport(capabilities, store, &reason))) + if (failed( + supports.getGroupSlotsStoreSupport(capabilities, store, &reason))) return emitLayoutSupportContract( op, diagOS, "pto.vmi.group_store has no registered group_slots layout support", @@ -602,8 +602,8 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, return success(); std::string reason; - if (failed(supports.getGroupReduceAddFSupport(capabilities, reduce, - &reason))) + if (failed( + supports.getGroupReduceAddFSupport(capabilities, reduce, &reason))) return emitLayoutSupportContract( op, diagOS, "pto.vmi.group_reduce_addf has no registered group_slots layout " @@ -612,6 +612,23 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, return success(); } + if (auto reduce = dyn_cast(op)) { + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed( + supports.getGroupReduceMaxFSupport(capabilities, reduce, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_reduce_maxf has no registered group_slots layout " + "support", + reason); + return success(); + } + if (auto broadcast = dyn_cast(op)) { auto sourceType = cast(broadcast.getSource().getType()); VMILayoutAttr layout = sourceType.getLayoutAttr(); @@ -620,7 +637,7 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, std::string reason; if (failed(supports.getGroupBroadcastSupport(capabilities, broadcast, - &reason))) + &reason))) return emitLayoutSupportContract( op, diagOS, "pto.vmi.group_broadcast has no registered layout support", reason); @@ -697,8 +714,9 @@ struct PTOValidateVMILayoutIRPass } // namespace -LogicalResult mlir::pto::validateVMIProducerBoundaryIR( - ModuleOp module, llvm::raw_ostream *diagOS) { +LogicalResult +mlir::pto::validateVMIProducerBoundaryIR(ModuleOp module, + llvm::raw_ostream *diagOS) { WalkResult result = module.walk([&](Operation *op) { if (failed(verifyOperationBoundary(op, diagOS))) return WalkResult::interrupt(); @@ -710,8 +728,7 @@ LogicalResult mlir::pto::validateVMIProducerBoundaryIR( LogicalResult mlir::pto::validateVMILayoutAssignedIR( ModuleOp module, llvm::raw_ostream *diagOS, bool verifyHelperSupports) { WalkResult result = module.walk([&](Operation *op) { - if (failed(verifyLayoutAssignedOperation(op, diagOS, - verifyHelperSupports))) + if (failed(verifyLayoutAssignedOperation(op, diagOS, verifyHelperSupports))) return WalkResult::interrupt(); return WalkResult::advance(); }); diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 99e4314cf9..f976b0d5a7 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -349,6 +349,7 @@ struct LayoutSolver { solved.getSlots() > 0) return solved; if (value.getDefiningOp() || + value.getDefiningOp() || value.getDefiningOp()) return getPreferredGroupSlotsLayout(type, numGroups); if (value.getDefiningOp()) @@ -387,7 +388,8 @@ struct LayoutSolver { if (!resultType) continue; unsigned resultBits = getElementBitWidth(resultType.getElementType()); - std::optional vlaneElems = getVLaneElems(sourceType.getElementType()); + std::optional vlaneElems = + getVLaneElems(sourceType.getElementType()); if (vlaneElems && groupSize == 2 * *vlaneElems && resultBits == 16) return true; if (vlaneElems && groupSize == 4 * *vlaneElems && resultBits == 8) @@ -408,8 +410,8 @@ struct LayoutSolver { (layout.getBlockElems() == 1 || layout.getBlockElems() == 8); } - VMILayoutAttr getTruncFCompatibleGroupReduceSourceLayout( - VMIGroupReduceLayoutFact fact) { + VMILayoutAttr + getTruncFCompatibleGroupReduceSourceLayout(VMIGroupReduceLayoutFact fact) { if (fact.kind == VMIGroupReduceLayoutKind::TwoVLane) return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); if (fact.kind == VMIGroupReduceLayoutKind::FourVLane) @@ -458,6 +460,42 @@ struct LayoutSolver { VMISelectOp, VMIBitcastOp>(op); } + bool canGroupBroadcastProduceLayout(VMIGroupBroadcastOp broadcast, + VMILayoutAttr resultLayout) { + if (!resultLayout) + return false; + auto sourceType = cast(broadcast.getSource().getType()); + auto resultType = cast(broadcast.getResult().getType()); + int64_t numGroups = broadcast.getNumGroupsAttr().getInt(); + auto assignedSourceType = VMIVRegType::get( + ctx, sourceType.getElementCount(), sourceType.getElementType(), + getPreferredGroupSlotsLayout(sourceType, numGroups)); + auto assignedResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), resultLayout); + VMILayoutSupport supports; + return succeeded(supports.getGroupBroadcastSupport( + capabilities, assignedSourceType, assignedResultType, numGroups)); + } + + bool canEquivalenceClassAdoptConsumerLayout(Value value, + VMILayoutAttr requestedLayout) { + unsigned id = addDataValue(value); + if (id == ~0u) + return true; + unsigned root = find(id); + for (DataNode &node : dataNodes) { + if (find(dataIds.lookup(node.value)) != root) + continue; + if (auto broadcast = node.value.getDefiningOp()) { + if (node.value == broadcast.getResult() && + !canGroupBroadcastProduceLayout(broadcast, requestedLayout)) + return false; + } + } + return true; + } + bool canAdoptConsumerRequestedLayout(Value value, VMILayoutAttr requestedLayout) { Operation *definingOp = value.getDefiningOp(); @@ -469,6 +507,8 @@ struct LayoutSolver { if (!canProducerAdoptConsumerLayout(definingOp)) return false; } + if (!canEquivalenceClassAdoptConsumerLayout(value, requestedLayout)) + return false; if (value.hasOneUse()) return true; @@ -502,9 +542,7 @@ struct LayoutSolver { unsigned root = find(id); VMILayoutAttr existing = dataNodes[root].naturalLayout; if (existing && existing != request.layout) - return request.operand->getOwner()->emitError() - << kVMIDiagLayoutContractPrefix << "conflicting natural layouts " - << existing << " and " << request.layout; + continue; dataNodes[root].naturalLayout = request.layout; } return success(); @@ -560,6 +598,7 @@ struct LayoutSolver { bool sourceIsGroupSlotValue = (sourceLayout && sourceLayout.isGroupSlots()) || truncf.getSource().getDefiningOp() || + truncf.getSource().getDefiningOp() || truncf.getSource().getDefiningOp(); if (!sourceIsGroupSlotValue) return false; @@ -828,8 +867,44 @@ struct LayoutSolver { VMILayoutSupport supports; FailureOr fact = supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); - VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( - sourceType, numGroups); + VMILayoutAttr sourceLayout = + getPreferredGroupReduceSourceLayout(sourceType, numGroups); + VMILayoutAttr solvedSourceLayout = + getExplicitDataLayout(reduce.getSource()); + if (solvedSourceLayout && succeeded(fact) && + isCompatibleGroupReduceSourceLayout(*fact, solvedSourceLayout)) { + sourceLayout = solvedSourceLayout; + } else if (!sourceType.getLayoutAttr() && succeeded(fact)) { + if (hasCompatibleTruncFUseForGroupReduce(reduce.getSource(), + fact->groupSize)) { + if (VMILayoutAttr truncLayout = + getTruncFCompatibleGroupReduceSourceLayout(*fact)) + sourceLayout = truncLayout; + } + } + requestDataUse(reduce.getSourceMutable(), sourceLayout); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceLayout, + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + reduce.getResult(), + succeeded(fact) + ? fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, numGroups), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + auto resultType = cast(reduce.getResult().getType()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); + VMILayoutAttr sourceLayout = + getPreferredGroupReduceSourceLayout(sourceType, numGroups); VMILayoutAttr solvedSourceLayout = getExplicitDataLayout(reduce.getSource()); if (solvedSourceLayout && succeeded(fact) && @@ -850,9 +925,9 @@ struct LayoutSolver { return WalkResult::interrupt(); if (failed(setNaturalLayout( reduce.getResult(), - succeeded(fact) ? fact->resultLayout - : getPreferredGroupSlotsLayout(resultType, - numGroups), + succeeded(fact) + ? fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, numGroups), op))) return WalkResult::interrupt(); return WalkResult::advance(); @@ -864,8 +939,8 @@ struct LayoutSolver { VMILayoutSupport supports; FailureOr fact = supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); - VMILayoutAttr sourceLayout = getPreferredGroupReduceSourceLayout( - sourceType, numGroups); + VMILayoutAttr sourceLayout = + getPreferredGroupReduceSourceLayout(sourceType, numGroups); VMILayoutAttr solvedSourceLayout = getExplicitDataLayout(reduce.getSource()); if (solvedSourceLayout && succeeded(fact) && @@ -878,9 +953,9 @@ struct LayoutSolver { return WalkResult::interrupt(); if (failed(setNaturalLayout( reduce.getResult(), - succeeded(fact) ? fact->resultLayout - : getPreferredGroupSlotsLayout(resultType, - numGroups), + succeeded(fact) + ? fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, numGroups), op))) return WalkResult::interrupt(); return WalkResult::advance(); @@ -898,8 +973,8 @@ struct LayoutSolver { if (failed(requestMaskUse(hist.getMaskMutable(), getContiguousLayout(), "b8", op))) return WalkResult::interrupt(); - if (failed(setNaturalLayout(hist.getResult(), getContiguousLayout(), - op))) + if (failed( + setNaturalLayout(hist.getResult(), getContiguousLayout(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -909,8 +984,8 @@ struct LayoutSolver { if (failed(requestMaskUse(hist.getMaskMutable(), getContiguousLayout(), "b8", op))) return WalkResult::interrupt(); - if (failed(setNaturalLayout(hist.getResult(), getContiguousLayout(), - op))) + if (failed( + setNaturalLayout(hist.getResult(), getContiguousLayout(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -920,9 +995,8 @@ struct LayoutSolver { VMILayoutSupport supports; FailureOr fact = supports.getPreferredCastLayoutFact(sourceType, resultType); - if (succeeded(fact) && - (fact->kind == VMICastLayoutKind::Widen2x || - fact->kind == VMICastLayoutKind::Widen4x)) { + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { requestDataUse(extf.getSourceMutable(), fact->sourceLayout); if (failed( setNaturalLayout(extf.getResult(), fact->resultLayout, op))) @@ -936,9 +1010,8 @@ struct LayoutSolver { VMILayoutSupport supports; FailureOr fact = supports.getPreferredCastLayoutFact(sourceType, resultType); - if (succeeded(fact) && - (fact->kind == VMICastLayoutKind::Widen2x || - fact->kind == VMICastLayoutKind::Widen4x)) { + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { requestDataUse(extsi.getSourceMutable(), fact->sourceLayout); if (failed( setNaturalLayout(extsi.getResult(), fact->resultLayout, op))) @@ -952,9 +1025,8 @@ struct LayoutSolver { VMILayoutSupport supports; FailureOr fact = supports.getPreferredCastLayoutFact(sourceType, resultType); - if (succeeded(fact) && - (fact->kind == VMICastLayoutKind::Widen2x || - fact->kind == VMICastLayoutKind::Widen4x)) { + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x)) { requestDataUse(extui.getSourceMutable(), fact->sourceLayout); if (failed( setNaturalLayout(extui.getResult(), fact->resultLayout, op))) @@ -970,16 +1042,15 @@ struct LayoutSolver { supports.getPreferredCastLayoutFact(sourceType, resultType); VMILayoutAttr sourceLayout = getDataLayout(truncf.getSource()); if (succeeded(fact) && fact->kind == VMICastLayoutKind::Narrow2x && - sourceLayout && - sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { + sourceLayout && sourceLayout.isGroupSlots() && + sourceLayout.getSlots() == 1) { requestDataUse(truncf.getSourceMutable(), sourceLayout); if (failed(setNaturalLayout(truncf.getResult(), sourceLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } - if (succeeded(fact) && - (fact->kind == VMICastLayoutKind::Narrow2x || - fact->kind == VMICastLayoutKind::Narrow4x)) + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) requestDataUse(truncf.getSourceMutable(), fact->sourceLayout); VMILayoutAttr resultLayout = succeeded(fact) ? fact->resultLayout : getContiguousLayout(); @@ -995,16 +1066,15 @@ struct LayoutSolver { supports.getPreferredCastLayoutFact(sourceType, resultType); VMILayoutAttr sourceLayout = getDataLayout(trunci.getSource()); if (succeeded(fact) && fact->kind == VMICastLayoutKind::Narrow2x && - sourceLayout && - sourceLayout.isGroupSlots() && sourceLayout.getSlots() == 1) { + sourceLayout && sourceLayout.isGroupSlots() && + sourceLayout.getSlots() == 1) { requestDataUse(trunci.getSourceMutable(), sourceLayout); if (failed(setNaturalLayout(trunci.getResult(), sourceLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } - if (succeeded(fact) && - (fact->kind == VMICastLayoutKind::Narrow2x || - fact->kind == VMICastLayoutKind::Narrow4x)) + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) requestDataUse(trunci.getSourceMutable(), fact->sourceLayout); VMILayoutAttr resultLayout = succeeded(fact) ? fact->resultLayout : getContiguousLayout(); @@ -1551,6 +1621,14 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto reduce = dyn_cast(op)) { auto sourceType = cast(reduce.getSource().getType()); if (failed(requestMaskUse( diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index acb687eed0..a3babbf7ab 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -313,10 +313,10 @@ getPhysicalLogicalBitFootprint(VMIVRegType type) { static FailureOr getLayoutMaterializationSupport(VMILayoutAttr sourceLayout, - VMILayoutAttr resultLayout, - std::string *reason) { - auto fail = [&](const Twine &message) - -> FailureOr { + VMILayoutAttr resultLayout, + std::string *reason) { + auto fail = + [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -341,8 +341,9 @@ getLayoutMaterializationSupport(VMILayoutAttr sourceLayout, } // namespace FailureOr -VMILayoutSupport::getPreferredGroupReduceLayoutFact( - VMIVRegType sourceType, int64_t numGroups, std::string *reason) const { +VMILayoutSupport::getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, + int64_t numGroups, + std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -407,10 +408,8 @@ VMILayoutSupport::getPreferredGroupReduceLayoutFact( "2*VLaneElems, 4*VLaneElems, or full physical chunk multiples"); } -FailureOr -VMILayoutSupport::getPreferredCastLayoutFact(VMIVRegType sourceType, - VMIVRegType resultType, - std::string *reason) const { +FailureOr VMILayoutSupport::getPreferredCastLayoutFact( + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -422,7 +421,8 @@ VMILayoutSupport::getPreferredCastLayoutFact(VMIVRegType sourceType, unsigned resultBits = pto::getPTOStorageElemBitWidth(resultType.getElementType()); if (sourceBits == 0 || resultBits == 0) - return fail("requires source/result element types with known storage width"); + return fail( + "requires source/result element types with known storage width"); if (sourceType.getElementCount() != resultType.getElementCount()) return fail("requires source/result lane count to match"); @@ -472,9 +472,9 @@ VMILayoutSupport::getPreferredCastLayoutFact(VMIVRegType sourceType, FailureOr VMILayoutSupport::getContiguousStoreSupport(VMIVRegType valueType, - std::string *reason) const { - auto fail = [&](const Twine &message) - -> FailureOr { + std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -530,10 +530,9 @@ LogicalResult VMILayoutSupport::canFoldContiguousStoreMaterialization( FailureOr VMILayoutSupport::getDataLayoutMaterializationSupport( - VMIVRegType sourceType, VMIVRegType resultType, - std::string *reason) const { - auto fail = [&](const Twine &message) - -> FailureOr { + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -550,29 +549,25 @@ VMILayoutSupport::getDataLayoutMaterializationSupport( getLayoutMaterializationSupport(sourceLayout, resultLayout, reason); if (failed(support)) return failure(); - if (failed(checkLayoutMaterializationShape(sourceType, resultType, - sourceLayout, resultLayout, - reason))) + if (failed(checkLayoutMaterializationShape( + sourceType, resultType, sourceLayout, resultLayout, reason))) return failure(); return support; } -LogicalResult -VMILayoutSupport::canMaterializeDataLayout(VMIVRegType sourceType, - VMIVRegType resultType, - std::string *reason) const { - if (failed(getDataLayoutMaterializationSupport(sourceType, resultType, - reason))) +LogicalResult VMILayoutSupport::canMaterializeDataLayout( + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { + if (failed( + getDataLayoutMaterializationSupport(sourceType, resultType, reason))) return failure(); return success(); } FailureOr VMILayoutSupport::getMaskLayoutMaterializationSupport( - VMIMaskType sourceType, VMIMaskType resultType, - std::string *reason) const { - auto fail = [&](const Twine &message) - -> FailureOr { + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); return failure(); @@ -589,27 +584,23 @@ VMILayoutSupport::getMaskLayoutMaterializationSupport( getLayoutMaterializationSupport(sourceLayout, resultLayout, reason); if (failed(support)) return failure(); - if (failed(checkLayoutMaterializationShape(sourceType, resultType, - sourceLayout, resultLayout, - reason))) + if (failed(checkLayoutMaterializationShape( + sourceType, resultType, sourceLayout, resultLayout, reason))) return failure(); return support; } -LogicalResult -VMILayoutSupport::canMaterializeMaskLayout(VMIMaskType sourceType, - VMIMaskType resultType, - std::string *reason) const { - if (failed(getMaskLayoutMaterializationSupport(sourceType, resultType, - reason))) +LogicalResult VMILayoutSupport::canMaterializeMaskLayout( + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { + if (failed( + getMaskLayoutMaterializationSupport(sourceType, resultType, reason))) return failure(); return success(); } FailureOr VMILayoutSupport::getMaskGranularityMaterializationSupport( - VMIMaskType sourceType, VMIMaskType resultType, - std::string *reason) const { + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) @@ -633,16 +624,14 @@ VMILayoutSupport::getMaskGranularityMaterializationSupport( } LogicalResult VMILayoutSupport::canMaterializeMaskGranularity( - VMIMaskType sourceType, VMIMaskType resultType, - std::string *reason) const { + VMIMaskType sourceType, VMIMaskType resultType, std::string *reason) const { if (failed(getMaskGranularityMaterializationSupport(sourceType, resultType, - reason))) + reason))) return failure(); return success(); } -FailureOr -VMILayoutSupport::getGroupSlotLoadSupport( +FailureOr VMILayoutSupport::getGroupSlotLoadSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { @@ -711,9 +700,8 @@ FailureOr VMILayoutSupport::getGroupLoadSupport( !resultType.getElementType().isF32()) return fail("requires deinterleaved block8 f32 result layout"); - FailureOr groupSize = - getGroupSizeFromNumGroups(resultType, op.getNumGroupsAttr().getInt(), - reason); + FailureOr groupSize = getGroupSizeFromNumGroups( + resultType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) return failure(); @@ -815,13 +803,11 @@ VMILayoutSupport::getGroupSlotsStoreSupport( "unit-stride slots=8"); } -FailureOr -getGroupReduceAddSupportImpl(const VMITargetCapabilityRegistry &capabilities, - Operation *op, VMIVRegType sourceType, - VMIMaskType maskType, VMIVRegType resultType, - int64_t numGroups, bool requiresReassoc, - VMIReductionKind reductionKind, - std::string *reason) { +FailureOr getGroupReduceAddSupportImpl( + const VMITargetCapabilityRegistry &capabilities, Operation *op, + VMIVRegType sourceType, VMIMaskType maskType, VMIVRegType resultType, + int64_t numGroups, bool requiresReassoc, VMIReductionKind reductionKind, + std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr { if (reason) @@ -845,9 +831,9 @@ getGroupReduceAddSupportImpl(const VMITargetCapabilityRegistry &capabilities, getGroupSizeFromNumGroups(sourceType, numGroups, reason); FailureOr lanesPerPart = getDataLanesPerPart(sourceType.getElementType()); - int64_t vlaneElems = - succeeded(lanesPerPart) && *lanesPerPart % 8 == 0 ? *lanesPerPart / 8 - : -1; + int64_t vlaneElems = succeeded(lanesPerPart) && *lanesPerPart % 8 == 0 + ? *lanesPerPart / 8 + : -1; if (succeeded(groupSize) && resultLayout.getSlots() <= 0 && (*groupSize != vlaneElems && *groupSize != 2 * vlaneElems && *groupSize != 4 * vlaneElems)) @@ -933,10 +919,10 @@ getGroupReduceAddSupportImpl(const VMITargetCapabilityRegistry &capabilities, return fail("two-vlane group_reduce_add requires matching mask layout " "deinterleaved=2 with the same block_elems"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); - if (*resultArity != expectedResultArity || - *sourceArity != *resultArity * 2) - return fail("two-vlane group_reduce_add requires two source/mask parts per " - "result part"); + if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 2) + return fail( + "two-vlane group_reduce_add requires two source/mask parts per " + "result part"); return VMIGroupReduceAddFSupport{ VMIGroupReduceAddFSupportKind::TwoVLaneDeinterleaved2VcgaddVadd}; } @@ -952,10 +938,10 @@ getGroupReduceAddSupportImpl(const VMITargetCapabilityRegistry &capabilities, return fail("four-vlane group_reduce_add requires matching mask layout " "deinterleaved=4 with the same block_elems"); int64_t expectedResultArity = ceilDivNonNegative(numGroups, 8); - if (*resultArity != expectedResultArity || - *sourceArity != *resultArity * 4) - return fail("four-vlane group_reduce_add requires four source/mask parts per " - "result part"); + if (*resultArity != expectedResultArity || *sourceArity != *resultArity * 4) + return fail( + "four-vlane group_reduce_add requires four source/mask parts per " + "result part"); return VMIGroupReduceAddFSupport{ VMIGroupReduceAddFSupportKind::FourVLaneDeinterleaved4VcgaddTree}; } @@ -969,29 +955,52 @@ VMILayoutSupport::getGroupReduceAddFSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddFOp op, std::string *reason) const { return getGroupReduceAddSupportImpl( - capabilities, op.getOperation(), cast(op.getSource().getType()), + capabilities, op.getOperation(), + cast(op.getSource().getType()), cast(op.getMask().getType()), cast(op.getResult().getType()), op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/true, VMIReductionKind::GroupAddF, reason); } +FailureOr +VMILayoutSupport::getGroupReduceMaxFSupport( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceMaxFOp op, + std::string *reason) const { + return getGroupReduceAddSupportImpl( + capabilities, op.getOperation(), + cast(op.getSource().getType()), + cast(op.getMask().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/false, + VMIReductionKind::GroupMaxF, reason); +} + FailureOr VMILayoutSupport::getGroupReduceAddISupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceAddIOp op, std::string *reason) const { return getGroupReduceAddSupportImpl( - capabilities, op.getOperation(), cast(op.getSource().getType()), + capabilities, op.getOperation(), + cast(op.getSource().getType()), cast(op.getMask().getType()), cast(op.getResult().getType()), op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/false, VMIReductionKind::GroupAddI, reason); } -FailureOr -VMILayoutSupport::getGroupBroadcastSupport( +FailureOr VMILayoutSupport::getGroupBroadcastSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, std::string *reason) const { + return getGroupBroadcastSupport(capabilities, + cast(op.getSource().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), reason); +} + +FailureOr VMILayoutSupport::getGroupBroadcastSupport( + const VMITargetCapabilityRegistry &capabilities, VMIVRegType sourceType, + VMIVRegType resultType, int64_t numGroups, std::string *reason) const { (void)capabilities; auto fail = [&](const Twine &message) -> FailureOr { if (reason) @@ -999,15 +1008,12 @@ VMILayoutSupport::getGroupBroadcastSupport( return failure(); }; - auto sourceType = cast(op.getSource().getType()); - auto resultType = cast(op.getResult().getType()); if (sourceType.getElementType() != resultType.getElementType() || sourceType.getElementCount() != resultType.getElementCount()) return fail("requires source/result shape and element type to match"); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); - int64_t numGroups = op.getNumGroupsAttr().getInt(); if (!sourceLayout || !resultLayout) return fail("requires assigned source/result layouts"); if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != numGroups) @@ -1068,8 +1074,7 @@ VMILayoutSupport::getGroupBroadcastSupport( } FailureOr -VMILayoutSupport::getTruncFSupport(VMITruncFOp op, - std::string *reason) const { +VMILayoutSupport::getTruncFSupport(VMITruncFOp op, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -1126,8 +1131,7 @@ VMILayoutSupport::getTruncFSupport(VMITruncFOp op, } FailureOr -VMILayoutSupport::getExtFSupport(VMIExtFOp op, - std::string *reason) const { +VMILayoutSupport::getExtFSupport(VMIExtFOp op, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -1159,13 +1163,11 @@ VMILayoutSupport::getExtFSupport(VMIExtFOp op, if (fact->kind == VMICastLayoutKind::Widen2x && resultLayout.getFactor() == fact->factor && *resultArity == fact->factor * *sourceArity) - return VMIExtFSupport{ - VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32}; + return VMIExtFSupport{VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32}; if (fact->kind == VMICastLayoutKind::Widen4x && resultLayout.getFactor() == fact->factor && *resultArity == fact->factor * *sourceArity) - return VMIExtFSupport{ - VMIExtFSupportKind::ContiguousF8ToDeinterleaved4F32}; + return VMIExtFSupport{VMIExtFSupportKind::ContiguousF8ToDeinterleaved4F32}; return fail("unsupported extf source element width, result factor, or " "physical arity"); @@ -1173,7 +1175,7 @@ VMILayoutSupport::getExtFSupport(VMIExtFOp op, template static FailureOr getExtISupportImpl(OpT op, - std::string *reason) { + std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -1207,33 +1209,28 @@ static FailureOr getExtISupportImpl(OpT op, if (fact->kind == VMICastLayoutKind::Widen2x && resultLayout.getFactor() == fact->factor && *resultArity == fact->factor * *sourceArity) - return VMIExtISupport{ - VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32}; + return VMIExtISupport{VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32}; if (fact->kind == VMICastLayoutKind::Widen4x && resultLayout.getFactor() == fact->factor && *resultArity == fact->factor * *sourceArity) - return VMIExtISupport{ - VMIExtISupportKind::ContiguousI8ToDeinterleaved4I32}; + return VMIExtISupport{VMIExtISupportKind::ContiguousI8ToDeinterleaved4I32}; return fail("unsupported integer extension source/result element width, " "result factor, or physical arity"); } FailureOr -VMILayoutSupport::getExtSISupport(VMIExtSIOp op, - std::string *reason) const { +VMILayoutSupport::getExtSISupport(VMIExtSIOp op, std::string *reason) const { return getExtISupportImpl(op, reason); } FailureOr -VMILayoutSupport::getExtUISupport(VMIExtUIOp op, - std::string *reason) const { +VMILayoutSupport::getExtUISupport(VMIExtUIOp op, std::string *reason) const { return getExtISupportImpl(op, reason); } FailureOr -VMILayoutSupport::getTruncISupport(VMITruncIOp op, - std::string *reason) const { +VMILayoutSupport::getTruncISupport(VMITruncIOp op, std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -1302,7 +1299,7 @@ VMILayoutSupport::getTruncISupport(VMITruncIOp op, FailureOr VMILayoutSupport::getBitcastSupport(VMIBitcastOp op, - std::string *reason) const { + std::string *reason) const { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -1399,14 +1396,12 @@ getHistogramSupportImpl(OpTy op, std::string *reason) { } FailureOr -VMILayoutSupport::getDhistSupport(VMIDhistOp op, - std::string *reason) const { +VMILayoutSupport::getDhistSupport(VMIDhistOp op, std::string *reason) const { return getHistogramSupportImpl(op, reason); } FailureOr -VMILayoutSupport::getChistSupport(VMIChistOp op, - std::string *reason) const { +VMILayoutSupport::getChistSupport(VMIChistOp op, std::string *reason) const { if (reason) *reason = "CHISTv2 cumulative high-range semantics are not classified"; return failure(); diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 39ca049a1e..806f6c67fc 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -921,9 +921,9 @@ VMICapabilityResult requireIdentityMemRefLayout(Type memoryType, StringRef role, return VMICapabilityResult::missingCapability(reason); } -VMIMemorySafeReadProof computeSafeFullReadProof( - Type sourceType, std::optional constantOffset, - VMIVRegType resultType) { +VMIMemorySafeReadProof +computeSafeFullReadProof(Type sourceType, std::optional constantOffset, + VMIVRegType resultType) { VMIMemorySafeReadProof proof; proof.constantOffset = constantOffset; @@ -964,10 +964,11 @@ VMIMemorySafeReadProof computeSafeFullReadProof( return proof; } -VMIMemoryAccessPlan buildReadAccessPlan( - const VMITargetCapabilityRegistry &capabilities, Value source, - Type sourceType, VMIVRegType resultType, - std::optional constantOffset, VMIMemoryValidMaskKind validMask) { +VMIMemoryAccessPlan +buildReadAccessPlan(const VMITargetCapabilityRegistry &capabilities, + Value source, Type sourceType, VMIVRegType resultType, + std::optional constantOffset, + VMIMemoryValidMaskKind validMask) { VMIMemoryAccessPlan plan; plan.baseType = sourceType; plan.valueType = resultType; @@ -1032,9 +1033,10 @@ void requireUnavailableReadFallback(VMIMemoryAccessPlan &plan) { maskedLoadReason + scratchReason + guardedReason); } -FailureOr verifyFullOrSafeReadVRegChunks( - Operation *op, VMIVRegType type, Type sourceType, Value offset, - PatternRewriter &rewriter) { +FailureOr verifyFullOrSafeReadVRegChunks(Operation *op, + VMIVRegType type, + Type sourceType, Value offset, + PatternRewriter &rewriter) { std::string fullChunkReason; FailureOr lanesPerPart = checkFullDataPhysicalChunks(type, &fullChunkReason); @@ -1055,19 +1057,20 @@ FailureOr verifyFullOrSafeReadVRegChunks( return failure(); } -LogicalResult checkSupportedLoadShape( - const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, - Value source, Type sourceType, std::optional constantOffset, - std::string *reason) { +LogicalResult +checkSupportedLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIVRegType type, Value source, Type sourceType, + std::optional constantOffset, + std::string *reason) { auto fail = [&](const Twine &message) -> LogicalResult { if (reason) *reason = message.str(); return failure(); }; - VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( - capabilities, source, sourceType, type, constantOffset, - VMIMemoryValidMaskKind::AllTrue); + VMIMemoryAccessPlan accessPlan = + buildReadAccessPlan(capabilities, source, sourceType, type, + constantOffset, VMIMemoryValidMaskKind::AllTrue); if (!accessPlan.targetCapability.isSupported()) return fail(accessPlan.targetCapability.reason); @@ -2112,8 +2115,7 @@ FailureOr> materializeDynamicContiguousGroupMask( shiftScalar, *allMask) .getResult(); col = rewriter - .create(loc, indexVectorType, lane, groupBase, - *allMask) + .create(loc, indexVectorType, lane, groupBase, *allMask) .getResult(); } @@ -3057,10 +3059,11 @@ struct OneToNVMIEnsureLayoutOpPattern VMILayoutSupport supports; std::string supportReason; if (failed(supports.canMaterializeDataLayout(sourceType, resultType, - &supportReason))) + &supportReason))) return rewriter.notifyMatchFailure( - op, Twine("ensure_layout has no registered materialization support: ") + - supportReason); + op, + Twine("ensure_layout has no registered materialization support: ") + + supportReason); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); if (!sourceLayout || !resultLayout) @@ -3091,11 +3094,11 @@ struct OneToNVMIEnsureMaskLayoutOpPattern VMILayoutSupport supports; std::string supportReason; if (failed(supports.canMaterializeMaskLayout(sourceType, resultType, - &supportReason))) + &supportReason))) return rewriter.notifyMatchFailure( - op, - Twine("ensure_mask_layout has no registered materialization support: ") + - supportReason); + op, Twine("ensure_mask_layout has no registered materialization " + "support: ") + + supportReason); if (sourceType.getGranularity() != resultType.getGranularity()) return rewriter.notifyMatchFailure( op, "mask layout helper cannot also change granularity"); @@ -3130,7 +3133,7 @@ struct OneToNVMIEnsureMaskGranularityOpPattern VMILayoutSupport supports; std::string supportReason; if (failed(supports.canMaterializeMaskGranularity(sourceType, resultType, - &supportReason))) + &supportReason))) return rewriter.notifyMatchFailure( op, Twine("ensure_mask_granularity has no registered materialization " "support: ") + @@ -3623,8 +3626,8 @@ struct OneToNVMICreateGroupMaskOpPattern contiguousMaterializations = computeGroupMaskMaterializationForType( op, contiguousType, &contiguousReason); if (failed(contiguousMaterializations)) - return rewriter.notifyMatchFailure( - op, Twine("create_group_mask ") + contiguousReason); + return rewriter.notifyMatchFailure(op, Twine("create_group_mask ") + + contiguousReason); contiguousParts.reserve(contiguousMaterializations->size()); for (const ConstantMaskChunkMaterialization &materialization : @@ -3807,23 +3810,20 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { Value firstOffset = createChunkOffset( op.getLoc(), *offset, group * 4 * *lanesPerPart, rewriter); Value secondOffset = createChunkOffset( - op.getLoc(), *offset, (group * 4 + 2) * *lanesPerPart, - rewriter); - auto first = - rewriter.create(op.getLoc(), part0Type, part1Type, - *source, firstOffset, - rewriter.getStringAttr(*dist)); - auto second = - rewriter.create(op.getLoc(), part2Type, part3Type, - *source, secondOffset, - rewriter.getStringAttr(*dist)); - - auto even = rewriter.create( - op.getLoc(), part0Type, part2Type, first.getLow(), - second.getLow()); - auto odd = rewriter.create( - op.getLoc(), part1Type, part3Type, first.getHigh(), - second.getHigh()); + op.getLoc(), *offset, (group * 4 + 2) * *lanesPerPart, rewriter); + auto first = rewriter.create( + op.getLoc(), part0Type, part1Type, *source, firstOffset, + rewriter.getStringAttr(*dist)); + auto second = rewriter.create( + op.getLoc(), part2Type, part3Type, *source, secondOffset, + rewriter.getStringAttr(*dist)); + + auto even = + rewriter.create(op.getLoc(), part0Type, part2Type, + first.getLow(), second.getLow()); + auto odd = + rewriter.create(op.getLoc(), part1Type, part3Type, + first.getHigh(), second.getHigh()); part0.push_back(even.getLow()); part1.push_back(odd.getLow()); part2.push_back(even.getHigh()); @@ -5449,16 +5449,17 @@ struct OneToNVMIReduceAddFOpPattern } }; -template -struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { - OneToNVMIGroupReduceAddOpPattern( - TypeConverter &typeConverter, MLIRContext *context, - const VMITargetCapabilityRegistry &capabilities) +template +struct OneToNVMIGroupReduceOpPattern : OneToNOpConversionPattern { + OneToNVMIGroupReduceOpPattern(TypeConverter &typeConverter, + MLIRContext *context, + const VMITargetCapabilityRegistry &capabilities) : OneToNOpConversionPattern(typeConverter, context), capabilities(capabilities) {} LogicalResult - matchAndRewrite(OpTy op, typename OneToNOpConversionPattern::OpAdaptor adaptor, + matchAndRewrite(OpTy op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto sourceVMIType = cast(op.getSource().getType()); auto resultVMIType = cast(op.getResult().getType()); @@ -5472,15 +5473,14 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { getSupport(supports, op, &supportReason); if (failed(support)) return rewriter.notifyMatchFailure( - op, Twine("group_reduce_add has no layout support: ") + - supportReason); + op, Twine(op->getName().getStringRef()) + + " has no layout support: " + supportReason); FailureOr groupSize = getGroupSizeFromNumGroups( sourceVMIType, op.getNumGroupsAttr().getInt()); if (failed(groupSize)) return rewriter.notifyMatchFailure( - op, - "group_reduce_addf requires num_groups to evenly divide lane count"); + op, "group reduce requires num_groups to evenly divide lane count"); if (support->kind == VMIGroupReduceAddFSupportKind::OneVLaneVcgadd) { if (sourceParts.size() != maskParts.size() || @@ -5506,9 +5506,9 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { results.reserve(resultTypes.size()); for (auto [sourceIndex, sourcePart] : llvm::enumerate(sourceParts)) { results.push_back(rewriter - .create(op.getLoc(), resultType, - sourcePart, - maskParts[sourceIndex]) + .create(op.getLoc(), resultType, + sourcePart, + maskParts[sourceIndex]) .getResult()); } @@ -5554,16 +5554,18 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure( op, "s16 block8 group_reduce_addf requires uniform physical " "types"); - Value lo = - rewriter.create(op.getLoc(), resultType, loSource, loMask) - .getResult(); - Value hi = - rewriter.create(op.getLoc(), resultType, hiSource, hiMask) - .getResult(); - results.push_back( - rewriter - .create(op.getLoc(), resultType, lo, hi, *combineMask) - .getResult()); + Value lo = rewriter + .create(op.getLoc(), resultType, + loSource, loMask) + .getResult(); + Value hi = rewriter + .create(op.getLoc(), resultType, + hiSource, hiMask) + .getResult(); + results.push_back(rewriter + .create(op.getLoc(), resultType, lo, + hi, *combineMask) + .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -5608,21 +5610,24 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure( op, "s32 block8 group_reduce_addf requires uniform physical " "types"); - partials.push_back( - rewriter.create(op.getLoc(), resultType, source, mask) - .getResult()); + partials.push_back(rewriter + .create( + op.getLoc(), resultType, source, mask) + .getResult()); } - Value sum01 = rewriter - .create(op.getLoc(), resultType, partials[0], - partials[1], *combineMask) - .getResult(); - Value sum23 = rewriter - .create(op.getLoc(), resultType, partials[2], - partials[3], *combineMask) - .getResult(); + Value sum01 = + rewriter + .create(op.getLoc(), resultType, partials[0], + partials[1], *combineMask) + .getResult(); + Value sum23 = + rewriter + .create(op.getLoc(), resultType, partials[2], + partials[3], *combineMask) + .getResult(); results.push_back(rewriter - .create(op.getLoc(), resultType, sum01, - sum23, *combineMask) + .create(op.getLoc(), resultType, + sum01, sum23, *combineMask) .getResult()); } @@ -5642,10 +5647,9 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { &chunksPerGroup, rewriter))) return failure(); VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); - bool rowLocalSlots1Result = - resultLayout && resultLayout.isGroupSlots() && - resultLayout.getNumGroups() == groupCount && - resultLayout.getSlots() == 1; + bool rowLocalSlots1Result = resultLayout && resultLayout.isGroupSlots() && + resultLayout.getNumGroups() == groupCount && + resultLayout.getSlots() == 1; int64_t expectedResultParts = rowLocalSlots1Result ? groupCount : groupCount * chunksPerGroup; if (sourceParts.size() != maskParts.size() || @@ -5682,11 +5686,7 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { op, "failed to create group_reduce_addf masks"); for (int64_t group = 0; group < groupCount; ++group) { - FailureOr accumulator = - createZeroVector(op.getLoc(), resultType, rewriter); - if (failed(accumulator)) - return rewriter.notifyMatchFailure( - op, "failed to create group_reduce_addf accumulator"); + Value accumulator; for (int64_t chunk = 0; chunk < chunksPerGroup; ++chunk) { int64_t index = group * chunksPerGroup + chunk; @@ -5696,19 +5696,23 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { op, "group_reduce_addf requires uniform physical chunk types"); Value reduced = rewriter - .create(op.getLoc(), resultType, sourceParts[index], - maskParts[index]) + .create(op.getLoc(), resultType, + sourceParts[index], maskParts[index]) .getResult(); - *accumulator = rewriter - .create(op.getLoc(), resultType, reduced, - *accumulator, *firstLaneMask) - .getResult(); + if (!accumulator) { + accumulator = reduced; + continue; + } + accumulator = rewriter + .create(op.getLoc(), resultType, reduced, + accumulator, *firstLaneMask) + .getResult(); } int64_t destChunk = rowLocalSlots1Result ? group : group * chunksPerGroup; results[destChunk] = rewriter - .create(op.getLoc(), resultType, *accumulator, + .create(op.getLoc(), resultType, accumulator, results[destChunk], *firstLaneMask) .getResult(); } @@ -5718,18 +5722,24 @@ struct OneToNVMIGroupReduceAddOpPattern : OneToNOpConversionPattern { } private: - FailureOr - getSupport(VMILayoutSupport &supports, VMIGroupReduceAddFOp op, - std::string *reason) const { + FailureOr getSupport(VMILayoutSupport &supports, + VMIGroupReduceAddFOp op, + std::string *reason) const { return supports.getGroupReduceAddFSupport(capabilities, op, reason); } - FailureOr - getSupport(VMILayoutSupport &supports, VMIGroupReduceAddIOp op, - std::string *reason) const { + FailureOr getSupport(VMILayoutSupport &supports, + VMIGroupReduceAddIOp op, + std::string *reason) const { return supports.getGroupReduceAddISupport(capabilities, op, reason); } + FailureOr getSupport(VMILayoutSupport &supports, + VMIGroupReduceMaxFOp op, + std::string *reason) const { + return supports.getGroupReduceMaxFSupport(capabilities, op, reason); + } + const VMITargetCapabilityRegistry &capabilities; }; @@ -5994,8 +6004,7 @@ struct OneToNVMIDhistOpPattern : OneToNOpConversionPattern { FailureOr lanesPerPart = getDataLanesPerPart(sourceType.getElementType()); if (failed(lanesPerPart)) - return rewriter.notifyMatchFailure(op, - "failed to compute source lanes"); + return rewriter.notifyMatchFailure(op, "failed to compute source lanes"); Location loc = op.getLoc(); Value bin0 = createI32Constant(loc, 0, rewriter); @@ -6012,21 +6021,19 @@ struct OneToNVMIDhistOpPattern : OneToNOpConversionPattern { Value chunkMask = userMask; int64_t firstLane = static_cast(index) * *lanesPerPart; - int64_t activeLanes = - std::min(*lanesPerPart, - sourceType.getElementCount() - firstLane); + int64_t activeLanes = std::min( + *lanesPerPart, sourceType.getElementCount() - firstLane); if (activeLanes < *lanesPerPart) { - FailureOr validMask = - createPrefixMaskForActiveLanes(loc, maskType, activeLanes, - rewriter); + FailureOr validMask = createPrefixMaskForActiveLanes( + loc, maskType, activeLanes, rewriter); FailureOr allMask = createAllTrueMask(loc, maskType, rewriter); if (failed(validMask) || failed(allMask)) return rewriter.notifyMatchFailure( op, "failed to materialize tail-valid b8 mask"); - chunkMask = rewriter - .create(loc, maskType, chunkMask, *validMask, - *allMask) - .getResult(); + chunkMask = + rewriter + .create(loc, maskType, chunkMask, *validMask, *allMask) + .getResult(); } lo = rewriter.create(loc, loType, lo, source, chunkMask, bin0) @@ -6301,9 +6308,9 @@ struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult - matchAndRewrite( - OpT op, typename OneToNOpConversionPattern::OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(OpT op, + typename OneToNOpConversionPattern::OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceParts.empty()) @@ -6330,8 +6337,7 @@ struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { !isa(resultVRegType.getElementType()) || (resultVRegTypes.empty() ? pto::getPTOStorageElemBitWidth( resultVRegType.getElementType()) != 32 - : resultVRegType != - resultVRegTypes.front())) + : resultVRegType != resultVRegTypes.front())) return rewriter.notifyMatchFailure( op, "unsupported physical integer extension result type"); resultVRegTypes.push_back(resultVRegType); @@ -6996,15 +7002,15 @@ void populateVMIOneToNConversionPatterns( OneToNVMIReduceMinMaxFOpPattern, OneToNVMIReduceMinMaxFOpPattern, OneToNVMIExtFOpPattern, OneToNVMITruncFOpPattern, - OneToNVMIExtIOpPattern, - OneToNVMIExtIOpPattern, OneToNVMITruncIOpPattern, - OneToNVMIBitcastOpPattern, OneToNVMIChannelSplitOpPattern, - OneToNVMIChannelMergeOpPattern, OneToNVMIShuffleOpPattern>( - typeConverter, patterns.getContext()); - patterns - .add, - OneToNVMIGroupReduceAddOpPattern>( - typeConverter, patterns.getContext(), capabilities); + OneToNVMIExtIOpPattern, OneToNVMIExtIOpPattern, + OneToNVMITruncIOpPattern, OneToNVMIBitcastOpPattern, + OneToNVMIChannelSplitOpPattern, OneToNVMIChannelMergeOpPattern, + OneToNVMIShuffleOpPattern>(typeConverter, patterns.getContext()); + patterns.add< + OneToNVMIGroupReduceOpPattern, + OneToNVMIGroupReduceOpPattern, + OneToNVMIGroupReduceOpPattern>( + typeConverter, patterns.getContext(), capabilities); patterns.add( typeConverter, patterns.getContext(), capabilities); } @@ -7384,13 +7390,16 @@ checkSupportedReduceShape(const VMITargetCapabilityRegistry &capabilities, } template -LogicalResult checkSupportedGroupReduceAddShape( - const VMITargetCapabilityRegistry &capabilities, OpTy op, - std::string *reason = nullptr) { +LogicalResult +checkSupportedGroupReduceShape(const VMITargetCapabilityRegistry &capabilities, + OpTy op, std::string *reason = nullptr) { VMILayoutSupport supports; if constexpr (std::is_same_v) { if (succeeded(supports.getGroupReduceAddFSupport(capabilities, op, reason))) return success(); + } else if constexpr (std::is_same_v) { + if (succeeded(supports.getGroupReduceMaxFSupport(capabilities, op, reason))) + return success(); } else { if (succeeded(supports.getGroupReduceAddISupport(capabilities, op, reason))) return success(); @@ -7642,7 +7651,8 @@ verifySupportedVMIToVPTOOps(ModuleOp module, broadcast.emitError() << kVMIDiagUnsupportedPrefix << "pto.vmi.group_broadcast requires full source chunks with " - "#pto.vmi.layout, a dense full result layout, " + "#pto.vmi.layout, a dense full result " + "layout, " "and num_groups deriving a group size that divides or is a " "multiple of physical chunk lanes (" << reason << ")"; @@ -8062,13 +8072,14 @@ verifySupportedVMIToVPTOOps(ModuleOp module, if (auto reduce = dyn_cast(op)) { std::string reason; if (succeeded( - checkSupportedGroupReduceAddShape(capabilities, reduce, &reason))) + checkSupportedGroupReduceShape(capabilities, reduce, &reason))) return WalkResult::advance(); reduce.emitError() << kVMIDiagUnsupportedPrefix << "pto.vmi.group_reduce_addf lowers through pto.vcgadd for 32B " "VLane groups or through pto.vcadd with reassoc, contiguous full " - "source/mask chunks, #pto.vmi.layout result " + "source/mask chunks, #pto.vmi.layout " + "result " "chunks, and num_groups deriving a group size aligned to " "physical chunks (" << reason << ")"; @@ -8078,7 +8089,7 @@ verifySupportedVMIToVPTOOps(ModuleOp module, if (auto reduce = dyn_cast(op)) { std::string reason; if (succeeded( - checkSupportedGroupReduceAddShape(capabilities, reduce, &reason))) + checkSupportedGroupReduceShape(capabilities, reduce, &reason))) return WalkResult::advance(); reduce.emitError() << kVMIDiagUnsupportedPrefix @@ -8090,6 +8101,21 @@ verifySupportedVMIToVPTOOps(ModuleOp module, return WalkResult::interrupt(); } + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupReduceShape(capabilities, reduce, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_maxf lowers through pto.vcgmax/vmax only " + "for f16/f32 values, matching source/mask chunks, " + "#pto.vmi.layout result chunks, and " + "num_groups deriving a group size aligned to physical chunks (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto reduce = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedReduceShape( diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto new file mode 100644 index 0000000000..1ae3f90a15 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_reduce_maxf_quant( + %src: !pto.ptr, + %scale_out: !pto.ptr, + %out8: !pto.ptr, + %off: index) { + %c8 = arith.constant 8 : index + %c256 = arith.constant 256 : index + %eps = arith.constant 1.000000e-04 : f32 + %fp8_max = arith.constant 4.480000e+02 : f32 + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax_raw = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %eps2 = pto.vmi.broadcast %eps : f32 -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.maxf %amax_raw, %eps2 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %fp8_max2 = pto.vmi.broadcast %fp8_max : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %fp8_max2 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %scale_out[%off], %c8 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %out8[%off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_maxf_quant( +// ASSIGN: %[[X:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[ABS:.*]] = pto.vmi.absf %[[X]] +// ASSIGN: %[[AMAX_RAW:.*]] = pto.vmi.group_reduce_maxf %[[ABS]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[SCALE:.*]] = pto.vmi.divf +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[SCALE]] +// ASSIGN: %[[SCALE_VEC:.*]] = pto.vmi.group_broadcast %[[SCALE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Q:.*]] = pto.vmi.divf %[[X]], %[[SCALE_VEC]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Q_SPLIT:.*]] = pto.vmi.ensure_layout %[[Q]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Q8:.*]] = pto.vmi.truncf %[[Q_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_maxf_quant( +// LOWER: pto.vcgmax +// LOWER: pto.vmax +// LOWER: pto.vsel +// LOWER: pto.vdiv +// LOWER: pto.vdintlv +// LOWER: pto.vcvt +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/compare.py b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/compare.py new file mode 100644 index 0000000000..c6e34633b5 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/compare.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_f32(name: str, atol: float, rtol: float) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) + output = np.fromfile(f"{name}.bin", dtype=np.float32) + close = golden.shape == output.shape and np.allclose(golden, output, atol=atol, rtol=rtol) + if close: + return True + diff = np.nonzero(~np.isclose(golden, output, atol=atol, rtol=rtol))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + return False + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print(f"[ERROR] compare failed {name} idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}") + return False + + +def main() -> None: + if not check_f32("v2", 1e-5, 1e-5) or not check_u8("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py new file mode 100644 index 0000000000..39f0af76f7 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ELEMS = 256 +GROUPS = 2 +GROUP_SIZE = ELEMS // GROUPS +FP8_MAX = np.float32(448.0) +SCALES = np.array([0.25, 0.5], dtype=np.float32) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8) + + +def generate(output_dir: Path) -> None: + repeats = (GROUP_SIZE + len(Q_VALUES) - 1) // len(Q_VALUES) + q_group = np.tile(Q_VALUES, repeats)[:GROUP_SIZE].astype(np.float32) + q = np.concatenate([q_group, q_group]).astype(np.float32) + src = np.empty(ELEMS, dtype=np.float32) + golden_scale = np.full(ELEMS, SENTINEL_F32, dtype=np.float32) + for group in range(GROUPS): + begin = group * GROUP_SIZE + end = begin + GROUP_SIZE + src[begin:end] = (q_group * SCALES[group]).astype(np.float32) + amax = np.max(np.abs(src[begin:end])).astype(np.float32) + scale = np.maximum(amax, np.float32(1.0e-4)) / FP8_MAX + golden_scale[group * 8] = scale + golden_out8_group = np.tile(F8E4M3FN_BYTES, repeats)[:GROUP_SIZE].astype(np.uint8) + golden_out8 = np.concatenate([golden_out8_group, golden_out8_group]).astype(np.uint8) + + scale_out = np.full(ELEMS, SENTINEL_F32, dtype=np.float32) + out8 = np.full(ELEMS, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale_out.tofile(output_dir / "v2.bin") + out8.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out8.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/kernel.pto b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/kernel.pto new file mode 100644 index 0000000000..f2dcc0cd16 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/kernel.pto @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_simdvf_per_token_cast_to_fp8_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out8_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %eps = arith.constant 1.000000e-04 : f32 + %fp8_max = arith.constant 4.480000e+02 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out8_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out8_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out8_gm, %ub_out8_u8, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax_raw = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %eps1 = pto.vmi.broadcast %eps : f32 -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.maxf %amax_raw, %eps1 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %fp8_max1 = pto.vmi.broadcast %fp8_max : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %fp8_max1 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%c0], %c8 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale + {num_groups = 2} : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out8_f8[%c0] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out8_u8, %out8_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/launch.cpp b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/launch.cpp new file mode 100644 index 0000000000..630c7d55af --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_simdvf_per_token_cast_to_fp8_kernel(__gm__ float *src, + __gm__ float *scale, + __gm__ uint8_t *out8); + +void LaunchVmi_simdvf_per_token_cast_to_fp8_kernel(float *src, float *scale, + uint8_t *out8, + void *stream) { + vmi_simdvf_per_token_cast_to_fp8_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out8); +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/main.cpp b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/main.cpp new file mode 100644 index 0000000000..cbb7149b86 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/main.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_simdvf_per_token_cast_to_fp8_kernel(float *src, float *scale, + uint8_t *out8, + void *stream); + +int main() { + constexpr size_t kElems = 256; + size_t srcBytes = kElems * sizeof(float); + size_t scaleBytes = kElems * sizeof(float); + size_t out8Bytes = kElems * sizeof(uint8_t); + float *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *out8Host = nullptr; + float *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *out8Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&out8Host), out8Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&out8Device, out8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", out8Bytes, out8Host, out8Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(out8Device, out8Bytes, out8Host, out8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_simdvf_per_token_cast_to_fp8_kernel(srcDevice, scaleDevice, + out8Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(out8Host, out8Bytes, out8Device, out8Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", out8Host, out8Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(out8Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(out8Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/ptoas.flags b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From 8c4a80114b9322cb5486a5247c0104182e0963a5 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 24 Jun 2026 23:08:37 +0800 Subject: [PATCH 33/54] feat: add cce-aligned vmi kernel cases --- docs/designs/vmi-dialect-design.md | 2080 ----------------- docs/designs/vmi-implementation-manual.md | 204 +- docs/designs/vmi-introduction.md | 87 +- .../vmi-layout-assignment-implementation.md | 37 +- .../vmi-layout-assignment-lowering-design.md | 22 +- docs/designs/vmi-layout-lowering-cases.md | 26 +- .../vmi-mxfp8-32x32-expected-lowering.md | 236 ++ include/PTO/IR/VMIAttrs.td | 12 +- include/PTO/IR/VMIOps.td | 63 +- include/PTO/Transforms/VMILayoutSupport.h | 13 +- .../PTO/Transforms/VMITargetCapabilities.h | 10 +- lib/PTO/IR/VMI.cpp | 262 ++- lib/PTO/Transforms/PTOValidateVMIIR.cpp | 50 +- lib/PTO/Transforms/VMILayoutAssignment.cpp | 241 +- lib/PTO/Transforms/VMILayoutFoldConsumers.cpp | 3 - lib/PTO/Transforms/VMILayoutSupport.cpp | 99 +- lib/PTO/Transforms/VMIToVPTO.cpp | 1463 +++++++++--- test/lit/vmi/vmi_gather_indices_invalid.pto | 2 +- .../vmi/vmi_group_reduce_maxi_i8_invalid.pto | 24 + ...t_assignment_group_load_block8_truncf.pto} | 17 +- ...ment_group_reduce_s64_broadcast_reduce.pto | 4 +- ...assignment_group_reduce_s64_tail_store.pto | 2 +- ...out_assignment_group_reduce_s64_truncf.pto | 2 +- .../vmi_layout_assignment_group_slot_load.pto | 18 + ...assignment_group_slot_load_dual_layout.pto | 1 + ...gnment_group_store_slots1_unit_stride.pto} | 18 +- .../vmi/vmi_layout_assignment_load_truncf.pto | 59 - .../vmi_layout_assignment_trunci_sparse.pto | 40 + .../vmi/vmi_layout_fold_consumers_store.pto | 28 - ...> vmi_layout_gate_bitcast_group_slots.pto} | 14 +- .../vmi_layout_gate_store_support_invalid.pto | 15 - .../vmi/vmi_memory_element_type_invalid.pto | 25 - test/lit/vmi/vmi_op_verifier_basic.pto | 5 - test/lit/vmi/vmi_shuffle_indices_invalid.pto | 33 + ...to => vmi_to_vpto_bitcast_group_slots.pto} | 21 +- .../vmi/vmi_to_vpto_ensure_layout_deint4.pto | 24 + .../vmi/vmi_to_vpto_gather_f16_invalid.pto | 2 +- test/lit/vmi/vmi_to_vpto_gather_u16.pto | 37 + ...group_broadcast_s32_deint2_small_group.pto | 29 + test/lit/vmi/vmi_to_vpto_group_ops.pto | 5 +- ...mi_to_vpto_group_reduce_s256_broadcast.pto | 2 +- test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto | 4 +- .../vmi_to_vpto_group_reduce_s64_support.pto | 2 +- test/lit/vmi/vmi_to_vpto_group_slot_load.pto | 81 +- .../vmi_to_vpto_group_slot_load_support.pto | 1 + .../vmi_to_vpto_group_store_slots1_1pt.pto | 28 + ...oup_store_slots1_unit_stride_alignment.pto | 47 + ...to_vpto_group_store_slots8_packed_byte.pto | 43 + test/lit/vmi/vmi_to_vpto_integer_casts.pto | 91 + .../vmi/vmi_to_vpto_load_safe_tail_memref.pto | 19 - .../vmi/vmi_to_vpto_memory_space_invalid.pto | 33 - .../vmi/vmi_to_vpto_memref_layout_invalid.pto | 33 - .../vmi/vmi_to_vpto_store_width_invalid.pto | 15 - test/lit/vmi/vmi_to_vpto_stride_load.pto | 35 + test/lit/vmi/vmi_to_vpto_stride_store.pto | 33 + test/lit/vmi/vmi_to_vpto_tile_read_write.pto | 64 - .../vmi/vmi_to_vpto_tile_write_deint_tail.pto | 34 - test/lit/vmi/vmi_to_vpto_tile_write_tail.pto | 33 - test/lit/vmi/vmi_to_vpto_truncf.pto | 24 + ... => vmi_truncf_rounding_token_invalid.pto} | 14 +- ...mi_truncf_rounding_unsupported_invalid.pto | 20 + test/lit/vpto/vgather2_u16_vpto_llvm.pto | 30 + ...i_fp4_e1_packed_surface_verify_invalid.pto | 19 + .../vmi_fp4_packed_surface_verify_invalid.pto | 19 + test/lit/vpto/vmi_sitofp.pto | 42 + test/lit/vpto/vmi_truncf_hif8.pto | 96 + .../group-reduce-i32-maxi-store/compare.py | 37 + .../vmi/group-reduce-i32-maxi-store/golden.py | 45 + .../group-reduce-i32-maxi-store/kernel.pto | 51 + .../group-reduce-i32-maxi-store/launch.cpp | 35 + .../vmi/group-reduce-i32-maxi-store/main.cpp | 86 + .../ptoas.flags | 0 test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md | 226 ++ test/vpto/cases/vmi/kernels/README.md | 49 + .../anti-mx-f8-bf16-scaled-16x512/compare.py | 34 + .../anti-mx-f8-bf16-scaled-16x512/golden.py | 75 + .../anti-mx-f8-bf16-scaled-16x512/kernel.pto | 96 + .../anti-mx-f8-bf16-scaled-16x512/launch.cpp | 45 + .../anti-mx-f8-bf16-scaled-16x512/main.cpp | 93 + .../anti-mx-f8-bf16-scaled-16x512/ptoas.flags | 1 + .../anti-mx-f8-bf16-scaled-4x128/compare.py | 34 + .../anti-mx-f8-bf16-scaled-4x128/golden.py | 67 + .../anti-mx-f8-bf16-scaled-4x128/kernel.pto | 85 + .../launch.cpp | 16 +- .../anti-mx-f8-bf16-scaled-4x128/main.cpp | 93 + .../anti-mx-f8-bf16-scaled-4x128/ptoas.flags | 1 + .../anti-mx-f8-bf16-scaled-64x2048/compare.py | 34 + .../anti-mx-f8-bf16-scaled-64x2048/golden.py | 75 + .../anti-mx-f8-bf16-scaled-64x2048/kernel.pto | 109 + .../anti-mx-f8-bf16-scaled-64x2048/launch.cpp | 45 + .../anti-mx-f8-bf16-scaled-64x2048/main.cpp | 93 + .../ptoas.flags | 1 + .../anti-mx-f8-f16-scaled-4x128/compare.py | 34 + .../anti-mx-f8-f16-scaled-4x128/golden.py | 64 + .../anti-mx-f8-f16-scaled-4x128/kernel.pto | 85 + .../anti-mx-f8-f16-scaled-4x128/launch.cpp | 44 + .../anti-mx-f8-f16-scaled-4x128/main.cpp | 93 + .../anti-mx-f8-f16-scaled-4x128/ptoas.flags | 1 + .../anti-mx-f8-f32-scaled-4x128/compare.py | 34 + .../anti-mx-f8-f32-scaled-4x128/golden.py | 60 + .../anti-mx-f8-f32-scaled-4x128/kernel.pto | 83 + .../anti-mx-f8-f32-scaled-4x128/launch.cpp | 44 + .../anti-mx-f8-f32-scaled-4x128/main.cpp | 93 + .../anti-mx-f8-f32-scaled-4x128/ptoas.flags | 1 + .../compare.py | 34 + .../golden.py | 78 + .../kernel.pto | 96 + .../launch.cpp | 45 + .../main.cpp | 93 + .../ptoas.flags | 1 + .../compare.py | 34 + .../golden.py | 73 + .../kernel.pto | 85 + .../launch.cpp | 45 + .../anti-mx-f8e5m2-bf16-scaled-4x128/main.cpp | 93 + .../ptoas.flags | 1 + .../compare.py | 23 +- .../block-mx-quant-bf16-e4m3-4x128/golden.py | 69 + .../block-mx-quant-bf16-e4m3-4x128/kernel.pto | 151 ++ .../block-mx-quant-bf16-e4m3-4x128/launch.cpp | 45 + .../block-mx-quant-bf16-e4m3-4x128/main.cpp | 106 + .../ptoas.flags | 1 + .../block-mx-quant-bf16-e5m2-4x128/compare.py | 36 + .../block-mx-quant-bf16-e5m2-4x128/golden.py | 71 + .../block-mx-quant-bf16-e5m2-4x128/kernel.pto | 151 ++ .../block-mx-quant-bf16-e5m2-4x128/launch.cpp | 45 + .../block-mx-quant-bf16-e5m2-4x128/main.cpp | 106 + .../ptoas.flags | 1 + .../block-mx-quant-f16-e4m3-64x256/compare.py | 36 + .../block-mx-quant-f16-e4m3-64x256/golden.py | 61 + .../block-mx-quant-f16-e4m3-64x256/kernel.pto | 138 ++ .../block-mx-quant-f16-e4m3-64x256/launch.cpp | 45 + .../block-mx-quant-f16-e4m3-64x256/main.cpp | 106 + .../ptoas.flags | 1 + .../block-mx-quant-f16-e5m2-8x256/compare.py | 36 + .../block-mx-quant-f16-e5m2-8x256/golden.py | 64 + .../block-mx-quant-f16-e5m2-8x256/kernel.pto | 154 ++ .../block-mx-quant-f16-e5m2-8x256/launch.cpp | 45 + .../block-mx-quant-f16-e5m2-8x256/main.cpp | 106 + .../block-mx-quant-f16-e5m2-8x256/ptoas.flags | 1 + .../block-quant-bf16-fp8-2x128/compare.py | 47 + .../block-quant-bf16-fp8-2x128/golden.py | 75 + .../kernel.pto | 67 +- .../block-quant-bf16-fp8-2x128/launch.cpp | 42 + .../block-quant-bf16-fp8-2x128/main.cpp | 93 + .../block-quant-bf16-fp8-2x128/ptoas.flags | 1 + .../block-quant-bf16-fp8-32x128/compare.py | 47 + .../block-quant-bf16-fp8-32x128/golden.py | 75 + .../block-quant-bf16-fp8-32x128/kernel.pto | 83 + .../block-quant-bf16-fp8-32x128/launch.cpp | 42 + .../block-quant-bf16-fp8-32x128/main.cpp | 93 + .../block-quant-bf16-fp8-32x128/ptoas.flags | 1 + .../compare.py | 47 + .../golden.py | 122 + .../kernel.pto | 123 + .../launch.cpp | 42 + .../main.cpp | 93 + .../ptoas.flags | 1 + .../block-quant-bf16-fp8-4x128/compare.py | 47 + .../block-quant-bf16-fp8-4x128/golden.py | 75 + .../block-quant-bf16-fp8-4x128/kernel.pto | 113 + .../block-quant-bf16-fp8-4x128/launch.cpp | 42 + .../block-quant-bf16-fp8-4x128/main.cpp | 93 + .../block-quant-bf16-fp8-4x128/ptoas.flags | 1 + .../block-quant-f16-fp8-16x256/compare.py | 47 + .../block-quant-f16-fp8-16x256/golden.py | 69 + .../block-quant-f16-fp8-16x256/kernel.pto | 89 + .../block-quant-f16-fp8-16x256/launch.cpp | 41 + .../block-quant-f16-fp8-16x256/main.cpp | 93 + .../block-quant-f16-fp8-16x256/ptoas.flags | 1 + .../block-quant-f16-fp8-4x256/compare.py | 47 + .../block-quant-f16-fp8-4x256/golden.py | 69 + .../block-quant-f16-fp8-4x256/kernel.pto | 89 + .../block-quant-f16-fp8-4x256/launch.cpp | 41 + .../main.cpp | 44 +- .../block-quant-f16-fp8-4x256/ptoas.flags | 1 + .../block-quant-f16-fp8-8x128/compare.py | 47 + .../block-quant-f16-fp8-8x128/golden.py | 64 + .../block-quant-f16-fp8-8x128/kernel.pto | 89 + .../block-quant-f16-fp8-8x128/launch.cpp | 41 + .../block-quant-f16-fp8-8x128/main.cpp | 93 + .../block-quant-f16-fp8-8x128/ptoas.flags | 1 + .../compare.py | 50 + .../golden.py | 86 + .../kernel.pto | 188 ++ .../launch.cpp | 44 + .../main.cpp | 95 + .../ptoas.flags | 1 + .../compare.py | 50 + .../golden.py | 75 + .../kernel.pto | 188 ++ .../launch.cpp | 44 + .../main.cpp | 95 + .../ptoas.flags | 1 + .../compare.py | 50 + .../golden.py | 75 + .../kernel.pto | 117 + .../launch.cpp | 44 + .../main.cpp | 95 + .../ptoas.flags | 1 + .../compare.py | 50 + .../golden.py | 81 + .../kernel.pto | 112 + .../launch.cpp | 44 + .../dynamic-quant-pertoken-bf16-4x32/main.cpp | 95 + .../ptoas.flags | 1 + .../compare.py | 50 + .../golden.py | 82 + .../kernel.pto | 118 + .../launch.cpp | 44 + .../main.cpp | 95 + .../ptoas.flags | 1 + .../compare.py | 50 + .../dynamic-quant-pertoken-f16-4x32/golden.py | 70 + .../kernel.pto | 112 + .../launch.cpp | 44 + .../dynamic-quant-pertoken-f16-4x32/main.cpp | 95 + .../ptoas.flags | 1 + .../compare.py | 50 + .../golden.py | 101 + .../kernel.pto | 156 ++ .../launch.cpp | 45 + .../main.cpp | 105 + .../ptoas.flags | 1 + .../compare.py | 50 + .../golden.py | 90 + .../kernel.pto | 160 ++ .../launch.cpp | 45 + .../main.cpp | 105 + .../ptoas.flags | 1 + .../compare.py | 50 + .../golden.py | 90 + .../kernel.pto | 156 ++ .../launch.cpp | 45 + .../main.cpp | 105 + .../ptoas.flags | 1 + .../simdvf-per-token-cast-to-fp8/golden.py | 62 - .../swiglu-mx-quant-bf16-e4m3-4x8/compare.py | 36 + .../swiglu-mx-quant-bf16-e4m3-4x8/golden.py | 59 + .../swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto | 157 ++ .../swiglu-mx-quant-bf16-e4m3-4x8/launch.cpp | 43 + .../swiglu-mx-quant-bf16-e4m3-4x8/main.cpp | 95 + .../swiglu-mx-quant-bf16-e4m3-4x8/ptoas.flags | 1 + .../swiglu-mx-quant-bf16-e5m2-4x8/compare.py | 36 + .../swiglu-mx-quant-bf16-e5m2-4x8/golden.py | 59 + .../swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto | 157 ++ .../swiglu-mx-quant-bf16-e5m2-4x8/launch.cpp | 43 + .../swiglu-mx-quant-bf16-e5m2-4x8/main.cpp | 95 + .../swiglu-mx-quant-bf16-e5m2-4x8/ptoas.flags | 1 + .../compare.py | 36 + .../swiglu-mx-quant-f16-e4m3-64x512/golden.py | 60 + .../kernel.pto | 153 ++ .../launch.cpp | 42 + .../swiglu-mx-quant-f16-e4m3-64x512/main.cpp | 95 + .../ptoas.flags | 1 + .../compare.py | 36 + .../golden.py | 62 + .../kernel.pto | 163 ++ .../launch.cpp | 42 + .../swiglu-mx-quant-f16-e5m2-128x256/main.cpp | 95 + .../ptoas.flags | 1 + .../tquant-int8-asym-64x128/compare.py | 32 + .../kernels/tquant-int8-asym-64x128/golden.py | 48 + .../tquant-int8-asym-64x128/kernel.pto | 106 + .../tquant-int8-asym-64x128/launch.cpp | 44 + .../kernels/tquant-int8-asym-64x128/main.cpp | 99 + .../tquant-int8-asym-64x128/ptoas.flags | 1 + .../kernels/tquant-int8-sym-64x128/compare.py | 32 + .../kernels/tquant-int8-sym-64x128/golden.py | 44 + .../kernels/tquant-int8-sym-64x128/kernel.pto | 121 + .../kernels/tquant-int8-sym-64x128/launch.cpp | 41 + .../kernels/tquant-int8-sym-64x128/main.cpp | 90 + .../tquant-int8-sym-64x128/ptoas.flags | 1 + .../kernels/tquant-mxfp8-32x32-nd/compare.py | 36 + .../kernels/tquant-mxfp8-32x32-nd/golden.py | 57 + .../kernels/tquant-mxfp8-32x32-nd/kernel.pto | 114 + .../kernels/tquant-mxfp8-32x32-nd/launch.cpp | 42 + .../kernels/tquant-mxfp8-32x32-nd/main.cpp | 93 + .../kernels/tquant-mxfp8-32x32-nd/ptoas.flags | 1 + .../kernels/tquant-mxfp8-32x64-nz/compare.py | 36 + .../kernels/tquant-mxfp8-32x64-nz/golden.py | 75 + .../kernels/tquant-mxfp8-32x64-nz/kernel.pto | 174 ++ .../kernels/tquant-mxfp8-32x64-nz/launch.cpp | 44 + .../kernels/tquant-mxfp8-32x64-nz/main.cpp | 116 + .../kernels/tquant-mxfp8-32x64-nz/ptoas.flags | 1 + 285 files changed, 16684 insertions(+), 3258 deletions(-) delete mode 100644 docs/designs/vmi-dialect-design.md create mode 100644 docs/designs/vmi-mxfp8-32x32-expected-lowering.md create mode 100644 test/lit/vmi/vmi_group_reduce_maxi_i8_invalid.pto rename test/lit/vmi/{vmi_layout_assignment_group_load_block8_truncf_invalid.pto => vmi_layout_assignment_group_load_block8_truncf.pto} (70%) rename test/lit/vmi/{vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto => vmi_layout_assignment_group_store_slots1_unit_stride.pto} (66%) create mode 100644 test/lit/vmi/vmi_layout_assignment_trunci_sparse.pto rename test/lit/vmi/{vmi_to_vpto_bitcast_group_slots_invalid.pto => vmi_layout_gate_bitcast_group_slots.pto} (70%) create mode 100644 test/lit/vmi/vmi_shuffle_indices_invalid.pto rename test/lit/vmi/{vmi_layout_gate_bitcast_group_slots_invalid.pto => vmi_to_vpto_bitcast_group_slots.pto} (59%) create mode 100644 test/lit/vmi/vmi_to_vpto_gather_u16.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_broadcast_s32_deint2_small_group.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_store_slots1_unit_stride_alignment.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_store_slots8_packed_byte.pto create mode 100644 test/lit/vmi/vmi_to_vpto_stride_load.pto create mode 100644 test/lit/vmi/vmi_to_vpto_stride_store.pto delete mode 100644 test/lit/vmi/vmi_to_vpto_tile_read_write.pto delete mode 100644 test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto delete mode 100644 test/lit/vmi/vmi_to_vpto_tile_write_tail.pto rename test/lit/vmi/{vmi_to_vpto_tile_write_tail_deint_invalid.pto => vmi_truncf_rounding_token_invalid.pto} (55%) create mode 100644 test/lit/vmi/vmi_truncf_rounding_unsupported_invalid.pto create mode 100644 test/lit/vpto/vgather2_u16_vpto_llvm.pto create mode 100644 test/lit/vpto/vmi_fp4_e1_packed_surface_verify_invalid.pto create mode 100644 test/lit/vpto/vmi_fp4_packed_surface_verify_invalid.pto create mode 100644 test/lit/vpto/vmi_sitofp.pto create mode 100644 test/lit/vpto/vmi_truncf_hif8.pto create mode 100644 test/vpto/cases/vmi/group-reduce-i32-maxi-store/compare.py create mode 100644 test/vpto/cases/vmi/group-reduce-i32-maxi-store/golden.py create mode 100644 test/vpto/cases/vmi/group-reduce-i32-maxi-store/kernel.pto create mode 100644 test/vpto/cases/vmi/group-reduce-i32-maxi-store/launch.cpp create mode 100644 test/vpto/cases/vmi/group-reduce-i32-maxi-store/main.cpp rename test/vpto/cases/vmi/{kernels/simdvf-per-token-cast-to-fp8 => group-reduce-i32-maxi-store}/ptoas.flags (100%) create mode 100644 test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md create mode 100644 test/vpto/cases/vmi/kernels/README.md create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/compare.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/golden.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/kernel.pto rename test/vpto/cases/vmi/kernels/{simdvf-per-token-cast-to-fp8 => anti-mx-f8-bf16-scaled-4x128}/launch.cpp (72%) create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/compare.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/golden.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/compare.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/golden.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/ptoas.flags rename test/vpto/cases/vmi/kernels/{simdvf-per-token-cast-to-fp8 => block-mx-quant-bf16-e4m3-4x128}/compare.py (62%) create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/compare.py create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/golden.py create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/compare.py create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/golden.py create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/golden.py rename test/vpto/cases/vmi/kernels/{simdvf-per-token-cast-to-fp8 => block-quant-bf16-fp8-2x128}/kernel.pto (50%) create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/compare.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/golden.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/compare.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/golden.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/compare.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/golden.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/launch.cpp rename test/vpto/cases/vmi/kernels/{simdvf-per-token-cast-to-fp8 => block-quant-f16-fp8-4x256}/main.cpp (72%) create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/compare.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/golden.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/compare.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/golden.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/compare.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/golden.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/compare.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/golden.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/compare.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/golden.py create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/ptoas.flags delete mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/compare.py create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/golden.py create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/compare.py create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/golden.py create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/compare.py create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/golden.py create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/compare.py create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/golden.py create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/compare.py create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/golden.py create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/compare.py create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/golden.py create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/ptoas.flags create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/compare.py create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/golden.py create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/ptoas.flags diff --git a/docs/designs/vmi-dialect-design.md b/docs/designs/vmi-dialect-design.md deleted file mode 100644 index 897b1661cc..0000000000 --- a/docs/designs/vmi-dialect-design.md +++ /dev/null @@ -1,2080 +0,0 @@ -# VMI dialect 设计 - -## 背景 - -VPTO 的 `!pto.vreg` 是 256 bytes 物理向量寄存器抽象。很多 VPTO op 暴露的是 -physical placement:`vcvt` part、pack/unpack、interleave/deinterleave、load/store dist、 -predicate granularity 等。TileLang `T.parallel` 或其它前端想表达的是逻辑向量语义,不应该 -手写这些 physical placement。 - -VMI dialect 的目标是提供一层 PTO-friendly 的 semantic vector IR。它不是任何外部向量 dialect -的语法克隆,也不是 VPTO physical dialect。VMI 的设计来源是 PTO virtual vector ISA 需要承接的 -逻辑向量语义、layout、mask granularity、memory safety 和控制流 layout join;后续 lowering 只从 -VMI 决定 physical layout 和 VPTO op。 - -本设计采用 `vmi.vreg` 作为 layout carrier,不再引入单独的 `vbundle` type: - -```text -semantic VMI - -> layout-assigned VMI - -> physical VPTO -``` - -VMI 的 producer 在核心设计之外。TileLang/PTO lowering、手写 VMI 测试或其它 import 工具都可以 -产生 VMI,但它们不能定义 VMI 的 semantic surface。核心设计只要求 producer 在进入 VMI boundary -时生成合法 VMI IR。 - -## 和旧 VMI layout 设计的关系 - -旧文档中的核心形式是: - -```mlir -!pto.vmi.vreg -!pto.vmi.mask -``` - -这个方向是对的:`vmi.vreg` 本身是 virtual aggregate type,可以承载完整 logical vector, -layout 放在它上面比放在 physical `!pto.vreg` 上更合理。 - -旧设计需要补强的地方主要是 layout descriptor 和 lowering contract,而不是推翻 -`vmi.vreg`: - -1. 旧 layout descriptor 把 `logical_shape`、`phys_dtype`、`phys_lanes` 放进 attr,和 - `vreg` / target registry 存在重复信息。重复字段会产生 verifier 漂移。 -2. `axes=[#axis<...>]` 太开放,缺少每个 layout 的精确定义、part ordering 和 lane map。 -3. 旧设计要求 `N * bitwidth(T)` 是 256B 整数倍,无法覆盖 tail / 非整 tile。 -4. mask 只写成 `mask`,但没有定义 data layout、mask layout、mask granularity - conversion 在宽度转换中的同步规则。 -5. 控制流 join 没有定义:`scf.if` 两边 layout 不同、`scf.for` loop-carried layout 如何稳定。 -6. memory access map 和 register layout 没有切开,容易把 strided memory view 误当成 vreg - layout。 -7. hard vector semantics 缺失,例如 padding read、active prefix index、dynamic permute、 - compress/expand、scan/reduction/contract 的 VMI 表达和 lowering contract。 - -因此本设计保留 `vmi.vreg` 这个 carrier,但不沿用旧 layout descriptor 的 -开放式语义。旧文档没有定义 “logical behavior -> hardware mismatch -> physical -decomposition -> lane map -> propagation/sink” 这条 source contract;这是本文新增的核心约束。 - -换句话说,本文不是复述旧 `vmi.layout`,而是把旧的开放式 axis descriptor 收紧成一个很小的 -public layout 集合。本设计只接受 `contiguous`、`deinterleaved = 2`、`deinterleaved = 4`。 -source contract 是新增 layout kind 的准入规则,不是要求实现 generic axes 或任意 lane-map -descriptor。 - -## 目标 - -1. VMI surface 表达逻辑向量语义,不暴露 VPTO part/dist/interleave 细节。 -2. `vmi.vreg` 是 virtual aggregate type,可以表示大于 256B 的 logical vector。 -3. layout 放在 layout-assigned VMI type 上,不再另设 `vbundle`。 -4. VMI mask 是一等类型;surface mask 表达 logical predicate,layout-assigned mask 才携带 - concrete predicate granularity `b8/b16/b32`。 -5. VMI 支持 tail / 非整 tile;padding physical lane 不可观察。 -6. VMI lowering 支持控制流中的 layout join。 -7. VMI producer boundary 后的 IR 必须只依赖 VMI semantic op/type 表达逻辑向量语义。 - -## 非目标 - -1. 不改变 physical `!pto.vreg` 的含义。它仍然是 256 bytes physical register。 -2. 不把 VMI 做成任何外部向量 dialect 的逐 op 复制品;VMI 只表达 PTO lowering 需要的 logical - vector semantics。 -3. 不把 scalar lane extract 当作 VMI vector op。scalar lane extract 是 vector-to-scalar - boundary,必须在进入 VMI 前被 producer 消除,或以明确 diagnostic 退出 PTO 路线。 -4. 不把 VPTO load/store dist 暴露成 VMI surface op。dist 是 lowering 选择。 - -## VMI Producer Boundary Contract - -VMI 是 PTO 路线上的 virtual vector ISA。任何 producer 在进入 VMI boundary 后,必须满足下面之一: - -1. 逻辑向量语义已经表达为 native VMI semantic op。 -2. 逻辑向量语义已经表达为一组 VMI semantic op 的组合,并保持 producer 的 observable semantics。 -3. 该行为不是 VMI 负责的向量计算,而是 vector-to-scalar / tensor / debug / transform boundary, - 已经在进入 VMI 前由 producer 消除,或以明确 diagnostic 退出 PTO 路线。 - -不能把“当前阶段不支持”作为 VMI 设计结果。一个 PTO virtual vector semantic 如果属于 VMI 负责的 -逻辑向量语义,文档必须给出 VMI op、组合 lowering、layout contract、memory fallback 或 target -capability diagnostic。diagnostic 只允许表示语义边界或目标能力缺失,不能表示“VMI 没有设计这个能力”。 - -`pto.vmi -> pto` 的完成条件是: - -```text -at VMI producer boundary: - logical vector semantics are represented by VMI op/type - no physical VPTO op is introduced by the producer - no hidden layout/mask/type side table is required to interpret a VMI value - -after vmi-layout-assignment: - every vmi.vreg/vmi.mask has an explicit #pto.vmi.layout - every mask granularity matches its consumer - every control-flow yield/iter_arg/result has one stable layout - -after vmi-to-vpto: - no pto.vmi op/type remains - every logical VMI value has been lowered to ordered physical VPTO values -``` - -### Capability And Fallback Policy - -所有 direct lowering 和 fallback 选择必须来自显式配置,不能依赖 pass 内隐藏全局状态: - -```text -TargetCapabilityRegistry: - element-type storage/compute/convert support - layout source/sink/conversion support - memory access capability: OOB, masked, gather/scatter, block-strided - predicate capability: granularity conversion, prefix-popcount, rearrangement - reduction/scan/contract capability - scratch memory spaces, alignment, and lifetime rules - -VMIToPTOOptions: - enableScratchFallback - enableGuardedScalarFallback - enableIndexBufferFallback - allowDebugStrip - targetVScaleSpecialization - diagnosticVerbosity -``` - -fallback 被 option 禁用时,diagnostic 必须报告 `disabled_by_option`。target registry 缺能力时, -diagnostic 必须报告 `missing_capability`。debug-only op 只能由 debug pipeline 消费,或在 -`allowDebugStrip` 明确开启时剥离;否则报 `VMI-DEBUG-BOUNDARY`。 - -fallback resource 也必须显式建模: - -```text -scratch fallback: - memory space, alignment, element type, shape, lifetime, and deallocation point - must be explicit in the lowering plan - scratch initialization, such as padding fill, must dominate later scratch load - -guarded scalar/vector fallback: - guard must dominate every memory effect it protects - invalid lane must not compute a memory effect through an OOB memref address - -index-buffer fallback: - index element width, signedness, and address unit must match the consumer - buffer lifetime must dominate gather/scatter or compaction use -``` - -如果无法分配 scratch、无法放置 guard、或 index buffer 宽度不满足目标要求,diagnostic 使用 -`VMI-FALLBACK-RESOURCE`,并说明是 resource 缺失而不是语义不可表达。 - -## 类型模型 - -### Surface Type - -VMI surface type 不显式写 layout: - -```mlir -!pto.vmi.vreg<128xf32> -!pto.vmi.vreg<256xf8> -!pto.vmi.vreg<1xf32> - -!pto.vmi.mask<128xpred> -!pto.vmi.mask<256xpred> -``` - -`N` 是 logical lane count,`T` 是 logical element type。surface `mask` 表示 N 个 -logical predicate lane,不预先绑定 VPTO predicate granularity。layout assignment 根据 consumer -选择 concrete granularity: - -```text -f32/i32 consumer -> b32 -f16/bf16/i16 consumer -> b16 -f8/i8 consumer -> b8 -``` - -如果一个 logical mask 被不同 width consumer 使用,VMI lowering 必须按 use 插入 -`vmi.ensure_mask_granularity` 或重物化 mask producer,不能假设某个 concrete granularity 可直接 -给所有 consumer 使用。 - -VMI type 以 1-D logical vector 为核心。来自 multi-rank producer value 的语义在进入 VMI boundary 前按 row-major flatten 成: - -```mlir -!pto.vmi.vreg<64xf32> -!pto.vmi.mask<64xpred> -``` - -VMI value 本身只承载 flattened lane sequence,不携带隐式 rank side table。需要 rank 信息的 op -必须在自身 attr 中保存 logical shape / indexing map,例如 `logical_shape = [8, 8]`。这样保持 -与既有 `vmi.vreg` 设计一致,同时不丢失 transfer、transpose、reshape 等 op 的语义。 - -shape-sensitive op 的规则是: - -```text -elementwise / select: - operate on flattened lanes and preserve any surrounding op-provided shape context - -tile_read / tile_write: - carry logical_shape and permutation_map attrs - -shape_cast / reshape / transpose / contract: - carry source/result shapes, maps, and iterator metadata as op attrs - -block argument / function argument: - carries only flat vreg type; any later shaped use must provide its own shape attrs -``` - -因此 logical shape 信息不能保存在 C++ side table,也不能要求 consumer 从 defining op 反查。 - -Rank-0 logical vector 仍然是 VMI vector value,不是 scalar SSA value: - -```mlir -rank-0 logical vector -> !pto.vmi.vreg<1xT> -rank-0 logical predicate -> !pto.vmi.mask<1xpred> -``` - -只有产生 scalar result 的 extract 才是 vector-to-scalar boundary。rank-0 logical vector load、 -bitcast、mask 和 arithmetic 仍然走 VMI,不能因为只有一个 lane 就绕开 VMI verifier。 - -Scalable logical vector 不能直接进入 VMI type,因为 `vmi.vreg` 的 `N` 是 concrete logical lane -count。producer 必须先根据 target profile 和 tiling decision 把 scalable semantics specialize 成固定 -`N`;否则在 VMI boundary 报 `VMI-SCALABLE-VECTOR`。这不是 VMI 的临时缺口,而是 -固定 256B physical vreg lowering 的前置约束。 - -### Layout-Assigned Type - -`vmi-layout-assignment` 后,所有 VMI data/mask value 都必须带 layout: - -```mlir -!pto.vmi.vreg<128xf32, #pto.vmi.layout> -!pto.vmi.vreg<128xf32, #pto.vmi.layout> -!pto.vmi.vreg<256xf32, #pto.vmi.layout> - -!pto.vmi.mask<128xb32, #pto.vmi.layout> -!pto.vmi.mask<128xb32, #pto.vmi.layout> -``` - -这里的 `#pto.vmi.layout` 是唯一的 VMI register layout carrier。它不是 `#pto.vlayout` -的直接复用,也不是 `vbundle` 的 type 参数;但它必须采用同一套精确 lane-map 语义,保证后续 -lower 到 physical VPTO 时可验证。 - -### 非整 Tile - -VMI type 不要求 `N * bitwidth(T)` 是 256B 整数倍: - -```mlir -!pto.vmi.vreg<100xf32> -!pto.vmi.mask<100xpred> -``` - -physical lowering 时按 256B part 向上取整。超出 `N` 的 physical lane 是 padding lane: - -```text -padding lane: - may be poison/undef internally - must not be stored - must not affect compare/reduction/scan - must not become visible through layout conversion -``` - -任何 store、reduction、compress、mask-producing op 都必须用 logical lane count 或 explicit -mask 保护 padding lane。 - -## Layout 设计来源 - -VMI layout 的价值必须从逻辑 vector 行为推导,而不是从 layout 名字推导。判断流程是: - -```text -1. 前端想表达一个完整的 logical vector 行为。 -2. VPTO 底层指令不能把这个 logical vector 天然放进一个 contiguous physical sequence。 -3. 但 VPTO 可以把这个 logical vector 拆成一组有固定 lane-map 的 physical parts。 -4. 后续常见 op 可以在这些 parts 上逐 part 保持 logical semantics。 -5. 边界 consumer 能直接消费这种 parts,或存在可验证的 materialize path。 -6. 因此值得把这个 parts relation 提升为 VMI layout。 -``` - -layout 不是“某条指令的名字”,而是一个 representation relation: - -```text -Layout L defines: - logical vector value V[NxT] - <-> ordered physical parts P0, P1, ... - with exact map logical lane i -> (part, lane) -``` - -只有当这个 relation 能让 VMI 保持“用户看到的是一个连续 logical vector”,同时避免前端手写 -parts,layout 才有设计价值。 - -### Register Layout 集合 - -VMI register layout 不采用复杂通用 descriptor,而是定义为封闭集合: - -```text -#pto.vmi.layout -#pto.vmi.layout -#pto.vmi.layout -``` - -`deinterleaved = K` 表示一个 logical vector 被拆成 K 个 physical part,第 `p` 个 part 保存 -logical lane `p, p + K, p + 2K, ...`。这个名字直接描述元素摆放,不绑定到某条 VPTO op,也不 -引入旧 `axes` 的通用维度系统。 - -不加入 `channel`、`packed_bits`、`blocked`、`stride`、`permutation` 等 layout kind。 -这些能力先由 VMI semantic op、memory access plan 或 explicit layout conversion 表达。只有当 -一个新 representation 同时满足下面的 source contract,才允许扩展 layout 目录。 - -### Layout Source Contract - -每个 VMI layout kind 必须来自一条明确的 source contract: - -```text -logical behavior: - VMI 想表达的用户级 vector 行为 - -hardware mismatch: - 为什么 VPTO 不能用一个 contiguous physical sequence 天然承载 - -physical decomposition: - VPTO 实际能产生或消费的 physical parts - -lane map: - logical lane -> physical part/lane 的精确定义 - -propagation rule: - 哪些 VMI op 可以逐 part 保持语义 - -boundary rule: - 哪些 load/store/pack/convert consumer 可以直接消费,哪些必须 materialize - -mask rule: - 对应 mask 如何生成、转换和消费 -``` - -没有这份 source contract 的 lane movement 不能进入 `#pto.vmi.layout`。 - -### Source 1: Widen Cast To Larger Logical Vector - -逻辑行为: - -```mlir -%w = pto.vmi.extf %a - : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> -``` - -用户语义是“128 个 f16 lane 加宽成 128 个连续 f32 lane”。但 128 个 f32 是 512B,超过单个 -256B physical vreg。VPTO 的可行 lowering 不是一个 contiguous 512B register,而是两条 part -conversion: - -```text -even part: - physical even[i] = extf(logical[2*i]) - -odd part: - physical odd[i] = extf(logical[2*i+1]) -``` - -因此需要一个 layout 表达“这个 VMI value 仍然是 logical `128xf32`,但 physical representation -是 even/odd 两个 parts”: - -```mlir -#pto.vmi.layout -``` - -lane map: - -```text -part = i % 2 -lane = floor(i / 2) -physical[part][lane] = logical[i] -``` - -这个 layout 的价值在于后续 elementwise op 不需要 materialize contiguous representation: - -```mlir -%s = pto.vmi.addf %w, %b - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> -``` - -lowering 可以变成两路 add: - -```text -add even parts -add odd parts -``` - -最后如果 store consumer 能把 even/odd parts 交织写回 contiguous memory,就不需要中途 -`ensure_layout contiguous`。 - -同理: - -```mlir -%w = pto.vmi.extf %a - : !pto.vmi.vreg<256xf8> -> !pto.vmi.vreg<256xf32> -``` - -需要: - -```mlir -#pto.vmi.layout -``` - -这里不再使用抽象 stride 命名。`deinterleaved = 4` 的来源是 `f8 -> f32` 的 VPTO part -conversion contract,不是任意 stride 语义。 - -### Source 2: Narrow / Pack Consumer - -逻辑行为: - -```mlir -%n = pto.vmi.truncf %x - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> -``` - -如果 `%x` 已经是 `#pto.vmi.layout`,VPTO 可以用 pack/narrow 类 -consumer 把 even/odd f32 parts 合成 contiguous f16 result。这里 layout 的来源不是 producer,而是 -consumer 能直接接受这种 decomposition: - -```text -source layout: - logical f32 value represented as even/odd f32 parts - -consumer: - narrowing pack consumes those parts - -result: - contiguous f16 logical vector -``` - -因此 `deinterleaved` 必须同时登记 producer contract 和 inverse/sink contract。否则 layout 只能 -产生,不能被合法消耗。 - -### Source 3: Same-Width Layout Materialization - -逻辑行为: - -```mlir -%x = pto.vmi.ensure_layout %v - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -``` - -这里不新增 surface view op。目标不是产生两个独立 semantic vectors,而是让同一个 logical -vector 继续作为一个 VMI value 存活,只是 physical representation 变成 even/odd parts。IR 中由 -`vmi-layout-assignment` 插入 -`pto.vmi.ensure_layout`,并由 target registry 证明存在 preserving materialization path。VPTO 的 -`vdintlv/vintlv` 类 register rearrangement 可以产生或消费这种 representation。 - -这和 `vcvt` 产生的 even/odd representation 使用同一个 layout: - -```mlir -#pto.vmi.layout -``` - -区别只在 source contract: - -```text -logical behavior: - 同宽 logical vector 保持一个 VMI value,但 physical parts 分别保存 even/odd lanes - -hardware mismatch: - VPTO interleave/deinterleave 指令以两个 physical vreg parts 表达 - -layout: - deinterleaved=2 -``` - -如果 VMI op 的语义本来就是“返回两个独立 vectors”,例如 AoS -> SoA 后用户分别使用 `%x` -和 `%y`,那不需要 layout,直接产生两个 `vmi.vreg`。只有当“一个 logical vector value” -需要以 even/odd parts 长期存活时,才使用 `deinterleaved=2`。 - -### Channel Split / Merge 不是 Register Layout - -channel split/merge 的用户代码通常有两种形态。 - -第一种是把 interleaved data 当作普通 flat vector: - -```text -logical = [r0, g0, b0, a0, r1, g1, b1, a1, ...] -对每个 lane 做同一种逐元素操作 -``` - -这种情况下 `contiguous` representation 就能表达用户语义,不需要 channel layout。 - -第二种是用户按 channel 编程: - -```mlir -%r, %g, %b, %a = pto.vmi.channel_split %rgba - : !pto.vmi.vreg<128xi8> - -> !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8>, - !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8> - -%r2 = pto.vmi.addi %r, %bias_r : !pto.vmi.vreg<32xi8> -%g2 = pto.vmi.addi %g, %bias_g : !pto.vmi.vreg<32xi8> -%b2 = pto.vmi.addi %b, %bias_b : !pto.vmi.vreg<32xi8> -%a2 = pto.vmi.addi %a, %bias_a : !pto.vmi.vreg<32xi8> -%out = pto.vmi.channel_merge %r2, %g2, %b2, %a2 - : !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8>, - !pto.vmi.vreg<32xi8>, !pto.vmi.vreg<32xi8> - -> !pto.vmi.vreg<128xi8> -``` - -这里自然的 IR 是多个 semantic VMI values,而不是“一个 VMI value 带 channel layout”。 -目标专用 split/merge 能力是 `channel_split/channel_merge` 的 lowering contract;load/store -memory boundary 的 dist/sink contract 也可以作为等价 lowering path。 - -`channel_split` / `channel_merge` 的语义必须能完全退化成 static shuffle,不能引入额外 -layout 规则。`C` 不需要单独 attr:`channel_split` 的 `C` 来自 result 个数, -`channel_merge` 的 `C` 来自 operand 个数。设 input 有 `N = C * M` 个 logical lanes: - -```text -channel_split(input, C): - out[c][i] = input[i * C + c] - for 0 <= c < C - for 0 <= i < M - -channel_merge(out[0], ..., out[C-1]): - result[i * C + c] = out[c][i] - for 0 <= i < M - for 0 <= c < C -``` - -如果 `N` 不能被 `C` 整除,或者 merge operands 的 logical lane count 不一致,op verifier -必须拒绝。需要 tail 的场景通过外层 mask / valid lane 语义表达,不能让 channel op 自己发明 -padding lane。 - -因此这两个 op 的价值只是 canonical interface:producer 可以直接表达 channel 语义, -外部 import 工具也可以把识别出的 static shuffle pattern canonicalize 成它们;如果没有 -识别或目标没有专用 lowering,保持或退回 `pto.vmi.shuffle` 仍然是等价路径。 -当前 direct VPTO lowering 只接受能形成完整 physical channel groups 的形状:flat contiguous -source/result 与 virtual deinterleaved=C channel layout 必须有相同 physical arity,或已经是 matching -deinterleaved=C layout 的 identity forwarding。arity-changing partial group 需要额外 packing/drop -padding plan,不能直接 lowering。 - -所以 VMI register layout 目录不为 channel-specific representation 引入 layout kind,也不预留 -半成品 layout 语义。本文覆盖的用户形态要么是 flat contiguous vector,要么是多个 channel -semantic value;都不需要“一个 VMI value 带 channel layout”。 - -### Pack / Unpack 不作为长期 Layout - -pack/unpack 的逻辑行为通常是 width conversion 或 memory encoding: - -```text -wide logical vector -> narrow logical vector -narrow memory payload -> wide logical vector -``` - -它们的结果可以是 `contiguous` logical vector;pack/unpack 是 producer/sink/conversion -contract,不是必须长期传播的 register layout。只有当目标 ISA 提供 packed-format arithmetic, -并且 VMI 真的要让 packed representation 跨 compute 存活时,才需要另立 -`packed_bits` layout。本设计没有 packed-format arithmetic source contract,因此 pack/unpack 不进入 -长期 register layout。 - -### 不应成为 Register Layout 的东西 - -以下能力虽然来自 VPTO/VISA,但不是 VMI register layout: - -| 能力 | 原因 | -|---|---| -| `vsldb/vsstb` block stride | 描述 memory address map;result register 可仍是 contiguous representation | -| gather/scatter index | runtime address map,不是 static logical lane 到 physical part 的关系 | -| dynamic `vselr` | runtime permutation,应是 `pto.vmi.permute` op | -| `vsqz/vusqz` compaction | runtime mask 决定 lane destination,应是 `compress/active_prefix_index` op | -| one-shot `vintlv/vdintlv` | 如果只是 boundary conversion,不应提升成长期 layout;若表示一个 VMI value 的 even/odd parts,则归入 `deinterleaved=2` | - -VMI layout 只解决“一个 logical vector value 在寄存器中长期以什么 parts representation 存活” -的问题。memory address、runtime permutation、dynamic compaction 都是其它语义。 - -### Lane Map - -设: - -```text -N = logical lane count -lanesPerDataPart(T) = 256B / sizeof(T) -lanesPerMaskPart(b8) = 256 -lanesPerMaskPart(b16) = 128 -lanesPerMaskPart(b32) = 64 -``` - -`contiguous`: - -```text -chunk = floor(i / lanesPerPart) -lane = i % lanesPerPart -physical[chunk][lane] = logical[i] -``` - -`deinterleaved = K`,其中 `K` 只能是 2 或 4: - -```text -p = i % K -q = floor(i / K) -chunk = floor(q / lanesPerPart) -lane = q % lanesPerPart -physical[p][chunk][lane] = logical[i] -``` - -`deinterleaved=2` 和 `deinterleaved=4` 的 physical value ordering 固定为 part-major: - -```text -p0_chunk0, p0_chunk1, ..., p1_chunk0, p1_chunk1, ..., p(K-1)_chunk0, ... -``` - -所有 verifier、type converter、physical lowering 和 control-flow conversion 必须使用同一套 -ordering。 - -### Physical Arity - -`vmi-to-vpto` 不能按示例猜 physical value 个数,必须由 type + layout 统一推导。 - -对 data vreg: - -```text -lanesPerPart = 256B / sizeof(T) - -contiguous: - chunks = ceil(N / lanesPerPart) - physical values = chunks - -deinterleaved = K: - lanesPerLogicalPart = ceil(N / K) - chunksPerPart = ceil(lanesPerLogicalPart / lanesPerPart) - physical values = K * chunksPerPart -``` - -对 mask: - -```text -lanesPerPart = lanesPerMaskPart(G) -same formula as data, replacing T with mask granularity G -``` - -每个 physical value 的有效 lane 由 lane map 反推: - -```text -contiguous valid: - logical = chunk * lanesPerPart + lane - valid = logical < N - -deinterleaved valid: - logical = K * (chunk * lanesPerPart + lane) + p - valid = logical < N -``` - -padding lane 可以是 poison/undef,但 store、mask-producing op、reduction、scan、compress 和 -layout conversion 都必须显式带着 `valid` 信息,不能只依赖 physical register 宽度。 - -### Broadcast 不作为 Register Layout - -VMI surface 使用 `broadcast` 表达前端语义: - -```mlir -%v = pto.vmi.broadcast %x : f32 -> !pto.vmi.vreg<128xf32> -``` - -也就是: - -```text -for i in 0 .. N: - v[i] = x -``` - -这不是 logical lane 到 physical part/lane 的 placement relation,而是一个 value producer -可以延迟 materialize 的事实。`vmi.broadcast` 应保持为 semantic op 或 layout-polymorphic -producer: - -```text -consumer wants contiguous: - materialize scalar into contiguous physical parts - -consumer wants deinterleaved=2: - materialize same scalar into even/odd parts - -consumer wants deinterleaved=4: - materialize same scalar into p0/p1/p2/p3 parts -``` - -因此 broadcast 不进入 `#pto.vmi.layout` 目录。它由 `vmi-layout-assignment` 按 consumer -layout 重物化或下沉到 consumer lowering,而不是作为 `vreg` 的 layout kind。 - -#### Broadcast Materialization - -MLIR SSA value 不能对不同 use 拥有不同 result type。因此 scalar broadcast 的多 layout -适配不是“一个 VMI value 同时带多个 layout”,而是在 layout assignment 中按 use 重物化。 - -semantic VMI: - -```mlir -%b = pto.vmi.broadcast %x : f32 -> !pto.vmi.vreg<128xf32> -%u = pto.vmi.addf %a_contiguous, %b - : !pto.vmi.vreg<128xf32> -%v = pto.vmi.addf %a_split, %b - : !pto.vmi.vreg<128xf32> -``` - -如果 `%u` 需要 `contiguous`,`%v` 需要 `deinterleaved=2`,layout assignment 重写为: - -```mlir -%b0 = pto.vmi.broadcast %x - : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -%u = pto.vmi.addf %a_contiguous, %b0 - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> - -%b1 = pto.vmi.broadcast %x - : f32 -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -%v = pto.vmi.addf %a_split, %b1 - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> -``` - -physical materialization: - -```text -contiguous: - each physical chunk is filled with scalar x - -deinterleaved=2: - even part is filled with scalar x - odd part is filled with scalar x - -deinterleaved=4: - p0/p1/p2/p3 parts are all filled with scalar x -``` - -这要求 `pto.vmi.broadcast` 标记为 rematerializable,并满足 dominance:clone 位置必须被 scalar -operand `%x` dominate。跨控制流时,如果 scalar operand 可在各 predecessor/body 内使用, -优先在 consumer 所在 block 重物化;否则必须在控制流 join 处选择一个具体 layout 并 materialize。 - -这个规则只对 scalar-to-vector broadcast 是零语义风险的。低 rank vector 到高 rank vector 的 -broadcast 可能需要真实 lane replication/shuffle,不能默认按任意 consumer layout 免费重物化; -这类 broadcast 必须携带 broadcast map,并按普通 VMI op 做 layout assignment。 - -VMI register layout 目录因此是: - -```text -contiguous -deinterleaved=2 -deinterleaved=4 -``` - -channel split/merge、pack/unpack、memory stride、dynamic permutation、dynamic compaction -不在目录内。它们分别由 VMI semantic op、conversion、memory access plan、`vmi.permute`、 -`vmi.compress/active_prefix_index` 承接。 - -## Pipeline - -### 1. VMI Producer Boundary - -VMI core pipeline 从合法 VMI semantic IR 开始。Producer 可以是 TileLang/PTO lowering、手写 VMI -测试或其它外部 import 工具,但 producer 不属于 VMI core pipeline。 - -进入 VMI boundary 时必须满足: - -```text -all logical vector semantics are represented by pto.vmi semantic ops -all VMI data/mask values use surface VMI type without layout -no physical VPTO op is introduced -no hidden layout/mask/type side table is required -scalar/tensor/debug/transform boundary has already been handled by producer -``` - -该 boundary 需要 verifier gate。它验证 VMI IR 自身完整,不验证某个外部 source dialect 的 -coverage。 - -### 2. `vmi-layout-assignment` - -该阶段把无 layout VMI type 转换成 layout-assigned VMI type,推荐实现为独立 pass: - -```mlir -!pto.vmi.vreg<128xf32> - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> - -!pto.vmi.mask<128xpred> - -> !pto.vmi.mask<128xb32, #pto.vmi.layout> -``` - -layout assignment 做三件事: - -1. 为每个 producer 选择 natural layout。 -2. 为每个 consumer 协调 operand/result layout。 -3. 在必要处插入: - -```mlir -pto.vmi.ensure_layout -pto.vmi.ensure_mask_layout -pto.vmi.ensure_mask_granularity -``` - -layout assignment 不是局部 pattern 贪心插 conversion,而是约束求解: - -```text -nodes: - every VMI SSA value - block arguments and region/function results - rematerializable producers such as scalar broadcast/iota/constant - -allowed layouts: - contiguous - deinterleaved=2 - deinterleaved=4 - filtered by element type, mask granularity, op capability, and target registry - -hard constraints: - op verifier constraints, such as same-layout elementwise operands - data/mask layout alignment for predicated ops - control-flow block argument/yield/call signature equality - external ABI layout boundary - source/sink contracts for width conversion, load/store, pack/narrow - -soft costs: - natural producer layout preference - ensure_layout materialization cost from target registry - store/load sink cost - rematerialization cost for broadcast/iota/constant - scratch/guarded fallback resource cost -``` - -求解顺序: - -```text -1. Build constraints for the whole region/SCC, including control-flow and call edges. -2. Propagate impossible layouts and required mask granularities. -3. Choose a minimum-cost layout for each node. -4. Use deterministic tie-break: prefer existing natural layout, then contiguous. -5. Insert ensure_layout/ensure_mask_layout or rematerialize producers at chosen use sites. -6. Re-run verifier gates; no hidden side table may be needed to interpret the result. -``` - -如果 hard constraints 冲突,或所有 legal paths 都缺 target capability/resource,报 -`VMI-LAYOUT-CONTRACT` 或更具体 diagnostic。diagnostic payload 必须列出 conflict value、producer -natural layout、consumer required layouts、available conversion paths 和被禁用的 fallback。 - -#### Consumer Layout Demand - -“consumer 需要某个 layout”不是前端语义要求,而是 layout assignment 为了让 operands/results -的 lane-map 对齐并减少 layout conversion 选择的共同 representation。 - -典型例子: - -```mlir -%w = pto.vmi.extf %a - : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> - -%b = pto.vmi.broadcast %scalar - : f32 -> !pto.vmi.vreg<128xf32> - -%s = pto.vmi.addf %w, %b - : !pto.vmi.vreg<128xf32> -``` - -`%w` 的 logical 语义是 `128xf32`,但 VPTO `f16 -> f32` 的自然 lowering 产生 even/odd -两路 parts: - -```text -w_even[i] = extf(a[2*i]) -w_odd[i] = extf(a[2*i+1]) -``` - -因此 `%w` 的 natural layout 是: - -```mlir -#pto.vmi.layout -``` - -`addf` 是 layout-polymorphic elementwise op。它有两个合法选择: - -```text -choice A: - materialize %w to contiguous - materialize broadcast to contiguous - do one contiguous add sequence - -choice B: - materialize broadcast directly as deinterleaved=2 - do add on even parts and odd parts separately - keep result as deinterleaved=2 -``` - -choice B 通常更便宜,因为不需要把 `%w_even/%w_odd` 先 interleave 成 contiguous。broadcast -能直接适配 `deinterleaved=2`,是因为它的 logical lanes 全部等于同一个 scalar: - -```text -b_even = [scalar, scalar, ...] -b_odd = [scalar, scalar, ...] -``` - -所以这里说 `addf` consumer “需要” `deinterleaved=2`,准确含义是: - -```text -layout assignment 选择 deinterleaved=2 作为 addf 的共同 operand/result representation, -因为其中一个 operand 的 natural layout 已经是 deinterleaved=2,并且 broadcast 可零语义风险地重物化到该 layout。 -``` - -### 3. `vmi-to-vpto` - -该阶段把 layout-assigned VMI type 做 1:N physical type conversion,推荐实现为独立 pass: - -```text -!pto.vmi.vreg<128xf32, contiguous> - -> !pto.vreg<64xf32>, !pto.vreg<64xf32> - -!pto.vmi.vreg<128xf32, deinterleaved=2> - -> !pto.vreg<64xf32>, !pto.vreg<64xf32> - -!pto.vmi.vreg<256xf32, deinterleaved=4> - -> !pto.vreg<64xf32>, !pto.vreg<64xf32>, - !pto.vreg<64xf32>, !pto.vreg<64xf32> - -!pto.vmi.mask<128xb32, deinterleaved=2> - -> !pto.mask, !pto.mask -``` - -需要 internal projection/materialization op: - -```mlir -%p0, %p1 = pto.vmi.unpack %v - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> - -> !pto.vreg<64xf32>, !pto.vreg<64xf32> - -%v = pto.vmi.pack %p0, %p1 - : !pto.vreg<64xf32>, !pto.vreg<64xf32> - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -``` - -`pack/unpack` 不是新的 layout carrier,只是 layouted `vmi.vreg` 到 physical VPTO parts 的 -projection/materialization。 - -`unpack` 必须能作用在任意 SSA value 上,不能依赖 defining op。VMI value 可以来自 block -argument、`scf.if` result、loop iter_arg、function argument 或 call result;这些 value 没有 -可 look-through 的 layout materialization defining op。 - -`pack/unpack` 的 operand/result 个数必须使用 Physical Arity 公式推导。非整 tile 时,最后一个 -chunk 的 padding lane 仍属于 physical value,但不属于 logical value。 - -### Layout Conversion Materialization - -`pto.vmi.ensure_layout` / `pto.vmi.ensure_mask_layout` 是 logical-value-preserving conversion: - -```text -for every logical lane i: - dst.logical[i] = src.logical[i] -for every padding lane: - dst padding remains unobservable -``` - -source/result layout 完全相同时,`ensure_layout` / `ensure_mask_layout` 是 identity forwarding; -即使存在 partial/tail physical chunk,也不需要 target materialization path。source/result layout -不同时才需要 registry 证明 preserving conversion 及其 full-chunk/tail 处理策略。当前 direct path -允许 equal-arity partial/tail conversion:source/result 的 physical arity 必须相同,且两边都能组成完整 -contiguous/deinterleaved=2/4 `intlv` materialization group;arity-changing partial conversion 和 uneven -deinterleaved groups 继续报 unsupported。 - -合法 materialization path 必须来自 target registry: - -```text -same layout: - no-op - -contiguous <-> deinterleaved=2: - direct interleave/deinterleave register op, load/store dist sink/source, - or scratch/ordered fallback - -contiguous <-> deinterleaved=4: - direct 4-way layout sink/source, proven staged 2-way sequence, - or scratch/ordered fallback - -deinterleaved=2 <-> deinterleaved=4: - convert through contiguous only if both legs have preserving paths, - otherwise use scratch/ordered fallback or report VMI-LAYOUT-CONTRACT -``` - -`deinterleaved=4` 不能默认假设“两次二路 interleave”就是正确 materialization。只有当 staged -sequence 的 lane map 被 registry 证明等价于: - -```text -logical = 4 * lane + p -``` - -才允许使用。否则必须选择 store sink、scratch buffer 或 diagnostic。 - -### Verifier Gates - -每个 pipeline 边界都必须有 hard verifier,不能把残缺 IR 留给后续 pass 猜测: - -```text -at VMI producer boundary: - every logical vector value is represented by !pto.vmi.vreg / !pto.vmi.mask - every logical vector operation is represented by pto.vmi semantic op - no physical VPTO op has been introduced - no hidden layout/mask/type side table is required to interpret a value - -after vmi-layout-assignment: - every !pto.vmi.vreg / !pto.vmi.mask has #pto.vmi.layout - layout kind is one of contiguous/deinterleaved=2/deinterleaved=4 - mask granularity matches each consumer - branch operands, block arguments, function arguments/results, and yields agree on layout - no hidden layout/mask/type side table is required to interpret a value - -before vmi-to-vpto: - every pto.vmi.ensure_layout / ensure_mask_layout has a registered preserving materialization path - every fallback path has resource decision and dominance/lifetime proof - -after vmi-to-vpto: - no pto.vmi op or type remains - no UnrealizedConversionCastOp remains - no pto.vmi.pack/unpack/ensure_* helper remains - every physical value arity matches the Physical Arity helper -``` - -layout、mask、valid-lane 和 physical arity 信息必须存在于 IR type/attr/op operand 中,或可由它们 -纯函数推导;不能依赖 C++ side table。违反这些 gate 时使用 `VMI-PASS-INVARIANT` 或更具体的 -diagnostic,例如 `VMI-LAYOUT-CONTRACT`、`VMI-MEMORY-ACCESS`、`VMI-RESIDUAL-OP`。 - -## Layout Assignment 规则 - -### Elementwise - -same-layout operands: - -```text -vmi.addf/vmi.mulf/vmi.cmpi/vmi.select - fan out per physical part - result keeps operand layout -``` - -different-layout operands: - -```text -choose consumer-demanded layout -insert ensure_layout for other operands -vmi.broadcast can rematerialize in consumer-demanded layout -``` - -### Width Conversion - -典型 natural layout: - -```text -vmi.extf 128xf16 -> 128xf32: - source contiguous f16 - result deinterleaved=2 f32 - -vmi.extf 256xf8 -> 256xf32: - source contiguous f8 - result deinterleaved=4 f32 - -vmi.truncf 128xf32 -> 128xf16: - source may be deinterleaved=2 f32 - result contiguous f16 if pack/store sink requires contiguous - -vmi.truncf 256xf32 -> 256xf8: - source may be deinterleaved=4 f32 - result contiguous f8 if pack/store sink requires contiguous -``` - -Direct `vcvt` lowering 可以覆盖同一 contract 下的 partial/tail case:`extf` 的 logical lanes -必须仍然装进一个 contiguous narrow source physical chunk,并自然产生 deinterleaved=2/4 result; -`truncf` 的 deinterleaved=2/4 source parts 必须能 pack 成一个 contiguous narrow result chunk。 -这些路径允许 VPTO 对 padding lanes 执行 conversion,但 padding 只能流向 result padding lanes, -不能变成 logical result。 - -Mask granularity assignment 把 surface `mask` 转成 concrete -`mask`。consumer 决定所需 granularity: - -```text -f16 op consumes mask -f32 op consumes mask -f8 op consumes mask -``` - -如果 data 从 f16 扩到 f32,后续 f32 consumer 需要: - -```mlir -!pto.vmi.mask -``` - -不能继续复用 `mask`。 - -mask-producing op 的 granularity 不是 producer 固有属性: - -```text -vmi.create_mask / constant_mask: - logical predicate producer; granularity chosen by users - create_mask 的 logical prefix 语义不受目标 PAT_VL token 集合限制; - unsupported PAT_VL count 可以用 pto.plt_b* materialize - constant_mask 的 non-prefix chunk 用 prefix 差分和 predicate boolean ops materialize - -vmi.cmpf/cmpi: - result logical lane count follows compared data - concrete granularity chosen by mask consumers, not by compare element type alone - -multi-use mask: - choose one concrete granularity for the original SSA value - insert ensure_mask_granularity or rematerialize cheap mask producers per use -``` - -`ensure_mask_granularity` 必须 preserve logical predicate lane `mask[i]`。当前 direct lowering 对 -concrete `b8/b16/b32` granularity 使用 `pto.punpack` 做 widening,使用 `pto.ppack` 加 `pto.por` -做 narrowing,并按需要串联相邻级别完成 `b8 <-> b32`。如果目标缺少 predicate rearrangement 或 -granularity conversion,报 `VMI-LAYOUT-CONTRACT`,不能把 b16/b32 mask 当成同一 physical bit -pattern 直接复用。 - -### Predication - -Region-style mask 不作为长期 region op 保留到 VPTO lowering。producer 必须把 mask thread 到 -具体 VMI op: - -```text -masked load/store: - use pto.vmi.masked_load / pto.vmi.masked_store - -masked arithmetic with passthru: - compute candidate result - merge with passthru by pto.vmi.select(mask, candidate, passthru) - -masked reduction/scan: - inactive and padding lanes are excluded from the logical iteration -``` - -如果一个 masked op 的 inactive lane 语义要求“不读内存”或“不执行有副作用操作”,不能用 -full op + select 伪装;必须使用对应 masked VMI op、ordered fallback,或报 target capability -diagnostic。 - -### Memory Ops - -VMI memory op 表达 memory semantics,不表达 register layout。lowering 先构造 access plan: - -```text -base -logical lane count -logical_shape attr, if any -lane-to-address map -contiguity -block-strided row classification -read/write validity mask -padding plan -footprint safety proof -target OOB capability -``` - -memory access map 不是 register layout。比如 `tile_read` 的 memref stride 可以识别 -block-strided rows,并选择 `vsldb`,但 result `vmi.vreg` 的 register layout 仍由 -layout assignment 决定。 - -Producer-specific packed element view 不进入 VMI type。它们必须在 VMI memory op 之前规范化为 -element memref + access map: - -```text -memref> - -> base element type T - -> logical address = original index * K + vector_lane -``` - -normalization 必须保留 offset、stride、alignment、memory space 和 alias 信息。无法证明等价 -element view 时,报 `VMI-MEMORY-ACCESS`,不能把 packed element memref 伪装成 contiguous VMI -load/store。 - -direct path examples: - -```text -contiguous full-safe: - vlds/vsts - !pto.ptr source/destination must be UB-backed; memref source/destination - must either have unknown memory space at this stage or explicitly use - #pto.address_space - -32B block-strided rows with block-uniform mask: - vsldb/vsstb - -interleave/deinterleave boundary: - vldsx2/vstsx2 dist or explicit rearrangement - -indexed memory: - gather/scatter; ordinary scatter requires pairwise-distinct active indices -``` - -GM-backed VMI memory is semantic input, not a direct vector load/store target. -Current `vmi-to-vpto` direct memory lowering emits `pto.vlds`, `pto.vldsx2`, -`pto.vsts`, or `pto.vstsx2`; those VPTO ops operate on UB-backed vector memory. -If a `pto.vmi.load/store/tile_read/tile_write` still names GM at this stage, -the missing step is an explicit memory movement/materialization plan, scratch -plan, or UB view normalization. Otherwise the pass must report `VMI-UNSUPPORTED` -instead of silently producing illegal VPTO. - -### Control Flow - -VMI layouted type 可以跨 internal control flow,但 public ABI 不允许 layout leak。 - -MLIR conversion framework 可以做 region/block/signature 的 structural type conversion,但它不会 -自动决定 layout。`vmi-layout-assignment` 必须先把每个 block argument、region yield、branch -operand 和 call boundary 的 layout 固定下来,再交给 `vmi-to-vpto` 做 1:N type conversion。 - -`scf.if` join: - -```text -if all incoming layouts equal: - keep that layout -else: - choose consumer-demanded layout, otherwise contiguous - insert ensure_layout / ensure_mask_layout before yield -``` - -`scf.for` loop-carried value: - -```text -init layout == iter_arg layout == yield layout == loop result layout -``` - -如果 loop body repeatedly consumes deinterleaved=2/deinterleaved=4,优先保持该 natural layout;如果只有 loop -exit 需要 contiguous,则在 exit 后转换,不在 backedge 每轮转换。 - -`cf.br` / `cf.cond_br` block arguments: - -```text -target block argument has one chosen layout -each predecessor operand is converted to that layout before branch -``` - -function boundary: - -```text -internal VMI functions: - function argument/result layout is part of layout assignment - all callsites and returns must agree with the specialized signature layout - -external/public ABI: - must not expose #pto.vmi.layout - materialize to memory, scalar ABI, or final physical PTO ABI before crossing boundary -``` - -recursive or mutually recursive VMI functions require SCC fixed-point layout assignment. If a stable signature -layout cannot be found without inserting conversion on every cycle edge, choose `contiguous` at the function -boundary and keep deinterleaved layouts inside the function body. - -## VMI Op Families - -本节列出 VMI 必须拥有的 semantic op。assembly form 可在 ODS 中微调,但语义边界应保持。 -表中用 `/` 写在一起的名字表示多个独立 op,不表示一个 variadic opcode。去重后,正式 -semantic op 数量是 75 个。 -`ensure_layout`、`ensure_mask_layout`、`ensure_mask_granularity`、`pack`、`unpack` 是内部 -layout/materialization helper,不计入 semantic op;如果把 helper 也算作 VMI op,总数是 80 个。 - -该总表描述目标 semantic surface,不等价于当前第一批实现清单。当前 implementation slice -以 `docs/designs/vmi-implementation-manual.md` 的 Slice 1 为准;例如 `pto.vmi.from_elements` -虽然属于目标 construction family,但没有 scalar lane insert、vreg immediate 或 scratch -materialization plan 前不能宣称 direct lowering 已支持。 - -```text -construction: 6 -memory: 10 -arithmetic/conversion: 36 -permutation/mask/reduction/channel: 23 -semantic total: 75 -internal helpers: 5 -total including helpers: 80 -``` - -### Construction - -| Op | 语义 | -|---|---| -| `pto.vmi.constant` | logical constant vector,layout assignment 决定 materialization | -| `pto.vmi.broadcast` | scalar 或低 rank value broadcast 到 `vreg` | -| `pto.vmi.iota` | 从 scalar base 生成 logical lane index/value vector | -| `pto.vmi.from_elements` | 按 logical lane order 构造 | -| `pto.vmi.create_mask` | prefix 或 logical-shape mask | -| `pto.vmi.constant_mask` | static logical predicate mask, including non-prefix masks | -| `pto.vmi.mask_and/or/xor/not` | logical predicate elementwise operation | - -### Memory - -```mlir -%v = pto.vmi.load %base[%idx] - : memref -> !pto.vmi.vreg<128xf32> - -pto.vmi.store %v, %base[%idx] - : !pto.vmi.vreg<128xf32>, memref - -%v = pto.vmi.masked_load %base[%idx], %mask, %passthru - : memref, !pto.vmi.mask<128xpred>, - !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> - -pto.vmi.masked_store %v, %base[%idx], %mask - : !pto.vmi.vreg<128xf32>, memref, !pto.vmi.mask<128xpred> - -%g = pto.vmi.gather %base[%indices], %mask, %passthru - : memref, !pto.vmi.vreg<128xindex>, !pto.vmi.mask<128xpred>, - !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> - -pto.vmi.scatter %v, %base[%indices], %mask - : !pto.vmi.vreg<128xf32>, memref, - !pto.vmi.vreg<128xindex>, !pto.vmi.mask<128xpred> - -%e = pto.vmi.expand_load %base[%idx], %mask, %passthru - : memref, !pto.vmi.mask<128xpred>, - !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> - -pto.vmi.compress_store %v, %base[%idx], %mask - : !pto.vmi.vreg<128xf32>, memref, !pto.vmi.mask<128xpred> -``` - -`masked_load` 的 inactive lane 不能产生 memory read。full load + select 只有在 inactive -lane 地址 safe-readable 时才合法。 -当前直接 lowering 只覆盖 contiguous result/passthru/mask:full physical chunks 直接 `vlds + vsel`; -partial/tail chunks 必须先证明完整 physical read footprint safe-readable,否则报 `VMI-UNSUPPORTED`。 -在第一阶段的矩阵 quant/dequant lowering 中,默认假设 UB 中的行数据按元素连续,tail load 可以安全读满 -当前物理 vreg;tail 的对外写入效果仍由 `pto.vmi.create_mask` + `pto.vmi.masked_store` -约束。严格 no-read tail 不是这个默认路径的语义,后续通过 stable gather 模式承接:该模式应把 -contiguous tail masked load 转为 `VGATHER2 + Pg` 风格的 per-lane non-faulting load。当前 -`vmi-to-vpto` 只预留 `enable-stable-gather-masked-load` 开关;开关打开且遇到 -`pto.vmi.masked_load` 时必须给 TODO diagnostic,不能退化成普通 `vlds + vsel`。 - -普通 `vmi.store` 和 `vmi.masked_store` 的 contiguous tail 可以用 true predicate store 承接: -full physical chunk 使用 all-true mask 或用户 mask,最后一个 partial chunk 使用 prefix valid-lane -mask;因此普通 `vmi.store` direct lowering 要求 value element width 能对应 -`pto.mask`。`masked_store` 先把用户 mask 与 valid-lane mask 做 logical AND。 -deinterleaved=2/4 tail store/masked_store 只有在每个 deinterleaved part 的 physical chunk 数相同、可先组成完整 -`vintlv/pintlv` group 并 materialize 成 contiguous chunks 时才直接支持;materialized 后 active -lane 为 0 的 padding-only chunk 不发 store。load padding 仍需要独立的 access plan,不能通过未受保护的 -full-footprint memory op 偷跑。 - -`gather/scatter` 使用 logical lane order 解释 `%indices`,index 单位和 memref element type -一致。`gather` inactive lane 返回 `%passthru[i]` 且不能读内存。`scatter` inactive lane 不能写 -内存;如果 active lanes 可能写同一地址,direct VPTO lowering 必须证明目标语义与 logical -lane order 等价,否则使用 ordered fallback 或报 `VMI-MEMORY-ACCESS`。 - -当前 `gather` direct lowering 覆盖一个保守子集: - -```text -source: - !pto.ptr - -layout: - result / indices / mask / passthru all contiguous - all physical chunks are full, so padding lanes cannot trigger memory reads - -type: - T is 32-bit element type - indices are signless or unsigned i32 - mask granularity is b32 - -lowering: - gathered = pto.vgather2_bc source, indices, mask - result = pto.vsel gathered, passthru, mask -``` - -`VGATHER2_BC` false predicate lanes do not read memory but produce zero result lanes. VMI `gather` requires false -lanes to preserve passthru, so the `vsel` is semantically required, not an optimization artifact. `f16/b16/f8/i8` -gather, tail gather, non-contiguous layout, memref/gm source, and fallback through guarded scalar load or scratch are -future target-capability paths. - -`scatter` 的基础语义要求所有 active logical lanes 的 `%indices` 两两不同。inactive lane 不写内存, -因此不参与这个唯一性约束。如果两个 active lane 的 index 相同,程序违反 `pto.vmi.scatter` 的 -语义前置条件;VMI 不为这种输入定义 logical lane order 或 winner。 - -```mlir -pto.vmi.scatter %v, %base[%indices], %mask - : !pto.vmi.vreg<64xf32>, !pto.ptr, - !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> -``` - -当前 direct path 的其它限制与 gather 对齐:UB pointer destination、contiguous full physical chunks、 -32-bit value element、i32 indices 和 b32 mask。允许冲突的 scatter 不能复用普通 `pto.vmi.scatter`, -因为 `VSCATTER` 对重复 index 的 grant procedure 是目标相关/未定义的,不等价于确定的 VMI logical -lane order。后续如果需要定义 duplicate-index scatter,需要新增显式语义,例如 ordered fallback、 -atomic scatter、reduce-scatter 或 target-specific unordered scatter。 - -`expand_load/compress_store` 表达 masked contiguous stream,不是 arbitrary indexed access: - -```text -expand_load: - k = 0 - for i in 0 .. N: - if mask[i]: - result[i] = base[idx + k] - k += 1 - else: - result[i] = passthru[i] - -compress_store: - k = 0 - for i in 0 .. N: - if mask[i]: - base[idx + k] = value[i] - k += 1 -``` - -Current direct `expand_load` lowering supports two paths. The first is the -degenerate all-active case: - -```text -mask == all_true => expand_load(base[idx], mask, passthru) == load(base[idx]) -``` - -The accepted mask must be statically proven all active through -`pto.vmi.create_mask` with constant `active_lanes >= N`, or a dense all-true -`pto.vmi.constant_mask`. The result, passthru, and mask layouts must be -contiguous. Partial/tail chunks still need the same safe full-read proof as -ordinary `vmi.load`; otherwise the direct path reports `VMI-UNSUPPORTED`. - -The second direct path covers one full 32-bit UB physical chunk with a runtime -mask: - -```text -base' = pto.addptr base, idx -indices = pto.vusqz(zero_i32_carrier, mask) -gathered = pto.vgather2_bc base', indices, mask -result = pto.vsel gathered, passthru, mask -``` - -It requires contiguous result/passthru/mask layout, 32-bit element type, b32 -mask granularity and one full physical chunk. Multi-chunk runtime masks need a -cross-chunk prefix-count carry; f16/b16/f8/i8 need a gather packing contract. -Unsupported cases still require guarded load, scratch fallback, or diagnostic, -and must not be lowered as a plain full load. - -Current direct `compress_store` lowering is intentionally narrower than the -surface semantics. It requires contiguous value/mask layout, exactly one full -physical chunk, and a UB `!pto.ptr` destination. The direct sequence is: - -```text -store_base = pto.addptr base, idx -sqz = pto.vsqz value, mask -align0 = pto.init_align -align1 = pto.vstur align0, sqz, store_base, "POST_UPDATE" -pto.vstar align1, store_base -``` - -The paired `vstur` consumer is what makes the later VPTO LLVM emitter select -`VSQZ #st=1`; emitting `vsqz` without that store consumer is only register -compress. Full physical chunk is required in this first path because padding -mask lanes must not be squeezed into memory. Multi-chunk `compress_store` -needs cross-chunk compaction and SQZN/store-state planning; deinterleaved -layouts need logical lane order reconstruction before the store chain. - -### Index And Address Contract - -`!pto.vmi.vreg` 是 logical index vector,不是 physical address vector。进入 VPTO 前, -index 必须按 target registry legalize 成目标支持的整数宽度: - -```text -index legalization: - choose target index bitwidth - prove every lane value fits, or insert preserving extend/trunc/check sequence - preserve signedness required by the consuming op -``` - -memory op 的 index 单位是 memref element,不是 byte。byte address 由 memref layout、element -size、base offset 和 lane index 共同计算: - -```text -logical element offset -> memref affine/strided map -> byte address -``` - -`gather/scatter` 的 `%indices`、`expand_load/compress_store` 的 active-prefix offset、`iota` 生成 -的 lane index 都必须在同一套 address unit 下解释。不能把 element index 直接当 byte offset,也 -不能在没有 range proof 时把 `index` 静默截断成较窄整数。 - -`active_prefix_index(mask)` 返回当前 lane 之前的 active lane 数: - -```text -idx[i] = popcount(mask[0 .. i)) -``` - -因此 `expand_load/compress_store` active lane 使用 `base + idx[i]`。如果目标缺少 prefix-popcount -或 index-vector lowering,必须选择 index-buffer/guarded fallback,或报 `VMI-FALLBACK-RESOURCE` -/ `VMI-LAYOUT-CONTRACT`。 - -`tile_read/tile_write` 承接 transfer-style padding 和 multi-dimensional access semantics: - -```mlir -%tile = pto.vmi.tile_read %view[%c0, %c0], %pad, %mask - {logical_shape = [8, 8], - permutation_map = affine_map<(d0, d1) -> (d0, d1)>} - : memref<8x8xf32, strided<[?, 1], offset: ?>>, f32, - !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xf32> - -pto.vmi.tile_write %tile, %view[%c0, %c0], %mask - {logical_shape = [8, 8], - permutation_map = affine_map<(d0, d1) -> (d0, d1)>} - : !pto.vmi.vreg<64xf32>, memref<8x8xf32, strided<[?, 1], offset: ?>>, - !pto.vmi.mask<64xpred> -``` - -`tile_read/tile_write` 只承接 memref memory semantics。producer 的 transfer-style read/write 如果作用在 -tensor source/destination 上,必须在进入 VMI 前 bufferize 成 memref access plan,或退出 PTO -路线。tensor write-back style 语义是产生新 tensor,不是对 memref 的 memory effect;不能把它 -伪装成 `pto.vmi.tile_write`。未处理的 tensor transfer 报 `VMI-TENSOR-BOUNDARY`。 - -`tile_read` invalid lane 的 result 必须等于 padding,不是后继 op 的 inactive lane。 - -`tile_read` lowering 必须先构造三个对象: - -```text -validMask(result lane): - logical lane is inside result shape - and explicit transfer mask maps to true - and source address is in bounds - -paddingValue(result lane): - scalar padding: same value for every invalid lane - vector-element padding: select element by suffix coordinate - broadcast/permuted padding: apply the same result-lane map as data - -safeReadProof: - proves the actual physical load footprint is safe-readable - independent from validMask -``` - -`validMask=false` 只说明 result lane 应等于 padding,不说明该 lane 的 source address 可以被读。 -因此 `tile_read` 的 preserving lowering 决策是: - -```text -safeReadProof == full and validMask all-true: - direct load - -safeReadProof == full and validMask not all-true: - loaded = full load - pad = materialize paddingValue in result layout - result = select(validMask, loaded, pad) - -target has true masked/non-faulting load: - loaded = masked load with inactive lanes not read - pad = materialize paddingValue in result layout - result = select(validMask, loaded, pad) unless inactive result is already padding - -safeReadProof != full: - split full-safe and partial paths, or - fill scratch with paddingValue, guarded-copy only valid lanes, then load scratch, or - use guarded scalar/vector fallback -``` - -First implementation stage note: - -```text -The padding-preserving branches above are semantic requirements for the full -design, but they are not part of the first-stage VMI implementation. The first -stage may lower only all-valid direct reads, or physical-tail reads whose extra -lanes are outside the logical VMI value and remain unobservable. If invalid -logical lanes require transfer_read paddingValue materialization, true -masked/non-faulting load, scratch, or guarded fallback, lowering must stop with -the implementation diagnostic code VMI-UNSUPPORTED instead of emitting an -approximate full load. -``` - -如果所有 preserving paths 都因 target capability 或 option 被禁用,报 `VMI-MEMORY-ACCESS`, -payload 必须指出缺的是 unsafe partial `tile_read` padding-preserving path。 - -`tile_write` 没有 padding value,但有 write-valid mask: - -```text -writeMask(source lane): - logical lane is inside source shape - and explicit transfer mask maps to true - and destination address is in bounds -``` - -`writeMask=false` 的 lane 不能产生 memory effect。只有 full physical footprint safe-writable 且 -writeMask all-true 时,才能使用 predicate-ignored store。partial write 必须使用 true masked -store、split/guarded fallback、scatter-like fallback,或报 `VMI-MEMORY-ACCESS`。 -当前 direct `vmi.tile_write` 只覆盖 flat contiguous tail:最后一个 partial chunk 使用 prefix -valid-lane predicate 发 `vsts`,同样要求 value element width 能对应 `pto.mask`。 -deinterleaved=2/4 tail 只有在能先完整 materialize 到 contiguous -chunks 时直接支持,padding-only materialized chunk 不发 store;带 transfer mask coordinate remap 的 -tile write 仍必须走独立 access plan。 - -explicit transfer mask 的坐标属于 transfer access space,不一定等于 flattened result/source lane -坐标。non-minor-identity transfer 必须先做 predicate coordinate remap;缺少 remap capability 时, -diagnostic 必须点名 transfer mask coordinate remap,而不是泛化成普通 memory failure。 - -### Arithmetic And Conversion - -VMI 不复用外部 elementwise arithmetic op。需要定义对应 VMI op: - -| Semantic | VMI op | -|---|---| -| float binary | `pto.vmi.addf/subf/mulf/divf/minf/maxf` | -| float unary | `pto.vmi.negf/sqrt/exp/ln/relu` | -| integer binary | `pto.vmi.addi/subi/muli` | -| bitwise/shift | `pto.vmi.andi/ori/xori/not/shli/shrui` | -| fused multiply-add | `pto.vmi.fma` | -| float casts | `pto.vmi.extf/truncf` | -| bitcast | `pto.vmi.bitcast` | -| compare/select | `pto.vmi.cmpf/cmpi/select` | - -Integer div/rem, arithmetic right shift, integer casts, int-float casts, and -index casts are intentionally not in the current VMI surface. They need -explicit signedness, rounding, saturation, overflow/remainder, and VPTO target -contracts before ODS ops are introduced. - -producer constant 转成 `pto.vmi.constant`,包括 dense、splat 和 rank-0 logical vector。 -constant 的 element type、shape、splatness 和 poison/undef 属性如果存在,必须保留到 VMI -constant attr;padding physical lane 仍按 VMI padding rule 处理,不能把 padding lane 当成用户 -constant lane。 - -当前 VPTO direct lowering 只把 scalar broadcast 和 splat constant materialize 成 -`pto.vdup`。这条路径与逐元素 op 一样要求 physical element width 能对应 -`pto.mask`;其它 element type 或非 splat constant 必须先有明确的 materialization -contract,否则报 `VMI-UNSUPPORTED`。 - -VMI arithmetic op 必须保留原 `arith` op 的 numeric contract: - -```text -floating point: - fastmath flags - rounding mode, if present - NaN / signed-zero / inf behavior implied by flags - -integer: - signedness of div/rem/compare/extend - overflow flags such as nsw/nuw when present - truncation and extension width rules - -compare/select: - cmpf/cmpi predicate - select condition mask granularity and layout -``` - -lowering 不能因为 VPTO 有更快指令就加强或放松这些属性。比如没有 fastmath 允许时,`fma` -不能拆成 `mulf + addf`,也不能把 `mulf + addf` 合成 `fma`;带 `nsw/nuw` 的 integer op -可以利用 flag 做优化,不带 flag 的 op 必须保持 wraparound/defined overflow 语义。 - -`pto.vmi.fma` 不能默认拆成 `mulf + addf`。`bitcast` 只有在当前 contiguous/deinterleaved -layout 下 bit grouping physically adjacent、且每个对应 physical chunk 的 logical bit -footprint 相同时才能 direct;padding bits 只能流向 result padding bits。group_slots bitcast -暂不复用这个规则,必须等 slot-wise bitcast contract 定义清楚后再支持。否则需要 layout -conversion、scratch materialization 或 target capability diagnostic。 - -当前 VPTO direct lowering 对逐元素算术、逻辑、比较和 select 还有一条共同硬约束:物理 element -width 必须能对应到 `pto.mask`。因此 VMI 语义层可以承载 `index` 或 `f64` -这类类型,但在没有独立 lowering contract 前,`vmi-to-vpto` 必须报 `VMI-UNSUPPORTED`, -不能让 OneToN conversion 或 residual gate 隐式失败。 - -这条共同约束不是唯一约束。某些目标 VPTO/VISA op 还有自己的 element type contract, -必须在 `vmi-to-vpto` preflight 中单独检查。当前 direct lowering 明确承诺: - -```text -addf/subf/mulf: f16/bf16/f32 -divf: f16/f32 -minf/maxf: f16/bf16/f32 -negf/absf: f16/f32 -sqrt/exp/ln: f16/f32 -relu: f16/f32 -absi: signless/signed i8/i16/i32 -cmpf: f16/bf16/f32 -cmpi: signless/signed/unsigned i8/i16/i32 -``` - -因此 bf16/f8 虽然可能是合法 VMI float-like type 且能 materialize b16/b8 predicate mask, -但只要目标 direct op 不承诺该 element type,`vmi-to-vpto` 就必须先报 -`VMI-UNSUPPORTED`,直到定义对应 materialization 或 VPTO 目标能力。 - -当前 direct lowering 将 `pto.vmi.fma %lhs, %rhs, %acc` 映射为每个 physical part 上的 -`pto.vmula %acc_part, %lhs_part, %rhs_part, %all_true_mask`。该路径只承诺 f16/bf16/f32 -floating-point fused multiply-add;整数 multiply-accumulate、带 rounding/fastmath 变体或需要 -不同 accumulator 精度的形式必须单独建模,不能复用这个 op 偷换语义。 - -### Permutation, Mask, Reduction, Channel - -| Semantic | VMI op | -|---|---| -| static lane map | `pto.vmi.shuffle` | -| dynamic indexed lane map | `pto.vmi.permute` | -| logical interleave/deinterleave | `pto.vmi.interleave/deinterleave` | -| shape metadata change | `pto.vmi.shape_cast/reshape/transpose` | -| subvector update | `pto.vmi.slice/insert_slice/insert_element` | -| predicate logic | `pto.vmi.mask_and/or/xor/not` | -| prefix active index | `pto.vmi.active_prefix_index` | -| register compaction/expansion | `pto.vmi.compress/expand` | -| reduction/scan | `pto.vmi.reduction/scan` | -| contraction | `pto.vmi.contract/outerproduct` | -| channel split/merge | `pto.vmi.channel_split/channel_merge` | - -`pto.vmi.shuffle` 表达完整 static lane map。当前 VPTO direct lowering 先识别 physical chunk -forwarding:每个 result physical chunk 的所有非 padding lanes 必须来自同一个 source chunk, -且 source lane number 等于 result lane number;result padding lanes 不参与证明,forward 过来的 -物理 padding lanes 仍然不可观察。否则在每个 result physical chunk 都来自同一个 source chunk、 -result chunk 没有 padding lane、且 source lane index 是 ASC/DESC 连续序列时,用 `pto.vci` -生成 index vector 并发 `pto.vselr`。任意非 affine permutation、以及需要 tail lane 重排但无法安全 -materialize tail index vector 的场景,仍然需要通用 index-vector materialization、scratch fallback -或 target capability diagnostic。 - -`channel_split/channel_merge` 是 PTO-specific semantic op。它们表达用户按 channel 编程时的 -多个 logical VMI values,不能降格成 -`#pto.vmi.layout` kind。它们必须拥有 static shuffle 等价定义,canonicalization 可以双向进行: -识别出的 shuffle pattern 可以变成 channel op,channel op 也可以合法展开回 shuffle。 -Direct lowering 还必须证明 physical group 完整;否则即使 logical shuffle 语义成立,也要报 -target capability/materialization diagnostic,而不是让 OneToN pattern 在中途失败。 - -### Internal Layout Helpers - -这些 op 只允许存在于 VMI lowering 的中间阶段,不能作为 VMI semantic surface,也不能残留到 -physical VPTO 之后: - -| Op | 语义 | -|---|---| -| `pto.vmi.ensure_layout` | data vreg layout-preserving conversion | -| `pto.vmi.ensure_mask_layout` | mask layout-preserving conversion | -| `pto.vmi.ensure_mask_granularity` | logical predicate-preserving granularity conversion | -| `pto.vmi.unpack` | layouted VMI value projection to physical VPTO parts | -| `pto.vmi.pack` | physical VPTO parts materialized as one layouted VMI value | - -`active_prefix_index` 语义是: - -```text -idx[i] = popcount(mask[0 .. i)) -``` - -VMI surface 不暴露 VPTO `vusqz` 的无意义 source operand;需要 type/ABI carrier 时在 -`vmi-to-vpto` late materialize。 - -当前直接 lowering 只覆盖 contiguous 单物理 chunk。这个 case 可以用 `pto.vusqz` 精确承接: -`vmi-to-vpto` 先 materialize 一个 zero vreg 作为 VPTO `vusqz` 的 source carrier,再把 VMI mask -作为 governing predicate 传入。多物理 chunk 需要把前一 chunk 的 active count carry 到后一 chunk; -deinterleaved layout 还需要按逻辑 lane 顺序重建 prefix,因此不能逐物理 part 独立发 `vusqz`。 - -`vmi.compress(source, mask)` 语义是按 logical lane order 保留 active source lane 并压缩到结果前缀。 -当前直接 lowering 只覆盖 contiguous 单个 full physical chunk,可以用 `pto.vsqz(source, mask)` 承接。 -partial/tail chunk 不能直接走 `vsqz`,因为 padding mask lane 如果为 true,padding source lane 可能被 -压缩到可观察的 result 前缀。多物理 chunk 需要跨 chunk compaction;`compress_store` 还涉及 -`VSQZ #st=1` 与 `VSTUR`/`SQZN` 的配对约束,不能由 register `compress` 自动推出。 - -`vmi.compress_store(value, base[idx], mask)` 语义是按 logical lane order 把 active lane 写成连续 -memory stream。当前直接 lowering 只覆盖 contiguous、单个 full physical chunk 和 UB pointer -destination,并发出 `pto.vsqz -> pto.vstur POST_UPDATE -> pto.vstar` 的完整 store-state chain。非 full -chunk 暂不直接 lowering,因为 padding mask lane 可能被硬件 squeeze 成额外写出;multi-chunk 需要 -跨 chunk active count 和 SQZN FIFO/VSTUR 配对计划。 - -`shape_cast/reshape/transpose` 必须区分 metadata change 和 lane movement: - -```text -shape_cast / reshape: - preserve row-major flattened lane order - produce explicit result logical_shape attr - -transpose / flat_transpose: - changes logical lane order according to permutation - must lower through shuffle/permute/layout conversion/direct transpose capability -``` - -这些 op 的 source/result shape、permutation 和 broadcast map 都是 op attrs。VMI lowering 不能从 -producer defining op 或 side table 推断缺失 shape。 - -低 rank vector 到高 rank vector 的 broadcast 也不能当成 scalar broadcast 免费重物化。它必须 -保存 broadcast map: - -```text -result[indices] = source[broadcast_map(indices)] -``` - -只有 scalar-to-vector broadcast 可以按 consumer layout 任意重物化。 - -`iota` 是 lane index generation 的 VMI 表达: - -```text -iota(base, ASC): - result[i] = base + i - -iota(base, DESC): - result[i] = base - i -``` - -第一版 `iota` 的 `T` 跟随 VPTO `vci` 能承接的元素类型:integer 8/16/32 和 f16/f32。 -可变 step 不是 surface op 语义的一部分;如果 producer 需要 `base + i * step`,应表达为 -`iota(base=0) -> muli/vmi arithmetic -> addi/addf` 组合,或后续单独引入带 step 的 op。 -tail physical chunk 的 padding lane 可以承接 iota 的自然延续值,但这些 lane 不是 logical lane; -后续 memory/mask/reduction 等有外部效果的 consumer 必须继续按 valid logical lane 保护。 -deinterleaved layout 下的 physical part 需要 strided index materialization: - -```text -part p contains logical lanes p, p + factor, p + 2 * factor, ... -ASC value = base + p + factor * local_lane -DESC value = base - p - factor * local_lane -``` - -因此 direct `vci` 只覆盖 contiguous full-chunk path;deinterleaved path 必须额外物化 -`vci(0) * factor + base +/- p`,不能误降成每个 part 内连续的 `vci(base + p)`。当前 lowering -按 physical part 生成 `vci(0) + vmuls(factor) + vadds/vdup/vsub` 序列;padding/tail chunk -仍然需要独立的 padding-safe materialization plan。 - -`slice/insert_slice` 都按 logical lane order 定义,不读取或写入 padding lane: - -```text -slice(offset, size, stride): - result[j] = source[offset + j * stride] - -insert_slice(offset, stride): - result = dest - result[offset + j * stride] = update[j] - -insert_element(pos): - result = dest - result[pos] = scalar -``` - -`reduction/scan` 的 logical iteration 只覆盖 active logical lanes,padding lanes 不参与: - -```text -reduction(op, init, value, mask): - acc = init - for i in 0 .. N: - if mask is absent or mask[i]: - acc = op(acc, value[i]) - result = acc - -scan(op, init, value, mask): - acc = init - for i in 0 .. N: - if mask is absent or mask[i]: - acc = op(acc, value[i]) - result[i] = acc - else: - result[i] = passthru_or_identity -``` - -Current direct reduction support starts with integer add: - -```mlir -%r = pto.vmi.reduce_addi %value, %init, %mask - : !pto.vmi.vreg<64xi32>, !pto.vmi.vreg<1xi32>, - !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xi32> - -%rf = pto.vmi.reduce_addf %value, %init, %mask {reassoc} - : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, - !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> - -%rmax = pto.vmi.reduce_maxf %value, %init, %mask - : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<1xf32>, - !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<1xf32> - -%rmin = pto.vmi.reduce_minf %value, %init, %mask - : !pto.vmi.vreg<128xf16>, !pto.vmi.vreg<1xf16>, - !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf16> -``` - -`reduce_addi` preserves integer wraparound addition semantics. The direct -lowering requires contiguous layout, full 32-bit source physical chunks, -matching mask chunks, and one rank-0 init/result chunk. It emits `pto.vcadd` -for each masked source chunk, then serially accumulates each chunk result into -the rank-0 accumulator with `pto.vadd` under a `PAT_VL1` predicate. Padding -source lanes are rejected instead of being allowed to participate. - -`reduce_addf` is legal only with an explicit `{reassoc}` contract because the -ISA documents pair-wise FP reduction order. The direct lowering supports only -f32, contiguous layout, full source physical chunks, matching b32 mask chunks, -and one rank-0 init/result chunk. It uses the same per-chunk `vcadd` plus -serial `PAT_VL1 vadd` accumulation shape. Without `{reassoc}`, the verifier -rejects the op instead of silently changing ordered floating-point semantics. - -`reduce_maxf` and `reduce_minf` preserve VPTO-compatible floating-point min/max -reduction semantics. Direct lowering supports f16/f32, contiguous layout, full -source physical chunks, matching mask chunks, and one rank-0 init/result chunk. -For each physical source chunk, lowering emits `pto.vcmax` or `pto.vcmin`. -The chunk result's lowest lane is then accumulated into the rank-0 accumulator -with `pto.vmax` or `pto.vmin` under a `PAT_VL1` predicate. The index value that -`vcmax/vcmin` writes to the second lane is intentionally not part of the VMI op -result and is discarded by only observing lane 0. Inactive lane identities, -signed zero handling, and NaN behavior follow the underlying `vcmax/vcmin` and -`vmax/vmin` VPTO instructions. Padding source lanes are rejected, because the -logical reduction must not allow padding to become an inactive-lane identity or -a NaN-producing participant. - -lowering 可以选择 VPTO reduction/scan 指令、tree decomposition、scratch memory 或 scalarized -ordered fallback,但必须保持 numeric contract。没有目标能力时使用 `VMI-ELEMENT-TYPE` 或 -`VMI-LAYOUT-CONTRACT`,不能让未 lower 的逻辑向量 op 残留到 VPTO。 - -`contract/outerproduct` 在 VMI 中保留 indexing maps、iterator types、accumulator、mask 和 -element type,并且不允许绕过 VMI 直接回到其它向量 IR。如果目标有直接 matrix/vector contract -能力,lower 到直接 VPTO sequence;否则按 iterator space 分解成 VMI arithmetic + -reduction/scan,再走普通 VMI lowering。只有当 element type、accumulator 精度或 iterator -semantics 无法由目标表达时,才报 target capability diagnostic。 - -如果 producer 的 extract-like operation 结果仍是 logical vector,应表达成 `pto.vmi.slice`、 -`pto.vmi.shuffle` 或 `pto.vmi.shape_cast`。如果结果是 scalar,则属于 vector-to-scalar boundary, -不进入 VMI vector path,也不产生 `pto.vmi.extract`: - -```text -VMI-SCALAR-EXTRACT-BOUNDARY -``` - -## End-To-End Examples - -### f16 Widen Add Store - -Semantic VMI: - -```mlir -%a = pto.vmi.load %A[%i] - : memref -> !pto.vmi.vreg<128xf16> -%w = pto.vmi.extf %a - : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> -%s = pto.vmi.addf %w, %bias - : !pto.vmi.vreg<128xf32> -pto.vmi.store %s, %C[%i] - : !pto.vmi.vreg<128xf32>, memref -``` - -Layout-assigned VMI: - -```mlir -%a = pto.vmi.load %A[%i] - : memref -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> -%w = pto.vmi.extf %a - : !pto.vmi.vreg<128xf16, #pto.vmi.layout> - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -%s = pto.vmi.addf %w, %bias - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> -pto.vmi.store %s, %C[%i] - : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, memref -``` - -Physical lowering 可以生成 EVEN/ODD `vcvt`、两路 `vadd`,并在 store sink 使用 interleave -store 或显式 layout conversion。 - -### f8 To f32 - -```mlir -%a = pto.vmi.load %A[%i] - : memref -> !pto.vmi.vreg<256xf8> -%w = pto.vmi.extf %a - : !pto.vmi.vreg<256xf8> -> !pto.vmi.vreg<256xf32> -%s = pto.vmi.addf %w, %b - : !pto.vmi.vreg<256xf32> -pto.vmi.store %s, %C[%i] - : !pto.vmi.vreg<256xf32>, memref -``` - -layout assignment 可把 `%w/%s` 设为 `#pto.vmi.layout`。contiguous store 必须使用 -已验证的 layout sink 或先 materialize contiguous representation,不能把 p0/p1/p2/p3 part 当成连续内存写出。 - -### Block-Strided Tile Read - -```mlir -%tile = pto.vmi.tile_read %view[%c0, %c0], %pad, %mask - {logical_shape = [8, 8], - permutation_map = affine_map<(d0, d1) -> (d0, d1)>} - : memref<8x8xf32, strided<[?, 1], offset: ?>>, f32, - !pto.vmi.mask<64xpred> -> !pto.vmi.vreg<64xf32> -``` - -如果 access plan 证明每 row 是 32B contiguous block,row 间 stride 可落到 ISA stride 字段, -且 mask block-uniform,lowering 可以选择 `vsldb`。如果 padding 非零,仍需在 load 后用 -valid mask 修正 invalid lane。 - -## Risk Closure Matrix - -| 风险 | 设计闭环 | 测试出口 | -|---|---|---| -| producer 直接绕过 VMI 生成 physical VPTO | VMI Producer Boundary Contract + Verifier Gates | `vmi_producer_boundary.mlir`, `vmi_pipeline_hard_gates.mlir` | -| arith numeric contract 被 VPTO 快速路径改写 | fastmath/rounding/overflow/cmp predicate preservation | `vmi_arith_numeric_contract.mlir` | -| layout 设计泛化失控 | closed `contiguous/deinterleaved=2/4` layout set + source contract | `vmi_f16_ext_add_store_deinterleaved2.mlir`, `vmi_f8_ext_add_store_deinterleaved4.mlir` | -| layout assignment 局部贪心导致控制流/多 use 错误 | region/SCC constraint solver + deterministic tie-break | `vmi_layout_assignment_constraint_solver.mlir`, `vmi_cf_and_call_layout_boundary.mlir` | -| 1:N physicalization arity 漂移 | Physical Arity helper + hard gate | `vmi_physical_arity_non_full_deinterleaved.mlir` | -| `deinterleaved=4` materialization 错 lane | registered preserving materialization path | `vmi_ensure_layout_materialization_contract.mlir` | -| mask granularity 过早固化 | surface `mask` + consumer-driven granularity assignment | `vmi_mask_granularity_width_change.mlir` | -| non-scalar broadcast / transpose 被当成 metadata | explicit broadcast map and lane-movement semantics | `vmi_shape_broadcast_semantics.mlir` | -| transfer padding / OOB read 写成 full load/store | `validMask` / `paddingValue` / `safeReadProof` / `writeMask` decision tree | `vmi_tile_read_padding_decision_tree.mlir`, `vmi_tile_write_oob_no_effect.mlir` | -| index/address 单位或宽度被误用 | index/address legalization contract | `vmi_index_address_legalization.mlir` | -| reduction/scan/contract 回退成 residual logical-vector op | VMI semantic op + direct/decompose/scratch lowering contract | `vmi_reduction_scan_contract_coverage.mlir` | -| shape 信息依赖 hidden side table | flat VMI value + shape-sensitive op attrs | `vmi_shape_broadcast_semantics.mlir`, `vmi_pipeline_hard_gates.mlir` | -| fallback 缺资源时退化成残缺 lowering | explicit fallback resource contract + `VMI-FALLBACK-RESOURCE` | `vmi_fallback_resource_diagnostics.mlir` | -| tensor/debug/scalar boundary 混入 VMI | explicit boundary diagnostics | `vmi_tensor_transfer_boundary.mlir`, `vmi_debug_boundary.mlir`, `vmi_extract_boundary.mlir` | - -## Diagnostics - -| Code | 场景 | -|---|---| -| `VMI-SCALAR-EXTRACT-BOUNDARY` | scalar lane extract 不是 VMI vector op,必须在进入 VMI 前消除或退出 PTO 路线 | -| `VMI-SCALABLE-VECTOR` | scalable vector 未在进入 VMI 前 specialize 成固定 logical lane count | -| `VMI-ELEMENT-TYPE` | target registry 缺 storage/compute/convert capability | -| `VMI-LAYOUT-CONTRACT` | VMI layout、mask granularity 或控制流/调用边界约束冲突 | -| `VMI-MEMORY-ACCESS` | access plan 无 direct/fallback path | -| `VMI-LAYOUT-CONTRACT` | layout conversion 或 sink 未被 target registry 支持 | -| `VMI-FALLBACK-RESOURCE` | scratch、guard、index buffer 或 fallback index width 资源不可用 | -| `VMI-TENSOR-BOUNDARY` | tensor transfer 必须在进入 VMI 前 bufferize 或退出 PTO 路线 | -| `VMI-DEBUG-BOUNDARY` | debug op 必须在进入 VMI 前消费、剥离或退出 PTO 路线 | -| `VMI-PASS-INVARIANT` | pipeline hard gate 被破坏,例如 hidden side table、残留 conversion cast 或 layout 缺失 | -| `VMI-RESIDUAL-OP` | physicalization 后仍有非法 VMI op/type 或 helper | - -diagnostic payload 至少包含 source op、semantic reason、failed contract、available paths、 -missing capability 或 disabled fallback option。 - -## Implementation Plan - -具体文件布局、Slice 切分、ODS/type/op/pass/test 落地步骤见 -`docs/designs/vmi-implementation-manual.md`。本节只保留高层任务顺序。 - -1. 定义 `!pto.vmi.vreg`、`!pto.vmi.vreg`、 - `!pto.vmi.mask`、`!pto.vmi.mask`。 -2. 定义 layout 目录:`#pto.vmi.layout`、 - `#pto.vmi.layout`、 - `#pto.vmi.layout`, - 并实现统一 lane-map / physical-arity helper。 -3. 定义 VMI semantic op families:construction、memory、arith、conversion、mask、 - permutation、active-prefix、compress/expand、channel split/merge、reduction/scan/contract。 -4. 实现 VMI producer boundary verifier,禁止 producer 直接生成 physical VPTO 或依赖 hidden state。 -5. 实现 `vmi-layout-assignment`,包含 op transfer function、cost model、mask granularity - conversion、control-flow join。 -6. 实现 VMI memory lowering:access plan、safe-read/write proof、tile padding materialization、 - transfer mask coordinate remap、masked/guarded/scratch fallback。 -7. 实现 `vmi-to-vpto` 1:N type conversion,包含 `pack/unpack` materialization 和 structural - conversion。 -8. 加 target element-type / layout-sink / ISA contract / fallback resource registry。 -9. 加 VMI hard gate verifier:覆盖 VMI producer boundary、`vmi-layout-assignment`、 - `vmi-to-vpto` 后的残留 op/type、layout、mask granularity、conversion cast 和 hidden-state - invariant。 -10. 加 VMI diagnostic code registry 和 lit tests。 - -## Test Checklist - -1. `vmi_f16_ext_add_store_deinterleaved2.mlir` - - `extf` 后 result 是 `vreg<128xf32, deinterleaved=2>`,store 保持 contiguous logical order。 -2. `vmi_f8_ext_add_store_deinterleaved4.mlir` - - `deinterleaved=4` p0/p1/p2/p3 不被误写成 contiguous memory。 -3. `vmi_non_full_tile_padding_lanes.mlir` - - `vreg<100xf32>` padding lane 不可观察。 -4. `vmi_mask_granularity_width_change.mlir` - - surface `mask` 被不同 width consumer 使用时,正确生成 `mask` / - `mask` 并保持 data layout。 -5. `vmi_control_flow_layout_join.mlir` - - `scf.if/scf.for` layouted VMI type join 稳定。 -6. `vmi_tile_read_padding_safe_footprint.mlir` - - full physical load unsafe 时不偷读 invalid lane。 -7. `vmi_block_strided_rows_vsldb.mlir` - - `tile_read/tile_write` 识别 32B block rows,并拒绝 per-lane mask direct path。 -8. `vmi_active_prefix_index_compress.mlir` - - arbitrary mask compaction 使用 logical prefix order。 -9. `vmi_extract_boundary.mlir` - - scalar extract 输出 `VMI-SCALAR-EXTRACT-BOUNDARY`。 -10. `vmi_channel_split_merge_semantic_op.mlir` - - interleaved channel data 按用户语义拆成多个 VMI values,再通过 merge 写回。 -11. `vmi_producer_boundary.mlir` - - producer boundary 后只有 VMI semantic op/type,不出现 physical VPTO 或 hidden-state 依赖。 -12. `vmi_mask_threading.mlir` - - region-style mask 被 thread 到 masked VMI op 或 `vmi.select` merge,不残留 region mask。 -13. `vmi_gather_scatter_memory_semantics.mlir` - - inactive gather/scatter lane 不读写内存,scatter duplicate-index case 不走非法 direct path。 -14. `vmi_reduction_scan_contract_coverage.mlir` - - reduction/scan/contract 不回退成 residual logical-vector op,按 VMI lowering contract 处理。 -15. `vmi_cf_and_call_layout_boundary.mlir` - - `cf.br/cond_br` block arguments 和 internal call signatures 选择稳定 layout,external ABI 不泄露 layout。 -16. `vmi_iota_bitcast_insert_extract_coverage.mlir` - - lane index、bitcast、vector-result extract-like 和 insert-like 语义都有 VMI 承接。 -17. `vmi_memory_view_normalization.mlir` - - producer-specific vector element view 先规范化为 element view 和 access plan。 -18. `vmi_debug_boundary.mlir` - - debug-only op 不进入 VMI;未被 producer 消费时输出 `VMI-DEBUG-BOUNDARY`。 -19. `vmi_arith_numeric_contract.mlir` - - VMI arithmetic constant、fastmath、cmp predicate、integer signedness/overflow flags 保真。 -20. `vmi_shape_broadcast_semantics.mlir` - - `shape_cast/reshape` 只改 explicit op shape attrs,`transpose/flat_transpose` 和非 scalar broadcast 保持 lane map 语义且不依赖 shape side table。 -21. `vmi_physical_arity_non_full_deinterleaved.mlir` - - 非整 tile 下 `contiguous/deinterleaved=2/4` 的 physical value 个数和 valid lane map 一致。 -22. `vmi_ensure_layout_materialization_contract.mlir` - - `ensure_layout` 保持 logical lane 值,`deinterleaved=4` 只使用 registry 证明过的 materialization path。 -23. `vmi_tile_read_padding_decision_tree.mlir` - - safe full-read + non-all-true valid mask 生成 padding materialization + select;unsafe path 不读 invalid address。 -24. `vmi_tile_write_oob_no_effect.mlir` - - `tile_write` 的 writeMask=false lane 没有 memory effect,不被 lower 成 predicate-ignored store。 -25. `vmi_transfer_mask_coordinate_remap.mlir` - - non-minor-identity `tile_read/tile_write` 的 explicit mask 先映射到 result/source logical lane。 -26. `vmi_tile_read_vector_element_padding.mlir` - - vector-element padding 按 suffix coordinate 展开,invalid lane 使用对应 padding element。 -27. `vmi_index_address_legalization.mlir` - - `vreg`、gather/scatter indices、active-prefix offset 使用 element units 且宽度合法。 -28. `vmi_fallback_resource_diagnostics.mlir` - - scratch、guarded fallback、index-buffer fallback 缺资源时输出 `VMI-FALLBACK-RESOURCE`。 -29. `vmi_tensor_transfer_boundary.mlir` - - tensor transfer-style producer op 不伪装成 VMI memory op,未 bufferize 时输出 `VMI-TENSOR-BOUNDARY`。 -30. `vmi_pipeline_hard_gates.mlir` - - 各 pass 边界拒绝残留 VMI helper/unrealized cast/hidden state,且 final lowering 不残留 VMI op/type。 -31. `vmi_layout_assignment_constraint_solver.mlir` - - 多 use、rematerializable producer、control-flow join、layout conversion cost 冲突时选择稳定 layout 或输出精确 diagnostic。 diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index f86a812f05..e9790f56ea 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -1,6 +1,6 @@ # VMI 实现手册 -本文是 `docs/designs/vmi-dialect-design.md` 的落地手册。设计文档回答“为什么这样设计”,本文回答 +本文配套 `docs/designs/vmi-introduction.md` 和当前 VMI lowering 设计,回答 “按什么顺序改哪些文件、每一步做到什么程度才算完成”。 本文不替代最终 ODS / C++ verifier / lit 测试。实现时如果发现本文和 ODS 或 verifier 冲突,以 @@ -1313,7 +1313,7 @@ truncf f32 -> f16: truncf f32 -> fp8-like: result natural layout = contiguous -store/tile_write: +store: consumer requests contiguous externally visible order ``` @@ -1485,7 +1485,7 @@ Data natural layout: pto.vmi.channel_merge with C inputs: result natural = deinterleaved=C Data use request: - pto.vmi.store/tile_write: value requested as contiguous + pto.vmi.store: value requested as contiguous pto.vmi.channel_split with C results: source requested as deinterleaved=C op requiring a common operand/result layout: request producer class layout @@ -1587,7 +1587,7 @@ allowed: pto.vmi.create_mask not allowed in the first implementation: - load/tile_read + load arithmetic result conversion result shuffle/channel_split/channel_merge result @@ -1904,7 +1904,7 @@ layout-producing conversion: extf, truncf, bitcast externally ordered memory: - load, store, tile_read, tile_write + load, store value-indexed accumulation: dhist, chist @@ -1934,7 +1934,7 @@ vmi.store of deinterleaved=2: or materialize source to contiguous before physical store ``` -Therefore `store/tile_write` lowering must either: +Therefore `store` lowering must either: ```text 1. consume contiguous layout directly, or @@ -1960,9 +1960,6 @@ vmi.masked_load: vmi.store: materialize assigned source layout -> contiguous emit physical vsts chunks in memory order - -vmi.tile_read / vmi.tile_write: - follow the same externally ordered rule ``` Current direct memory lowering may only emit VPTO vector memory ops for @@ -1998,11 +1995,11 @@ current direct path supports this limited proof: ```text source is a statically shaped memref -offset is a constant non-negative index, or tile_read implicit offset 0 +offset is a constant non-negative index offset + physical_arity(result) * lanes_per_physical_part <= static memref element count ``` -When this proof holds, `vmi.load` / `vmi.tile_read` may still issue full `pto.vlds` chunks. The extra padding lanes are +When this proof holds, `vmi.load` may still issue full `pto.vlds` chunks. The extra padding lanes are not logical VMI lanes and must remain unobservable through later VMI materialization rules. Pointer sources, dynamic offsets, dynamic memrefs, and insufficient static footprints remain unsupported: @@ -2014,7 +2011,7 @@ pto.vlds/pto.vsts and requires UB-backed memory) ``` Store-style ops are different because inactive lanes can be made write-free with true predicates. `vmi.store`, -`vmi.masked_store`, and `vmi.tile_write` therefore support the explicit contiguous/deinterleaved tail-store +`vmi.masked_store` therefore support the explicit contiguous/deinterleaved tail-store materialization paths described below. ## 2. Slice 0: Type / Attr Bootstrap @@ -2218,8 +2215,6 @@ pto.vmi.store pto.vmi.masked_store pto.vmi.scatter pto.vmi.compress_store -pto.vmi.tile_read -pto.vmi.tile_write ``` Value-indexed accumulation: @@ -2304,13 +2299,15 @@ mask granularity must match selected element width after layout assignment source/result lane count equal source/result element types are float bitwidth changes in the expected direction +truncf rounding attr, when present, must be A/H and currently only applies to + f32 -> !pto.hif8 ``` Memory op verifier: ```text -load/tile_read memory element type must match result VMI data element type when the source is PtrType or MemRefType -store/tile_write memory element type must match stored VMI data element type when the destination is PtrType or MemRefType +load memory element type must match result VMI data element type when the source is PtrType or MemRefType +store memory element type must match stored VMI data element type when the destination is PtrType or MemRefType ``` Histogram op verifier: @@ -2524,6 +2521,9 @@ truncf f32 -> fp8-like: result. This mirrors the hardware packed-4 contract: each source part owns one quarter of the destination byte lanes, so the final externally visible vector remains logical lane order 0..N-1 after the merge. + default round mode is result-type specific: f8E4M3/f8E5M2 use rnd=R, hif8 + uses rnd=A. hif8 may explicitly request hybrid lowering with + pto.vmi.truncf {rounding = "H"}, which forwards rnd=H to every packed part. bitcast: source and result layouts must match @@ -2537,12 +2537,12 @@ bitcast: result logical bits. group_slots bitcast is unsupported until a slot-wise bitcast contract is defined. -load/tile_read: +load: baseline result layout is deterministic from explicit layout attrs or the producer natural layout; consumer-specific alternatives are represented by ensure_layout and optimized later -store/tile_write: +store: baseline requests contiguous source layout current implementation records a contiguous use-site request for vmi.store and inserts pto.vmi.ensure_layout when the stored value class solved to a @@ -2599,7 +2599,7 @@ implemented: extf source -> contiguous use-site request for supported f16/fp8-like to f32 paths truncf f32->f16 source -> deinterleaved=2 use-site request truncf f32->fp8-like source -> deinterleaved=4 use-site request - single-use pto.vmi.load / tile_read results can adopt a consumer-requested + single-use pto.vmi.load results can adopt a consumer-requested layout before type rewrite; this covers direct memory producers such as load -> truncf without inserting a redundant ensure_layout vmi.store data operand -> contiguous use-site request @@ -3246,10 +3246,12 @@ pto.vmi.truncf, direct path: result part; converted padding lanes remain result padding support f32 deinterleaved=4 source parts -> 8-bit contiguous result part materialize pto.pset_b32 "PAT_ALL" for the source conversion - emit pto.vcvt(p0_f32_part, mask, rnd=R, sat=SAT, part=P0) - emit pto.vcvt(p1_f32_part, mask, rnd=R, sat=SAT, part=P1) - emit pto.vcvt(p2_f32_part, mask, rnd=R, sat=SAT, part=P2) - emit pto.vcvt(p3_f32_part, mask, rnd=R, sat=SAT, part=P3) + emit pto.vcvt(p0_f32_part, mask, rnd=, sat=SAT, part=P0) + emit pto.vcvt(p1_f32_part, mask, rnd=, sat=SAT, part=P1) + emit pto.vcvt(p2_f32_part, mask, rnd=, sat=SAT, part=P2) + emit pto.vcvt(p3_f32_part, mask, rnd=, sat=SAT, part=P3) + result round is R for f8E4M3/f8E5M2, A for default hif8, or H for + hif8 truncf with {rounding = "H"} materialize pto.pset_b8 "PAT_ALL" merge mutually exclusive part results with pto.vor partial/tail is valid when the four source parts pack into one physical @@ -3552,13 +3554,13 @@ Unsupported diagnostics: non-splat pto.vmi.constant: VMI-UNSUPPORTED: non-splat pto.vmi.constant requires a vreg immediate or scratch materialization plan - partial/tail pto.vmi.load/tile_read: + partial/tail pto.vmi.load: VMI-UNSUPPORTED: pto.vmi. requires full physical chunks without padding lanes or a statically safe full-read footprint (...; safe-read proof failed: ...) - GM-backed direct pto.vmi.load/masked_load/expand_load/tile_read: + GM-backed direct pto.vmi.load/masked_load/expand_load: VMI-UNSUPPORTED: pto.vmi. ... (source is GM-backed, but current direct VMI-to-VPTO memory lowering emits pto.vlds/pto.vsts and requires UB-backed memory) - unsupported partial/tail pto.vmi.store/masked_store/tile_write: + unsupported partial/tail pto.vmi.store/masked_store: VMI-UNSUPPORTED: pto.vmi. requires an 8/16/32-bit predicate-maskable element type and either full physical chunks or contiguous/deinterleaved tail-store materialization, with UB-backed destination; unsupported cases include values such as f64/index that have no b64 predicate representation, GM-backed destinations that @@ -3628,6 +3630,16 @@ f32 -> f16: pto.vor merges mutually exclusive f16 part results into one contiguous vreg source/result physical arity must be 2 -> 1 current default conversion attrs are rnd=R, sat=SAT + +f32 -> 8-bit fp-like: + supported direct path when source is deinterleaved=4 and result is contiguous: + pto.vcvt part=P0/P1/P2/P3 consumes the four source partitions + pto.vor merges mutually exclusive byte-lane part results into one + contiguous vreg + source/result physical arity must be 4 -> 1 + current default conversion attrs are rnd=R for f8E4M3/f8E5M2 and rnd=A for + hif8. pto.vmi.truncf {rounding = "H"} is accepted only for f32 -> hif8 + and forwards rnd=H to the emitted pto.vcvt operations. ``` Memory lowering: @@ -3698,6 +3710,44 @@ vmi.masked_load: unsafe partial/tail read footprints target true masked/non-faulting load and guarded/scratch fallback +vmi.stride_load: + semantics: + result lane order is contiguous VMI logical order + source addresses are described by the VPTO block/repeat stride operands + mask false lanes are inactive for the underlying block-strided load + layout assignment: + result natural layout is contiguous + mask use is requested as contiguous with granularity derived from result element width + current direct path: + source must be !pto.ptr + result and mask must be one contiguous physical chunk + base = pto.addptr source, offset + result = pto.vsldb base, block_stride, repeat_stride, mask + unsupported cases: + multi-chunk result or mask + non-contiguous layouts + memref/gm source + +vmi.stride_store: + semantics: + value lane order is contiguous VMI logical order + destination addresses are described by the VPTO block/repeat stride operands + mask false lanes do not write memory + layout assignment: + value use is requested as contiguous + mask use is requested as contiguous with granularity derived from value element width + current direct path: + destination must be !pto.ptr + value and mask must be one contiguous physical chunk + base = pto.addptr destination, offset + updated_base = pto.vsstb value, base, block_stride, repeat_stride, mask + The updated base result is intentionally unused by VMI lowering, but the + post-update VPTO form matches CCE block-strided staging behavior. + unsupported cases: + multi-chunk value or mask + non-contiguous layouts + memref/gm destination + vmi.gather: semantics: if mask[lane] is true, result[lane] = memory[base + indices[lane]] @@ -3709,15 +3759,27 @@ vmi.gather: mask use is requested as contiguous with granularity derived from result element width current direct path: source must be !pto.ptr - T must be a 32-bit element type - indices must be signless or unsigned i32 - result / indices / passthru / mask must be contiguous full physical chunks - mask granularity must be b32 - for each physical chunk i: - gathered_i = pto.vgather2_bc source, indices_i, mask_i - result_i = pto.vsel gathered_i, passthru_i, mask_i + supported 32-bit mode: + T must be a 32-bit element type + indices must be signless or unsigned i32 + result / indices / passthru / mask must be contiguous full physical chunks + mask granularity must be b32 + for each physical chunk i: + gathered_i = pto.vgather2_bc source, indices_i, mask_i + result_i = pto.vsel gathered_i, passthru_i, mask_i + supported ui16 mode: + T must be ui16 + indices must be unsigned i16 + result / indices / passthru / mask must be one contiguous physical chunk + mask granularity must be b16 + gathered = pto.vgather2 source, indices, mask + result = pto.vsel gathered, passthru, mask + VPTO LLVM emitter bitcasts the physical index register from <128xi16> + to the installed Bisheng intrinsic ABI <64xi32>; this is the same + 256B register payload viewed as the wrapper-level vector_u16 index + container. reason for vsel: - VGATHER2_BC false predicate lanes do not read memory but produce zero; VMI false lanes preserve passthru. + VPTO gather false predicate lanes do not read memory but produce zero; VMI false lanes preserve passthru. unsupported cases: f16/b16/f8/i8 result element types partial/tail chunks @@ -3808,49 +3870,18 @@ vmi.masked_store: emitting stores. This preserves logical memory order and keeps inactive lanes write-free. non-full chunks: - vmi.store, vmi.masked_store, and vmi.tile_write support contiguous tail chunks by predicating the final pto.vsts with + vmi.store and vmi.masked_store support contiguous tail chunks by predicating the final pto.vsts with a prefix valid mask. masked_store additionally ANDs the user mask with the tail-valid mask. - deinterleaved=2/4 tail store/masked_store/tile_write is supported only through explicit layout materialization to + deinterleaved=2/4 tail store/masked_store is supported only through explicit layout materialization to contiguous chunks first. This requires every deinterleaved part to have the same physical chunk count, so the materializer can build complete vintlv/pintlv groups. After materialization, each contiguous chunk is predicated by the logical tail-valid mask; chunks whose active logical lane count is zero are not emitted as stores. Uneven deinterleaved groups, such as 129xf32 with deinterleaved=2, remain unsupported until a padding/scratch plan can assemble only the observable contiguous chunks. - vmi.load and tile_read support partial/tail chunks only when the direct full physical read is statically safe: - statically shaped memref source, constant non-negative offset (or tile_read offset 0), and enough elements for the + vmi.load support partial/tail chunks only when the direct full physical read is statically safe: + statically shaped memref source, constant non-negative offset, and enough elements for the whole physical read footprint. Padding lanes must never become observable. Other partial/tail load cases still need scratch/guarded/true-masked load planning. - -vmi.tile_read / vmi.tile_write, current direct full-footprint path: - This is not transfer_read padding lowering. It is only the tile/memref equivalent of the full-chunk direct memory - path above. - - tile_read: - source must lower to one VPTO buffer-like value. - logical lane count must be an exact multiple of the physical lanes per part. - use offset 0 as the tile base offset. - contiguous result layout reads physical chunks with pto.vlds. - deinterleaved=2 result layout prefers pto.vldsx2 "DINTLV_B8/B16/B32" with offset 0. - other supported layouts materialize the requested result layout after contiguous reads. - - tile_write: - destination must lower to one VPTO buffer-like value. - use offset 0 as the tile base offset. - value element width must be 8, 16, or 32 bits so pto.vsts/pto.vstsx2 can receive a materialized predicate. - contiguous source layout stores every physical chunk with pto.vsts and an all-true mask. - if the final contiguous chunk is partial, store it with a prefix valid-lane mask. - deinterleaved=2 source layout prefers pto.vstsx2 "INTLV_B8/B16/B32" with offset 0. - other supported layouts materialize the source value to contiguous layout first. - deinterleaved=2/4 tail source layouts are supported through this materialization path only when every - deinterleaved part has the same physical chunk count; zero-active materialized chunks are skipped. - - Unsupported: - padding value semantics - partial/tail tile footprints - transfer_read-style out-of-bounds reads - write masks - non-identity tile indexing/permutation - any path that would expose padding lanes or reorder externally visible memory ``` Histogram lowering: @@ -3958,9 +3989,9 @@ Slice 4 完成条件: vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto, vmi_to_vpto_expand_load_all_active.pto, vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto, and multi-chunk load/store layout tests. -4. Full-footprint tile_read/tile_write direct path lowers through pto.vlds/pto.vsts or deinterleaved=2 x2 dist +4. Full-footprint load/store direct path lowers through pto.vlds/pto.vsts or deinterleaved=2 x2 dist instructions with offset 0. - Covered by vmi_to_vpto_tile_read_write.pto. + Covered by the load/store direct-path and layout-folding tests. 5. Internal func.call boundaries expand callee signatures, call operands/results, and returned VMI values together. Covered by vmi_layout_assignment_call_boundary.pto, vmi_layout_assignment_indirect_call_invalid.pto, and vmi_to_vpto_call_boundary.pto. @@ -4004,11 +4035,12 @@ Slice 4 完成条件: vmi_to_vpto_chist_semantics_invalid.pto. ``` -## 7. Slice 5: Tile Memory And Padding +## 7. Slice 5: Memory Padding -The Slice 4 direct path may lower full-footprint `tile_read/tile_write` with offset 0. For partial `tile_read`, it may -also lower to plain `pto.vlds` only when the static safe-read proof above succeeds. Do not lower any other partial or -padded `tile_read` as a plain load until a richer access plan proves it is safe. +The Slice 4 direct path may lower full-footprint `load/store` when the +physical memory footprint is statically safe. Do not lower any partial, +padded, or out-of-bounds read-like operation as a plain `pto.vlds` until a +richer access plan proves it is safe. Implement an internal `VMIMemoryAccessPlan`: @@ -4050,8 +4082,8 @@ currently routed through the plan: direct pto.vmi.load partial/tail safe full-read proof pto.vmi.masked_load partial/tail safe full-read proof pto.vmi.expand_load static all-active safe full-read proof - VMI-to-VPTO rewrite match guard for load/tile_read full-or-safe reads - pto.vmi.store/tile_write direct write target decision with all-true writeMask kind + VMI-to-VPTO rewrite match guard for load full-or-safe reads + pto.vmi.store direct write target decision with all-true writeMask kind pto.vmi.masked_store direct write target decision with explicit writeMask kind unsafe partial/tail read fallback decision as RequiredUnavailable diagnostic covered by vmi_to_vpto_load_nonfull_invalid.pto, @@ -4087,7 +4119,7 @@ unless it has already canonicalized to an all-valid load/masked_load subset whose invalid lanes are proven absent. ``` -`tile_read` decision tree: +Read-like memory decision tree: ```text safeReadProof full && validMask all true: @@ -4106,7 +4138,7 @@ otherwise: future: split safe regions, scratch fill/copy/load, guarded fallback, or diagnostic ``` -`tile_write` decision tree: +Write-like memory decision tree: ```text writeMask all true && full footprint safe-writable: @@ -4239,12 +4271,12 @@ vmi_layout_assignment_mask_remat.mlir vmi_to_vpto_deinterleaved2.mlir vmi_to_vpto_deinterleaved4.mlir vmi_to_vpto_compaction_deint_invalid.mlir -vmi_to_vpto_non_full_tile.mlir +vmi_to_vpto_load_safe_tail_memref.mlir +vmi_to_vpto_masked_load_safe_tail_memref.mlir +vmi_to_vpto_store_tail.mlir vmi_to_vpto_dhist.mlir vmi_to_vpto_dhist_tail_mask.mlir vmi_to_vpto_chist_semantics_invalid.mlir -vmi_tile_read_padding.mlir -vmi_tile_write_mask.mlir vmi_pipeline_hard_gates.mlir ``` @@ -4276,7 +4308,7 @@ Recommended merge order: 6. vmi-to-vpto type conversion + pack/unpack/unpackable block args. 7. deinterleaved=2 f16 widen end-to-end. 8. deinterleaved=4 f8 widen end-to-end. -9. tile_read/tile_write padding-safe lowering. +9. load/store padding-safe lowering. 10. remaining semantic op families. ``` @@ -4408,10 +4440,10 @@ layout-changing producer: memory consumer/producer: examples: - load/store/tile_read/tile_write + load/store/load/store layout rule: - load/tile_read result natural layout is chosen by memory dist capability - store/tile_write value operand requests the layout that memory dist can consume + load result natural layout is chosen by memory dist capability + store value operand requests the layout that memory dist can consume lowering rule: direct path only when every physical chunk has no padding lane and footprint is safe diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index e7161dc4a0..8e884bf677 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -378,6 +378,10 @@ group_load / group_slot_load: result 根据 group size、row stride 和目标能力选择 contiguous、deinterleaved 或 group_slots。 +stride_load: + result 是 contiguous。block/repeat stride 只描述 memory address map, + 不改变 register 内 logical lane order。 + active_prefix_index: result 使用 contiguous。 ``` @@ -388,7 +392,7 @@ active_prefix_index: “这个 use site 希望 operand 是什么 layout”: ```text -store / tile_write / masked_store value: +store / masked_store value: wants contiguous ordinary reduce source/init: @@ -400,6 +404,10 @@ group_reduce source: group_store value: wants preferred group result layout +stride_store value: + wants contiguous。block/repeat stride 只描述 memory write address map, + 不表示 source vreg 是 sparse 或 NZ layout。 + truncf/trunci/extf/extsi/extui source: wants cast support 给出的 source layout @@ -415,7 +423,7 @@ use request:如果 operand 的 producer 可以直接用 consumer 需要的 lay 可采纳 producer 是受限集合: ```text -load / tile_read +load broadcast / constant / iota layout-transparent elementwise select @@ -575,6 +583,7 @@ constant_mask ```text load / masked_load / group_load / group_slot_load +stride_load reduce / group_reduce control-flow results ``` @@ -945,6 +954,63 @@ selector、lo/hi accumulator 和多条物理指令。 在 high range 上返回的是全局累计还是 range-local 累计。这个差异会影响是否需要 额外给 high half 加上 low half 的总计数,因此不能只按 op 名字猜 lowering。 +### 4.9 Block-Strided UB Staging + +有些 CCE kernel 并不是在 register 内做任意 byte shuffle,而是先把结果写到 +UB scratch,再用 block-strided vector load/store materialize 目标 UB layout。 +`quant_minimum` 的 MXFP8 NZ case 是典型例子: + +```text +compute: + row-major ND FP8 scratch + +row-wise staging: + for row in 0..31: + q8_row = vmi.stride_load(nd + row * 64, + block_stride=1, repeat_stride=1) + vmi.stride_store(q8_row, nz + row * 32, + block_stride=33, repeat_stride=1) + +copy-out: + 2D MTE copies two 1024B NZ planes from UB to GM +``` + +这里 `q8_row` 的 VMI value 仍然是 contiguous `64xf8` 逻辑向量: + +```mlir +%q8_row = pto.vmi.stride_load %nd[%nd_off], %c1_i16, %c1_i16, %mask + : !pto.ptr, i16, i16, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xf8E4M3FN> + +pto.vmi.stride_store %q8_row, %nz[%nz_off], %c33_i16, %c1_i16, %mask + : !pto.vmi.vreg<64xf8E4M3FN>, !pto.ptr, i16, i16, + !pto.vmi.mask<64xpred> +``` + +Assignment 形状: + +```text +stride_load result = contiguous +stride_load mask = contiguous, granularity follows result element width +stride_store value = contiguous +stride_store mask = contiguous, granularity follows value element width +``` + +VPTO 形状: + +```text +base_in = pto.addptr nd, nd_off +q8_row = pto.vsldb base_in, block_stride=1, repeat_stride=1, mask + +base_out = pto.addptr nz, nz_off +updated = pto.vsstb q8_row, base_out, block_stride=33, repeat_stride=1, mask + -> updated_base +``` + +这个场景说明:memory layout transformation 不一定要变成 VMI data layout。 +只要 VMI op 的语义是“从哪些地址读/写哪些 logical lane”,register value +仍然可以保持 contiguous,`vmi-to-vpto` 也仍然是 local lowering。 + ## 5. 当前边界 当前设计方向: @@ -973,15 +1039,30 @@ packed group_slots f32->f16 cast: 非法,除非 assignment 能把它 commute 到 group_broadcast 之后,或者使用 支持的 row-local slots=1 path。 +FP4 packed input/output: + packed FP4 不属于当前 VMI surface。PTO/VPTO 已有 !pto.f4E1M2x2 + 和 !pto.f4E2M1x2 packed 物理类型,且这些类型的 shape 语义是 + packed pair/byte 数,不是 logical FP4 lane 数。在 VMI 中直接写 + vreg 会让 N 表示物理 packed byte 还是逻辑 FP4 元素 + 产生歧义,因此 verifier 会直接拒绝 + vmi.vreg<...x!pto.f4E1M2x2/!pto.f4E2M1x2>。 + + 当前 VMI surface 不包含专用 FP4 packed-memory op。FP4 packed IO + 需要先作为独立语义重新设计,不能进入当前 dialect surface。 + extract: 暂不作为支持的 VMI surface。 padding transfer_read: 当前 tail 设计不需要;tail 使用 mask。 -scan / contract / gather / scatter / compress / active_prefix_index: +scan / contract / compress / active_prefix_index: dialect surface 中可以存在,但除非补充具体 case,否则不属于第一阶段聚焦的 layout/lowering 实现集合。 + +gather / scatter: + 当前只覆盖 UB pointer、contiguous layout 和已明确支持的 element/index 宽度。 + `ui16` gather 可承接 E8M0 byte-pair reorder;它不是通用 byte shuffle。 ``` 设计目标是优先保证语义完整:只要 VMI 接受某个 case,所需的 layout 沟通就必须 diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index fcdf7fe292..eb634ede79 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -4,7 +4,7 @@ `vmi-layout-assignment-lowering-design.md`,并以 `vmi-layout-lowering-cases.md` 为测试和验收来源。 -不使用旧 `vmi-dialect-design.md` 作为设计输入。 +不使用早期 VMI 草稿作为设计输入。 ## 1. Pipeline @@ -56,8 +56,8 @@ vmi-layout-fold-consumers: example: ensure_layout(deinterleaved=2 -> contiguous) feeding store may become a store of deinterleaved=2 when the store has a layout-aware vstsx2 INTLV lowering - current implementation: pto.vmi.store, pto.vmi.tile_write, and the value - operand of pto.vmi.masked_store when the existing mask arity matches, fed by + current implementation: pto.vmi.store and the value operand of + pto.vmi.masked_store when the existing mask arity matches, fed by ensure_layout from deinterleaved=2/4, block_elems=1 to contiguous. factor=2 uses the store's vstsx2 INTLV lowering; factor=4 is still store-local, but it materializes through physical interleave before vsts. @@ -100,7 +100,7 @@ pto-validate-vmi-layout-ir: `ensure_mask_granularity` at the layout gate, so unsupported helper materializations fail before `vmi-to-vpto`. It also checks the first semantic local lowering families, non-contiguous - `pto.vmi.store`/`pto.vmi.tile_write`, block8 + `pto.vmi.store`, block8 `pto.vmi.group_load`, `pto.vmi.group_slot_load`, group_slots `pto.vmi.group_store`, group_slots `pto.vmi.group_reduce_add{f|i}`, explicit-slots `pto.vmi.group_broadcast`, `pto.vmi.truncf`, @@ -529,7 +529,7 @@ contains enough information for `vmi-to-vpto`. Implementation-relevant layout facts: ```text -dense store/tile_write: +dense store: requests contiguous source. If the value is assigned deinterleaved, assignment inserts ensure_layout at the store use. A later optimization may fold ensure_layout + store into a layout-aware store lowering. @@ -709,7 +709,7 @@ Memory legality constraints: ```text S=32 tail fast load: - requires full_tile_readable + requires full_footprint_readable otherwise require gather fallback or diagnose compact S=12 logical S=16: @@ -1054,7 +1054,7 @@ vmi-to-vpto contract: diagnostic family builder / owner required failure 3.7.4 slots=1 unit-stride store buildStoreRequests no aligned row-local store path 3.9 dense store of group slots buildStoreRequests use group_store/group_broadcast -3.11.2 S=32 unsafe tail buildMaskRequests missing full_tile_readable/gather +3.11.2 S=32 unsafe tail buildMaskRequests missing full_footprint_readable/gather 3.13 slots=8 width cast buildCastRequests no packed slot cast transform 3.14 unsupported group size buildGroupReduceRequests no supported reduce layout/lowering 3.15.3 compact S=12 buildGroupMemoryRequests no compact gather plan @@ -1316,13 +1316,10 @@ truncf group-slot cast: group_store: row-local group_slots(G, slots=1) lowering is implemented as one lane-0 - vsts per group and is covered by the reduce->truncf->group_store lit case. - The current plan is accepted only when row_stride is a constant positive - multiple of the 32B store alignment in destination elements: 8 for f32, - 16 for f16, and 32 for f8. Unit-stride f32 output is rejected because only - the first row-local store is 32B-aligned; later `group_off + r` stores are - 4B apart. A future pack-to-slots=8 or unaligned-store lowering is required before - contiguous `%c1` slots=1 group_store can be accepted. + vsts per group for packed unit-stride output, or as one 1PT store per group + for non-unit row strides. The packed path is covered by the + reduce->truncf->group_store lit case, while the point-store path is covered + by `test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto`. Packed group_slots(G, slots=8) group_store is implemented only when num_groups is a multiple of 8 and row_stride is constant 1; it emits one PAT_VL8 store per packed slot block. Non-unit packed group stores remain a @@ -1449,7 +1446,7 @@ not allowed: walking from a consumer to a producer to decide a lowering walking from a consumer to a mask producer to decide whether a lowering is legal inspecting users to choose a result layout or materialization - recovering full_tile_readable from surrounding MTE/caller context + recovering full_footprint_readable from surrounding MTE/caller context ``` Current audit result: @@ -1624,7 +1621,7 @@ Aggregate catalog headings are covered through their endpoint subcases: ```text 3.11 partial tail groups: 3.11.1 positive S=64 active-row tail - 3.11.2 diagnostic S=32 tail without full_tile_readable + 3.11.2 diagnostic S=32 tail without full_footprint_readable 3.15 compact S=12 written as logical S=16: 3.15.1 positive source row stride 16 @@ -1928,10 +1925,10 @@ runtime SIM: test/vpto/cases/vmi/group-reduce-s64-tail-store ``` -The companion negative lit case for contiguous `%c1` slots=1 group_store is: +The companion lit case for non-unit slots=1 point-store lowering is: ```text -test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto +test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto ``` Current checked-in coverage for S=64 row-local group-slot RHS elementwise @@ -2130,7 +2127,7 @@ Diagnostic-only cases: ```text 3.9 dense store of group slots -3.11.2 S=32 tail without full_tile_readable +3.11.2 S=32 tail without full_footprint_readable 3.7.4 S=64 slots=1 group_store with unit output stride 3.13 packed group-slot f32 -> f16 cast 3.14 unsupported group size @@ -2165,7 +2162,7 @@ lit: test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto - test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto + test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto test/lit/vmi/vmi_ptoas_public_abi_invalid.pto diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 497f6cad8c..40edc656ad 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -1,8 +1,8 @@ # VMI Layout Assignment And Lowering Design 本文是新的 VMI layout assignment / lowering 设计文档。它只以 -`docs/designs/vmi-layout-lowering-cases.md` 为 source of truth,不继承旧 -`vmi-dialect-design.md` 的 layout 设计,以避免旧上下文污染。 +`docs/designs/vmi-layout-lowering-cases.md` 为 source of truth,不继承早期 +VMI 草稿的 layout 设计,以避免旧上下文污染。 目标: @@ -140,7 +140,7 @@ mask and tail: masked_load grouped tail feeding group_reduce masked select/store one semantic mask used by multiple predicate granularities - S=32 tail with and without full_tile_readable + S=32 tail with and without full_footprint_readable compact S=12 diagnostic strided memory: @@ -187,7 +187,7 @@ control-flow propagation: public ABI rejection memory legality: - full_tile_readable proof, grouped masks, predicate granularity, aligned + full_footprint_readable proof, grouped masks, predicate granularity, aligned strided group memory, stable gather diagnostic value-indexed accumulation: @@ -591,7 +591,7 @@ Mask semantics and memory legality are separate: mask: decides which logical lanes participate in compute/store semantics -full_tile_readable: +full_footprint_readable: decides whether a rounded-up physical load is allowed to read inactive lanes ``` @@ -606,10 +606,10 @@ Example: ```text S=32 tail num_groups=6: - without full_tile_readable: + without full_footprint_readable: fast DINTLV_B32 full-tile load is illegal - with full_tile_readable: + with full_footprint_readable: full 8-row physical tile may be loaded compute mask is PAT_VL48 per physical part group store mask is PAT_VL6 @@ -807,7 +807,7 @@ Hard constraints: group_slots cannot feed ordinary dense consumers direct group-slot width-changing cast requires an explicit slot-preserving transform public/external VMI function boundary requires a stable ABI or diagnostic -S=32 fast tail load requires full_tile_readable or gather fallback +S=32 fast tail load requires full_footprint_readable or gather fallback ``` `slots = 1` row-local cast may satisfy the slot-preserving transform requirement. @@ -936,7 +936,7 @@ The pattern must not: 1. inspect all users to decide result layout 2. inspect defining ops to decide source layout 3. choose between S=16 block_elems=1 and block_elems=8 -4. decide whether a load is full_tile_readable +4. decide whether a load is full_footprint_readable 5. decide function signature specialization ``` @@ -1016,8 +1016,8 @@ dense store of group_slots: packed group-slot f32->f16: group_broadcast before truncf, or keep group_store as f32 -S=32 tail without full_tile_readable: - mark source full_tile_readable or enable stable gather fallback +S=32 tail without full_footprint_readable: + mark source full_footprint_readable or enable stable gather fallback S=32 group_load with unaligned source_group_stride: choose a stride divisible by 8 f32 elements or enable stable gather fallback diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 3fd0c4b7eb..3fbe2bd3f1 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -1496,7 +1496,7 @@ for r = 0..7: = s * s ``` -#### 3.7.4 Unit-Stride Store Is Not A Valid Lowering Yet +#### 3.7.4 Slots=1 Store Lowers To Packed Or Point Stores The row-local S=64 result uses one physical vreg per group with the semantic value in lane 0: @@ -1505,17 +1505,15 @@ value in lane 0: %sum_r lane 0 = reduce(row_r[0..63]) ``` -The current VPTO lowering for `slots = 1` group_store emits one lane-0 `vsts` -per group. Therefore unit-stride f32 output would issue stores at: +The current VPTO lowering for `slots = 1` group_store has two paths. -```text -group_off + 0, group_off + 1, group_off + 2, ... -``` +For unit-stride output where all groups fit in one physical vector, the +lowering packs the lane-0 values into one dense vector and stores that vector +with a normal `vsts`. -Only the first address is necessarily 32B-aligned. The remaining f32 addresses -are 4B apart and are not valid for this `vsts` lowering. The compiler must not -accept this as a clean lowering until either pack-to-slots=8 materialization -support or unaligned-store support exists. +For non-unit row strides, each group stores its lane-0 scalar with a point +store. That emits `vsts` with `dist = "1PT_B32"` for f32 and only requires the +natural 4B alignment of the scalar element. VMI input: @@ -1525,14 +1523,10 @@ VMI input: pto.vmi.group_store %sum, %out[%group_off], %c1 {num_groups = 8} ``` -Required diagnostic: +Current checked-in coverage for the point-store path is: ```text -VMI-LAYOUT-CONTRACT: - pto.vmi.group_store with #pto.vmi.layout lowers - as one lane-0 vsts per group and requires constant positive row_stride - divisible by 8 f32 elements for 32B store alignment. Packed or unaligned - contiguous store lowering is not implemented. +test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto ``` ### 3.8 `group_reduce -> truncf -> group_broadcast -> store` diff --git a/docs/designs/vmi-mxfp8-32x32-expected-lowering.md b/docs/designs/vmi-mxfp8-32x32-expected-lowering.md new file mode 100644 index 0000000000..5130da3ec6 --- /dev/null +++ b/docs/designs/vmi-mxfp8-32x32-expected-lowering.md @@ -0,0 +1,236 @@ +# VMI MXFP8 32x32 Expected VPTO Lowering + +本文记录 `test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto` +的预期 VPTO lower 结果。输入 VMI case 在 `vecscope` 内按 8 行一组循环, +每次处理一个 `256xf32` tile,也就是 8 行 x 32 列。 + +这里写的是设计目标,不是当前 `--emit-vpto` 的实际输出。重点是把 E8M0 +scale 的内存效果写明确:每个 8x32 chunk 产生 8 个 scale byte。lowering +按 CCE 风格先写到 32B 对齐的 padded UB slot,再通过 UB->GM copy 的 +`src_stride=32B, dst_stride=8B` 消除 UB padding,使 GM 端仍然连续。 + +## Complete Expected PTO File + +```mlir +module attributes {pto.backend = "vpto", pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind, pto.target_arch = "a5"} { + func.func @vmi_tquant_mxfp8_32x32_nd_kernel(%src_gm: !pto.ptr, + %out_fp8_gm: !pto.ptr, + %out_e8m0_gm: !pto.ptr) attributes {pto.kernel} { + %false = arith.constant false + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c7 = arith.constant 7 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c24 = arith.constant 24 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c3_i32 = arith.constant 3 : i32 + %c4_i32 = arith.constant 4 : i32 + %c5_i32 = arith.constant 5 : i32 + %c6_i32 = arith.constant 6 : i32 + %c7_i32 = arith.constant 7 : i32 + %c8_i32 = arith.constant 8 : i32 + %c23_i32 = arith.constant 23 : i32 + %c24_i32 = arith.constant 24 : i32 + %c40_i32 = arith.constant 40 : i32 + %c48_i32 = arith.constant 48 : i32 + %c56_i32 = arith.constant 56 : i32 + %c254_i32 = arith.constant 254 : i32 + %c2139095040_i32 = arith.constant 2139095040 : i32 + + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_fp8_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_fp8_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_e8m0 = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.copy_gm_to_ubuf %src_gm, %ub_src, %c0_i64, %c1_i64, %c4096_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c4096_i64, %c4096_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.copy_gm_to_ubuf %out_fp8_gm, %ub_out_fp8_u8, %c0_i64, %c1_i64, %c1024_i64, %c0_i64, %c0_i64, %false, %c0_i64, %c1024_i64, %c1024_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i1, i64, i64, i64 + pto.set_flag[, , ] + pto.wait_flag[, , ] + + pto.vecscope { + scf.for %row = %c0 to %c32 step %c8 { + %elem_off = arith.muli %row, %c32 : index + %elem_off_64 = arith.addi %elem_off, %c64 : index + %elem_off_128 = arith.addi %elem_off, %c128 : index + %elem_off_192 = arith.addi %elem_off, %c192 : index + + %x0 = pto.vlds %ub_src[%elem_off] : !pto.ptr -> !pto.vreg<64xf32> + %x1 = pto.vlds %ub_src[%elem_off_64] : !pto.ptr -> !pto.vreg<64xf32> + %x2 = pto.vlds %ub_src[%elem_off_128] : !pto.ptr -> !pto.vreg<64xf32> + %x3 = pto.vlds %ub_src[%elem_off_192] : !pto.ptr -> !pto.vreg<64xf32> + + %d0, %d1 = pto.vdintlv %x0, %x1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %d2, %d3 = pto.vdintlv %x2, %x3 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %d4, %d5 = pto.vdintlv %d0, %d2 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %d6, %d7 = pto.vdintlv %d1, %d3 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + + %all_b32 = pto.pset_b32 "PAT_ALL" : !pto.mask + %slot8_b32 = pto.pge_b32 "PAT_VL8" : !pto.mask + %vl8_b32 = pto.pset_b32 "PAT_VL8" : !pto.mask + %vl16_b32 = pto.pset_b32 "PAT_VL16" : !pto.mask + %vl24_b32, %unused24 = pto.plt_b32 %c24_i32 : i32 -> !pto.mask, i32 + %vl32_b32 = pto.pset_b32 "PAT_VL32" : !pto.mask + %vl40_b32, %unused40 = pto.plt_b32 %c40_i32 : i32 -> !pto.mask, i32 + %vl48_b32, %unused48 = pto.plt_b32 %c48_i32 : i32 -> !pto.mask, i32 + %vl56_b32, %unused56 = pto.plt_b32 %c56_i32 : i32 -> !pto.mask, i32 + + %abs0 = pto.vabs %d4, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %abs1 = pto.vabs %d6, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %abs2 = pto.vabs %d5, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %abs3 = pto.vabs %d7, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + + %g0 = pto.vcgmax %abs0, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %g1 = pto.vcgmax %abs1, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %g2 = pto.vcgmax %abs2, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %g3 = pto.vcgmax %abs3, %all_b32 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %g01 = pto.vmax %g0, %g1, %slot8_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %g23 = pto.vmax %g2, %g3, %slot8_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %amax = pto.vmax %g01, %g23, %slot8_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + + %amax_i32 = pto.vbitcast %amax : !pto.vreg<64xf32> -> !pto.vreg<64xi32> + %exp_mask = pto.vdup %c2139095040_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %shift = pto.vdup %c23_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %emax = pto.vdup %c8_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %scale_exp_bias = pto.vdup %c254_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %exp_bits = pto.vand %amax_i32, %exp_mask, %all_b32 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %exp = pto.vshr %exp_bits, %shift, %all_b32 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %e8m0_payload_i32 = pto.vsub %exp, %emax, %all_b32 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + + %idx0 = pto.vdup %c0_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx1 = pto.vdup %c1_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx2 = pto.vdup %c2_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx3 = pto.vdup %c3_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx4 = pto.vdup %c4_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx5 = pto.vdup %c5_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx6 = pto.vdup %c6_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + %idx7 = pto.vdup %c7_i32, %all_b32 : i32, !pto.mask -> !pto.vreg<64xi32> + + %not_vl8 = pto.pnot %vl8_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_8_15 = pto.pand %vl16_b32, %not_vl8, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_1 = pto.vsel %idx1, %idx0, %range_8_15 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl16 = pto.pnot %vl16_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_16_23 = pto.pand %vl24_b32, %not_vl16, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_2 = pto.vsel %idx2, %broadcast_idx_1, %range_16_23 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl24 = pto.pnot %vl24_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_24_31 = pto.pand %vl32_b32, %not_vl24, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_3 = pto.vsel %idx3, %broadcast_idx_2, %range_24_31 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl32 = pto.pnot %vl32_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_32_39 = pto.pand %vl40_b32, %not_vl32, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_4 = pto.vsel %idx4, %broadcast_idx_3, %range_32_39 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl40 = pto.pnot %vl40_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_40_47 = pto.pand %vl48_b32, %not_vl40, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_5 = pto.vsel %idx5, %broadcast_idx_4, %range_40_47 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl48 = pto.pnot %vl48_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_48_55 = pto.pand %vl56_b32, %not_vl48, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx_6 = pto.vsel %idx6, %broadcast_idx_5, %range_48_55 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %not_vl56 = pto.pnot %vl56_b32, %all_b32 : !pto.mask, !pto.mask -> !pto.mask + %range_56_63 = pto.pand %all_b32, %not_vl56, %all_b32 : !pto.mask, !pto.mask, !pto.mask -> !pto.mask + %broadcast_idx = pto.vsel %idx7, %broadcast_idx_6, %range_56_63 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + + %scale_u16 = pto.vpack %e8m0_payload_i32, "LOWER" : !pto.vreg<64xi32> -> !pto.vreg<128xui16> + %scale_u8 = pto.vpack %scale_u16, "LOWER" : !pto.vreg<128xui16> -> !pto.vreg<256xui8> + %scale_slot = arith.divui %row, %c8 : index + %scale_ub_off = arith.muli %scale_slot, %c32 : index + %scale8_b8 = pto.pge_b8 "PAT_VL8" : !pto.mask + pto.vsts %scale_u8, %ub_out_e8m0[%scale_ub_off], %scale8_b8 {dist = "NORM_B8"} : !pto.vreg<256xui8>, !pto.ptr, !pto.mask + + %scale_exp = pto.vsub %scale_exp_bias, %e8m0_payload_i32, %all_b32 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %scale_bits = pto.vshl %scale_exp, %shift, %all_b32 : !pto.vreg<64xi32>, !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xi32> + %scale_f32 = pto.vbitcast %scale_bits : !pto.vreg<64xi32> -> !pto.vreg<64xf32> + %scale_vec = pto.vselr %scale_f32, %broadcast_idx : !pto.vreg<64xf32>, !pto.vreg<64xi32> -> !pto.vreg<64xf32> + + %m0 = pto.vmul %d4, %scale_vec, %all_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %m1 = pto.vmul %d6, %scale_vec, %all_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %m2 = pto.vmul %d5, %scale_vec, %all_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + %m3 = pto.vmul %d7, %scale_vec, %all_b32 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + + %i0, %i1 = pto.vintlv %m0, %m2 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %i2, %i3 = pto.vintlv %m1, %m3 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %i4, %i5 = pto.vintlv %i0, %i2 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %i6, %i7 = pto.vintlv %i1, %i3 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %r0, %r1 = pto.vdintlv %i4, %i5 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %r2, %r3 = pto.vdintlv %i6, %i7 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %r4, %r5 = pto.vdintlv %r0, %r2 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + %r6, %r7 = pto.vdintlv %r1, %r3 : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> + + %all_b8 = pto.pset_b8 "PAT_ALL" : !pto.mask + %q0 = pto.vcvt %r4, %all_b32 {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + %q1 = pto.vcvt %r6, %all_b32 {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + %q2 = pto.vcvt %r5, %all_b32 {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + %q3 = pto.vcvt %r7, %all_b32 {part = "P3", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + %q01 = pto.vor %q0, %q1, %all_b8 : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + %q012 = pto.vor %q01, %q2, %all_b8 : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + %q = pto.vor %q012, %q3, %all_b8 : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + pto.vsts %q, %ub_out_fp8_f8[%elem_off], %all_b8 : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask + } + } + + pto.set_flag[, , ] + pto.wait_flag[, , ] + pto.copy_ubuf_to_gm %ub_out_fp8_u8, %out_fp8_gm, %c0_i64, %c1_i64, %c1024_i64, %c0_i64, %c1024_i64, %c1024_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.copy_ubuf_to_gm %ub_out_e8m0, %out_e8m0_gm, %c0_i64, %c4_i64, %c8_i64, %c0_i64, %c8_i64, %c32_i64 + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, i64 + pto.barrier + return + } + } +} +``` + +## Scale Store Contract + +上面的 lower 对每次循环执行一条 `NORM_B8` store,写到 32B 对齐的 UB +slot: + +```text +row = 0 -> UB[0..7], UB[8..31] padding +row = 8 -> UB[32..39], UB[40..63] padding +row = 16 -> UB[64..71], UB[72..95] padding +row = 24 -> UB[96..103], UB[104..127] padding +``` + +最终 copy-out 只搬每个 slot 的前 8B: + +```text +copy len = 8B +repeat = 4 +source stride = 32B +destination stride = 8B +``` + +因此 GM 端效果仍然是连续 scale 输出: + +```text +GM[0..7] <- UB[0..7] +GM[8..15] <- UB[32..39] +GM[16..23] <- UB[64..71] +GM[24..31] <- UB[96..103] +``` diff --git a/include/PTO/IR/VMIAttrs.td b/include/PTO/IR/VMIAttrs.td index e8c44a1454..867567de30 100644 --- a/include/PTO/IR/VMIAttrs.td +++ b/include/PTO/IR/VMIAttrs.td @@ -28,14 +28,24 @@ def VMILayoutAttr : PTO_Attr<"VMILayout", "vmi.layout"> { static VMILayoutAttr getDeinterleaved(::mlir::MLIRContext *context, int64_t factor, int64_t blockElems = 1); + static VMILayoutAttr getSparse(::mlir::MLIRContext *context, + int64_t sparseFactor); static VMILayoutAttr getGroupSlots(::mlir::MLIRContext *context, int64_t numGroups, - int64_t slots = 0); + int64_t slots = 0, + int64_t sparseFactor = 1); bool isContiguous() const { return getKind() == "contiguous"; } bool isDeinterleaved() const { return getKind() == "deinterleaved"; } + bool isSparse() const { return getKind() == "sparse"; } bool isGroupSlots() const { return getKind() == "num_groups"; } int64_t getNumGroups() const { return getFactor(); } + bool hasSparseFactor() const { + return isSparse() || (isGroupSlots() && getBlockElems() != 1); + } + int64_t getSparseFactor() const { + return isSparse() ? getFactor() : getBlockElems(); + } }]; } diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 9acce8cd7b..1b44fc0e40 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -438,6 +438,16 @@ def VMIGroupReduceAddIOp : VMI_Op<"group_reduce_addi"> { let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; } +def VMIGroupReduceMaxIOp : VMI_Op<"group_reduce_maxi"> { + let summary = "VMI masked integer maximum reduction within fixed logical groups"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + VMI_MaskTypeConstraint:$mask, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `,` $mask attr-dict `:` type($source) `,` type($mask) `->` type($result)"; +} + def VMIGroupBroadcastOp : VMI_Op<"group_broadcast"> { let summary = "VMI broadcast group-slot values back to each logical group"; let arguments = (ins VMI_VRegTypeConstraint:$source, @@ -474,6 +484,23 @@ def VMIExtFOp : VMI_Op<"extf"> { def VMITruncFOp : VMI_Op<"truncf"> { let summary = "VMI floating-point elementwise truncation"; + let arguments = (ins VMI_VRegTypeConstraint:$source, + OptionalAttr:$rounding); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMIFPToSIOp : VMI_Op<"fptosi"> { + let summary = "VMI floating-point to signed integer elementwise conversion"; + let arguments = (ins VMI_VRegTypeConstraint:$source); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; +} + +def VMISIToFPOp : VMI_Op<"sitofp"> { + let summary = "VMI signed integer to floating-point elementwise conversion"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); let hasVerifier = 1; @@ -538,6 +565,16 @@ def VMIGroupSlotLoadOp : VMI_Op<"group_slot_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI block-strided vector load"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, + I16:$block_stride, I16:$repeat_stride, + VMI_MaskTypeConstraint:$mask); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $block_stride `,` $repeat_stride `,` $mask attr-dict `:` type($source) `,` type($block_stride) `,` type($repeat_stride) `,` type($mask) `->` type($result)"; +} + def VMIMaskedLoadOp : VMI_Op<"masked_load", [DeclareOpInterfaceMethods]> { let summary = "VMI logical masked vector load with passthrough lanes"; let arguments = (ins PtrOrMemRef:$source, Index:$offset, @@ -595,6 +632,16 @@ def VMIMaskedStoreOp : VMI_Op<"masked_store", [DeclareOpInterfaceMethods]> { + let summary = "VMI block-strided vector store"; + let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, + Index:$offset, I16:$block_stride, I16:$repeat_stride, + VMI_MaskTypeConstraint:$mask); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$value `,` $destination `[` $offset `]` `,` $block_stride `,` $repeat_stride `,` $mask attr-dict `:` type($value) `,` type($destination) `,` type($block_stride) `,` type($repeat_stride) `,` type($mask)"; +} + def VMIScatterOp : VMI_Op<"scatter", [DeclareOpInterfaceMethods]> { let summary = "VMI logical masked indexed scatter"; let arguments = (ins VMI_VRegTypeConstraint:$value, @@ -606,22 +653,6 @@ def VMIScatterOp : VMI_Op<"scatter", [DeclareOpInterfaceMethods]> { - let summary = "VMI logical tile read"; - let arguments = (ins AnyType:$source); - let results = (outs VMI_VRegTypeConstraint:$result); - let hasVerifier = 1; - let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; -} - -def VMITileWriteOp : VMI_Op<"tile_write", [DeclareOpInterfaceMethods]> { - let summary = "VMI logical tile write"; - let arguments = (ins VMI_VRegTypeConstraint:$value, AnyType:$destination); - let results = (outs); - let hasVerifier = 1; - let assemblyFormat = "$value `,` $destination attr-dict `:` type($value) `,` type($destination)"; -} - def VMIShuffleOp : VMI_Op<"shuffle"> { let summary = "VMI static lane shuffle"; let arguments = (ins VMI_VRegTypeConstraint:$source, DenseI64ArrayAttr:$indices); diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h index 429a20bf0d..f4fb3744dc 100644 --- a/include/PTO/Transforms/VMILayoutSupport.h +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -38,6 +38,7 @@ enum class VMILayoutMaterializationSupportKind { Identity, ContiguousToDeinterleaved, DeinterleavedToContiguous, + DeinterleavedToDeinterleavedViaContiguous, }; struct VMILayoutMaterializationSupport { @@ -92,7 +93,8 @@ struct VMIGroupLoadSupport { enum class VMIGroupSlotsStoreSupportKind { Slots8UnitStrideVsts, - Slots1AlignedLane0Vsts, + Slots1PointVsts, + Slots1PackedUnitStrideVsts, }; struct VMIGroupSlotsStoreSupport { @@ -162,7 +164,7 @@ struct VMIExtFSupport { enum class VMITruncISupportKind { Deinterleaved2I32ToContiguousI16, Deinterleaved4I32ToContiguousI8, - GroupSlots1I32ToI16, + GroupSlots1I32ToNarrow, }; struct VMITruncISupport { @@ -173,6 +175,8 @@ struct VMITruncISupport { enum class VMIExtISupportKind { ContiguousI16ToDeinterleaved2I32, ContiguousI8ToDeinterleaved4I32, + GroupSlotsI16ToI32, + GroupSlotsI8ToI32, }; struct VMIExtISupport { @@ -271,6 +275,11 @@ class VMILayoutSupport { VMIGroupReduceAddIOp op, std::string *reason = nullptr) const; + FailureOr + getGroupReduceMaxISupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupReduceMaxIOp op, + std::string *reason = nullptr) const; + FailureOr getGroupBroadcastSupport(const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, diff --git a/include/PTO/Transforms/VMITargetCapabilities.h b/include/PTO/Transforms/VMITargetCapabilities.h index 043da612e6..c9bded10f6 100644 --- a/include/PTO/Transforms/VMITargetCapabilities.h +++ b/include/PTO/Transforms/VMITargetCapabilities.h @@ -45,6 +45,7 @@ enum class VMIReductionKind { AddI, AddF, GroupAddI, + GroupMaxI, GroupAddF, GroupMaxF, MaxF, @@ -185,6 +186,10 @@ class VMITargetCapabilityRegistry { if (sourceLayout.isDeinterleaved() && resultLayout.isContiguous() && (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) return VMICapabilityResult::supported(); + if (sourceLayout.isDeinterleaved() && resultLayout.isDeinterleaved() && + (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4) && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) + return VMICapabilityResult::supported(); return VMICapabilityResult::missingCapability( "unsupported source/result layout pair"); } @@ -238,12 +243,13 @@ class VMITargetCapabilityRegistry { return VMICapabilityResult::missingCapability( "currently supports only f16/f32 elements for floating-point " "reduction"); - case VMIReductionKind::GroupAddI: { + case VMIReductionKind::GroupAddI: + case VMIReductionKind::GroupMaxI: { auto intType = dyn_cast(elementType); if (intType && intType.getWidth() == 32) return VMICapabilityResult::supported(); return VMICapabilityResult::missingCapability( - "grouped integer add reduction supports only i32 accumulator " + "grouped integer reduction supports only i32 accumulator " "elements because narrow integer reductions widen their result; " "cast i8/i16 storage before grouped reduction"); } diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index 25e08ac381..d07423ae60 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -164,9 +164,32 @@ static FailureOr getLayoutBlockElems(Type type) { return (*layout).isDeinterleaved() ? (*layout).getBlockElems() : 1; } +static FailureOr getVMIPhysicalElementType(VMIVRegType type) { + Type elementType = type.getElementType(); + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.hasSparseFactor()) + return elementType; + + auto integerType = dyn_cast(elementType); + if (!integerType || !integerType.isUnsigned()) + return failure(); + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + int64_t sparseFactor = layout.getSparseFactor(); + if (elementBits == 0 || sparseFactor <= 1) + return failure(); + int64_t physicalBits = static_cast(elementBits) * sparseFactor; + if (physicalBits != 16 && physicalBits != 32) + return failure(); + return IntegerType::get(type.getContext(), physicalBits); +} + static FailureOr getPhysicalLanesPerPart(Type type) { - if (auto vregType = dyn_cast(type)) - return getDataLanesPerPart(vregType.getElementType()); + if (auto vregType = dyn_cast(type)) { + FailureOr physicalElementType = getVMIPhysicalElementType(vregType); + if (failed(physicalElementType)) + return failure(); + return getDataLanesPerPart(*physicalElementType); + } if (auto maskType = dyn_cast(type)) return getMaskLanesPerPart(maskType.getGranularity()); return failure(); @@ -316,6 +339,16 @@ static LogicalResult verifyMemoryElementMatches(Operation *op, Type memoryType, return success(); } +static bool isPackedByteGroupStore(Type memoryType, VMIVRegType dataType) { + Type memoryElementType = getMemoryElementType(memoryType); + if (!memoryElementType) + return false; + auto memoryIntegerType = dyn_cast(memoryElementType); + auto dataIntegerType = dyn_cast(dataType.getElementType()); + return memoryIntegerType && dataIntegerType && + memoryIntegerType.getWidth() == 8 && dataIntegerType.getWidth() == 32; +} + static LogicalResult verifyNumGroups(Operation *op, VMIVRegType type, int64_t numGroups) { if (numGroups <= 0) @@ -339,8 +372,9 @@ static LogicalResult verifyPhysicalParts(Operation *op, Type vmiType, if (auto vregType = dyn_cast(vmiType)) { FailureOr lanesPerPart = - getDataLanesPerPart(vregType.getElementType()); - if (failed(lanesPerPart)) + getPhysicalLanesPerPart(vregType); + FailureOr physicalElementType = getVMIPhysicalElementType(vregType); + if (failed(lanesPerPart) || failed(physicalElementType)) return op->emitOpError( "requires data element type with known physical lane count"); for (Type physicalType : physicalTypes) { @@ -348,7 +382,7 @@ static LogicalResult verifyPhysicalParts(Operation *op, Type vmiType, if (!partType) return op->emitOpError("requires physical data parts to be !pto.vreg"); if (partType.getElementCount() != *lanesPerPart || - partType.getElementType() != vregType.getElementType()) + partType.getElementType() != *physicalElementType) return op->emitOpError( "requires physical data part type to match VMI lane-map helper"); } @@ -428,9 +462,16 @@ VMILayoutAttr VMILayoutAttr::getDeinterleaved(MLIRContext *context, return VMILayoutAttr::get(context, "deinterleaved", factor, blockElems, 0); } +VMILayoutAttr VMILayoutAttr::getSparse(MLIRContext *context, + int64_t sparseFactor) { + return VMILayoutAttr::get(context, "sparse", sparseFactor, 1, 0); +} + VMILayoutAttr VMILayoutAttr::getGroupSlots(MLIRContext *context, - int64_t numGroups, int64_t slots) { - return VMILayoutAttr::get(context, "num_groups", numGroups, 1, slots); + int64_t numGroups, int64_t slots, + int64_t sparseFactor) { + return VMILayoutAttr::get(context, "num_groups", numGroups, sparseFactor, + slots); } Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { @@ -458,22 +499,33 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { return {}; } } + } else if (kind == "sparse") { + if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) + return {}; } else if (kind == "num_groups") { if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) return {}; - if (succeeded(parser.parseOptionalComma())) { + while (succeeded(parser.parseOptionalComma())) { StringRef field; - if (failed(parser.parseKeyword(&field)) || field != "slots" || - failed(parser.parseEqual()) || failed(parser.parseInteger(slots))) { + if (failed(parser.parseKeyword(&field)) || failed(parser.parseEqual())) + return {}; + if (field == "slots") { + if (failed(parser.parseInteger(slots))) + return {}; + } else if (field == "sparse") { + if (failed(parser.parseInteger(blockElems))) + return {}; + } else { parser.emitError(parser.getCurrentLocation(), - "expected 'slots = '"); + "expected 'slots = ' or " + "'sparse = '"); return {}; } } } else { parser.emitError(parser.getCurrentLocation(), "expected VMI layout kind 'contiguous' or " - "'deinterleaved' or 'num_groups'"); + "'deinterleaved' or 'sparse' or 'num_groups'"); return {}; } @@ -490,10 +542,14 @@ void VMILayoutAttr::print(AsmPrinter &printer) const { printer << " = " << getFactor(); if (getBlockElems() != 1) printer << ", block_elems = " << getBlockElems(); + } else if (isSparse()) { + printer << " = " << getFactor(); } else if (isGroupSlots()) { printer << " = " << getFactor(); if (getSlots() != 0) printer << ", slots = " << getSlots(); + if (getBlockElems() != 1) + printer << ", sparse = " << getBlockElems(); } printer << ">"; } @@ -524,13 +580,24 @@ VMILayoutAttr::verify(function_ref emitError, return success(); } + if (kind == "sparse") { + if (factor <= 1) + return emitError() << "#pto.vmi.layout requires sparse factor greater than 1"; + if (blockElems != 1 || slots != 0) + return emitError() << "#pto.vmi.layout requires block_elems and slots to be their " + "defaults"; + return success(); + } + if (kind == "num_groups") { if (factor <= 0) return emitError() << "#pto.vmi.layout requires num_groups to be positive"; - if (blockElems != 1) + if (blockElems <= 0) return emitError() << "#pto.vmi.layout requires block_elems to be 1"; + << "> requires sparse factor to be positive"; if (slots < 0) return emitError() << "#pto.vmi.layout emitError, << formatVMIVRegType(elementCount, elementType, layout) << "' expected an integer, index, floating-point, or " "PTO low-precision element type"; + if (pto::isPTOFloat4PackedType(elementType)) + return emitError() + << "'" << formatVMIVRegType(elementCount, elementType, layout) + << "' uses a packed FP4 physical pair type as a VMI logical " + "element type; packed FP4 input/output is not a supported VMI " + "surface because the logical FP4 lane count and physical packed " + "byte count are ambiguous"; if (layout && !mlir::isa(layout)) return emitError() << "'" @@ -1172,23 +1246,24 @@ LogicalResult VMIGroupReduceMaxFOp::verify() { return verifyGroupReduceFloatOp(*this, /*requiresReassoc=*/false); } -LogicalResult VMIGroupReduceAddIOp::verify() { - auto sourceType = cast(getSource().getType()); - auto maskType = cast(getMask().getType()); - auto resultType = cast(getResult().getType()); +template +static LogicalResult verifyGroupReduceIntegerOp(OpTy op) { + auto sourceType = cast(op.getSource().getType()); + auto maskType = cast(op.getMask().getType()); + auto resultType = cast(op.getResult().getType()); if (!isVMIIntegerLikeType(sourceType.getElementType())) - return emitOpError("requires integer-like VMI source element type"); + return op.emitOpError("requires integer-like VMI source element type"); auto intType = dyn_cast(sourceType.getElementType()); if (!intType || intType.getWidth() != 32) - return emitOpError( + return op.emitOpError( "requires i32 accumulator element type; cast i8/i16 storage to i32 " "before grouped reduction because integer reduction widens narrow " "inputs"); if (sourceType.getElementCount() != resultType.getElementCount()) - return emitOpError( + return op.emitOpError( "requires source and result logical lane counts to match"); if (sourceType.getElementType() != resultType.getElementType()) - return emitOpError("requires source and result element types to match"); + return op.emitOpError("requires source and result element types to match"); if (auto sourceLayout = sourceType.getLayoutAttr()) { bool supportedSourceLayout = sourceLayout.isContiguous() || @@ -1199,21 +1274,29 @@ LogicalResult VMIGroupReduceAddIOp::verify() { (sourceLayout.getBlockElems() == 1 || sourceLayout.getBlockElems() == 8)); if (!supportedSourceLayout) - return emitOpError( + return op.emitOpError( "requires layout-assigned source to use contiguous layout or " "deinterleaved=2/4 layout with block_elems=1 or block_elems=8"); } if (auto resultLayout = resultType.getLayoutAttr()) { if (!resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) - return emitOpError() << "requires layout-assigned result to use " - "#pto.vmi.layout"; + resultLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) + return op.emitOpError() << "requires layout-assigned result to use " + "#pto.vmi.layout"; } - if (failed(verifyMaskMatchesData(getOperation(), maskType, sourceType))) + if (failed(verifyMaskMatchesData(op.getOperation(), maskType, sourceType))) return failure(); - return verifyNumGroups(getOperation(), sourceType, - getNumGroupsAttr().getInt()); + return verifyNumGroups(op.getOperation(), sourceType, + op.getNumGroupsAttr().getInt()); +} + +LogicalResult VMIGroupReduceAddIOp::verify() { + return verifyGroupReduceIntegerOp(*this); +} + +LogicalResult VMIGroupReduceMaxIOp::verify() { + return verifyGroupReduceIntegerOp(*this); } LogicalResult VMIGroupBroadcastOp::verify() { @@ -1321,6 +1404,50 @@ LogicalResult VMITruncFOp::verify() { getVMIElementBitWidth(resultType.getElementType())) return emitOpError( "requires result element type to be narrower than source element type"); + if (auto roundingAttr = (*this)->getAttrOfType("rounding")) { + StringRef rounding = roundingAttr.getValue(); + if (rounding != "A" && rounding != "H") + return emitOpError("rounding attr must be A or H"); + if (!sourceType.getElementType().isF32() || + !pto::isPTOHiFloat8Type(resultType.getElementType())) + return emitOpError( + "rounding attr is currently only supported for f32 to !pto.hif8 " + "truncf"); + } + return success(); +} + +LogicalResult VMIFPToSIOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMIFloatLikeType(sourceType.getElementType())) + return emitOpError("requires floating-point-like source element type"); + if (!isVMISignedOrSignlessIntegerType(resultType.getElementType())) + return emitOpError("requires signed or signless integer result element " + "type"); + if (getVMIElementBitWidth(resultType.getElementType()) != 32) + return emitOpError("requires 32-bit integer result element type"); + return success(); +} + +LogicalResult VMISIToFPOp::verify() { + auto sourceType = cast(getSource().getType()); + auto resultType = cast(getResult().getType()); + if (sourceType.getElementCount() != resultType.getElementCount()) + return emitOpError( + "requires source and result logical lane counts to match"); + if (!isVMISignedOrSignlessIntegerType(sourceType.getElementType())) + return emitOpError( + "requires signed or signless integer source element type"); + if (!isVMIFloatLikeType(resultType.getElementType())) + return emitOpError("requires floating-point-like result element type"); + if (getVMIElementBitWidth(sourceType.getElementType()) != 32) + return emitOpError("requires 32-bit integer source element type"); + if (!resultType.getElementType().isF32()) + return emitOpError("requires f32 result element type"); return success(); } @@ -1480,9 +1607,10 @@ LogicalResult VMIGatherOp::verify() { return failure(); auto indexElementType = dyn_cast(indicesType.getElementType()); - if (!indexElementType || indexElementType.getWidth() != 32 || - indexElementType.isSigned()) - return emitOpError("requires signless or unsigned 32-bit integer indices"); + if (!indexElementType || indexElementType.isSigned() || + (indexElementType.getWidth() != 16 && indexElementType.getWidth() != 32)) + return emitOpError( + "requires signless or unsigned 16-bit or 32-bit integer indices"); if (failed(verifyAllSameVRegShapeAndLayout( getOperation(), {indicesType, passthruType, resultType}, @@ -1492,6 +1620,14 @@ LogicalResult VMIGatherOp::verify() { {passthruType, resultType}, /*requireSameElement=*/true))) return failure(); + + auto resultIntegerType = dyn_cast(resultType.getElementType()); + if (indexElementType.getWidth() == 16 && + (!resultIntegerType || !resultIntegerType.isUnsigned() || + resultIntegerType.getWidth() != 16)) + return emitOpError( + "requires ui16 result and passthru element type when using ui16 " + "indices"); return verifyMaskMatchesData(getOperation(), maskType, resultType); } @@ -1535,7 +1671,8 @@ void VMIStoreOp::getEffects( LogicalResult VMIGroupStoreOp::verify() { auto valueType = cast(getValue().getType()); - if (failed(verifyMemoryElementMatches(getOperation(), + if (!isPackedByteGroupStore(getDestination().getType(), valueType) && + failed(verifyMemoryElementMatches(getOperation(), getDestination().getType(), valueType, "destination"))) return failure(); @@ -1549,6 +1686,21 @@ void VMIGroupStoreOp::getEffects( effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); } +LogicalResult VMIStrideLoadOp::verify() { + auto resultType = cast(getResult().getType()); + auto maskType = cast(getMask().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, resultType); +} + +void VMIStrideLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + LogicalResult VMIMaskedStoreOp::verify() { auto valueType = cast(getValue().getType()); auto maskType = cast(getMask().getType()); @@ -1565,6 +1717,22 @@ void VMIMaskedStoreOp::getEffects( effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); } +LogicalResult VMIStrideStoreOp::verify() { + auto valueType = cast(getValue().getType()); + auto maskType = cast(getMask().getType()); + if (failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), valueType, + "destination"))) + return failure(); + return verifyMaskMatchesData(getOperation(), maskType, valueType); +} + +void VMIStrideStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + LogicalResult VMIScatterOp::verify() { auto valueType = cast(getValue().getType()); auto indicesType = cast(getIndices().getType()); @@ -1592,30 +1760,6 @@ void VMIScatterOp::getEffects( effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); } -LogicalResult VMITileReadOp::verify() { - return verifyMemoryElementMatches(getOperation(), getSource().getType(), - cast(getResult().getType()), - "source"); -} - -void VMITileReadOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); -} - -LogicalResult VMITileWriteOp::verify() { - return verifyMemoryElementMatches(getOperation(), getDestination().getType(), - cast(getValue().getType()), - "destination"); -} - -void VMITileWriteOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); -} - LogicalResult VMIShuffleOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 7529953fa5..6eae21dbeb 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -539,22 +539,6 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, return success(); } - if (auto tileWrite = dyn_cast(op)) { - auto valueType = cast(tileWrite.getValue().getType()); - VMILayoutAttr layout = valueType.getLayoutAttr(); - if (!layout || layout.isContiguous()) - return success(); - - std::string reason; - if (failed(supports.getContiguousStoreSupport(valueType, &reason))) - return emitLayoutSupportContract( - op, diagOS, - "pto.vmi.tile_write has no registered contiguous-memory layout " - "support", - reason); - return success(); - } - if (auto load = dyn_cast(op)) { auto resultType = cast(load.getResult().getType()); VMILayoutAttr layout = resultType.getLayoutAttr(); @@ -629,6 +613,40 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, return success(); } + if (auto reduce = dyn_cast(op)) { + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed( + supports.getGroupReduceAddISupport(capabilities, reduce, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_reduce_addi has no registered group_slots layout " + "support", + reason); + return success(); + } + + if (auto reduce = dyn_cast(op)) { + auto resultType = cast(reduce.getResult().getType()); + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots()) + return success(); + + std::string reason; + if (failed( + supports.getGroupReduceMaxISupport(capabilities, reduce, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_reduce_maxi has no registered group_slots layout " + "support", + reason); + return success(); + } + if (auto broadcast = dyn_cast(op)) { auto sourceType = cast(broadcast.getSource().getType()); VMILayoutAttr layout = sourceType.getLayoutAttr(); diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index f976b0d5a7..5276732484 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -284,6 +284,23 @@ struct LayoutSolver { return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); } + VMILayoutAttr getPreferredGroupBroadcastSourceLayout(Value value, + int64_t numGroups) { + auto type = dyn_cast(value.getType()); + if (!type) + return getContiguousLayout(); + if (VMILayoutAttr existing = type.getLayoutAttr()) + if (existing.isGroupSlots() && existing.getSlots() > 0) + return existing; + VMILayoutAttr solved = getDataLayout(value); + if (solved && solved.isGroupSlots() && solved.getNumGroups() == numGroups && + solved.getSlots() > 0) + return solved; + if (value.getDefiningOp()) + return getPreferredGroupSlotLoadLayout(type, numGroups); + return getPreferredGroupSlotsLayout(type, numGroups); + } + VMILayoutAttr getPreferredGroupLoadResultLayout(VMIGroupLoadOp op) { auto type = cast(op.getResult().getType()); if (VMILayoutAttr existing = type.getLayoutAttr()) @@ -350,7 +367,8 @@ struct LayoutSolver { return solved; if (value.getDefiningOp() || value.getDefiningOp() || - value.getDefiningOp()) + value.getDefiningOp() || + value.getDefiningOp()) return getPreferredGroupSlotsLayout(type, numGroups); if (value.getDefiningOp()) return getPreferredGroupSlotLoadLayout(type, numGroups); @@ -452,12 +470,12 @@ struct LayoutSolver { bool canProducerAdoptConsumerLayout(Operation *op) { if (!op) return false; - return isa(op); + return isa(op); } bool canGroupBroadcastProduceLayout(VMIGroupBroadcastOp broadcast, @@ -496,12 +514,37 @@ struct LayoutSolver { return true; } + bool isUnsupportedGroupBroadcastResultForLayout(Value value, + VMILayoutAttr layout) { + auto broadcast = value.getDefiningOp(); + return broadcast && !canGroupBroadcastProduceLayout(broadcast, layout); + } + + LogicalResult constrainElementwiseBinary(OpOperand &lhs, OpOperand &rhs, + Value result, Operation *op) { + VMILayoutAttr lhsLayout = getExplicitDataLayout(lhs.get()); + VMILayoutAttr rhsLayout = getExplicitDataLayout(rhs.get()); + VMILayoutAttr fallback = getContiguousLayout(); + if ((lhsLayout && + isUnsupportedGroupBroadcastResultForLayout(rhs.get(), lhsLayout)) || + (rhsLayout && + isUnsupportedGroupBroadcastResultForLayout(lhs.get(), rhsLayout))) { + requestDataUse(lhs, fallback); + requestDataUse(rhs, fallback); + return setNaturalLayout(result, fallback, op); + } + + if (failed(unite(lhs.get(), rhs.get(), op))) + return failure(); + return unite(lhs.get(), result, op); + } + bool canAdoptConsumerRequestedLayout(Value value, VMILayoutAttr requestedLayout) { Operation *definingOp = value.getDefiningOp(); if (!definingOp) return false; - if (!isa(definingOp)) { + if (!isa(definingOp)) { if (!requestedLayout || requestedLayout.isContiguous()) return false; if (!canProducerAdoptConsumerLayout(definingOp)) @@ -669,38 +712,44 @@ struct LayoutSolver { return WalkResult::advance(); } if (auto addf = dyn_cast(op)) { - if (failed(unite(addf.getLhs(), addf.getRhs(), op)) || - failed(unite(addf.getLhs(), addf.getResult(), op))) + if (failed(constrainElementwiseBinary(addf.getLhsMutable(), + addf.getRhsMutable(), + addf.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto addi = dyn_cast(op)) { - if (failed(unite(addi.getLhs(), addi.getRhs(), op)) || - failed(unite(addi.getLhs(), addi.getResult(), op))) + if (failed(constrainElementwiseBinary(addi.getLhsMutable(), + addi.getRhsMutable(), + addi.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto subf = dyn_cast(op)) { - if (failed(unite(subf.getLhs(), subf.getRhs(), op)) || - failed(unite(subf.getLhs(), subf.getResult(), op))) + if (failed(constrainElementwiseBinary(subf.getLhsMutable(), + subf.getRhsMutable(), + subf.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto subi = dyn_cast(op)) { - if (failed(unite(subi.getLhs(), subi.getRhs(), op)) || - failed(unite(subi.getLhs(), subi.getResult(), op))) + if (failed(constrainElementwiseBinary(subi.getLhsMutable(), + subi.getRhsMutable(), + subi.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto mulf = dyn_cast(op)) { - if (failed(unite(mulf.getLhs(), mulf.getRhs(), op)) || - failed(unite(mulf.getLhs(), mulf.getResult(), op))) + if (failed(constrainElementwiseBinary(mulf.getLhsMutable(), + mulf.getRhsMutable(), + mulf.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto muli = dyn_cast(op)) { - if (failed(unite(muli.getLhs(), muli.getRhs(), op)) || - failed(unite(muli.getLhs(), muli.getResult(), op))) + if (failed(constrainElementwiseBinary(muli.getLhsMutable(), + muli.getRhsMutable(), + muli.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -712,20 +761,23 @@ struct LayoutSolver { return WalkResult::advance(); } if (auto divf = dyn_cast(op)) { - if (failed(unite(divf.getLhs(), divf.getRhs(), op)) || - failed(unite(divf.getLhs(), divf.getResult(), op))) + if (failed(constrainElementwiseBinary(divf.getLhsMutable(), + divf.getRhsMutable(), + divf.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto minf = dyn_cast(op)) { - if (failed(unite(minf.getLhs(), minf.getRhs(), op)) || - failed(unite(minf.getLhs(), minf.getResult(), op))) + if (failed(constrainElementwiseBinary(minf.getLhsMutable(), + minf.getRhsMutable(), + minf.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto maxf = dyn_cast(op)) { - if (failed(unite(maxf.getLhs(), maxf.getRhs(), op)) || - failed(unite(maxf.getLhs(), maxf.getResult(), op))) + if (failed(constrainElementwiseBinary(maxf.getLhsMutable(), + maxf.getRhsMutable(), + maxf.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -764,33 +816,47 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto fptosi = dyn_cast(op)) { + if (failed(unite(fptosi.getSource(), fptosi.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + if (auto sitofp = dyn_cast(op)) { + if (failed(unite(sitofp.getSource(), sitofp.getResult(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto andi = dyn_cast(op)) { - if (failed(unite(andi.getLhs(), andi.getRhs(), op)) || - failed(unite(andi.getLhs(), andi.getResult(), op))) + if (failed(constrainElementwiseBinary(andi.getLhsMutable(), + andi.getRhsMutable(), + andi.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto ori = dyn_cast(op)) { - if (failed(unite(ori.getLhs(), ori.getRhs(), op)) || - failed(unite(ori.getLhs(), ori.getResult(), op))) + if (failed(constrainElementwiseBinary( + ori.getLhsMutable(), ori.getRhsMutable(), ori.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto xori = dyn_cast(op)) { - if (failed(unite(xori.getLhs(), xori.getRhs(), op)) || - failed(unite(xori.getLhs(), xori.getResult(), op))) + if (failed(constrainElementwiseBinary(xori.getLhsMutable(), + xori.getRhsMutable(), + xori.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto shli = dyn_cast(op)) { - if (failed(unite(shli.getLhs(), shli.getRhs(), op)) || - failed(unite(shli.getLhs(), shli.getResult(), op))) + if (failed(constrainElementwiseBinary(shli.getLhsMutable(), + shli.getRhsMutable(), + shli.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } if (auto shrui = dyn_cast(op)) { - if (failed(unite(shrui.getLhs(), shrui.getRhs(), op)) || - failed(unite(shrui.getLhs(), shrui.getResult(), op))) + if (failed(constrainElementwiseBinary(shrui.getLhsMutable(), + shrui.getRhsMutable(), + shrui.getResult(), op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -960,11 +1026,39 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + auto resultType = cast(reduce.getResult().getType()); + int64_t numGroups = reduce.getNumGroupsAttr().getInt(); + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredGroupReduceLayoutFact(sourceType, numGroups); + VMILayoutAttr sourceLayout = + getPreferredGroupReduceSourceLayout(sourceType, numGroups); + VMILayoutAttr solvedSourceLayout = + getExplicitDataLayout(reduce.getSource()); + if (solvedSourceLayout && succeeded(fact) && + isCompatibleGroupReduceSourceLayout(*fact, solvedSourceLayout)) + sourceLayout = solvedSourceLayout; + requestDataUse(reduce.getSourceMutable(), sourceLayout); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceLayout, + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + if (failed(setNaturalLayout( + reduce.getResult(), + succeeded(fact) + ? fact->resultLayout + : getPreferredGroupSlotsLayout(resultType, numGroups), + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto broadcast = dyn_cast(op)) { - auto sourceType = cast(broadcast.getSource().getType()); - requestDataUse(broadcast.getSourceMutable(), - getPreferredGroupSlotsLayout( - sourceType, broadcast.getNumGroupsAttr().getInt())); + requestDataUse( + broadcast.getSourceMutable(), + getPreferredGroupBroadcastSourceLayout( + broadcast.getSource(), broadcast.getNumGroupsAttr().getInt())); return WalkResult::advance(); } if (auto hist = dyn_cast(op)) { @@ -1007,6 +1101,16 @@ struct LayoutSolver { if (auto extsi = dyn_cast(op)) { auto sourceType = cast(extsi.getSource().getType()); auto resultType = cast(extsi.getResult().getType()); + VMILayoutAttr sourceLayout = getDataLayout(extsi.getSource()); + if (sourceLayout && sourceLayout.isGroupSlots() && + sourceLayout.getSlots() == 8 && + getElementBitWidth(sourceType.getElementType()) < 32 && + getElementBitWidth(resultType.getElementType()) == 32) { + requestDataUse(extsi.getSourceMutable(), sourceLayout); + if (failed(setNaturalLayout(extsi.getResult(), sourceLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } VMILayoutSupport supports; FailureOr fact = supports.getPreferredCastLayoutFact(sourceType, resultType); @@ -1022,6 +1126,16 @@ struct LayoutSolver { if (auto extui = dyn_cast(op)) { auto sourceType = cast(extui.getSource().getType()); auto resultType = cast(extui.getResult().getType()); + VMILayoutAttr sourceLayout = getDataLayout(extui.getSource()); + if (sourceLayout && sourceLayout.isGroupSlots() && + sourceLayout.getSlots() == 8 && + getElementBitWidth(sourceType.getElementType()) < 32 && + getElementBitWidth(resultType.getElementType()) == 32) { + requestDataUse(extui.getSourceMutable(), sourceLayout); + if (failed(setNaturalLayout(extui.getResult(), sourceLayout, op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } VMILayoutSupport supports; FailureOr fact = supports.getPreferredCastLayoutFact(sourceType, resultType); @@ -1065,11 +1179,18 @@ struct LayoutSolver { FailureOr fact = supports.getPreferredCastLayoutFact(sourceType, resultType); VMILayoutAttr sourceLayout = getDataLayout(trunci.getSource()); - if (succeeded(fact) && fact->kind == VMICastLayoutKind::Narrow2x && - sourceLayout && sourceLayout.isGroupSlots() && - sourceLayout.getSlots() == 1) { + if (succeeded(fact) && sourceLayout && sourceLayout.isGroupSlots() && + (sourceLayout.getSlots() == 1 || sourceLayout.getSlots() == 8) && + (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) { requestDataUse(trunci.getSourceMutable(), sourceLayout); - if (failed(setNaturalLayout(trunci.getResult(), sourceLayout, op))) + VMILayoutAttr resultLayout = sourceLayout; + if (sourceLayout.getSlots() == 8 && + fact->kind == VMICastLayoutKind::Narrow4x) + resultLayout = VMILayoutAttr::getGroupSlots( + ctx, sourceLayout.getNumGroups(), sourceLayout.getSlots(), + /*sparseFactor=*/4); + if (failed(setNaturalLayout(trunci.getResult(), resultLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); } @@ -1132,6 +1253,17 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto load = dyn_cast(op)) { + auto resultType = cast(load.getResult().getType()); + if (failed( + setNaturalLayout(load.getResult(), getContiguousLayout(), op))) + return WalkResult::interrupt(); + if (failed(requestMaskUse( + load.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(resultType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto store = dyn_cast(op)) { requestDataUse(store.getValueMutable(), getContiguousLayout()); return WalkResult::advance(); @@ -1152,6 +1284,15 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto store = dyn_cast(op)) { + auto valueType = cast(store.getValue().getType()); + requestDataUse(store.getValueMutable(), getContiguousLayout()); + if (failed(requestMaskUse( + store.getMaskMutable(), getContiguousLayout(), + getMaskGranularityForElement(valueType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto scatter = dyn_cast(op)) { auto valueType = cast(scatter.getValue().getType()); requestDataUse(scatter.getValueMutable(), getContiguousLayout()); @@ -1171,10 +1312,6 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } - if (auto tileWrite = dyn_cast(op)) { - requestDataUse(tileWrite.getValueMutable(), getContiguousLayout()); - return WalkResult::advance(); - } if (auto split = dyn_cast(op)) { int64_t channels = split.getNumResults(); VMICapabilityResult capability = capabilities.supportsChannelCount( @@ -1637,6 +1774,14 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto reduce = dyn_cast(op)) { + auto sourceType = cast(reduce.getSource().getType()); + if (failed(requestMaskUse( + reduce.getMaskMutable(), sourceType.getLayoutAttr(), + getMaskGranularityForElement(sourceType.getElementType()), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto load = dyn_cast(op)) { auto resultType = cast(load.getResult().getType()); if (failed(requestMaskUse( diff --git a/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp b/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp index fda374f661..ac7942d93a 100644 --- a/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp +++ b/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp @@ -108,9 +108,6 @@ struct VMILayoutFoldConsumersPass if (auto store = dyn_cast(op)) tryFoldEnsureLayoutIntoOperand(store.getValueMutable(), maybeDeadEnsures); - if (auto tileWrite = dyn_cast(op)) - tryFoldEnsureLayoutIntoOperand(tileWrite.getValueMutable(), - maybeDeadEnsures); if (auto maskedStore = dyn_cast(op)) tryFoldEnsureLayoutIntoMaskedStore(maskedStore, maybeDeadEnsures, maybeDeadMaskEnsures); diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index a3babbf7ab..393b5b687c 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -335,6 +335,12 @@ getLayoutMaterializationSupport(VMILayoutAttr sourceLayout, (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) return VMILayoutMaterializationSupport{ VMILayoutMaterializationSupportKind::DeinterleavedToContiguous}; + if (sourceLayout.isDeinterleaved() && resultLayout.isDeinterleaved() && + (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4) && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind:: + DeinterleavedToDeinterleavedViaContiguous}; return fail("unsupported source/result layout pair"); } @@ -774,17 +780,18 @@ VMILayoutSupport::getGroupSlotsStoreSupport( if (elementBits == 0 || 256 % elementBits != 0) return fail("slots=1 group_store requires an 8/16/32-bit element " "type"); - int64_t alignedStrideElems = 256 / elementBits; std::optional rowStride = getConstantIndexValue(op.getRowStride()); - if (!rowStride || *rowStride <= 0 || *rowStride % alignedStrideElems != 0) - return fail(Twine("slots=1 group_store currently lowers as one " - "lane-0 vsts per group and requires constant " - "positive row_stride divisible by ") + - Twine(alignedStrideElems) + - " elements for 32B store alignment; packed or unaligned " - "contiguous store lowering is not implemented"); + FailureOr lanesPerPart = + getDataLanesPerPart(valueType.getElementType()); + if (rowStride && *rowStride == 1 && succeeded(lanesPerPart) && + numGroups <= *lanesPerPart) + return VMIGroupSlotsStoreSupport{ + VMIGroupSlotsStoreSupportKind::Slots1PackedUnitStrideVsts}; + if (rowStride && *rowStride <= 0) + return fail("slots=1 group_store requires positive row_stride when " + "row_stride is constant"); return VMIGroupSlotsStoreSupport{ - VMIGroupSlotsStoreSupportKind::Slots1AlignedLane0Vsts}; + VMIGroupSlotsStoreSupportKind::Slots1PointVsts}; } if (layout.getSlots() == 8) { @@ -989,6 +996,19 @@ VMILayoutSupport::getGroupReduceAddISupport( VMIReductionKind::GroupAddI, reason); } +FailureOr +VMILayoutSupport::getGroupReduceMaxISupport( + const VMITargetCapabilityRegistry &capabilities, VMIGroupReduceMaxIOp op, + std::string *reason) const { + return getGroupReduceAddSupportImpl( + capabilities, op.getOperation(), + cast(op.getSource().getType()), + cast(op.getMask().getType()), + cast(op.getResult().getType()), + op.getNumGroupsAttr().getInt(), /*requiresReassoc=*/false, + VMIReductionKind::GroupMaxI, reason); +} + FailureOr VMILayoutSupport::getGroupBroadcastSupport( const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastOp op, std::string *reason) const { @@ -1064,6 +1084,15 @@ FailureOr VMILayoutSupport::getGroupBroadcastSupport( return VMIGroupBroadcastSupport{ VMIGroupBroadcastSupportKind::GroupSlotsVselr}; + bool deinterleavedSmallGroup = + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() == 1 && + *groupSize < *lanesPerPart && *groupSize >= *resultFactor && + *groupSize % *resultFactor == 0 && + *lanesPerPart % (*groupSize / *resultFactor) == 0; + if (deinterleavedSmallGroup) + return VMIGroupBroadcastSupport{ + VMIGroupBroadcastSupportKind::GroupSlotsVselr}; + int64_t logicalSpanPerResultChunk = *lanesPerPart * *resultFactor; if (*groupSize < *lanesPerPart || *groupSize % logicalSpanPerResultChunk != 0) return fail("deinterleaved result requires every physical result chunk to " @@ -1107,7 +1136,7 @@ VMILayoutSupport::getTruncFSupport(VMITruncFOp op, std::string *reason) const { } if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || - !sourceType.getElementType().isF32() || *resultArity != 1) + !sourceType.getElementType().isF32()) return fail("requires f32 deinterleaved source and contiguous result"); FailureOr fact = @@ -1118,11 +1147,13 @@ VMILayoutSupport::getTruncFSupport(VMITruncFOp op, std::string *reason) const { "element width"); if (fact->kind == VMICastLayoutKind::Narrow2x && - sourceLayout.getFactor() == fact->factor && *sourceArity == fact->factor) + sourceLayout.getFactor() == fact->factor && + *sourceArity == fact->factor * *resultArity) return VMITruncFSupport{ VMITruncFSupportKind::Deinterleaved2F32ToContiguousF16}; if (fact->kind == VMICastLayoutKind::Narrow4x && - sourceLayout.getFactor() == fact->factor && *sourceArity == fact->factor) + sourceLayout.getFactor() == fact->factor && + *sourceArity == fact->factor * *resultArity) return VMITruncFSupport{ VMITruncFSupportKind::Deinterleaved4F32ToContiguousF8}; @@ -1192,6 +1223,34 @@ static FailureOr getExtISupportImpl(OpT op, failed(resultArity)) return fail("requires assigned source/result layouts and computable " "physical arity"); + + if (sourceLayout.isGroupSlots() && resultLayout.isGroupSlots()) { + if (!isa(sourceType.getElementType()) || + !isa(resultType.getElementType())) + return fail("requires integer source/result element types"); + if (sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 8 || resultLayout.getSlots() != 8) + return fail("requires matching group_slots(num_groups=G, slots=8) " + "source/result layouts"); + if (*sourceArity != *resultArity) + return fail("group_slots integer extension requires matching physical " + "arity"); + + unsigned sourceBits = pto::getPTOStorageElemBitWidth( + sourceType.getElementType()); + unsigned resultBits = pto::getPTOStorageElemBitWidth( + resultType.getElementType()); + if (resultBits != 32) + return fail("group_slots integer extension requires 32-bit result " + "element type"); + if (sourceBits == 16) + return VMIExtISupport{VMIExtISupportKind::GroupSlotsI16ToI32}; + if (sourceBits == 8) + return VMIExtISupport{VMIExtISupportKind::GroupSlotsI8ToI32}; + return fail("group_slots integer extension source must be 8-bit or " + "16-bit"); + } + if (!sourceLayout.isContiguous() || !resultLayout.isDeinterleaved() || !isa(sourceType.getElementType()) || !isa(resultType.getElementType())) @@ -1259,13 +1318,15 @@ VMILayoutSupport::getTruncISupport(VMITruncIOp op, std::string *reason) const { if (sourceLayout.isGroupSlots() || resultLayout.isGroupSlots()) { if (!sourceLayout.isGroupSlots() || !resultLayout.isGroupSlots() || sourceLayout.getNumGroups() != resultLayout.getNumGroups() || - sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || - sourceBits != 32 || resultBits != 16 || *sourceArity != *resultArity) + sourceLayout.getSlots() != resultLayout.getSlots() || + (sourceLayout.getSlots() != 1 && sourceLayout.getSlots() != 8) || + sourceBits != 32 || (resultBits != 16 && resultBits != 8) || + *sourceArity != *resultArity) return fail("group-slot trunci requires matching " - "group_slots(num_groups=G, slots=1) source/result layouts, " - "32-bit integer source, 16-bit integer result, and matching " - "physical arity"); - return VMITruncISupport{VMITruncISupportKind::GroupSlots1I32ToI16}; + "group_slots(num_groups=G, slots=1 or 8) source/result layouts, " + "32-bit integer source, 8/16-bit integer result, and " + "matching physical arity"); + return VMITruncISupport{VMITruncISupportKind::GroupSlots1I32ToNarrow}; } if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || @@ -1314,8 +1375,6 @@ VMILayoutSupport::getBitcastSupport(VMIBitcastOp op, return fail("requires assigned source and result layouts"); if (sourceLayout != resultLayout) return fail("requires matching source and result layouts"); - if (sourceLayout.isGroupSlots()) - return fail("does not support group_slots layouts"); FailureOr sourceArity = getVMIPhysicalArity(sourceType); FailureOr resultArity = getVMIPhysicalArity(resultType); diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 806f6c67fc..4e0c7f3829 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -122,6 +122,16 @@ bool isVMIOp(Operation *op) { return op->getName().getStringRef().starts_with("pto.vmi."); } +StringRef getTruncFRoundModeForResult(Type resultElementType) { + return pto::isPTOHiFloat8Type(resultElementType) ? "A" : "R"; +} + +StringRef getTruncFRoundMode(VMITruncFOp op, Type resultElementType) { + if (auto roundingAttr = op->getAttrOfType("rounding")) + return roundingAttr.getValue(); + return getTruncFRoundModeForResult(resultElementType); +} + bool isLayoutAssignedVMIType(Type type) { if (auto vregType = dyn_cast(type)) return static_cast(vregType.getLayoutAttr()); @@ -231,6 +241,25 @@ materializeVMIToVPTO(OpBuilder &builder, TypeRange resultTypes, Value input, return SmallVector(unpackOp->getResults()); } +static FailureOr getVMIVRegPhysicalElementType(VMIVRegType type) { + Type elementType = type.getElementType(); + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.hasSparseFactor()) + return elementType; + + auto integerType = dyn_cast(elementType); + if (!integerType || !integerType.isUnsigned()) + return failure(); + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + int64_t sparseFactor = layout.getSparseFactor(); + if (elementBits == 0 || sparseFactor <= 1) + return failure(); + int64_t physicalBits = static_cast(elementBits) * sparseFactor; + if (physicalBits != 16 && physicalBits != 32) + return failure(); + return IntegerType::get(type.getContext(), physicalBits); +} + class VMIToVPTOTypeConverter final : public OneToNTypeConverter { public: VMIToVPTOTypeConverter() { @@ -238,13 +267,17 @@ class VMIToVPTOTypeConverter final : public OneToNTypeConverter { addConversion( [](VMIVRegType type, SmallVectorImpl &results) -> LogicalResult { FailureOr arity = getVMIPhysicalArity(type); + FailureOr physicalElementType = + getVMIVRegPhysicalElementType(type); + if (failed(arity) || failed(physicalElementType)) + return failure(); FailureOr lanesPerPart = - getDataLanesPerPart(type.getElementType()); - if (failed(arity) || failed(lanesPerPart)) + getDataLanesPerPart(*physicalElementType); + if (failed(lanesPerPart)) return failure(); for (int64_t i = 0; i < *arity; ++i) results.push_back(VRegType::get(type.getContext(), *lanesPerPart, - type.getElementType())); + *physicalElementType)); return success(); }); addConversion( @@ -625,8 +658,13 @@ FailureOr getVMITypeElementCount(Type type) { } FailureOr getVMITypeLanesPerPart(Type type) { - if (auto vregType = dyn_cast(type)) - return getDataLanesPerPart(vregType.getElementType()); + if (auto vregType = dyn_cast(type)) { + FailureOr physicalElementType = + getVMIVRegPhysicalElementType(vregType); + if (failed(physicalElementType)) + return failure(); + return getDataLanesPerPart(*physicalElementType); + } if (auto maskType = dyn_cast(type)) return getMaskLanesPerPart(maskType.getGranularity()); return failure(); @@ -796,6 +834,27 @@ std::optional getConstantIndexValue(Value value) { return std::nullopt; } +bool isKnownIndexMultipleOf(Value value, int64_t multiple, int depth = 0) { + if (multiple <= 1) + return true; + if (depth > 6) + return false; + if (std::optional constant = getConstantIndexValue(value)) + return *constant % multiple == 0; + + if (auto add = value.getDefiningOp()) + return isKnownIndexMultipleOf(add.getLhs(), multiple, depth + 1) && + isKnownIndexMultipleOf(add.getRhs(), multiple, depth + 1); + if (auto sub = value.getDefiningOp()) + return isKnownIndexMultipleOf(sub.getLhs(), multiple, depth + 1) && + isKnownIndexMultipleOf(sub.getRhs(), multiple, depth + 1); + if (auto mul = value.getDefiningOp()) + return isKnownIndexMultipleOf(mul.getLhs(), multiple, depth + 1) || + isKnownIndexMultipleOf(mul.getRhs(), multiple, depth + 1); + + return false; +} + FailureOr getStaticMemRefElementCount(Type type) { auto memrefType = dyn_cast(type); if (!memrefType || !memrefType.hasStaticShape()) @@ -807,6 +866,24 @@ FailureOr getStaticMemRefElementCount(Type type) { return elements; } +static Type getMemoryElementType(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getElementType(); + if (auto memrefType = dyn_cast(type)) + return memrefType.getElementType(); + return {}; +} + +static bool isPackedByteGroupStore(Type destinationType, VRegType valueType) { + Type destinationElementType = getMemoryElementType(destinationType); + auto destinationIntegerType = + dyn_cast_or_null(destinationElementType); + auto valueIntegerType = dyn_cast(valueType.getElementType()); + return destinationIntegerType && valueIntegerType && + pto::getPTOStorageElemBitWidth(destinationIntegerType) == 8 && + pto::getPTOStorageElemBitWidth(valueIntegerType) == 32; +} + enum class VMIMemoryValidMaskKind { AllTrue, ExplicitMask, @@ -1328,15 +1405,19 @@ checkSupportedGatherShape(const VMITargetCapabilityRegistry &capabilities, if (!sourceCapability.isSupported()) return fail(sourceCapability.reason); - if (pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 32) - return fail("currently requires 32-bit result element type so physical " - "offset and result lane counts match pto.vgather2_bc"); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); auto indexElementType = dyn_cast(indicesType.getElementType()); - if (!indexElementType || indexElementType.getWidth() != 32 || - indexElementType.isSigned()) - return fail("requires signless or unsigned 32-bit indices"); - if (maskType.getGranularity() != "b32") - return fail("requires b32 mask granularity"); + if (!indexElementType || indexElementType.isSigned()) + return fail("requires signless or unsigned integer indices"); + bool isU16Gather = resultBits == 16 && indexElementType.isUnsigned() && + indexElementType.getWidth() == 16 && + maskType.getGranularity() == "b16"; + bool isB32Gather = resultBits == 32 && indexElementType.getWidth() == 32 && + maskType.getGranularity() == "b32"; + if (!isU16Gather && !isB32Gather) + return fail("requires either 32-bit results with 32-bit indices and b32 " + "mask, or ui16 results with ui16 indices and b16 mask"); FailureOr resultArity = getVMIPhysicalArity(resultType); FailureOr indicesArity = getVMIPhysicalArity(indicesType); @@ -1350,20 +1431,25 @@ checkSupportedGatherShape(const VMITargetCapabilityRegistry &capabilities, return fail("requires result, indices, passthru, and mask to have the " "same physical arity"); - std::string resultReason; - std::string indicesReason; - std::string passthruReason; - std::string maskReason; - if (failed(checkFullDataPhysicalChunks(resultType, &resultReason))) - return fail(Twine("result requires full physical chunks; ") + resultReason); - if (failed(checkFullDataPhysicalChunks(indicesType, &indicesReason))) - return fail(Twine("indices require full physical chunks; ") + - indicesReason); - if (failed(checkFullDataPhysicalChunks(passthruType, &passthruReason))) - return fail(Twine("passthru requires full physical chunks; ") + - passthruReason); - if (failed(checkFullVMIPhysicalChunks(maskType, &maskReason))) - return fail(Twine("mask requires full physical chunks; ") + maskReason); + if (isB32Gather) { + std::string resultReason; + std::string indicesReason; + std::string passthruReason; + std::string maskReason; + if (failed(checkFullDataPhysicalChunks(resultType, &resultReason))) + return fail(Twine("result requires full physical chunks; ") + + resultReason); + if (failed(checkFullDataPhysicalChunks(indicesType, &indicesReason))) + return fail(Twine("indices require full physical chunks; ") + + indicesReason); + if (failed(checkFullDataPhysicalChunks(passthruType, &passthruReason))) + return fail(Twine("passthru requires full physical chunks; ") + + passthruReason); + if (failed(checkFullVMIPhysicalChunks(maskType, &maskReason))) + return fail(Twine("mask requires full physical chunks; ") + maskReason); + } else if (*resultArity != 1) { + return fail("ui16 gather currently supports one physical chunk"); + } return success(); } @@ -1429,6 +1515,77 @@ checkSupportedScatterShape(const VMITargetCapabilityRegistry &capabilities, return success(); } +LogicalResult +checkSupportedStrideStoreShape(const VMITargetCapabilityRegistry &capabilities, + VMIStrideStoreOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto valueType = cast(op.getValue().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr valueLayout = valueType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!valueLayout || !maskLayout) + return fail("requires assigned value and mask layouts"); + if (!valueLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("requires contiguous value and mask layouts"); + + VMICapabilityResult destinationCapability = + capabilities.supportsUBPointerMemory(op.getDestination().getType(), + "destination", "pto.vsstb", + "pto.vsstb writes only UB"); + if (!destinationCapability.isSupported()) + return fail(destinationCapability.reason); + if (failed(checkSupportedStoreShape(capabilities, valueType, + op.getDestination(), + op.getDestination().getType(), reason))) + return failure(); + + FailureOr valueArity = getVMIPhysicalArity(valueType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(valueArity) || failed(maskArity)) + return fail("requires computable physical arity"); + if (*valueArity != 1 || *maskArity != 1) + return fail("currently supports one physical value/mask chunk"); + return success(); +} + +LogicalResult +checkSupportedStrideLoadShape(const VMITargetCapabilityRegistry &capabilities, + VMIStrideLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + auto maskType = cast(op.getMask().getType()); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskType.getLayoutAttr(); + if (!resultLayout || !maskLayout) + return fail("requires assigned result and mask layouts"); + if (!resultLayout.isContiguous() || !maskLayout.isContiguous()) + return fail("requires contiguous result and mask layouts"); + + VMICapabilityResult sourceCapability = capabilities.supportsUBPointerMemory( + op.getSource().getType(), "source", "pto.vsldb", + "pto.vsldb reads only UB"); + if (!sourceCapability.isSupported()) + return fail(sourceCapability.reason); + + FailureOr resultArity = getVMIPhysicalArity(resultType); + FailureOr maskArity = getVMIPhysicalArity(maskType); + if (failed(resultArity) || failed(maskArity)) + return fail("requires computable physical arity"); + if (*resultArity != 1 || *maskArity != 1) + return fail("currently supports one physical result/mask chunk"); + return success(); +} + Value stripMaskMaterialization(Value value) { while (true) { if (auto ensure = value.getDefiningOp()) { @@ -2336,8 +2493,12 @@ LogicalResult checkFullGroupBroadcastResultShape( bool blockFragmentSmallGroup = layout.isDeinterleaved() && layout.getBlockElems() > 1 && groupSize < lanesPerPart && lanesPerPart % layout.getBlockElems() == 0; + bool deinterleavedSmallGroup = + layout.isDeinterleaved() && layout.getBlockElems() == 1 && + groupSize < lanesPerPart && groupSize >= *factor && + groupSize % *factor == 0 && lanesPerPart % (groupSize / *factor) == 0; int64_t logicalSpanPerResultChunk = lanesPerPart * *factor; - if (!blockFragmentSmallGroup && + if (!blockFragmentSmallGroup && !deinterleavedSmallGroup && (groupSize < lanesPerPart || groupSize % logicalSpanPerResultChunk != 0)) return fail("group_broadcast deinterleaved result requires every " @@ -2415,6 +2576,99 @@ FailureOr createGroupSlotIndexVector(Location loc, VRegType indexType, return result; } +FailureOr createMappedGroupSlotIndexVector( + Location loc, VMIVRegType resultVMIType, int64_t part, int64_t chunk, + VRegType indexType, int64_t groupSize, int64_t slots, int64_t &sourceChunk, + PatternRewriter &rewriter) { + if (groupSize <= 0 || slots <= 0) + return failure(); + + int64_t lanesPerPart = indexType.getElementCount(); + SmallVector slotByLane; + slotByLane.reserve(lanesPerPart); + std::optional resolvedSourceChunk; + for (int64_t lane = 0; lane < lanesPerPart; ++lane) { + FailureOr logicalLane = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, lane); + if (failed(logicalLane)) + return failure(); + int64_t group = *logicalLane / groupSize; + int64_t candidateSourceChunk = group / slots; + if (resolvedSourceChunk && *resolvedSourceChunk != candidateSourceChunk) + return failure(); + resolvedSourceChunk = candidateSourceChunk; + slotByLane.push_back(group % slots); + } + if (!resolvedSourceChunk) + return failure(); + sourceChunk = *resolvedSourceChunk; + + FailureOr baseScalar = createScalarOffsetConstant( + loc, indexType.getElementType(), slotByLane.front(), rewriter); + FailureOr maskType = + getMaskTypeForVReg(indexType, rewriter.getContext()); + FailureOr allMask = createAllTrueMaskForVReg(loc, indexType, rewriter); + if (failed(baseScalar) || failed(maskType) || failed(allMask)) + return failure(); + + Value result = rewriter + .create(loc, indexType, *baseScalar, *allMask, + /*position=*/nullptr) + .getResult(); + int64_t rangeBegin = 0; + while (rangeBegin < lanesPerPart) { + int64_t slot = slotByLane[rangeBegin]; + int64_t rangeEnd = rangeBegin + 1; + while (rangeEnd < lanesPerPart && slotByLane[rangeEnd] == slot) + ++rangeEnd; + if (rangeBegin != 0 || slot != slotByLane.front()) { + FailureOr slotScalar = createScalarOffsetConstant( + loc, indexType.getElementType(), slot, rewriter); + FailureOr laneMask = + createLaneRangeMask(loc, *maskType, rangeBegin, rangeEnd, rewriter); + if (failed(slotScalar) || failed(laneMask)) + return failure(); + Value splat = rewriter + .create(loc, indexType, *slotScalar, *allMask, + /*position=*/nullptr) + .getResult(); + result = rewriter.create(loc, indexType, splat, result, *laneMask) + .getResult(); + } + rangeBegin = rangeEnd; + } + return result; +} + +template +FailureOr reduceVcgSlotsToLane0(Location loc, Value reduced, + VRegType resultType, Value firstLaneMask, + PatternRewriter &rewriter) { + unsigned indexBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (indexBits != 8 && indexBits != 16 && indexBits != 32) + return failure(); + + auto indexElementType = IntegerType::get(rewriter.getContext(), indexBits); + auto indexType = VRegType::get( + rewriter.getContext(), resultType.getElementCount(), indexElementType); + Value accumulator = reduced; + for (int64_t slot = 1; slot < 8; ++slot) { + FailureOr slotIndex = createGroupSlotIndexVector( + loc, indexType, resultType.getElementCount(), slot, rewriter); + if (failed(slotIndex)) + return failure(); + Value selected = + rewriter.create(loc, resultType, reduced, *slotIndex) + .getResult(); + accumulator = rewriter + .create(loc, resultType, selected, + accumulator, firstLaneMask) + .getResult(); + } + return accumulator; +} + std::optional getX2MemoryDistToken(Type elementType, StringRef prefix) { unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); @@ -2423,6 +2677,13 @@ std::optional getX2MemoryDistToken(Type elementType, return (Twine(prefix) + "_B" + Twine(elementBits)).str(); } +std::optional getPointStoreDistToken(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + if (elementBits != 8 && elementBits != 16 && elementBits != 32) + return std::nullopt; + return (Twine("1PT_B") + Twine(elementBits)).str(); +} + std::optional getVPTOCmpMode(StringRef predicate) { if (predicate == "eq" || predicate == "ne" || predicate == "lt" || predicate == "le" || predicate == "gt" || predicate == "ge") @@ -2645,6 +2906,19 @@ FailureOr> materializeDataLayoutConversion( return results; } + if (sourceLayout.isDeinterleaved() && resultLayout.isDeinterleaved() && + (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4) && + (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) { + VMILayoutAttr contiguous = + VMILayoutAttr::getContiguous(rewriter.getContext()); + FailureOr> dense = materializeDataLayoutConversion( + op, sourceParts, resultTypes, sourceLayout, contiguous, rewriter); + if (failed(dense)) + return failure(); + return materializeDataLayoutConversion(op, *dense, resultTypes, contiguous, + resultLayout, rewriter); + } + (void)rewriter.notifyMatchFailure( op, "unsupported VMI data layout materialization"); return failure(); @@ -4074,28 +4348,35 @@ struct OneToNVMIGroupSlotLoadOpPattern if (!stride || *stride != 1) return rewriter.notifyMatchFailure( op, "slots=8 group_slot_load requires constant unit stride"); - if (resultTypes.size() != 1) - return rewriter.notifyMatchFailure( - op, "slots=8 group_slot_load expects one physical result"); - auto resultType = dyn_cast(resultTypes.front()); - if (!resultType) - return rewriter.notifyMatchFailure( - op, "group_slot_load result must be vreg"); - FailureOr maskType = - getMaskTypeForVReg(resultType, rewriter.getContext()); - if (failed(maskType)) - return rewriter.notifyMatchFailure( - op, "unsupported element type for group_slot_load mask"); - FailureOr oneBlockMask = - createPrefixMask(op.getLoc(), *maskType, "PAT_VL1", rewriter); - if (failed(oneBlockMask)) - return rewriter.notifyMatchFailure( - op, "failed to create group_slot_load mask"); - Value slotBase = makePtr(*offset); - results.push_back(rewriter - .create(op.getLoc(), resultType, slotBase, - zeroI16, zeroI16, *oneBlockMask) - .getResult()); + for (auto [chunk, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "group_slot_load result must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_slot_load mask"); + int64_t groupBegin = static_cast(chunk) * slots; + int64_t activeGroups = std::min(slots, numGroups - groupBegin); + if (activeGroups <= 0) + return rewriter.notifyMatchFailure( + op, "slots=8 group_slot_load has no active groups for chunk"); + std::string pattern = (Twine("PAT_VL") + Twine(activeGroups)).str(); + FailureOr slotMask = + createPrefixMask(op.getLoc(), *maskType, pattern, rewriter); + if (failed(slotMask)) + return rewriter.notifyMatchFailure( + op, "failed to create slots=8 group_slot_load mask"); + Value groupOffset = + createChunkOffset(op.getLoc(), *offset, groupBegin, rewriter); + Value slotBase = makePtr(groupOffset); + results.push_back(rewriter + .create(op.getLoc(), vregType, slotBase, + zeroI16, zeroI16, *slotMask) + .getResult()); + } rewriter.replaceOp(op, results, adaptor.getResultMapping()); return success(); } @@ -4247,10 +4528,17 @@ struct OneToNVMIGatherOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure( op, "gather physical part type mismatch"); - Value gathered = rewriter - .create(op.getLoc(), resultType, - *source, indices, mask) - .getResult(); + unsigned resultBits = pto::getPTOStorageElemBitWidth( + cast(resultType).getElementType()); + Value gathered = resultBits == 16 + ? rewriter + .create(op.getLoc(), resultType, + *source, indices, mask) + .getResult() + : rewriter + .create(op.getLoc(), resultType, + *source, indices, mask) + .getResult(); results.push_back( rewriter .create(op.getLoc(), resultType, gathered, passthru, mask) @@ -4497,16 +4785,72 @@ struct OneToNVMIGroupStoreOpPattern if (elementBits == 0 || 256 % elementBits != 0) return rewriter.notifyMatchFailure( op, "slots=1 group_store requires supported element width"); - int64_t alignedStrideElems = 256 / elementBits; std::optional constantRowStride = getConstantIndexValue(op.getRowStride()); - if (!constantRowStride || *constantRowStride <= 0 || - *constantRowStride % alignedStrideElems != 0) + FailureOr lanesPerPart = + getDataLanesPerPart(valueVMIType.getElementType()); + int64_t alignedStoreElems = 256 / elementBits; + if (constantRowStride && *constantRowStride == 1 && + succeeded(lanesPerPart) && layout.getNumGroups() <= *lanesPerPart && + isKnownIndexMultipleOf(op.getOffset(), alignedStoreElems)) { + auto firstType = dyn_cast(valueParts.front().getType()); + if (!firstType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(firstType, rewriter.getContext()); + FailureOr allMask = + createAllTrueMaskForVReg(op.getLoc(), firstType, rewriter); + if (failed(maskType) || failed(allMask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for packed group_store mask"); + + Value packed = + rewriter + .create(op.getLoc(), firstType, valueParts.front(), + *allMask, rewriter.getStringAttr("LOWEST")) + .getResult(); + for (int64_t group = 1; group < layout.getNumGroups(); ++group) { + auto vregType = dyn_cast(valueParts[group].getType()); + if (!vregType || vregType != firstType) + return rewriter.notifyMatchFailure( + op, "packed group_store requires uniform vreg parts"); + Value splat = + rewriter + .create(op.getLoc(), firstType, valueParts[group], + *allMask, rewriter.getStringAttr("LOWEST")) + .getResult(); + FailureOr laneMask = createLaneRangeMask( + op.getLoc(), *maskType, group, group + 1, rewriter); + if (failed(laneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create packed group_store lane mask"); + packed = rewriter + .create(op.getLoc(), firstType, splat, packed, + *laneMask) + .getResult(); + } + + FailureOr storeMask = createPrefixMaskForActiveLanes( + op.getLoc(), *maskType, layout.getNumGroups(), rewriter); + if (failed(storeMask)) + return rewriter.notifyMatchFailure( + op, "failed to create packed group_store store mask"); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, packed, *destination, + *offset, /*dist=*/nullptr, *storeMask); + rewriter.eraseOp(op); + return success(); + } + if (constantRowStride && *constantRowStride <= 0) return rewriter.notifyMatchFailure( - op, Twine("slots=1 group_store requires constant positive " - "row_stride divisible by ") + - Twine(alignedStrideElems) + - " elements for 32B lane-0 vsts alignment"); + op, "slots=1 group_store requires positive row_stride when " + "row_stride is constant"); + std::optional pointDist = + getPointStoreDistToken(valueVMIType.getElementType()); + if (!pointDist) + return rewriter.notifyMatchFailure( + op, "slots=1 group_store requires 1PT_B8/B16/B32 store support"); for (auto [group, value] : llvm::enumerate(valueParts)) { auto vregType = dyn_cast(value.getType()); @@ -4528,7 +4872,8 @@ struct OneToNVMIGroupStoreOpPattern /*chunkLaneOffset=*/0, rewriter); rewriter.create(op.getLoc(), /*updated_base=*/Type{}, value, *destination, - groupOffset, /*dist=*/nullptr, *mask); + groupOffset, rewriter.getStringAttr(*pointDist), + *mask); } rewriter.eraseOp(op); @@ -4550,6 +4895,133 @@ struct OneToNVMIGroupStoreOpPattern return rewriter.notifyMatchFailure( op, "slots=8 group_store arity mismatch"); + if (!valueParts.empty()) { + auto firstVRegType = dyn_cast(valueParts.front().getType()); + if (!firstVRegType) + return rewriter.notifyMatchFailure(op, + "group_store value must be vreg"); + bool packedByteStore = isPackedByteGroupStore( + op.getDestination().getType(), firstVRegType); + if (packedByteStore) { + bool sparsePackedByteStore = layout.hasSparseFactor(); + for (Value value : valueParts) { + auto vregType = dyn_cast(value.getType()); + if (!vregType || vregType != firstVRegType) + return rewriter.notifyMatchFailure( + op, "packed slots=8 group_store requires uniform vreg parts"); + } + + FailureOr maskType = + getMaskTypeForVReg(firstVRegType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for packed group_store mask"); + if (!sparsePackedByteStore && numGroups == 8 && + valueParts.size() == 1 && + isKnownIndexMultipleOf(*offset, 32)) { + MLIRContext *ctx = rewriter.getContext(); + auto ui16 = IntegerType::get( + ctx, 16, IntegerType::SignednessSemantics::Unsigned); + auto ui8 = IntegerType::get( + ctx, 8, IntegerType::SignednessSemantics::Unsigned); + auto packed16Type = VRegType::get(ctx, 128, ui16); + auto packed8Type = VRegType::get(ctx, 256, ui8); + Value packed16 = + rewriter + .create(op.getLoc(), packed16Type, + valueParts.front(), + rewriter.getStringAttr("LOWER")) + .getResult(); + Value packed8 = + rewriter + .create(op.getLoc(), packed8Type, packed16, + rewriter.getStringAttr("LOWER")) + .getResult(); + FailureOr packedMaskType = + getMaskTypeForVReg(packed8Type, ctx); + if (failed(packedMaskType)) + return rewriter.notifyMatchFailure( + op, "failed to create packed byte group_store mask type"); + FailureOr storeMask = createPrefixMaskForActiveLanes( + op.getLoc(), *packedMaskType, numGroups, rewriter); + if (failed(storeMask)) + return rewriter.notifyMatchFailure( + op, "failed to create packed byte group_store mask"); + rewriter.create( + op.getLoc(), /*updated_base=*/Type{}, packed8, *destination, + *offset, rewriter.getStringAttr("NORM_B8"), *storeMask); + rewriter.eraseOp(op); + return success(); + } + + auto indexElementType = IntegerType::get( + rewriter.getContext(), + pto::getPTOStorageElemBitWidth(firstVRegType.getElementType())); + auto indexType = + VRegType::get(rewriter.getContext(), + firstVRegType.getElementCount(), indexElementType); + FailureOr slotIndex = createGroupSlotIndexVector( + op.getLoc(), indexType, /*groupSize=*/8, /*baseGroupSlot=*/0, + rewriter); + FailureOr allMask = + createAllTrueMaskForVReg(op.getLoc(), firstVRegType, rewriter); + if (failed(slotIndex) || failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to create packed group_store lane selector"); + + for (int64_t blockStart = 0; blockStart < numGroups; + blockStart += 32) { + FailureOr zero = + createZeroVector(op.getLoc(), firstVRegType, rewriter); + if (failed(zero)) + return rewriter.notifyMatchFailure( + op, "failed to create packed group_store accumulator"); + Value merged = *zero; + for (int64_t localPart = 0; localPart < 4; ++localPart) { + int64_t partIndex = blockStart / 8 + localPart; + if (partIndex >= static_cast(valueParts.size())) + break; + int64_t remainingGroups = numGroups - partIndex * 8; + int64_t activeGroups = std::min(8, remainingGroups); + if (activeGroups <= 0) + break; + Value selected = + rewriter + .create(op.getLoc(), firstVRegType, + valueParts[partIndex], *slotIndex) + .getResult(); + FailureOr laneMask = + createLaneRangeMask(op.getLoc(), *maskType, localPart * 8, + localPart * 8 + activeGroups, rewriter); + if (failed(laneMask)) + return rewriter.notifyMatchFailure( + op, "failed to create packed group_store lane mask"); + merged = rewriter + .create(op.getLoc(), firstVRegType, selected, + merged, *laneMask) + .getResult(); + } + + int64_t activeGroups = + std::min(32, numGroups - blockStart); + FailureOr storeMask = createPrefixMaskForActiveLanes( + op.getLoc(), *maskType, activeGroups, rewriter); + if (failed(storeMask)) + return rewriter.notifyMatchFailure( + op, "failed to create packed group_store store mask"); + Value groupOffset = createGroupChunkOffset( + op.getLoc(), *offset, *rowStride, blockStart / 4, + /*chunkLaneOffset=*/0, rewriter); + rewriter.create( + op.getLoc(), /*updated_base=*/Type{}, merged, *destination, + groupOffset, rewriter.getStringAttr("PK4_B32"), *storeMask); + } + + rewriter.eraseOp(op); + return success(); + } + } + for (auto [slotBlock, value] : llvm::enumerate(valueParts)) { auto vregType = dyn_cast(value.getType()); if (!vregType) @@ -4705,217 +5177,126 @@ struct OneToNVMIMaskedStoreOpPattern } }; -struct OneToNVMIScatterOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; +struct OneToNVMIStrideLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult - matchAndRewrite(VMIScatterOp op, OpAdaptor adaptor, + matchAndRewrite(VMIStrideLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + FailureOr source = getSingleValue( + op, adaptor.getSource(), "stride_load source must convert to one value", + rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), "stride_load offset must convert to one value", + rewriter); + FailureOr blockStride = getSingleValue( + op, adaptor.getBlockStride(), + "stride_load block_stride must convert to one value", rewriter); + FailureOr repeatStride = getSingleValue( + op, adaptor.getRepeatStride(), + "stride_load repeat_stride must convert to one value", rewriter); + if (failed(source) || failed(offset) || failed(blockStride) || + failed(repeatStride)) + return failure(); + + ValueRange maskParts = adaptor.getMask(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (resultTypes.size() != 1 || maskParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "stride_load supports one physical result/mask chunk"); + auto resultType = dyn_cast(resultTypes.front()); + if (!resultType || !isa(maskParts.front().getType())) + return rewriter.notifyMatchFailure( + op, "stride_load requires physical vreg/mask parts"); + + Value base = rewriter + .create(op.getLoc(), (*source).getType(), + *source, *offset) + .getResult(); + Value loaded = + rewriter + .create(op.getLoc(), resultType, base, *blockStride, + *repeatStride, maskParts.front()) + .getResult(); + rewriter.replaceOp(op, SmallVector{loaded}, + adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMIStrideStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIStrideStoreOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { FailureOr destination = getSingleValue( op, adaptor.getDestination(), - "scatter destination must convert to one value", rewriter); - if (failed(destination)) + "stride_store destination must convert to one value", rewriter); + FailureOr offset = getSingleValue( + op, adaptor.getOffset(), + "stride_store offset must convert to one value", rewriter); + FailureOr blockStride = getSingleValue( + op, adaptor.getBlockStride(), + "stride_store block_stride must convert to one value", rewriter); + FailureOr repeatStride = getSingleValue( + op, adaptor.getRepeatStride(), + "stride_store repeat_stride must convert to one value", rewriter); + if (failed(destination) || failed(offset) || failed(blockStride) || + failed(repeatStride)) return failure(); ValueRange valueParts = adaptor.getValue(); - ValueRange indicesParts = adaptor.getIndices(); ValueRange maskParts = adaptor.getMask(); - if (valueParts.size() != indicesParts.size() || - valueParts.size() != maskParts.size()) - return rewriter.notifyMatchFailure(op, "scatter physical arity mismatch"); - - for (auto [value, indices, mask] : - llvm::zip_equal(valueParts, indicesParts, maskParts)) { - if (!isa(value.getType()) || - !isa(indices.getType()) || !isa(mask.getType())) - return rewriter.notifyMatchFailure( - op, "scatter physical part type mismatch"); - rewriter.create(op.getLoc(), value, *destination, indices, - mask); - } + if (valueParts.size() != 1 || maskParts.size() != 1) + return rewriter.notifyMatchFailure( + op, "stride_store supports one physical value/mask chunk"); + if (!isa(valueParts.front().getType()) || + !isa(maskParts.front().getType())) + return rewriter.notifyMatchFailure( + op, "stride_store requires physical vreg/mask parts"); + Value base = rewriter + .create(op.getLoc(), (*destination).getType(), + *destination, *offset) + .getResult(); + rewriter.create(op.getLoc(), base.getType(), valueParts.front(), + base, *blockStride, *repeatStride, + maskParts.front()); rewriter.eraseOp(op); return success(); } }; -struct OneToNVMITileReadOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; - - LogicalResult - matchAndRewrite(VMITileReadOp op, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { - auto resultVMIType = cast(op.getResult().getType()); - FailureOr source = - getSingleValue(op, adaptor.getSource(), - "tile_read source must convert to one value", rewriter); - if (failed(source)) - return failure(); - - Value zero = rewriter.create(op.getLoc(), 0); - FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( - op, resultVMIType, (*source).getType(), zero, rewriter); - if (failed(lanesPerPart)) - return failure(); - - TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); - VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); - if (resultLayout && resultLayout.isDeinterleaved() && - resultLayout.getFactor() == 2) { - std::optional dist = - getX2MemoryDistToken(resultVMIType.getElementType(), "DINTLV"); - if (dist && !resultTypes.empty() && resultTypes.size() % 2 == 0) { - int64_t groups = resultTypes.size() / 2; - SmallVector lows; - SmallVector highs; - lows.reserve(groups); - highs.reserve(groups); - for (int64_t group = 0; group < groups; ++group) { - Type lowType = resultTypes[group]; - Type highType = resultTypes[groups + group]; - if (lowType != highType) - return rewriter.notifyMatchFailure( - op, "vldsx2 requires matching low/high result types"); - Value chunkOffset = createChunkOffset( - op.getLoc(), zero, group * 2 * *lanesPerPart, rewriter); - auto load = rewriter.create(op.getLoc(), lowType, highType, - *source, chunkOffset, - rewriter.getStringAttr(*dist)); - lows.push_back(load.getLow()); - highs.push_back(load.getHigh()); - } - SmallVector results; - results.reserve(resultTypes.size()); - results.append(lows); - results.append(highs); - rewriter.replaceOp(op, results, adaptor.getResultMapping()); - return success(); - } - } - - SmallVector contiguousParts; - contiguousParts.reserve(resultTypes.size()); - for (auto [index, resultType] : llvm::enumerate(resultTypes)) { - auto vregType = dyn_cast(resultType); - if (!vregType) - return rewriter.notifyMatchFailure(op, "tile_read result must be vreg"); - Value chunkOffset = - createChunkOffset(op.getLoc(), zero, index * *lanesPerPart, rewriter); - contiguousParts.push_back(rewriter - .create(op.getLoc(), resultType, - /*updated_base=*/Type{}, - *source, chunkOffset, - /*dist=*/nullptr) - .getResult()); - } - - FailureOr> results = materializeDataLayoutConversion( - op, contiguousParts, resultTypes, - VMILayoutAttr::getContiguous(rewriter.getContext()), - resultVMIType.getLayoutAttr(), rewriter); - if (failed(results)) - return failure(); - - rewriter.replaceOp(op, *results, adaptor.getResultMapping()); - return success(); - } -}; - -struct OneToNVMITileWriteOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern::OneToNOpConversionPattern; +struct OneToNVMIScatterOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult - matchAndRewrite(VMITileWriteOp op, OpAdaptor adaptor, + matchAndRewrite(VMIScatterOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { - auto valueVMIType = cast(op.getValue().getType()); - FailureOr lanesPerPart = - getDataLanesPerPart(valueVMIType.getElementType()); - if (failed(lanesPerPart)) - return rewriter.notifyMatchFailure( - op, "tile_write requires known physical lanes per part"); - bool fullPhysicalChunks = - succeeded(checkFullDataPhysicalChunks(valueVMIType, nullptr)); FailureOr destination = getSingleValue( op, adaptor.getDestination(), - "tile_write destination must convert to one value", rewriter); + "scatter destination must convert to one value", rewriter); if (failed(destination)) return failure(); ValueRange valueParts = adaptor.getValue(); - Value zero = rewriter.create(op.getLoc(), 0); - VMILayoutSupport localSupports; - FailureOr storeSupport = - localSupports.getContiguousStoreSupport(valueVMIType); - if (succeeded(storeSupport) && - storeSupport->kind == - VMIContiguousStoreSupportKind::Deinterleaved2Vstsx2) { - std::optional dist = - getX2MemoryDistToken(valueVMIType.getElementType(), "INTLV"); - if (dist && !valueParts.empty() && valueParts.size() % 2 == 0) { - int64_t groups = valueParts.size() / 2; - for (int64_t group = 0; group < groups; ++group) { - Value low = valueParts[group]; - Value high = valueParts[groups + group]; - if (low.getType() != high.getType()) - return rewriter.notifyMatchFailure( - op, "vstsx2 requires matching low/high value types"); - auto vregType = dyn_cast(low.getType()); - if (!vregType) - return rewriter.notifyMatchFailure(op, - "tile_write value must be vreg"); - FailureOr mask = - createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); - if (failed(mask)) - return rewriter.notifyMatchFailure( - op, "unsupported element type for tile_write mask"); - Value chunkOffset = createChunkOffset( - op.getLoc(), zero, group * 2 * *lanesPerPart, rewriter); - rewriter.create(op.getLoc(), low, high, *destination, - chunkOffset, rewriter.getStringAttr(*dist), - *mask); - } - rewriter.eraseOp(op); - return success(); - } - } - - SmallVector contiguousTypes; - contiguousTypes.reserve(valueParts.size()); - for (Value value : valueParts) - contiguousTypes.push_back(value.getType()); - - FailureOr> storeParts = materializeDataLayoutConversion( - op, valueParts, contiguousTypes, valueVMIType.getLayoutAttr(), - VMILayoutAttr::getContiguous(rewriter.getContext()), rewriter); - if (failed(storeParts)) - return failure(); + ValueRange indicesParts = adaptor.getIndices(); + ValueRange maskParts = adaptor.getMask(); + if (valueParts.size() != indicesParts.size() || + valueParts.size() != maskParts.size()) + return rewriter.notifyMatchFailure(op, "scatter physical arity mismatch"); - for (auto [index, value] : llvm::enumerate(*storeParts)) { - auto vregType = dyn_cast(value.getType()); - if (!vregType) - return rewriter.notifyMatchFailure(op, "tile_write value must be vreg"); - if (!fullPhysicalChunks) { - FailureOr activeLanes = - getContiguousActiveDataLanes(valueVMIType, index); - if (failed(activeLanes)) - return rewriter.notifyMatchFailure( - op, "failed to compute tile_write active lanes"); - if (*activeLanes == 0) - continue; - } - FailureOr mask = - fullPhysicalChunks - ? createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter) - : createContiguousStoreMask(op.getLoc(), valueVMIType, index, - vregType, rewriter); - if (failed(mask)) + for (auto [value, indices, mask] : + llvm::zip_equal(valueParts, indicesParts, maskParts)) { + if (!isa(value.getType()) || + !isa(indices.getType()) || !isa(mask.getType())) return rewriter.notifyMatchFailure( - op, "unsupported element type for tile_write mask"); - Value chunkOffset = - createChunkOffset(op.getLoc(), zero, index * *lanesPerPart, rewriter); - rewriter.create(op.getLoc(), - /*updated_base=*/Type{}, value, *destination, - chunkOffset, /*dist=*/nullptr, *mask); + op, "scatter physical part type mismatch"); + rewriter.create(op.getLoc(), value, *destination, indices, + mask); } rewriter.eraseOp(op); @@ -4923,6 +5304,7 @@ struct OneToNVMITileWriteOpPattern : OneToNOpConversionPattern { } }; + template struct OneToNVMIBinaryOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; @@ -5699,6 +6081,12 @@ struct OneToNVMIGroupReduceOpPattern : OneToNOpConversionPattern { .create(op.getLoc(), resultType, sourceParts[index], maskParts[index]) .getResult(); + FailureOr lane0Reduced = reduceVcgSlotsToLane0( + op.getLoc(), reduced, resultType, *firstLaneMask, rewriter); + if (failed(lane0Reduced)) + return rewriter.notifyMatchFailure( + op, "failed to fold group_reduce_addf VLane partials"); + reduced = *lane0Reduced; if (!accumulator) { accumulator = reduced; continue; @@ -5734,6 +6122,12 @@ struct OneToNVMIGroupReduceOpPattern : OneToNOpConversionPattern { return supports.getGroupReduceAddISupport(capabilities, op, reason); } + FailureOr getSupport(VMILayoutSupport &supports, + VMIGroupReduceMaxIOp op, + std::string *reason) const { + return supports.getGroupReduceMaxISupport(capabilities, op, reason); + } + FailureOr getSupport(VMILayoutSupport &supports, VMIGroupReduceMaxFOp op, std::string *reason) const { @@ -5827,6 +6221,7 @@ struct OneToNVMIGroupBroadcastOpPattern op, "group_broadcast requires uniform physical vreg types"); int64_t sourceChunk = flatIndex; int64_t baseGroupSlot = 0; + Value mappedGroupSlotIndex; if (resultLayoutFactor == 1) { if (*groupSize >= lanesPerPart) { int64_t chunksPerGroup = *groupSize / lanesPerPart; @@ -5853,6 +6248,9 @@ struct OneToNVMIGroupBroadcastOpPattern bool blockFragmentSmallGroup = resultLayout && resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && *groupSize < lanesPerPart; + bool deinterleavedSmallGroup = + resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 1 && *groupSize < lanesPerPart; if (blockFragmentSmallGroup) { int64_t runningFlatIndex = 0; bool found = false; @@ -5888,6 +6286,45 @@ struct OneToNVMIGroupBroadcastOpPattern if (!found) return rewriter.notifyMatchFailure( op, "group_broadcast result chunk index is out of range"); + } else if (deinterleavedSmallGroup) { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = + getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast deinterleaved small-group source " + "requires explicit group_slots slots or derivable " + "legacy slot count"); + slots = groupCount / sourceParts.size(); + } + FailureOr index = createMappedGroupSlotIndexVector( + op.getLoc(), resultVMIType, part, chunk, indexType, + *groupSize, slots, sourceChunk, rewriter); + if (failed(index)) + return rewriter.notifyMatchFailure( + op, + "failed to create group_broadcast mapped group-slot index " + "vector"); + mappedGroupSlotIndex = *index; + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); } else { int64_t runningFlatIndex = 0; bool found = false; @@ -5954,7 +6391,11 @@ struct OneToNVMIGroupBroadcastOpPattern bool blockFragmentSmallGroup = resultLayout && resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1; - if (resultLayoutFactor != 1 && !blockFragmentSmallGroup) + bool deinterleavedSmallGroup = resultLayout && + resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 1; + if (resultLayoutFactor != 1 && !blockFragmentSmallGroup && + !deinterleavedSmallGroup) return rewriter.notifyMatchFailure( op, "group_broadcast small-group deinterleaved result is not " "supported"); @@ -5962,9 +6403,12 @@ struct OneToNVMIGroupBroadcastOpPattern sourceChunk >= static_cast(sourceParts.size())) return rewriter.notifyMatchFailure( op, "group_broadcast source chunk is out of range"); - FailureOr groupSlotIndex = createGroupSlotIndexVector( - op.getLoc(), indexType, selectionGroupSize, baseGroupSlot, - rewriter); + FailureOr groupSlotIndex = + mappedGroupSlotIndex + ? FailureOr(mappedGroupSlotIndex) + : createGroupSlotIndexVector(op.getLoc(), indexType, + selectionGroupSize, baseGroupSlot, + rewriter); if (failed(groupSlotIndex)) return rewriter.notifyMatchFailure( op, "failed to create group_broadcast group-slot index vector"); @@ -6240,15 +6684,11 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { return success(); } - if ((sourceParts.size() != 2 && sourceParts.size() != 4) || - resultTypes.size() != 1) - return rewriter.notifyMatchFailure( - op, "only f32 deinterleaved=2/4 to 16/8-bit contiguous truncf is " - "supported"); + if (resultTypes.empty()) + return rewriter.notifyMatchFailure(op, "truncf requires result chunks"); auto sourceType0 = dyn_cast(sourceParts.front().getType()); - auto resultType = dyn_cast(resultTypes.front()); - if (!sourceType0 || !sourceType0.getElementType().isF32() || !resultType) + if (!sourceType0 || !sourceType0.getElementType().isF32()) return rewriter.notifyMatchFailure( op, "unsupported physical truncf source/result type"); for (Value sourcePart : sourceParts) { @@ -6258,15 +6698,32 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { op, "truncf source physical parts must have matching f32 type"); } - unsigned resultBits = - pto::getPTOStorageElemBitWidth(resultType.getElementType()); + SmallVector resultVRegTypes; + resultVRegTypes.reserve(resultTypes.size()); + for (Type physicalResultType : resultTypes) { + auto resultType = dyn_cast(physicalResultType); + if (!resultType || + (resultVRegTypes.empty() ? pto::getPTOStorageElemBitWidth( + resultType.getElementType()) == 0 + : resultType != resultVRegTypes.front())) + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf result type"); + resultVRegTypes.push_back(resultType); + } + + unsigned resultBits = pto::getPTOStorageElemBitWidth( + resultVRegTypes.front().getElementType()); ArrayRef parts; - if (sourceParts.size() == 2 && resultBits == 16) { + int64_t factor = 0; + if (resultBits == 16 && sourceParts.size() == 2 * resultTypes.size()) { static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; parts = kEvenOddParts; - } else if (sourceParts.size() == 4 && resultBits == 8) { + factor = 2; + } else if (resultBits == 8 && + sourceParts.size() == 4 * resultTypes.size()) { static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; parts = kPacked4Parts; + factor = 4; } else { return rewriter.notifyMatchFailure( op, "unsupported physical truncf source/result width relation"); @@ -6274,31 +6731,44 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { FailureOr sourceMask = createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); - FailureOr resultMask = - createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); - if (failed(sourceMask) || failed(resultMask)) + if (failed(sourceMask)) return rewriter.notifyMatchFailure(op, "failed to build truncf masks"); - StringAttr rnd = rewriter.getStringAttr("R"); + StringAttr rnd = rewriter.getStringAttr( + getTruncFRoundMode(op, resultVRegTypes.front().getElementType())); StringAttr sat = rewriter.getStringAttr("SAT"); - SmallVector partials; - partials.reserve(parts.size()); - for (auto [sourcePart, part] : llvm::zip_equal(sourceParts, parts)) { - partials.push_back(rewriter - .create(op.getLoc(), resultType, - sourcePart, *sourceMask, rnd, sat, - rewriter.getStringAttr(part)) - .getResult()); - } + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [chunkIndex, resultType] : llvm::enumerate(resultVRegTypes)) { + FailureOr resultMask = + createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); + if (failed(resultMask)) + return rewriter.notifyMatchFailure( + op, "failed to build truncf result mask"); + + SmallVector partials; + partials.reserve(parts.size()); + for (int64_t partIndex = 0; partIndex < factor; ++partIndex) { + Value sourcePart = + sourceParts[partIndex * resultTypes.size() + chunkIndex]; + partials.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, + *sourceMask, rnd, sat, + rewriter.getStringAttr(parts[partIndex])) + .getResult()); + } - Value merged = partials.front(); - for (Value partial : llvm::drop_begin(partials)) - merged = rewriter - .create(op.getLoc(), resultType, merged, partial, - *resultMask) - .getResult(); + Value merged = partials.front(); + for (Value partial : llvm::drop_begin(partials)) + merged = rewriter + .create(op.getLoc(), resultType, merged, partial, + *resultMask) + .getResult(); + results.push_back(merged); + } - rewriter.replaceOp(op, merged, adaptor.getResultMapping()); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); return success(); } }; @@ -6311,6 +6781,8 @@ struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { matchAndRewrite(OpT op, typename OneToNOpConversionPattern::OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceParts.empty()) @@ -6329,6 +6801,123 @@ struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { "type"); } + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (sourceLayout && resultLayout && sourceLayout.isGroupSlots() && + resultLayout.isGroupSlots()) { + if (sourceLayout.getNumGroups() != resultLayout.getNumGroups() || + sourceLayout.getSlots() != 8 || resultLayout.getSlots() != 8 || + sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot integer extension shape"); + + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceVMIType.getElementType()); + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()); + if ((sourceBits != 8 && sourceBits != 16) || resultBits != 32) + return rewriter.notifyMatchFailure( + op, "group-slot integer extension requires 8/16-bit source and " + "32-bit result element widths"); + + FailureOr maskType = + getMaskTypeForVReg(sourceType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "failed to create group-slot integer extension mask type"); + FailureOr slotMask = createPrefixMaskForActiveLanes( + op.getLoc(), *maskType, sourceLayout.getSlots(), rewriter); + if (failed(slotMask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot integer extension mask"); + + SmallVector partNames; + int64_t partFactor = 0; + if (sourceBits == 16) { + partNames.assign({"EVEN", "ODD"}); + partFactor = 2; + } else { + partNames.assign({"P0", "P1", "P2", "P3"}); + partFactor = 4; + } + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [chunkIndex, sourcePart, resultType] : + llvm::enumerate(sourceParts, resultTypes)) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || pto::getPTOStorageElemBitWidth( + resultVRegType.getElementType()) != 32) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot integer extension result type"); + + SmallVector convertedParts; + convertedParts.reserve(partNames.size()); + for (StringRef partName : partNames) { + convertedParts.push_back( + rewriter + .create(op.getLoc(), resultVRegType, sourcePart, + *slotMask, /*rnd=*/nullptr, /*sat=*/nullptr, + rewriter.getStringAttr(partName)) + .getResult()); + } + + FailureOr resultMaskType = + getMaskTypeForVReg(resultVRegType, rewriter.getContext()); + FailureOr resultAllMask = + createAllTrueMaskForVReg(op.getLoc(), resultVRegType, rewriter); + if (failed(resultMaskType) || failed(resultAllMask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot integer extension result seed"); + + auto indexType = VRegType::get( + rewriter.getContext(), resultVRegType.getElementCount(), + IntegerType::get(rewriter.getContext(), 32)); + int64_t groupBegin = + static_cast(chunkIndex) * sourceLayout.getSlots(); + int64_t activeSlots = std::min( + sourceLayout.getSlots(), sourceLayout.getNumGroups() - groupBegin); + if (activeSlots <= 0) + return rewriter.notifyMatchFailure( + op, "group-slot integer extension has no active slots"); + Value assembled; + for (int64_t slot = 0; slot < activeSlots; ++slot) { + int64_t partIndex = slot % partFactor; + int64_t sourceLane = slot / partFactor; + FailureOr laneIndexScalar = createScalarOffsetConstant( + op.getLoc(), indexType.getElementType(), sourceLane, rewriter); + FailureOr laneMask = createLaneRangeMask( + op.getLoc(), *resultMaskType, slot, slot + 1, rewriter); + if (failed(laneIndexScalar) || failed(laneMask)) + return rewriter.notifyMatchFailure( + op, "failed to build group-slot integer extension slot mask"); + Value laneIndex = + rewriter + .create(op.getLoc(), indexType, *laneIndexScalar, + *resultAllMask, /*position=*/nullptr) + .getResult(); + Value selected = + rewriter + .create(op.getLoc(), resultVRegType, + convertedParts[partIndex], laneIndex) + .getResult(); + if (!assembled) { + assembled = selected; + continue; + } + assembled = rewriter + .create(op.getLoc(), resultVRegType, selected, + assembled, *laneMask) + .getResult(); + } + + results.push_back(assembled); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + SmallVector resultVRegTypes; resultVRegTypes.reserve(resultTypes.size()); for (Type resultType : resultTypes) { @@ -6404,11 +6993,14 @@ struct OneToNVMITruncIOpPattern : OneToNOpConversionPattern { if (sourceLayout && resultLayout && sourceLayout.isGroupSlots() && resultLayout.isGroupSlots()) { if (sourceLayout.getNumGroups() != resultLayout.getNumGroups() || - sourceLayout.getSlots() != 1 || resultLayout.getSlots() != 1 || + sourceLayout.getSlots() != resultLayout.getSlots() || + (sourceLayout.getSlots() != 1 && sourceLayout.getSlots() != 8) || pto::getPTOStorageElemBitWidth(sourceVMIType.getElementType()) != 32 || - pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) != - 16 || + (pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) != + 16 && + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) != + 8) || sourceParts.size() != resultTypes.size()) return rewriter.notifyMatchFailure( op, "unsupported group-slot trunci shape"); @@ -6416,27 +7008,54 @@ struct OneToNVMITruncIOpPattern : OneToNOpConversionPattern { SmallVector results; results.reserve(resultTypes.size()); StringAttr sat = rewriter.getStringAttr("SAT"); - StringAttr even = rewriter.getStringAttr("EVEN"); - FailureOr lane0Mask = createPrefixMask( - op.getLoc(), MaskType::get(rewriter.getContext(), "b32"), "PAT_VL1", - rewriter); - if (failed(lane0Mask)) + const char *activeSlotPattern = + sourceLayout.getSlots() == 1 ? "PAT_VL1" : "PAT_VL8"; + FailureOr activeSlotMask = createPrefixMask( + op.getLoc(), MaskType::get(rewriter.getContext(), "b32"), + activeSlotPattern, rewriter); + if (failed(activeSlotMask)) return rewriter.notifyMatchFailure( - op, "failed to build group-slot trunci lane0 mask"); + op, "failed to build group-slot trunci active slot mask"); for (auto [sourcePart, physicalResultType] : llvm::zip_equal(sourceParts, resultTypes)) { auto sourceType = dyn_cast(sourcePart.getType()); auto resultType = dyn_cast(physicalResultType); if (!sourceType || pto::getPTOStorageElemBitWidth(sourceType.getElementType()) != 32 || - !resultType || - pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 16) + !resultType) return rewriter.notifyMatchFailure( op, "unsupported group-slot trunci physical type"); + + unsigned physicalResultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (resultLayout.hasSparseFactor() && + resultLayout.getSparseFactor() == 4 && + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) == + 8 && + physicalResultBits == 32) { + if (sourcePart.getType() == resultType) { + results.push_back(sourcePart); + } else { + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart) + .getResult()); + } + continue; + } + + if (physicalResultBits != 16 && physicalResultBits != 8) + return rewriter.notifyMatchFailure( + op, "unsupported group-slot trunci physical type"); + + StringAttr part = + physicalResultBits == 16 + ? rewriter.getStringAttr("EVEN") + : rewriter.getStringAttr("P0"); results.push_back(rewriter .create(op.getLoc(), resultType, - sourcePart, *lane0Mask, - /*rnd=*/nullptr, sat, even) + sourcePart, *activeSlotMask, + /*rnd=*/nullptr, sat, part) .getResult()); } rewriter.replaceOp(op, results, adaptor.getResultMapping()); @@ -6511,6 +7130,92 @@ struct OneToNVMITruncIOpPattern : OneToNOpConversionPattern { } }; +struct OneToNVMIFPToSIOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIFPToSIOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "fptosi physical source/result arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + StringAttr rnd = rewriter.getStringAttr("R"); + StringAttr sat = rewriter.getStringAttr("SAT"); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto sourceType = dyn_cast(sourcePart.getType()); + auto resultVRegType = dyn_cast(resultType); + if (!sourceType || !sourceType.getElementType().isF32() || + !resultVRegType || + !isa(resultVRegType.getElementType()) || + pto::getPTOStorageElemBitWidth(resultVRegType.getElementType()) != 32) + return rewriter.notifyMatchFailure( + op, "fptosi requires physical f32 source and 32-bit integer " + "result chunks"); + + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to build fptosi mask"); + results.push_back(rewriter + .create(op.getLoc(), resultVRegType, + sourcePart, *mask, rnd, sat, + /*part=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + +struct OneToNVMISIToFPOpPattern : OneToNOpConversionPattern { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMISIToFPOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + ValueRange sourceParts = adaptor.getSource(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (sourceParts.size() != resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "sitofp physical source/result arity mismatch"); + + SmallVector results; + results.reserve(resultTypes.size()); + StringAttr rnd = rewriter.getStringAttr("R"); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + auto sourceType = dyn_cast(sourcePart.getType()); + auto resultVRegType = dyn_cast(resultType); + if (!sourceType || !isa(sourceType.getElementType()) || + pto::getPTOStorageElemBitWidth(sourceType.getElementType()) != 32 || + !resultVRegType || !resultVRegType.getElementType().isF32()) + return rewriter.notifyMatchFailure( + op, "sitofp requires physical 32-bit integer source and f32 " + "result chunks"); + + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to build sitofp mask"); + results.push_back(rewriter + .create(op.getLoc(), resultVRegType, + sourcePart, *mask, rnd, + /*sat=*/nullptr, /*part=*/nullptr) + .getResult()); + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + struct OneToNVMIBitcastOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; @@ -6968,11 +7673,12 @@ void populateVMIOneToNConversionPatterns( OneToNVMIMaskBinaryOpPattern, OneToNVMIMaskUnaryOpPattern, OneToNVMILoadOpPattern, OneToNVMIGroupLoadOpPattern, OneToNVMIGroupSlotLoadOpPattern, + OneToNVMIStrideLoadOpPattern, OneToNVMIMaskedLoadOpPattern, OneToNVMIGatherOpPattern, OneToNVMIExpandLoadOpPattern, OneToNVMIStoreOpPattern, OneToNVMIGroupStoreOpPattern, OneToNVMIMaskedStoreOpPattern, - OneToNVMIScatterOpPattern, OneToNVMITileReadOpPattern, - OneToNVMITileWriteOpPattern, OneToNVMIBinaryOpPattern, + OneToNVMIStrideStoreOpPattern, OneToNVMIScatterOpPattern, + OneToNVMIBinaryOpPattern, OneToNVMIBinaryOpPattern, OneToNVMIBinaryOpPattern, OneToNVMIBinaryOpPattern, @@ -7003,12 +7709,14 @@ void populateVMIOneToNConversionPatterns( OneToNVMIReduceMinMaxFOpPattern, OneToNVMIExtFOpPattern, OneToNVMITruncFOpPattern, OneToNVMIExtIOpPattern, OneToNVMIExtIOpPattern, - OneToNVMITruncIOpPattern, OneToNVMIBitcastOpPattern, + OneToNVMITruncIOpPattern, OneToNVMIFPToSIOpPattern, + OneToNVMISIToFPOpPattern, OneToNVMIBitcastOpPattern, OneToNVMIChannelSplitOpPattern, OneToNVMIChannelMergeOpPattern, OneToNVMIShuffleOpPattern>(typeConverter, patterns.getContext()); patterns.add< OneToNVMIGroupReduceOpPattern, OneToNVMIGroupReduceOpPattern, + OneToNVMIGroupReduceOpPattern, OneToNVMIGroupReduceOpPattern>( typeConverter, patterns.getContext(), capabilities); patterns.add( @@ -7091,6 +7799,64 @@ LogicalResult checkSupportedTruncIShape(VMITruncIOp op, return success(); } +LogicalResult checkSupportedFPToSIShape(VMIFPToSIOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (sourceLayout != resultLayout) + return fail("requires source/result layouts to match"); + if (!sourceType.getElementType().isF32()) + return fail("requires f32 source element type"); + if (!isa(resultType.getElementType()) || + pto::getPTOStorageElemBitWidth(resultType.getElementType()) != 32) + return fail("requires 32-bit integer result element type"); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity) || + *sourceArity != *resultArity) + return fail("requires matching computable physical arity"); + return success(); +} + +LogicalResult checkSupportedSIToFPShape(VMISIToFPOp op, + std::string *reason = nullptr) { + auto fail = [&](const Twine &message) { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto sourceType = cast(op.getSource().getType()); + auto resultType = cast(op.getResult().getType()); + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + if (sourceLayout != resultLayout) + return fail("requires source/result layouts to match"); + if (!isa(sourceType.getElementType()) || + pto::getPTOStorageElemBitWidth(sourceType.getElementType()) != 32) + return fail("requires 32-bit integer source element type"); + if (!resultType.getElementType().isF32()) + return fail("requires f32 result element type"); + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity) || + *sourceArity != *resultArity) + return fail("requires matching computable physical arity"); + return success(); +} + LogicalResult checkSupportedBitcastShape(VMIBitcastOp op, std::string *reason) { VMILayoutSupport supports; if (failed(supports.getBitcastSupport(op, reason))) @@ -7400,6 +8166,9 @@ checkSupportedGroupReduceShape(const VMITargetCapabilityRegistry &capabilities, } else if constexpr (std::is_same_v) { if (succeeded(supports.getGroupReduceMaxFSupport(capabilities, op, reason))) return success(); + } else if constexpr (std::is_same_v) { + if (succeeded(supports.getGroupReduceMaxISupport(capabilities, op, reason))) + return success(); } else { if (succeeded(supports.getGroupReduceAddISupport(capabilities, op, reason))) return success(); @@ -7686,6 +8455,17 @@ verifySupportedVMIToVPTOOps(ModuleOp module, op, "pto.vmi.load", cast(load.getResult().getType()), load.getSource(), getConstantIndexValue(load.getOffset())); } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedStrideLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.stride_load lowers through pto.vsldb only for one " + "contiguous physical result/mask chunk and a supported UB source (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto load = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedGroupLoadShape(capabilities, load, &reason))) @@ -7800,39 +8580,31 @@ verifySupportedVMIToVPTOOps(ModuleOp module, << reason << ")"; return WalkResult::interrupt(); } - if (auto scatter = dyn_cast(op)) { + if (auto store = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedScatterShape(capabilities, scatter, &reason))) + if (succeeded( + checkSupportedStrideStoreShape(capabilities, store, &reason))) return WalkResult::advance(); - scatter.emitError() + store.emitError() << kVMIDiagUnsupportedPrefix - << "pto.vmi.scatter lowers through pto.vscatter only with a UB " - "pointer destination, contiguous full physical chunks, 32-bit " - "value elements, i32 indices, and b32 masks (" + << "pto.vmi.stride_store lowers through pto.vsstb only for one " + "contiguous physical value/mask chunk and a supported UB " + "destination (" << reason << ")"; return WalkResult::interrupt(); } - if (auto tileRead = dyn_cast(op)) - return emitMemoryUnsupported( - op, "pto.vmi.tile_read", - cast(tileRead.getResult().getType()), - tileRead.getSource(), 0); - if (auto tileWrite = dyn_cast(op)) { + if (auto scatter = dyn_cast(op)) { std::string reason; - if (succeeded(checkSupportedStoreShape( - capabilities, cast(tileWrite.getValue().getType()), - tileWrite.getDestination(), tileWrite.getDestination().getType(), - &reason))) + if (succeeded(checkSupportedScatterShape(capabilities, scatter, &reason))) return WalkResult::advance(); - tileWrite.emitError() + scatter.emitError() << kVMIDiagUnsupportedPrefix - << "pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable " - "element type and either full physical chunks or contiguous " - "tail-store layout, with UB-backed destination (" + << "pto.vmi.scatter lowers through pto.vscatter only with a UB " + "pointer destination, contiguous full physical chunks, 32-bit " + "value elements, i32 indices, and b32 masks (" << reason << ")"; return WalkResult::interrupt(); } - if (auto ensure = dyn_cast(op)) { auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); @@ -8101,6 +8873,21 @@ verifySupportedVMIToVPTOOps(ModuleOp module, return WalkResult::interrupt(); } + if (auto reduce = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedGroupReduceShape(capabilities, reduce, &reason))) + return WalkResult::advance(); + reduce.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_reduce_maxi lowers through pto.vcgmax/vmax only " + "for i32 accumulator values; i8/i16 storage must be cast to i32 " + "before grouped reduction because narrow integer reductions " + "widen their result (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto reduce = dyn_cast(op)) { std::string reason; if (succeeded( @@ -8189,6 +8976,32 @@ verifySupportedVMIToVPTOOps(ModuleOp module, return WalkResult::interrupt(); } + if (auto fptosi = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedFPToSIShape(fptosi, &reason))) + return WalkResult::advance(); + + fptosi.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.fptosi supports f32 source chunks to matching 32-bit " + "integer result chunks with identical assigned layouts (" + << reason << ")"; + return WalkResult::interrupt(); + } + + if (auto sitofp = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedSIToFPShape(sitofp, &reason))) + return WalkResult::advance(); + + sitofp.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.sitofp supports 32-bit integer source chunks to " + "matching f32 result chunks with identical assigned layouts (" + << reason << ")"; + return WalkResult::interrupt(); + } + if (auto extsi = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedExtSIShape(extsi, &reason))) @@ -8198,7 +9011,9 @@ verifySupportedVMIToVPTOOps(ModuleOp module, << kVMIDiagUnsupportedPrefix << "pto.vmi.extsi supports contiguous signed/signless 8-bit or " "16-bit integer physical source chunks to 32-bit integer " - "deinterleaved=4/2 results (" + "deinterleaved=4/2 results, or matching " + "group_slots(num_groups=G, slots=8) 8/16-bit integer source to " + "32-bit integer result (" << reason << ")"; return WalkResult::interrupt(); } @@ -8212,7 +9027,9 @@ verifySupportedVMIToVPTOOps(ModuleOp module, << kVMIDiagUnsupportedPrefix << "pto.vmi.extui supports contiguous unsigned 8-bit or 16-bit " "integer physical source chunks to unsigned 32-bit integer " - "deinterleaved=4/2 results (" + "deinterleaved=4/2 results, or matching " + "group_slots(num_groups=G, slots=8) 8/16-bit integer source to " + "32-bit integer result (" << reason << ")"; return WalkResult::interrupt(); } @@ -8228,8 +9045,8 @@ verifySupportedVMIToVPTOOps(ModuleOp module, "source parts to one contiguous 16-bit integer result chunk, " "32-bit integer deinterleaved=4 source parts to one contiguous " "8-bit integer result chunk, or 32-bit integer " - "group_slots(num_groups=G, slots=1) to 16-bit integer " - "group_slots(num_groups=G, slots=1) (" + "group_slots(num_groups=G, slots=1 or 8) to 8/16-bit integer " + "group_slots(num_groups=G, slots=1 or 8) (" << reason << ")"; return WalkResult::interrupt(); } diff --git a/test/lit/vmi/vmi_gather_indices_invalid.pto b/test/lit/vmi/vmi_gather_indices_invalid.pto index 057e3d1244..4b37624430 100644 --- a/test/lit/vmi/vmi_gather_indices_invalid.pto +++ b/test/lit/vmi/vmi_gather_indices_invalid.pto @@ -22,4 +22,4 @@ module { } } -// CHECK: 'pto.vmi.gather' op requires signless or unsigned 32-bit integer indices +// CHECK: 'pto.vmi.gather' op requires signless or unsigned 16-bit or 32-bit integer indices diff --git a/test/lit/vmi/vmi_group_reduce_maxi_i8_invalid.pto b/test/lit/vmi/vmi_group_reduce_maxi_i8_invalid.pto new file mode 100644 index 0000000000..b9416e81cf --- /dev/null +++ b/test/lit/vmi/vmi_group_reduce_maxi_i8_invalid.pto @@ -0,0 +1,24 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment 2>&1 | FileCheck %s + +module { + func.func @direct_i8_group_reduce_max_invalid( + %source: !pto.vmi.vreg<256xi8>, + %mask: !pto.vmi.mask<256xpred>) { + %max = pto.vmi.group_reduce_maxi %source, %mask {num_groups = 8} + : !pto.vmi.vreg<256xi8>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xi8> + return + } +} + +// CHECK: requires i32 accumulator element type; cast i8/i16 storage to i32 before grouped reduction because integer reduction widens narrow inputs diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto similarity index 70% rename from test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto rename to test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto index 187a79d42b..9cf2720915 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto @@ -6,10 +6,10 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s module { - func.func @vmi_layout_assignment_group_load_block8_truncf_invalid( + func.func @vmi_layout_assignment_group_load_block8_truncf( %src: !pto.ptr, %sum_dst: !pto.ptr, %dense_dst: !pto.ptr, @@ -26,8 +26,6 @@ module { -> !pto.vmi.vreg<128xf32> pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} : !pto.vmi.vreg<128xf32>, !pto.ptr - // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<128xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion - // CHECK: failed helper conversion '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' (unsupported source/result layout pair) %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> pto.vmi.store %h, %dense_dst[%off] @@ -35,3 +33,14 @@ module { return } } + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_load_block8_truncf( +// CHECK: pto.vsldb +// CHECK: pto.vcgadd +// CHECK: pto.vintlv +// CHECK: pto.vdintlv +// CHECK: pto.vcvt +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto index 982d1d8a28..882d0cd30e 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto @@ -44,10 +44,10 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( -// LOWER-COUNT-8: pto.vcadd +// LOWER-COUNT-8: pto.vcgadd // LOWER-COUNT-8: pto.vdup {{.*}} {position = "LOWEST"} // LOWER-COUNT-8: pto.vmul -// LOWER-COUNT-8: pto.vcadd +// LOWER-COUNT-8: pto.vcgadd // LOWER-COUNT-8: pto.vsts // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto index 6cbedb442b..1a68df9d86 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto @@ -34,7 +34,7 @@ module { // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_tail_store( // LOWER-COUNT-6: pto.vlds -// LOWER-COUNT-6: pto.vcadd +// LOWER-COUNT-6: pto.vcgadd // LOWER-COUNT-6: pto.vsts // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto index cf46aa5870..d97210bc7b 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto @@ -39,7 +39,7 @@ module { // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_truncf( // LOWER: pto.pge_b32 "PAT_VL1" -// LOWER: pto.vcadd +// LOWER: pto.vcgadd // LOWER: pto.vsel // LOWER: pto.pge_b32 "PAT_VL1" // LOWER: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto index 94cd55c58c..00c15db0a0 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto @@ -35,6 +35,16 @@ module { : !pto.vmi.vreg<128xf32>, !pto.ptr return } + + func.func @vmi_layout_assignment_group_slot_load_extui( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<256xui32> { + %c1 = arith.constant 1 : index + %narrow = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xui8> + %wide = pto.vmi.extui %narrow + : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + return %wide : !pto.vmi.vreg<256xui32> + } } // CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8( @@ -54,3 +64,11 @@ module { // CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK: pto.vmi.group_store %[[OUT]] // CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_extui( +// CHECK-SAME: -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> +// CHECK: %[[NARROW:.*]] = pto.vmi.group_slot_load +// CHECK-SAME: -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> +// CHECK: %[[WIDE:.*]] = pto.vmi.extui %[[NARROW]] +// CHECK-SAME: !pto.vmi.vreg<256xui8, #pto.vmi.layout> -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> +// CHECK: return %[[WIDE]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto index b5533d9abc..9992655422 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto @@ -63,6 +63,7 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_group_slot_load_dual_layout( +// LOWER: pto.pge_b32 "PAT_VL8" // LOWER: pto.vsldb // LOWER: pto.vsts {{.*}}, %arg21[%arg23], {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask // LOWER-COUNT-8: pto.vsldb diff --git a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride.pto similarity index 66% rename from test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto rename to test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride.pto index 35959585de..89e67118ab 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride.pto @@ -6,10 +6,10 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s module { - func.func @vmi_layout_assignment_group_store_slots1_unit_stride_invalid( + func.func @vmi_layout_assignment_group_store_slots1_unit_stride( %source: !pto.vmi.vreg<512xf32>, %mask: !pto.vmi.mask<512xpred>, %dst: !pto.ptr, @@ -19,14 +19,16 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> -> !pto.vmi.vreg<512xf32> - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots layout support - // CHECK-SAME: slots=1 group_store currently lowers as one lane-0 vsts per group - // CHECK-SAME: requires constant positive row_stride divisible by 8 elements - // CHECK-SAME: packed or unaligned contiguous store lowering is not implemented - // CHECK: note: see current operation: "pto.vmi.group_store" - // CHECK-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} : !pto.vmi.vreg<512xf32>, !pto.ptr return } } + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_store_slots1_unit_stride( +// CHECK-COUNT-8: pto.vcgadd +// CHECK: pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK-COUNT-8: pto.vsts {{.*}} {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_load_truncf.pto b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto index 6b2d588e04..ca1ee6c921 100644 --- a/test/lit/vmi/vmi_layout_assignment_load_truncf.pto +++ b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto @@ -20,15 +20,6 @@ module { return %narrow : !pto.vmi.vreg<128xf16> } - func.func @vmi_layout_assignment_tile_read_truncf( - %src: memref<128xf32>) -> !pto.vmi.vreg<128xf16> { - %wide = pto.vmi.tile_read %src - : memref<128xf32> -> !pto.vmi.vreg<128xf32> - %narrow = pto.vmi.truncf %wide - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> - return %narrow : !pto.vmi.vreg<128xf16> - } - func.func @vmi_layout_assignment_load_truncf_multi_use( %src: !pto.ptr, %dst: !pto.ptr, @@ -42,17 +33,6 @@ module { return %narrow : !pto.vmi.vreg<128xf16> } - func.func @vmi_layout_assignment_tile_read_truncf_multi_use( - %src: memref<128xf32>, - %dst: memref<128xf32>) -> !pto.vmi.vreg<128xf16> { - %wide = pto.vmi.tile_read %src - : memref<128xf32> -> !pto.vmi.vreg<128xf32> - pto.vmi.tile_write %wide, %dst - : !pto.vmi.vreg<128xf32>, memref<128xf32> - %narrow = pto.vmi.truncf %wide - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> - return %narrow : !pto.vmi.vreg<128xf16> - } } // ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf( @@ -73,25 +53,6 @@ module { // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. -// ASSIGN-LABEL: func.func @vmi_layout_assignment_tile_read_truncf( -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> -// ASSIGN: %[[WIDE:.*]] = pto.vmi.tile_read -// ASSIGN-SAME: memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN-NOT: pto.vmi.ensure_layout -// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> -// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> - -// LOWER-LABEL: func.func @vmi_layout_assignment_tile_read_truncf( -// LOWER: %[[ZERO:.*]] = arith.constant 0 : index -// LOWER: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%[[ZERO]]], "DINTLV_B32" -// LOWER: %[[EVEN:.*]] = pto.vcvt %[[P0]], {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} -// LOWER: %[[ODD:.*]] = pto.vcvt %[[P1]], {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} -// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] -// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> -// LOWER-NOT: pto.vmi. -// LOWER-NOT: !pto.vmi. - // ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf_multi_use( // ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: %[[WIDE:.*]] = pto.vmi.load @@ -111,23 +72,3 @@ module { // LOWER: return {{.*}} : !pto.vreg<128xf16> // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. - -// ASSIGN-LABEL: func.func @vmi_layout_assignment_tile_read_truncf_multi_use( -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> -// ASSIGN: %[[WIDE:.*]] = pto.vmi.tile_read -// ASSIGN-SAME: memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: pto.vmi.tile_write %[[WIDE]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> -// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> - -// LOWER-LABEL: func.func @vmi_layout_assignment_tile_read_truncf_multi_use( -// LOWER: pto.vsts -// LOWER: pto.vdintlv -// LOWER: pto.vcvt -// LOWER: return {{.*}} : !pto.vreg<128xf16> -// LOWER-NOT: pto.vmi. -// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_trunci_sparse.pto b/test/lit/vmi/vmi_layout_assignment_trunci_sparse.pto new file mode 100644 index 0000000000..9ab9d16090 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_trunci_sparse.pto @@ -0,0 +1,40 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_assignment_group_slot_trunci_sparse( + %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui8> + pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xui8>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slot_trunci_sparse( +// ASSIGN: %[[NARROW:.*]] = pto.vmi.trunci +// ASSIGN-SAME: -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> +// ASSIGN: pto.vmi.group_store %[[NARROW]] + +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slot_trunci_sparse( +// LOWER-NOT: pto.vcvt +// LOWER-NOT: pto.vpack +// LOWER: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<64xi32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_fold_consumers_store.pto b/test/lit/vmi/vmi_layout_fold_consumers_store.pto index 281d737861..e8249eec06 100644 --- a/test/lit/vmi/vmi_layout_fold_consumers_store.pto +++ b/test/lit/vmi/vmi_layout_fold_consumers_store.pto @@ -30,16 +30,6 @@ module { return } - func.func @vmi_layout_fold_consumers_tile_write( - %src: !pto.vmi.vreg<128xf16>, - %dst: memref<128xf32>) { - %wide = pto.vmi.extf %src - : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> - pto.vmi.tile_write %wide, %dst - : !pto.vmi.vreg<128xf32>, memref<128xf32> - return - } - } // FOLD-LABEL: func.func @vmi_layout_fold_consumers_store( @@ -72,21 +62,3 @@ module { // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. - -// FOLD-LABEL: func.func @vmi_layout_fold_consumers_tile_write( -// FOLD-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> -// FOLD: %[[WIDE:.*]] = pto.vmi.extf %[[SRC]] -// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// FOLD-NOT: pto.vmi.ensure_layout -// FOLD: pto.vmi.tile_write %[[WIDE]] -// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// FOLD-NOT: pto.vmi.ensure_layout -// FOLD: return - -// LOWER-LABEL: func.func @vmi_layout_fold_consumers_tile_write( -// LOWER: %[[WIDE0:.*]] = pto.vcvt -// LOWER: %[[WIDE1:.*]] = pto.vcvt -// LOWER-NOT: pto.vintlv -// LOWER: pto.vstsx2 %[[WIDE0]], %[[WIDE1]] -// LOWER-NOT: pto.vmi. -// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto similarity index 70% rename from test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto rename to test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto index 49d728f73d..50ac7a6b3f 100644 --- a/test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto @@ -6,18 +6,18 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -pto-validate-vmi-layout-ir | FileCheck %s module { - func.func @vmi_to_vpto_bitcast_group_slots_invalid( - %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + func.func @vmi_layout_gate_bitcast_group_slots( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { %out = pto.vmi.bitcast %source : !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> - return + return %out : !pto.vmi.vreg<128xi32, #pto.vmi.layout> } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.bitcast requires matching source/result layouts -// CHECK-SAME: identical physical arity and matching per-chunk logical bit footprints -// CHECK-SAME: does not support group_slots layouts +// CHECK-LABEL: func.func @vmi_layout_gate_bitcast_group_slots( +// CHECK: pto.vmi.bitcast diff --git a/test/lit/vmi/vmi_layout_gate_store_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_store_support_invalid.pto index 7c62871865..14e874beac 100644 --- a/test/lit/vmi/vmi_layout_gate_store_support_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_store_support_invalid.pto @@ -20,18 +20,3 @@ module { return } } - -// ----- - -module { - func.func @vmi_layout_gate_tile_write_deint_tail_invalid( - %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, - %dst: memref<129xf32>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.tile_write has no registered contiguous-memory layout support - // CHECK-SAME: requires arity divisible by layout factor - pto.vmi.tile_write %value, %dst - : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, - memref<129xf32> - return - } -} diff --git a/test/lit/vmi/vmi_memory_element_type_invalid.pto b/test/lit/vmi/vmi_memory_element_type_invalid.pto index 4d6a199e11..5f1b8133bf 100644 --- a/test/lit/vmi/vmi_memory_element_type_invalid.pto +++ b/test/lit/vmi/vmi_memory_element_type_invalid.pto @@ -30,28 +30,3 @@ module { } // CHECK: 'pto.vmi.store' op requires memory destination element type to match VMI data element type - -// ----- - -module { - func.func @vmi_tile_read_element_type_invalid(%src: memref<128xf32>) { - %value = pto.vmi.tile_read %src - : memref<128xf32> -> !pto.vmi.vreg<128xf16> - return - } -} - -// CHECK: 'pto.vmi.tile_read' op requires memory source element type to match VMI data element type - -// ----- - -module { - func.func @vmi_tile_write_element_type_invalid( - %value: !pto.vmi.vreg<128xf16>, %dst: memref<128xf32>) { - pto.vmi.tile_write %value, %dst - : !pto.vmi.vreg<128xf16>, memref<128xf32> - return - } -} - -// CHECK: 'pto.vmi.tile_write' op requires memory destination element type to match VMI data element type diff --git a/test/lit/vmi/vmi_op_verifier_basic.pto b/test/lit/vmi/vmi_op_verifier_basic.pto index 3ba8eb29dc..0970134935 100644 --- a/test/lit/vmi/vmi_op_verifier_basic.pto +++ b/test/lit/vmi/vmi_op_verifier_basic.pto @@ -11,7 +11,6 @@ module { func.func @vmi_op_verifier_basic( %ptr: !pto.ptr, - %tile: memref<128xf32>, %layouted: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { @@ -46,8 +45,6 @@ module { pto.vmi.group_store %slot_loaded, %ptr[%c0], %c1 {num_groups = 8} : !pto.vmi.vreg<128xf32>, !pto.ptr pto.vmi.store %loaded, %ptr[%c0] : !pto.vmi.vreg<128xf32>, !pto.ptr - %tile_read = pto.vmi.tile_read %tile : memref<128xf32> -> !pto.vmi.vreg<128xf32> - pto.vmi.tile_write %tile_read, %tile : !pto.vmi.vreg<128xf32>, memref<128xf32> %small = "pto.vmi.shuffle"(%broadcast) { indices = array @@ -102,8 +99,6 @@ module { // CHECK: pto.vmi.group_slot_load // CHECK: pto.vmi.group_store // CHECK: pto.vmi.store -// CHECK: pto.vmi.tile_read -// CHECK: pto.vmi.tile_write // CHECK: pto.vmi.ensure_layout // CHECK: pto.vmi.ensure_mask_layout // CHECK: pto.vmi.ensure_mask_granularity diff --git a/test/lit/vmi/vmi_shuffle_indices_invalid.pto b/test/lit/vmi/vmi_shuffle_indices_invalid.pto new file mode 100644 index 0000000000..fe6582ee86 --- /dev/null +++ b/test/lit/vmi/vmi_shuffle_indices_invalid.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -split-input-file 2>&1 | FileCheck %s + +module { + func.func @vmi_shuffle_index_count_invalid(%src: !pto.vmi.vreg<8xui8>) { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<8xui8>) -> !pto.vmi.vreg<4xui8> + return + } +} + +// CHECK: 'pto.vmi.shuffle' op requires shuffle index count to match result logical lane count + +// ----- + +module { + func.func @vmi_shuffle_index_oob_invalid(%src: !pto.vmi.vreg<8xui8>) { + %out = "pto.vmi.shuffle"(%src) { + indices = array + } : (!pto.vmi.vreg<8xui8>) -> !pto.vmi.vreg<4xui8> + return + } +} + +// CHECK: 'pto.vmi.shuffle' op requires every shuffle index to select an existing source logical lane diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto b/test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto similarity index 59% rename from test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto rename to test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto index f946686b6f..d20be7e33e 100644 --- a/test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto @@ -6,17 +6,24 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -pto-validate-vmi-layout-ir 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_layout_gate_bitcast_group_slots_invalid( - %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { - // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.bitcast has no registered layout support - // CHECK-SAME: does not support group_slots layouts - // CHECK: note: see current operation: %{{.*}} = "pto.vmi.bitcast" + func.func @vmi_to_vpto_bitcast_group_slots( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { %out = pto.vmi.bitcast %source : !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> - return + return %out : !pto.vmi.vreg<128xi32, #pto.vmi.layout> } } + +// CHECK-LABEL: func.func @vmi_to_vpto_bitcast_group_slots( +// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> +// CHECK-SAME: -> !pto.vreg<64xi32> +// CHECK: %[[B0:.*]] = pto.vbitcast %[[V0]] : !pto.vreg<64xf32> -> !pto.vreg<64xi32> +// CHECK: return %[[B0]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto b/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto index 4c15af9f19..cd78e684c7 100644 --- a/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto +++ b/test/lit/vmi/vmi_to_vpto_ensure_layout_deint4.pto @@ -34,6 +34,19 @@ module { return %p0, %p1, %p2, %p3 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> } + + func.func @vmi_to_vpto_ensure_layout_deint2_to_deint4( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %split = pto.vmi.ensure_layout %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%split) + : (!pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> + } } // CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint4_to_contiguous( @@ -55,3 +68,14 @@ module { // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_ensure_layout_deint2_to_deint4( +// CHECK: pto.vintlv +// CHECK: pto.vintlv +// CHECK: pto.vdintlv +// CHECK: pto.vdintlv +// CHECK: pto.vdintlv +// CHECK: pto.vdintlv +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto b/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto index 83bf5db675..21c8753ec7 100644 --- a/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_gather_f16_invalid.pto @@ -25,4 +25,4 @@ module { } // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.gather lowers through pto.vgather2_bc + pto.vsel only -// CHECK-SAME: 32-bit result element type +// CHECK-SAME: 32-bit result elements diff --git a/test/lit/vmi/vmi_to_vpto_gather_u16.pto b/test/lit/vmi/vmi_to_vpto_gather_u16.pto new file mode 100644 index 0000000000..bcf0caede3 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_gather_u16.pto @@ -0,0 +1,37 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_gather_u16( + %src: !pto.ptr, + %indices: !pto.vmi.vreg<32xui16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<32xb16, #pto.vmi.layout>, + %passthru: !pto.vmi.vreg<32xui16, #pto.vmi.layout>) + -> !pto.vreg<128xui16> { + %out = pto.vmi.gather %src[%indices], %mask, %passthru + : !pto.ptr, + !pto.vmi.vreg<32xui16, #pto.vmi.layout>, + !pto.vmi.mask<32xb16, #pto.vmi.layout>, + !pto.vmi.vreg<32xui16, #pto.vmi.layout> + -> !pto.vmi.vreg<32xui16, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<32xui16, #pto.vmi.layout>) + -> !pto.vreg<128xui16> + return %part : !pto.vreg<128xui16> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_gather_u16( +// CHECK: %[[GATHER:.*]] = pto.vgather2 %arg0, %arg1, %arg2 : !pto.ptr, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: %[[OUT:.*]] = pto.vsel %[[GATHER]], %arg3, %arg2 : !pto.vreg<128xui16>, !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: return %[[OUT]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_s32_deint2_small_group.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_s32_deint2_small_group.pto new file mode 100644 index 0000000000..f82a877737 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_s32_deint2_small_group.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_s32_deint2_small_group( + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %broadcast = pto.vmi.group_broadcast %source {num_groups = 4} + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%broadcast) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) + return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_s32_deint2_small_group( +// CHECK-COUNT-2: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_ops.pto b/test/lit/vmi/vmi_to_vpto_group_ops.pto index 019b45f7c5..edef94f273 100644 --- a/test/lit/vmi/vmi_to_vpto_group_ops.pto +++ b/test/lit/vmi/vmi_to_vpto_group_ops.pto @@ -33,9 +33,10 @@ module { // CHECK-LABEL: func.func @vmi_to_vpto_group_ops( // CHECK-COUNT-8: pto.vlds -// CHECK-COUNT-8: pto.vcadd +// CHECK: pto.vcgadd +// CHECK: pto.vselr +// CHECK-COUNT-7: pto.vcgadd // CHECK-COUNT-8: {position = "LOWEST"} -// CHECK-NOT: pto.vselr // CHECK-COUNT-8: pto.vsts // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto index f2681f3359..5d06c115ac 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto @@ -35,7 +35,7 @@ module { } // CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s256_broadcast( -// CHECK: pto.vcadd +// CHECK: pto.vcgadd // CHECK: pto.vadd // CHECK: pto.vsel // CHECK: pto.vdup {{.*}} {position = "LOWEST"} diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto index 55ae7fd255..935eb2b80f 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto @@ -34,10 +34,10 @@ module { // CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64( // CHECK-DAG: %[[VL1:.*]] = pto.pge_b32 "PAT_VL1" -// CHECK: pto.vcadd +// CHECK: pto.vcgadd // CHECK: pto.vadd // CHECK: pto.vsel {{.*}}, {{.*}}, %[[VL1]] -// CHECK: pto.vcadd +// CHECK: pto.vcgadd // CHECK: pto.vsel {{.*}}, {{.*}}, %[[VL1]] // CHECK: return {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK-NOT: pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto index 99359b1a8e..7b7bc76761 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto @@ -32,7 +32,7 @@ module { } // CHECK-LABEL: func.func @vmi_to_vpto_group_reduce_s64_support( -// CHECK-COUNT-8: pto.vcadd +// CHECK-COUNT-8: pto.vcgadd // CHECK: pto.vsel // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto index cf6591f36c..7403599d5f 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto @@ -52,10 +52,59 @@ module { !pto.ptr return } + + func.func @vmi_to_vpto_group_slot_load_u8_scale_broadcast( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %c23_i32 = arith.constant 23 : i32 + %scale_u8 = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<256xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<256xui32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + %bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + %scale = pto.vmi.bitcast %bits + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + pto.vmi.store %vec, %dst[%off] + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_slot_load_u16_broadcast( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %c1 = arith.constant 1 : index + %scale_u16 = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + %scale_u32 = pto.vmi.extui %scale_u16 + : !pto.vmi.vreg<256xui16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> + %vec = pto.vmi.group_broadcast %scale_u32 {num_groups = 8} + : !pto.vmi.vreg<256xui32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> + pto.vmi.store %vec, %dst[%off] + : !pto.vmi.vreg<256xui32, #pto.vmi.layout>, + !pto.ptr + return + } } // CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_slots8( -// CHECK-DAG: %[[MASK:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask // CHECK: %[[BASE:.*]] = pto.addptr %arg0, %arg1 : -> // CHECK: %[[OUT:.*]] = pto.vsldb %[[BASE]], {{.*}}, {{.*}}, %[[MASK]] : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<64xf32> // CHECK: return %[[OUT]] @@ -64,7 +113,7 @@ module { // CHECK-COUNT-8: pto.vsldb // CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_slots8_store( -// CHECK: %[[LOAD_MASK:.*]] = pto.pge_b32 "PAT_VL1" : !pto.mask +// CHECK: %[[LOAD_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask // CHECK: %[[BASE:.*]] = pto.addptr %arg0, %arg2 : -> // CHECK: %[[OUT:.*]] = pto.vsldb %[[BASE]], {{.*}}, {{.*}}, %[[LOAD_MASK]] // CHECK: %[[STORE_MASK:.*]] = pto.pge_b32 "PAT_VL8" : !pto.mask @@ -72,3 +121,31 @@ module { // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_u8_scale_broadcast( +// CHECK: pto.pge_b8 "PAT_VL8" : !pto.mask +// CHECK: pto.vsldb {{.*}} : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<256xui8> +// CHECK-DAG: pto.vcvt {{.*}} {part = "P0"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK-DAG: pto.vcvt {{.*}} {part = "P1"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK-DAG: pto.vcvt {{.*}} {part = "P2"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK-DAG: pto.vcvt {{.*}} {part = "P3"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vselr {{.*}} : !pto.vreg<64xui32>, !pto.vreg<64xi32> -> !pto.vreg<64xui32> +// CHECK: pto.vsel {{.*}} : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vshl +// CHECK: pto.vselr +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_u16_broadcast( +// CHECK: pto.pge_b16 "PAT_VL8" : !pto.mask +// CHECK: pto.vsldb {{.*}} : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<128xui16> +// CHECK-DAG: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> +// CHECK-DAG: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vselr {{.*}} : !pto.vreg<64xui32>, !pto.vreg<64xi32> -> !pto.vreg<64xui32> +// CHECK: pto.vsel {{.*}} : !pto.vreg<64xui32>, !pto.vreg<64xui32>, !pto.mask -> !pto.vreg<64xui32> +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xui32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto index e806b28b92..c519205638 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto @@ -23,6 +23,7 @@ module { } // CHECK-LABEL: func.func @vmi_to_vpto_group_slot_load_support( +// CHECK: pto.pge_b32 "PAT_VL8" // CHECK: pto.vsldb // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto new file mode 100644 index 0000000000..8f949fc0f7 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto @@ -0,0 +1,28 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_store_slots1_1pt( + %value: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c2 = arith.constant 2 : index + pto.vmi.group_store %value, %dst[%off], %c2 {num_groups = 8} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_store_slots1_1pt( +// CHECK-COUNT-8: pto.vsts {{.*}} {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots1_unit_stride_alignment.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots1_unit_stride_alignment.pto new file mode 100644 index 0000000000..dc68aed9db --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots1_unit_stride_alignment.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @aligned_unit_stride_group_store( + %value: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %dst: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + pto.vmi.group_store %value, %dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @unaligned_unit_stride_group_store( + %value: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %row: index) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %off = arith.muli %row, %c2 : index + pto.vmi.group_store %value, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @aligned_unit_stride_group_store( +// CHECK-COUNT-8: pto.vdup +// CHECK: pto.vsts {{.*}} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: dist = "1PT_B32" + +// CHECK-LABEL: func.func @unaligned_unit_stride_group_store( +// CHECK-NOT: pto.vdup +// CHECK-COUNT-8: pto.vsts {{.*}} {dist = "1PT_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots8_packed_byte.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots8_packed_byte.pto new file mode 100644 index 0000000000..08fa565554 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots8_packed_byte.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_store_slots8_i32_to_u8( + %value: !pto.vmi.vreg<1024xi32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + pto.vmi.group_store %value, %dst[%off], %c1 {num_groups = 32} + : !pto.vmi.vreg<1024xi32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_store_slots8_i32_to_u8_padded( + %value: !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + %dst: !pto.ptr) { + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + pto.vmi.group_store %value, %dst[%c32], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + !pto.ptr + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_store_slots8_i32_to_u8( +// CHECK-COUNT-1: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<64xi32>, !pto.ptr, !pto.mask +// CHECK-LABEL: func.func @vmi_to_vpto_group_store_slots8_i32_to_u8_padded( +// CHECK: pto.vpack {{.*}} "LOWER" : !pto.vreg<64xi32> -> !pto.vreg<128xui16> +// CHECK: pto.vpack {{.*}} "LOWER" : !pto.vreg<128xui16> -> !pto.vreg<256xui8> +// CHECK: pto.vsts {{.*}} {dist = "NORM_B8"} : !pto.vreg<256xui8>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_integer_casts.pto b/test/lit/vmi/vmi_to_vpto_integer_casts.pto index 50051aab6d..d0ab02f361 100644 --- a/test/lit/vmi/vmi_to_vpto_integer_casts.pto +++ b/test/lit/vmi/vmi_to_vpto_integer_casts.pto @@ -38,6 +38,64 @@ module { -> !pto.vreg<256xui8> return %p : !pto.vreg<256xui8> } + + func.func @vmi_to_vpto_fptosi_f32_to_i32( + %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32>) { + %wide = pto.vmi.fptosi %input + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + %p0, %p1, %p2, %p3 = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32>) + return %p0, %p1, %p2, %p3 + : !pto.vreg<64xi32>, !pto.vreg<64xi32>, + !pto.vreg<64xi32>, !pto.vreg<64xi32> + } + + func.func @vmi_to_vpto_group_slot_trunci_i32_to_ui8( + %wide: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xui8, #pto.vmi.layout> + pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<128xui8, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8( + %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> + pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8_sparse( + %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + %dst: !pto.ptr, + %off: index) { + %c1 = arith.constant 1 : index + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> + pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + !pto.ptr + return + } } // CHECK-LABEL: func.func @vmi_to_vpto_extui_u8_to_u32( @@ -62,3 +120,36 @@ module { // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_fptosi_f32_to_i32( +// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_trunci_i32_to_ui8( +// CHECK: pto.vcvt {{.*}} {part = "P0", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: pto.vcvt {{.*}} {part = "P0", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8( +// CHECK: pto.pge_b32 "PAT_VL8" +// CHECK: pto.vcvt {{.*}} {part = "P0", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> +// CHECK: pto.vsts +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8_sparse( +// CHECK-NOT: pto.vcvt +// CHECK-NOT: pto.vpack +// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<64xi32>, !pto.ptr, !pto.mask +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto index 157f57f84d..40bbe153c0 100644 --- a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto +++ b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref.pto @@ -31,15 +31,6 @@ module { return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> } - func.func @vmi_to_vpto_tile_read_safe_tail_memref(%src: memref<128xf32>) - -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { - %value = pto.vmi.tile_read %src - : memref<128xf32> -> !pto.vmi.vreg<100xf32, #pto.vmi.layout> - %p0, %p1 = "pto.vmi.unpack"(%value) - : (!pto.vmi.vreg<100xf32, #pto.vmi.layout>) - -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) - return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> - } } // CHECK-LABEL: func.func @vmi_to_vpto_load_safe_tail_memref( @@ -61,13 +52,3 @@ module { // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast - -// CHECK-LABEL: func.func @vmi_to_vpto_tile_read_safe_tail_memref( -// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[P0:.*]] = pto.vlds %arg0[%[[C0]]] : memref<128xf32> -> !pto.vreg<64xf32> -// CHECK: %[[P1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> -// CHECK: return %[[P0]], %[[P1]] -// CHECK-NOT: pto.vmi. -// CHECK-NOT: !pto.vmi. -// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto b/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto index 7a222a35ad..8d1485d965 100644 --- a/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto @@ -95,36 +95,3 @@ module { // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_store requires either full physical chunks or contiguous tail-store value/mask layout // CHECK-SAME: with UB-backed destination // CHECK-SAME: destination is GM-backed - -// ----- - -module { - func.func @vmi_tile_read_gm_unsupported( - %src: memref<64xf32, #pto.address_space>) { - %value = pto.vmi.tile_read %src - : memref<64xf32, #pto.address_space> - -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> - return - } -} - -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_read requires full physical chunks without padding lanes or a statically safe full-read footprint -// CHECK-SAME: source is GM-backed -// CHECK-SAME: requires UB-backed memory - -// ----- - -module { - func.func @vmi_tile_write_gm_unsupported( - %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, - %dst: memref<64xf32, #pto.address_space>) { - pto.vmi.tile_write %value, %dst - : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, - memref<64xf32, #pto.address_space> - return - } -} - -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type -// CHECK-SAME: with UB-backed destination -// CHECK-SAME: destination is GM-backed diff --git a/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto b/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto index 489891c72a..c3483450bc 100644 --- a/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto @@ -142,36 +142,3 @@ module { // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.masked_store requires either full physical chunks or contiguous tail-store value/mask layout // CHECK-SAME: destination memref layout is non-identity // CHECK-SAME: contiguous identity lane-to-address maps - -// ----- - -module { - func.func @vmi_tile_read_strided_memref_unsupported( - %src: memref<128xf32, strided<[2], offset: 0>>) { - %value = pto.vmi.tile_read %src - : memref<128xf32, strided<[2], offset: 0>> - -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> - return - } -} - -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_read requires full physical chunks without padding lanes or a statically safe full-read footprint -// CHECK-SAME: source memref layout is non-identity -// CHECK-SAME: contiguous identity lane-to-address maps - -// ----- - -module { - func.func @vmi_tile_write_strided_memref_unsupported( - %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, - %dst: memref<128xf32, strided<[2], offset: 0>>) { - pto.vmi.tile_write %value, %dst - : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, - memref<128xf32, strided<[2], offset: 0>> - return - } -} - -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type -// CHECK-SAME: destination memref layout is non-identity -// CHECK-SAME: contiguous identity lane-to-address maps diff --git a/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto b/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto index b412afdcca..34e75012b2 100644 --- a/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_store_width_invalid.pto @@ -21,18 +21,3 @@ module { // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store requires an 8/16/32-bit predicate-maskable element type // CHECK-SAME: requires an 8/16/32-bit element type - -// ----- - -module { - func.func @vmi_tile_write_f64_unsupported( - %value: !pto.vmi.vreg<32xf64, #pto.vmi.layout>, - %dst: memref<32xf64>) { - pto.vmi.tile_write %value, %dst - : !pto.vmi.vreg<32xf64, #pto.vmi.layout>, memref<32xf64> - return - } -} - -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type -// CHECK-SAME: requires an 8/16/32-bit element type diff --git a/test/lit/vmi/vmi_to_vpto_stride_load.pto b/test/lit/vmi/vmi_to_vpto_stride_load.pto new file mode 100644 index 0000000000..d30ce58a1f --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_stride_load.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_stride_load( + %src: !pto.ptr, + %offset: index, + %mask: !pto.vmi.mask<64xb8, #pto.vmi.layout>) + -> !pto.vreg<256xf8E4M3FN> { + %c1 = arith.constant 1 : i16 + %out = pto.vmi.stride_load %src[%offset], %c1, %c1, %mask + : !pto.ptr, i16, i16, + !pto.vmi.mask<64xb8, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf8E4M3FN, #pto.vmi.layout> + %part = "pto.vmi.unpack"(%out) + : (!pto.vmi.vreg<64xf8E4M3FN, #pto.vmi.layout>) + -> !pto.vreg<256xf8E4M3FN> + return %part : !pto.vreg<256xf8E4M3FN> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_stride_load( +// CHECK: %[[BASE:.*]] = pto.addptr %arg0, %arg1 : -> +// CHECK: %[[LOAD:.*]] = pto.vsldb %[[BASE]], %c1{{[^,]*}}, %c1{{[^,]*}}, %arg2 : !pto.ptr, i16, i16, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: return %[[LOAD]] : !pto.vreg<256xf8E4M3FN> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_stride_store.pto b/test/lit/vmi/vmi_to_vpto_stride_store.pto new file mode 100644 index 0000000000..f581ed1bc7 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_stride_store.pto @@ -0,0 +1,33 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_stride_store( + %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + %dst: !pto.ptr, + %mask: !pto.vmi.mask<64xb32, #pto.vmi.layout>) { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : i16 + %c4 = arith.constant 4 : i16 + pto.vmi.stride_store %value, %dst[%c0], %c2, %c4, %mask + : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + !pto.ptr, i16, i16, + !pto.vmi.mask<64xb32, #pto.vmi.layout> + return + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_stride_store( +// CHECK: %[[BASE:.*]] = pto.addptr %arg1, %c0 : -> +// CHECK: %{{.*}} = pto.vsstb %arg0, %[[BASE]], %c2{{[^,]*}}, %c4{{[^,]*}}, %arg2 : !pto.vreg<64xf32>, !pto.ptr, i16, i16, !pto.mask -> !pto.ptr +// CHECK: return +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_read_write.pto b/test/lit/vmi/vmi_to_vpto_tile_read_write.pto deleted file mode 100644 index 5b7e6dbe00..0000000000 --- a/test/lit/vmi/vmi_to_vpto_tile_read_write.pto +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s - -module { - func.func @vmi_to_vpto_tile_read_write_contiguous(%src: memref<128xf32>, %dst: memref<128xf32>) { - %value = pto.vmi.tile_read %src - : memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> - pto.vmi.tile_write %value, %dst - : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, memref<128xf32> - return - } - - func.func @vmi_to_vpto_tile_read_deint2(%src: memref<128xf32>) - -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { - %value = pto.vmi.tile_read %src - : memref<128xf32> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> - %p0, %p1 = "pto.vmi.unpack"(%value) - : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) - -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) - return %p0, %p1 : !pto.vreg<64xf32>, !pto.vreg<64xf32> - } - - func.func @vmi_to_vpto_tile_write_deint2( - %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, - %dst: memref<128xf32>) { - pto.vmi.tile_write %value, %dst - : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, memref<128xf32> - return - } -} - -// CHECK-LABEL: func.func @vmi_to_vpto_tile_read_write_contiguous( -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index -// CHECK: %[[L0:.*]] = pto.vlds %arg0[%[[ZERO]]] : memref<128xf32> -> !pto.vreg<64xf32> -// CHECK: %[[L1:.*]] = pto.vlds %arg0[%[[C64]]] : memref<128xf32> -> !pto.vreg<64xf32> -// CHECK: pto.vsts %[[L0]], %arg1[%[[ZERO]]] -// CHECK: pto.vsts %[[L1]], %arg1[{{.*}}] -// CHECK-NOT: pto.vmi. -// CHECK-NOT: !pto.vmi. -// CHECK-NOT: unrealized_conversion_cast - -// CHECK-LABEL: func.func @vmi_to_vpto_tile_read_deint2( -// CHECK: %[[ZERO:.*]] = arith.constant 0 : index -// CHECK: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%[[ZERO]]], "DINTLV_B32" -// CHECK: return %[[P0]], %[[P1]] -// CHECK-NOT: pto.vmi. -// CHECK-NOT: !pto.vmi. -// CHECK-NOT: unrealized_conversion_cast - -// CHECK-LABEL: func.func @vmi_to_vpto_tile_write_deint2( -// CHECK: %[[ZERO:.*]] = arith.constant 0 : index -// CHECK: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" -// CHECK: pto.vstsx2 %arg0, %arg1, %arg2[%[[ZERO]]], "INTLV_B32", %[[MASK]] -// CHECK-NOT: pto.vmi. -// CHECK-NOT: !pto.vmi. -// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto b/test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto deleted file mode 100644 index 701d921186..0000000000 --- a/test/lit/vmi/vmi_to_vpto_tile_write_deint_tail.pto +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s - -module { - func.func @vmi_to_vpto_tile_write_deint_tail( - %value: !pto.vmi.vreg<4xf32, #pto.vmi.layout>, - %dst: memref<4xf32>) { - pto.vmi.tile_write %value, %dst - : !pto.vmi.vreg<4xf32, #pto.vmi.layout>, - memref<4xf32> - return - } -} - -// CHECK-LABEL: func.func @vmi_to_vpto_tile_write_deint_tail( -// CHECK-SAME: %[[P0:[^,]+]]: !pto.vreg<64xf32> -// CHECK-SAME: %[[P1:[^,]+]]: !pto.vreg<64xf32> -// CHECK-SAME: %[[DST:[^)]+]]: memref<4xf32> -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32 -// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = pto.vintlv %[[P0]], %[[P1]] -// CHECK: %[[MASK:.*]], %{{.*}} = pto.plt_b32 %[[C4]] : i32 -> !pto.mask, i32 -// CHECK: pto.vsts %[[LOW]], %[[DST]][%[[ZERO]]], %[[MASK]] -// CHECK-NOT: pto.vsts %[[HIGH]] -// CHECK-NOT: pto.vmi. -// CHECK-NOT: !pto.vmi. -// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_write_tail.pto b/test/lit/vmi/vmi_to_vpto_tile_write_tail.pto deleted file mode 100644 index d4f37d48fc..0000000000 --- a/test/lit/vmi/vmi_to_vpto_tile_write_tail.pto +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s - -module { - func.func @vmi_to_vpto_tile_write_tail( - %value: !pto.vmi.vreg<100xf32, #pto.vmi.layout>, - %dst: memref<100xf32>) { - pto.vmi.tile_write %value, %dst - : !pto.vmi.vreg<100xf32, #pto.vmi.layout>, memref<100xf32> - return - } -} - -// CHECK-LABEL: func.func @vmi_to_vpto_tile_write_tail( -// CHECK-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> -// CHECK-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> -// CHECK-SAME: %[[DST:[^)]+]]: memref<100xf32> -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C36:.*]] = arith.constant 36 : i32 -// CHECK: %[[FULL_MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask -// CHECK: pto.vsts %[[V0]], %[[DST]][%[[ZERO]]], %[[FULL_MASK]] -// CHECK: %[[TAIL_MASK:.*]], %{{.*}} = pto.plt_b32 %[[C36]] : i32 -> !pto.mask, i32 -// CHECK: pto.vsts %[[V1]], %[[DST]][{{.*}}], %[[TAIL_MASK]] -// CHECK-NOT: pto.vmi. -// CHECK-NOT: !pto.vmi. -// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_truncf.pto b/test/lit/vmi/vmi_to_vpto_truncf.pto index e8d8340c83..edac1ec223 100644 --- a/test/lit/vmi/vmi_to_vpto_truncf.pto +++ b/test/lit/vmi/vmi_to_vpto_truncf.pto @@ -37,6 +37,18 @@ module { -> !pto.vreg<128xf16> return %p : !pto.vreg<128xf16> } + + func.func @vmi_to_vpto_truncf_f32_to_f16_multichunk( + %wide: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) { + %narrow = pto.vmi.truncf %wide + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %p0, %p1 = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<256xf16, #pto.vmi.layout>) + -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>) + return %p0, %p1 : !pto.vreg<128xf16>, !pto.vreg<128xf16> + } } // CHECK-LABEL: func.func @vmi_to_vpto_truncf_f32_to_f16( @@ -54,3 +66,15 @@ module { // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast + +// CHECK-LABEL: func.func @vmi_to_vpto_truncf_f32_to_f16_multichunk( +// CHECK: %[[EVEN0:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[ODD0:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[OUT0:.*]] = pto.vor %[[EVEN0]], %[[ODD0]] +// CHECK: %[[EVEN1:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[ODD1:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: %[[OUT1:.*]] = pto.vor %[[EVEN1]], %[[ODD1]] +// CHECK: return %[[OUT0]], %[[OUT1]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_tile_write_tail_deint_invalid.pto b/test/lit/vmi/vmi_truncf_rounding_token_invalid.pto similarity index 55% rename from test/lit/vmi/vmi_to_vpto_tile_write_tail_deint_invalid.pto rename to test/lit/vmi/vmi_truncf_rounding_token_invalid.pto index 4d4dac9d6d..cc191cacea 100644 --- a/test/lit/vmi/vmi_to_vpto_tile_write_tail_deint_invalid.pto +++ b/test/lit/vmi/vmi_truncf_rounding_token_invalid.pto @@ -6,17 +6,15 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s module { - func.func @vmi_to_vpto_tile_write_tail_deint_invalid( - %value: !pto.vmi.vreg<129xf32, #pto.vmi.layout>, - %dst: memref<129xf32>) { - pto.vmi.tile_write %value, %dst - : !pto.vmi.vreg<129xf32, #pto.vmi.layout>, memref<129xf32> + func.func @vmi_truncf_rounding_token_invalid( + %source: !pto.vmi.vreg<256xf32>) { + %result = pto.vmi.truncf %source {rounding = "R"} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256x!pto.hif8> return } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.tile_write requires an 8/16/32-bit predicate-maskable element type -// CHECK-SAME: requires every deinterleaved part to have the same physical chunk count +// CHECK: rounding attr must be A or H diff --git a/test/lit/vmi/vmi_truncf_rounding_unsupported_invalid.pto b/test/lit/vmi/vmi_truncf_rounding_unsupported_invalid.pto new file mode 100644 index 0000000000..0847c26a10 --- /dev/null +++ b/test/lit/vmi/vmi_truncf_rounding_unsupported_invalid.pto @@ -0,0 +1,20 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --emit-pto-ir %s 2>&1 | FileCheck %s + +module { + func.func @vmi_truncf_rounding_unsupported_invalid( + %source: !pto.vmi.vreg<128xf32>) { + %result = pto.vmi.truncf %source {rounding = "H"} + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + return + } +} + +// CHECK: rounding attr is currently only supported for f32 to !pto.hif8 truncf diff --git a/test/lit/vpto/vgather2_u16_vpto_llvm.pto b/test/lit/vpto/vgather2_u16_vpto_llvm.pto new file mode 100644 index 0000000000..d2e8f983de --- /dev/null +++ b/test/lit/vpto/vgather2_u16_vpto_llvm.pto @@ -0,0 +1,30 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ( mkdir -p %T && ptoas --pto-arch=a5 --pto-backend=vpto %s -o %t --mlir-print-ir-after=convert-func-to-llvm 2>&1 || true ) | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vgather2_u16(%src: !pto.ptr, + %idx: !pto.vreg<128xui16>, + %dst: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %out = pto.vgather2 %src, %idx, %mask + : !pto.ptr, !pto.vreg<128xui16>, !pto.mask + -> !pto.vreg<128xui16> + pto.vsts %out, %dst[%c0], %mask + : !pto.vreg<128xui16>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK-LABEL: llvm.func @vgather2_u16_mix_aiv +// CHECK: %[[IDX:.*]] = llvm.bitcast %arg1 : vector<128xi16> to vector<64xi32> +// CHECK: llvm.call @llvm.hivm.vgather2.v300.v128u16(%arg0, %[[IDX]], diff --git a/test/lit/vpto/vmi_fp4_e1_packed_surface_verify_invalid.pto b/test/lit/vpto/vmi_fp4_e1_packed_surface_verify_invalid.pto new file mode 100644 index 0000000000..65cf7e6223 --- /dev/null +++ b/test/lit/vpto/vmi_fp4_e1_packed_surface_verify_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_fp4_e1_packed_surface_invalid( + %arg0: !pto.vmi.vreg<256x!pto.f4E1M2x2>) attributes {pto.kernel} { + return + } +} + +// CHECK: error: '!pto.vmi.vreg<256x!pto.f4E1M2x2>' uses a packed FP4 physical pair type as a VMI logical element type +// CHECK-SAME: packed FP4 input/output is not a supported VMI surface diff --git a/test/lit/vpto/vmi_fp4_packed_surface_verify_invalid.pto b/test/lit/vpto/vmi_fp4_packed_surface_verify_invalid.pto new file mode 100644 index 0000000000..18e8f6fd30 --- /dev/null +++ b/test/lit/vpto/vmi_fp4_packed_surface_verify_invalid.pto @@ -0,0 +1,19 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_fp4_packed_surface_invalid( + %arg0: !pto.vmi.vreg<256x!pto.f4E2M1x2>) attributes {pto.kernel} { + return + } +} + +// CHECK: error: '!pto.vmi.vreg<256x!pto.f4E2M1x2>' uses a packed FP4 physical pair type as a VMI logical element type +// CHECK-SAME: packed FP4 input/output is not a supported VMI surface diff --git a/test/lit/vpto/vmi_sitofp.pto b/test/lit/vpto/vmi_sitofp.pto new file mode 100644 index 0000000000..fca0e63698 --- /dev/null +++ b/test/lit/vpto/vmi_sitofp.pto @@ -0,0 +1,42 @@ +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// CHECK-LABEL: func.func @vmi_sitofp_kernel +// CHECK: pto.vcvt {{.*}} {rnd = "R"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<64xf32> + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_sitofp_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<64xi32> + %f = pto.vmi.sitofp %x + : !pto.vmi.vreg<64xi32> -> !pto.vmi.vreg<64xf32> + pto.vmi.store %f, %ub_dst[%c0] + : !pto.vmi.vreg<64xf32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/lit/vpto/vmi_truncf_hif8.pto b/test/lit/vpto/vmi_truncf_hif8.pto new file mode 100644 index 0000000000..260c43ad7a --- /dev/null +++ b/test/lit/vpto/vmi_truncf_hif8.pto @@ -0,0 +1,96 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>/dev/null | FileCheck %s + +// CHECK-LABEL: func.func @vmi_truncf_hif8_default_kernel +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK-LABEL: func.func @vmi_truncf_hif8_hybrid_kernel +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_truncf_hif8_default_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %h = pto.vmi.truncf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256x!pto.hif8> + pto.vmi.store %h, %ub_dst[%c0] + : !pto.vmi.vreg<256x!pto.hif8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } + + func.func @vmi_truncf_hif8_hybrid_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %x = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %h = pto.vmi.truncf %x {rounding = "H"} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256x!pto.hif8> + pto.vmi.store %h, %ub_dst[%c0] + : !pto.vmi.vreg<256x!pto.hif8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/compare.py b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/compare.py new file mode 100644 index 0000000000..612b15c3f6 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/compare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v2.bin", dtype=np.int32) + output = np.fromfile("v2.bin", dtype=np.int32) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + if golden.shape != output.shape: + print(f"[ERROR] shape mismatch golden={golden.shape} output={output.shape}") + sys.exit(2) + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed idx={idx} " + f"golden={golden[idx] if idx >= 0 else 'n/a'} " + f"output={output[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/golden.py b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/golden.py new file mode 100644 index 0000000000..1aa24c830e --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/golden.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under +# the terms and conditions of CANN Open Software License Agreement Version 2.0 +# (the "License"). Please refer to the License for details. You may not use +# this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +# AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +# FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +# for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 8 + + +def generate(output_dir: Path) -> None: + src = np.empty((ROWS, COLS), dtype=np.int32) + for row in range(ROWS): + base = np.arange(COLS, dtype=np.int32) * ((row % 3) + 1) + src[row, :] = base - row * 5 - 9 + src[row, COLS // 2] = row * 11 - 17 + src[row, COLS - 1] = 23 - row * 7 + dst = np.full(ROWS, -777, dtype=np.int32) + golden = np.max(src, axis=1).astype(np.int32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + dst.tofile(output_dir / "v2.bin") + golden.tofile(output_dir / "golden_v2.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/kernel.pto new file mode 100644 index 0000000000..f11fa15503 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/kernel.pto @@ -0,0 +1,51 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_group_reduce_i32_maxi_store_kernel(%src_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c512_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> + %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<64xi32> + %sum = pto.vmi.group_reduce_maxi %x, %mask {num_groups = 8} + : !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> + -> !pto.vmi.vreg<64xi32> + pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<64xi32>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c32_i64 + nburst(%c1_i64, %c32_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/launch.cpp b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/launch.cpp new file mode 100644 index 0000000000..7a7c0bacb9 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/launch.cpp @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_group_reduce_i32_maxi_store_kernel(__gm__ int32_t *src, + __gm__ int32_t *dst); + +void LaunchVmi_group_reduce_i32_maxi_store_kernel(int32_t *src, int32_t *dst, + void *stream) { + vmi_group_reduce_i32_maxi_store_kernel<<<1, nullptr, stream>>>( + (__gm__ int32_t *)src, (__gm__ int32_t *)dst); +} diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/main.cpp b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/main.cpp new file mode 100644 index 0000000000..0aa2835503 --- /dev/null +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/main.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_group_reduce_i32_maxi_store_kernel(int32_t *src, int32_t *dst, + void *stream); + +int main() { + constexpr size_t kInputElems = 64; + constexpr size_t kOutputElems = 8; + size_t srcBytes = kInputElems * sizeof(int32_t); + size_t dstBytes = kOutputElems * sizeof(int32_t); + int32_t *srcHost = nullptr; + int32_t *dstHost = nullptr; + int32_t *srcDevice = nullptr; + int32_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_group_reduce_i32_maxi_store_kernel(srcDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/ptoas.flags b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/ptoas.flags similarity index 100% rename from test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/ptoas.flags rename to test/vpto/cases/vmi/group-reduce-i32-maxi-store/ptoas.flags diff --git a/test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md b/test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md new file mode 100644 index 0000000000..c92e060d37 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md @@ -0,0 +1,226 @@ + + +# VMI Kernels 的 CCE Case 范围 + +本文档从目标仓库 `.work/external/a5-kernel-standalone/cce` 的 raw CCE 测试入口列出 +VMI kernel 迁移范围。当前审计快照为本地 clone 的 +`main@ee81c3660d6336ecaecd805f02ffb2d69446984e`。不要从当前 +`test/vpto/cases/vmi/kernels` 已存在目录反推目标范围:历史目录包含额外 probe、 +历史 VMI coverage 和尚未对齐 CCE 数据流的 case。 + +## 统计规则 + +“必须支持”只统计目标仓库中的正确性、等价性和 minimum 测试入口。 +`smoke`、`timing`、`bench`、`debug`、`experiments` 和 bandwidth sweep 只用于补充代表 shape +或性能验证,不自动扩展为必须支持的语义 case。 + +| CCE family | 正确性来源 | 必须支持数量 | 首批 VMI 目标 | 暂缓或非首批 | +| --- | --- | ---: | --- | --- | +| `quant_minimum` | `quant_minimum/test/test_tquant.py` | 4 | 全部 4 个 | 该 suite 无暂缓项 | +| `block_quant` | `block_quant/test/test_equivalence.py` | 7 | 全部 7 个 | 除非 raw CCE 新增正确性入口,否则 HIF8 只算 VMI/compiler probe | +| `dynamic_quant` | `dynamic_quant/test/test_dq_equivalence.py` | 9 | 全部 9 个 | subset wrapper 不新增语义 case | +| `dequant/anti_mx_quant` | `dequant/anti_mx_quant/test/test_equivalence.py` | 16 | 先支持 FP8 case | FP4 输入因 VMI FP4 surface 未设计而暂缓 | +| `block_mx_quant` | `block_mx_quant/test/test_equivalence.py`; `test_cce.py` 是更宽的 smoke/correctness surface | 14 canonical,30 full union | 先支持 canonical FP8/OCP/rint 行 | FP4、DDR `scale_alg=2` 和额外 `test_cce.py` union 行暂缓 | +| `swiglu_mx_quant` | `swiglu_mx_quant/test/test_equivalence.py` | 14 | 先支持 FP8/OCP/rint f16/bf16 行 | FP4 暂缓;CCE 源码中 `scale_alg=1` CUBLAS 路径异常 | +| `tutorial/block_mx_quant` | `tutorial/block_mx_quant/README.md` | 已由 `block_mx_quant` 覆盖 | BF16 FP8 tutorial shape 作为代表覆盖 | tutorial FP4 与主 `block_mx_quant` 共用 FP4 blocker | + +## quant_minimum + +来源:`.work/external/a5-kernel-standalone/cce/quant_minimum/test/test_tquant.py`。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| `mxfp8_32x32_nd` | 必须支持 | +| `mxfp8_32x64_nz` | 必须支持 | +| `int8_sym_64x128_nd` | 必须支持 | +| `int8_asym_64x128_nd` | 必须支持 | + +`test_cycle_match.py` 只对同一组 `MINIMUM_CASES` 做 PTO/CCE cycle 对比,不新增语义 case。 + +## block_quant + +来源:`.work/external/a5-kernel-standalone/cce/block_quant/test/test_equivalence.py`。 +所有 case 都使用 `row_block_size=1`、`col_block_size=128`、`dst_type=292`。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| BF16 `(4,128)` | 必须支持 | +| FP16 `(8,128)` | 必须支持 | +| BF16 `(32,128)` | 必须支持 | +| FP16 `(16,256)` | 必须支持 | +| BF16 `(2,128)` | 必须支持 | +| FP16 `(4,256)` | 必须支持 | +| 带 `min_scale` 的 BF16 `(4,128)` | 必须支持 | + +`minimal_test.py` 是 smoke 子集。`test_hardware_scale.py`、`large_shape_correctness.py` +和 bandwidth sweep 主要验证更大 streaming shape;只有当它们暴露新的 VMI memory/layout +约束时,才增加代表性 runtime case。 + +## dynamic_quant + +来源:`.work/external/a5-kernel-standalone/cce/dynamic_quant/test/test_dq_equivalence.py`。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| per-token,无 smooth,FP16 `(4,32)` | 必须支持 | +| per-token,无 smooth,FP16 `(16,128)` | 必须支持 | +| per-token,smooth,FP16 `(8,64)` | 必须支持 | +| per-token,smooth,FP16 `(16,128)` | 必须支持 | +| per-channel,FP16 `(128,128)` | 必须支持 | +| per-channel,FP16 `(256,256)` | 必须支持 | +| per-token,无 smooth,BF16 `(4,32)` | 必须支持 | +| per-token,smooth,BF16 `(8,64)` | 必须支持 | +| per-channel,BF16 `(128,128)` | 必须支持 | + +`test_pertoken_only.py`、`test_perchannel_128.py` 和 `test_perchannel_all.py` +只是该表的子集或重新分组。 + +## dequant / anti_mx_quant + +来源:`.work/external/a5-kernel-standalone/cce/dequant/anti_mx_quant/test/test_equivalence.py`。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| FP8 E4M3 -> BF16 `(4,128)` | 首批必须支持 | +| FP8 E4M3 -> FP32 `(4,128)` | 首批必须支持 | +| FP8 E4M3 -> FP16 `(4,128)` | 首批必须支持 | +| FP8 E4M3 -> BF16 `(16,512)` | 首批必须支持 | +| FP8 E4M3 -> BF16 `(64,2048)` | 首批必须支持 | +| FP8 E4M3 -> BF16 `(1024,2048)` | 代表 large/perf;若 medium 已覆盖相同 lowering,首批 runtime shape 不必加入 | +| FP8 E5M2 -> BF16 `(4,128)` | 首批必须支持 | +| FP8 E5M2 -> BF16 `(16,512)` | 首批必须支持 | +| FP4 E2M1 -> BF16 `(4,64)` | 暂缓 | +| FP4 E2M1 -> BF16 `(16,256)` | 暂缓 | +| FP4 E2M1 -> BF16 `(4096,512)` | 暂缓 | +| FP4 E2M1 -> BF16 `(65536,2048)` | 暂缓 | +| FP4 E1M2 -> BF16 `(4,64)` | 暂缓 | +| FP4 E1M2 -> BF16 `(16,256)` | 暂缓 | +| FP4 E1M2 -> BF16 `(4096,512)` | 暂缓 | +| FP4 E1M2 -> BF16 `(65536,2048)` | 暂缓 | + +这些 FP4 行是真实目标仓库 case,但当前 VMI 尚未定义 logical FP4 packed input lane +或 packed-byte load/store 语义。不要用临时 byte trick 模拟。 + +## block_mx_quant + +canonical 来源:`.work/external/a5-kernel-standalone/cce/block_mx_quant/test/test_equivalence.py`。 +这是 `HW_RESULTS.md` 中报告的默认 14-case 正确性 suite。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| FP8 E4M3 BF16 `(4,128)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E4M3 FP16 `(64,256)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E5M2 BF16 `(4,128)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E5M2 FP16 `(8,256)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP4 E2M1 BF16 `(4,128)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E2M1 BF16 `(256,512)`, `scale_alg=0`, `round` | 暂缓 | +| FP4 E2M1 FP16 `(4,128)`, `scale_alg=0`, `floor` | 暂缓 | +| FP4 E2M1 BF16 `(1,2,2)`, `scale_alg=2`, `floor`, `dst_type_max=0` | 暂缓 | +| FP4 E2M1 BF16 `(4,128)`, `scale_alg=2`, `floor`, `dst_type_max=6` | 暂缓 | +| FP4 E2M1 BF16 `(4,128)`, `scale_alg=2`, `rint`, `dst_type_max=7` | 暂缓 | +| FP4 E1M2 BF16 `(4,128)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E1M2 FP16 `(8,256)`, `scale_alg=0`, `round` | 暂缓 | +| FP4 E1M2 BF16 `(4,128)`, `scale_alg=0`, `floor` | 暂缓 | +| FP4 tail BF16 `(100,300)`, E2M1, `scale_alg=0`, `rint` | 暂缓 | + +`test_cce.py` 额外枚举完整 small-shape type/rounding union: + +| Surface family | 额外覆盖 | +| --- | --- | +| FP8 OCP | FP16/BF16 x E4M3/E5M2, shape `(4,128)`, `rint` | +| FP4 E2M1 OCP | FP16/BF16 x `rint/round/floor`, shape `(4,128)` | +| FP4 E2M1 DDR | FP16/BF16 x `rint/round/floor`, shape `(4,128)`, `scale_alg=2` | +| FP4 E1M2 OCP | FP16/BF16 x `rint/round/floor`, shape `(4,128)` | + +VMI 实现以 canonical 14-case suite 作为迁移 checklist;`test_cce.py` union +在 FP4 设计完成后作为 surface 完整性 checklist。 + +## swiglu_mx_quant + +来源:`.work/external/a5-kernel-standalone/cce/swiglu_mx_quant/test/test_equivalence.py`。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| FP8 E4M3 BF16 `(4,8)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E4M3 FP16 `(64,512)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E5M2 BF16 `(4,8)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E5M2 FP16 `(128,256)`, `scale_alg=0`, `rint` | 首批必须支持 | +| FP8 E4M3 BF16 `(64,512)`, `scale_alg=1`, `rint` | 暂缓;CCE 标记 CUBLAS 路径异常 | +| FP8 E5M2 FP16 `(64,512)`, `scale_alg=1`, `rint` | 暂缓;CCE 标记 CUBLAS 路径异常 | +| FP4 E2M1 BF16 `(4,8)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E2M1 FP16 `(4,8)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E2M1 BF16 `(64,512)`, `scale_alg=0`, `round` | 暂缓 | +| FP4 E2M1 BF16 `(4,8)`, `scale_alg=0`, `floor` | 暂缓 | +| FP4 E2M1 BF16 `(128,256)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E1M2 BF16 `(4,8)`, `scale_alg=0`, `rint` | 暂缓 | +| FP4 E1M2 FP16 `(64,512)`, `scale_alg=0`, `round` | 暂缓 | +| FP4 E1M2 BF16 `(4,8)`, `scale_alg=0`, `floor` | 暂缓 | + +`test_smoke.py` 在 shape `(4,8)`、`(64,512)`、`(128,256)`、dtype BF16/FP16、 +FP4/FP8 输出模式和 FP8 `scale_alg=1` 上跑 48 个执行面。它不是正确性 oracle; +只有在等价性 case 覆盖后才使用。 + +`test_constant_input.py` 用于诊断 `(4,8)` 和 `(64,512)` 上 BF16 E4M3 OCP +constant input。它可以支撑 tiny deterministic VMI case,但除非 equivalence suite +新增对应项,否则不应为每个 dtype/output type 建一套并行矩阵。 + +## tutorial / block_mx_quant + +来源:`.work/external/a5-kernel-standalone/cce/tutorial/block_mx_quant/README.md`。 + +tutorial kernel 是教学用途,和主 `block_mx_quant` 共享算法:BF16 输入、`scale_alg=0`、 +FP8/FP4 输出、32x32 block scale 和 scale2 interleaving。README 说明 smoke 有 9 个 +BF16 output-type case,cross-check 有 7 个 byte-exact case,但当前快照中没有详细测试文件。 +因此 tutorial 覆盖只作为主 `block_mx_quant` 表的代表 shape,不作为独立 family。 + +## 当前 VMI 目录裁剪结果 + +该目录已裁剪为 target-scoped runtime case。删除的 case 只有在目标仓库新增匹配的正确性入口, +或迁移到独立的非目标 probe suite 后,才应重新引入。 + +当前 `test/vpto/cases/vmi/kernels` 已缩减为 35 个 case 目录。上面的目标 CCE canonical +正确性范围在 `block_mx_quant` 采用 14-case canonical suite 时有 64 行;如果把 +`block_mx_quant/test_cce.py` 作为完整 small-shape surface union 计入,则有 80 行。 +这些数量不能直接和当前支持集比较,因为目标列表仍包含当前 VMI 有意暂缓的 FP4 行。 + +| Area | 当前 VMI 目录数 | 目标 canonical 正确性 | 差异 | +| --- | ---: | ---: | --- | +| `quant_minimum` / `tquant` | 4 | 4 | 对齐 `MINIMUM_CASES` | +| `block_quant` | 7 | 7 | 对齐 `test_equivalence.py` | +| `dynamic_quant` | 9 | 9 | 对齐 `test_dq_equivalence.py` | +| `anti_mx_quant` | 7 | 16 | 保留当前 FP8 目标行;暂缓的 FP4 行不表达 | +| `block_mx_quant` | 4 | 14 canonical / 30 full union | 保留 canonical FP8 目标行;暂缓的 FP4/DDR 和额外 `test_cce.py` union 行不表达 | +| `swiglu_mx_quant` | 4 | 14 | 保留当前 FP8/OCP 目标行;暂缓的 FP4 和异常 CUBLAS 行不表达 | +| historical `anti_quant` | 0 | 0 | 已从 target-scoped 目录移除 | +| historical `swiglu_quant` | 0 | 0 | 已从 target-scoped 目录移除 | +| other probe | 0 | 0 | 已从 target-scoped 目录移除 | + +## 当前支持目录清单 + +当前 target-scoped runtime 目录精确包含以下 35 个 VMI case: + +| CCE family | VMI case 目录 | +| --- | --- | +| `quant_minimum` / `tquant` | `tquant-mxfp8-32x32-nd`, `tquant-mxfp8-32x64-nz`, `tquant-int8-sym-64x128`, `tquant-int8-asym-64x128` | +| `block_quant` | `block-quant-bf16-fp8-2x128`, `block-quant-bf16-fp8-4x128`, `block-quant-bf16-fp8-4x128-min-scale`, `block-quant-bf16-fp8-32x128`, `block-quant-f16-fp8-4x256`, `block-quant-f16-fp8-8x128`, `block-quant-f16-fp8-16x256` | +| `dynamic_quant` | `dynamic-quant-pertoken-f16-4x32`, `dynamic-quant-pertoken-f16-16x128`, `dynamic-quant-pertoken-smooth-f16-8x64`, `dynamic-quant-pertoken-smooth-f16-16x128`, `dynamic-quant-perchannel-f16-128x128`, `dynamic-quant-perchannel-f16-256x256`, `dynamic-quant-pertoken-bf16-4x32`, `dynamic-quant-pertoken-smooth-bf16-8x64`, `dynamic-quant-perchannel-bf16-128x128` | +| `dequant/anti_mx_quant` | `anti-mx-f8-bf16-scaled-4x128`, `anti-mx-f8-f32-scaled-4x128`, `anti-mx-f8-f16-scaled-4x128`, `anti-mx-f8-bf16-scaled-16x512`, `anti-mx-f8-bf16-scaled-64x2048`, `anti-mx-f8e5m2-bf16-scaled-4x128`, `anti-mx-f8e5m2-bf16-scaled-16x512` | +| `block_mx_quant` | `block-mx-quant-bf16-e4m3-4x128`, `block-mx-quant-f16-e4m3-64x256`, `block-mx-quant-bf16-e5m2-4x128`, `block-mx-quant-f16-e5m2-8x256` | +| `swiglu_mx_quant` | `swiglu-mx-quant-bf16-e4m3-4x8`, `swiglu-mx-quant-f16-e4m3-64x512`, `swiglu-mx-quant-bf16-e5m2-4x8`, `swiglu-mx-quant-f16-e5m2-128x256` | + +| 已移除的 VMI 区域 | 范围说明 | +| --- | --- | +| 额外 `anti-mx` FP8 E5M2 -> FP16/FP32 large-shape case | 对称 decode 覆盖;未列入目标 `anti_mx_quant/test_equivalence.py` | +| 额外 `dynamic_quant` BF16 larger smooth/no-smooth case | 实现扩展覆盖;不在 9-case 目标 equivalence 列表中 | +| 额外 `block_mx_quant` random/shared-scale case | 对独立 golden 覆盖有用;不是直接目标测试名 | +| 额外 `block_mx_quant` FP16 `(4,128)` E4M3/E5M2 行 | 只存在于更宽的 `test_cce.py` union;不是 canonical equivalence checklist 的一部分 | +| 额外 `swiglu_mx_quant` constant BF16 4x8 proxy | 诊断输入模式;除非迁移为精确 equivalence 行,否则不保留在 target-scoped runtime case 中 | +| HIF8 `block_quant` probe | compiler/runtime surface probe;不是该目标仓库中的 raw CCE 正确性 case | diff --git a/test/vpto/cases/vmi/kernels/README.md b/test/vpto/cases/vmi/kernels/README.md new file mode 100644 index 0000000000..ff0fad599f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/README.md @@ -0,0 +1,49 @@ + + +# VMI Kernel 用例 + +本目录只保留从目标仓库 `.work/external/a5-kernel-standalone/cce` 的 raw CCE +正确性、等价性和 minimum 测试入口迁移而来的 VMI runtime case。 +范围定义见 [CCE_CASE_SCOPE.md](CCE_CASE_SCOPE.md)。 + +不要从历史 VMI probe、benchmark sweep、debug shape、random stress 或当前目录外的 +实验脚本反推支持范围。新增 case 前先确认目标 CCE 测试入口是否提供对应正确性语义。 + +## 当前目录范围 + +当前目录保留 35 个 runtime case: + +| CCE family | 当前 case 数 | 范围 | +| --- | ---: | --- | +| `quant_minimum` / `tquant` | 4 | 对齐 `MINIMUM_CASES` | +| `block_quant` | 7 | 对齐 `test_equivalence.py` | +| `dynamic_quant` | 9 | 对齐 `test_dq_equivalence.py` | +| `dequant/anti_mx_quant` | 7 | 当前保留 VMI 能表达的 FP8 行;FP4 输入暂缓 | +| `block_mx_quant` | 4 | 当前保留 canonical FP8/OCP 等价性行;FP4、DDR `scale_alg=2` 和额外 `test_cce.py` union 行暂缓 | +| `swiglu_mx_quant` | 4 | 当前保留 FP8/OCP 等价性行;FP4 和 CCE 已标记异常的 CUBLAS `scale_alg=1` 暂缓 | + +## 设计上暂缓 + +下列目标 CCE 行是真实存在的,但在对应 VMI 语义设计清楚前,不应通过临时拼凑的 +runtime case 表达: + +| Case 类别 | 原因 | +| --- | --- | +| FP4 packed input/output | VMI 尚未定义 logical FP4 lane、packed-byte layout 和 FP4 load/store 语义 | +| `block_mx_quant` FP4 DDR `scale_alg=2` | 依赖 FP4 语义和 DDR scale 规则 | +| `swiglu_mx_quant` FP8 CUBLAS `scale_alg=1` | CCE 源码已标记该路径异常 | +| HIF8 `block_quant` | 只是 compiler/runtime surface probe,不是目标仓库里的 raw CCE 正确性 case | + +## 验证策略 + +每个保留的 runtime case 都应通过 `test/vpto/scripts/run_host_vpto_validation.sh`。 +新增 CCE 迁移 case 时,先在 [CCE_CASE_SCOPE.md](CCE_CASE_SCOPE.md) 记录对应的 +目标 CCE 来源行。除非新 case 覆盖不同的 VMI 语义或 lowering 约束,否则不要重复添加同构 shape。 diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/compare.py new file mode 100644 index 0000000000..f9d83f2328 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + out = np.fromfile("v3.bin", dtype=np.uint16) + + if golden.shape != out.shape or not np.array_equal(golden, out): + diff = np.nonzero(golden != out)[0] if golden.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + golden_value = f"0x{int(golden[idx]):04x}" if idx >= 0 else "n/a" + out_value = f"0x{int(out[idx]):04x}" if idx >= 0 else "n/a" + print( + f"[ERROR] bf16 compare failed idx={idx} " + f"golden={golden_value} output={out_value}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/golden.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/golden.py new file mode 100644 index 0000000000..09a91b695c --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 16 +COLS = 512 +ELEMS = ROWS * COLS +MXSCALE_BYTES = ROWS * (COLS // 32) +VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], + dtype=np.float32, +) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +E8M0_BYTES = np.array( + [0x7E, 0x7F, 0x80, 0x81, 0x7D, 0x82, 0x7C, 0x83], dtype=np.uint8 +) +SENTINEL_BF16 = np.uint16(0x7FC0) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS) + mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + groups = COLS // 32 + scale_repeats = (groups + len(E8M0_BYTES) - 1) // len(E8M0_BYTES) + scale_row = np.tile(E8M0_BYTES, scale_repeats)[:groups].astype(np.uint8) + mxscale_matrix = np.tile(scale_row, (ROWS, 1)).astype(np.uint8) + mxscale[:] = mxscale_matrix.reshape(-1) + scale_values = np.ldexp( + np.ones_like(mxscale_matrix, dtype=np.float32), + mxscale_matrix.astype(np.int32) - 127, + ) + scaled = decoded.copy() + for row in range(ROWS): + for group in range(groups): + start = group * 32 + stop = start + 32 + scaled[row, start:stop] *= scale_values[row, group] + dst = np.full(ELEMS, SENTINEL_BF16, dtype=np.uint16) + golden = f32_to_bf16_bits(scaled.reshape(-1)) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mxscale.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/kernel.pto new file mode 100644 index 0000000000..130e50768d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/kernel.pto @@ -0,0 +1,96 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_anti_mx_f8_bf16_scaled_16x512_kernel( + %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c23_i32 = arith.constant 23 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mxscale = pto.castptr %c16384_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c32768_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64 + nburst(%c32_i64, %c8_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c16 step %c1 { + %row_elem_off = arith.muli %row, %c512 : index + %row_scale_off = arith.muli %row, %c64 : index + scf.for %col = %c0 to %c512 step %c256 { + %offset = arith.addi %row_elem_off, %col : index + %packed = pto.vmi.load %ub_src_f8[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_tile = arith.divui %col, %c256 : index + %scale_tile_off = arith.muli %scale_tile, %c32 : index + %scale_off = arith.addi %row_scale_off, %scale_tile_off : index + %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xui8> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %out, %ub_dst[%offset] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c16384_i64 + nburst(%c1_i64, %c16384_i64, %c16384_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/launch.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/launch.cpp new file mode 100644 index 0000000000..bb37bcb8fa --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_anti_mx_f8_bf16_scaled_16x512_kernel(__gm__ uint8_t *src, + __gm__ uint8_t *mxscale, + __gm__ bfloat16_t *dst); + +void LaunchVmi_anti_mx_f8_bf16_scaled_16x512_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream) { + vmi_anti_mx_f8_bf16_scaled_16x512_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, + (__gm__ bfloat16_t *)dst); +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/main.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/main.cpp new file mode 100644 index 0000000000..b9080bacc5 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8_bf16_scaled_16x512_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kRows = 16; + constexpr size_t kCols = 512; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = kRows * (kCols / 32); + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8_bf16_scaled_16x512_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/compare.py new file mode 100644 index 0000000000..f9d83f2328 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + out = np.fromfile("v3.bin", dtype=np.uint16) + + if golden.shape != out.shape or not np.array_equal(golden, out): + diff = np.nonzero(golden != out)[0] if golden.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + golden_value = f"0x{int(golden[idx]):04x}" if idx >= 0 else "n/a" + out_value = f"0x{int(out[idx]):04x}" if idx >= 0 else "n/a" + print( + f"[ERROR] bf16 compare failed idx={idx} " + f"golden={golden_value} output={out_value}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/golden.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/golden.py new file mode 100644 index 0000000000..b714f6d4db --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/golden.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +ELEMS = ROWS * COLS +MXSCALE_BYTES = 32 +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +E8M0_BYTES = np.array([0x7E, 0x7F, 0x80, 0x81], dtype=np.uint8) +SENTINEL_BF16 = np.uint16(0x7FC0) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS) + mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + mxscale_matrix = np.tile(E8M0_BYTES, (ROWS, 1)).astype(np.uint8) + mxscale[: ROWS * 4] = mxscale_matrix.reshape(-1) + scale_values = np.ldexp( + np.ones_like(mxscale_matrix, dtype=np.float32), + mxscale_matrix.astype(np.int32) - 127, + ) + scaled = decoded.copy() + for row in range(ROWS): + for group in range(4): + start = group * 32 + stop = start + 32 + scaled[row, start:stop] *= scale_values[row, group] + dst = np.full(ELEMS, SENTINEL_BF16, dtype=np.uint16) + golden = f32_to_bf16_bits(scaled.reshape(-1)) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mxscale.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/kernel.pto new file mode 100644 index 0000000000..8063c2406b --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/kernel.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_anti_mx_f8_bf16_scaled_4x128_kernel( + %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c23_i32 = arith.constant 23 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mxscale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64 + nburst(%c2_i64, %c8_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c512 step %c256 iter_args(%dummy = %c0) -> (index) { + %packed = pto.vmi.load %ub_src_f8[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_tile = arith.divui %offset, %c256 : index + %scale_off = arith.muli %scale_tile, %c32 : index + %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xui8> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %out, %ub_dst[%offset] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + scf.yield %dummy : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/launch.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/launch.cpp similarity index 72% rename from test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/launch.cpp rename to test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/launch.cpp index 630c7d55af..8fd6bdeaab 100644 --- a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/launch.cpp +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/launch.cpp @@ -31,13 +31,15 @@ struct MrgSortExecutedNumList { #endif extern "C" __global__ [aicore] void -vmi_simdvf_per_token_cast_to_fp8_kernel(__gm__ float *src, - __gm__ float *scale, - __gm__ uint8_t *out8); +vmi_anti_mx_f8_bf16_scaled_4x128_kernel(__gm__ uint8_t *src, + __gm__ uint8_t *mxscale, + __gm__ bfloat16_t *dst); -void LaunchVmi_simdvf_per_token_cast_to_fp8_kernel(float *src, float *scale, - uint8_t *out8, +void LaunchVmi_anti_mx_f8_bf16_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, void *stream) { - vmi_simdvf_per_token_cast_to_fp8_kernel<<<1, nullptr, stream>>>( - (__gm__ float *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out8); + vmi_anti_mx_f8_bf16_scaled_4x128_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, + (__gm__ bfloat16_t *)dst); } diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/main.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/main.cpp new file mode 100644 index 0000000000..414bb6bd53 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8_bf16_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = 32; + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8_bf16_scaled_4x128_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/compare.py new file mode 100644 index 0000000000..f9d83f2328 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + out = np.fromfile("v3.bin", dtype=np.uint16) + + if golden.shape != out.shape or not np.array_equal(golden, out): + diff = np.nonzero(golden != out)[0] if golden.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + golden_value = f"0x{int(golden[idx]):04x}" if idx >= 0 else "n/a" + out_value = f"0x{int(out[idx]):04x}" if idx >= 0 else "n/a" + print( + f"[ERROR] bf16 compare failed idx={idx} " + f"golden={golden_value} output={out_value}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/golden.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/golden.py new file mode 100644 index 0000000000..abce51a142 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 64 +COLS = 2048 +ELEMS = ROWS * COLS +MXSCALE_BYTES = ROWS * (COLS // 32) +VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], + dtype=np.float32, +) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +E8M0_BYTES = np.array( + [0x7E, 0x7F, 0x80, 0x81, 0x7D, 0x82, 0x7C, 0x83], dtype=np.uint8 +) +SENTINEL_BF16 = np.uint16(0x7FC0) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS) + mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + groups = COLS // 32 + scale_repeats = (groups + len(E8M0_BYTES) - 1) // len(E8M0_BYTES) + scale_row = np.tile(E8M0_BYTES, scale_repeats)[:groups].astype(np.uint8) + mxscale_matrix = np.tile(scale_row, (ROWS, 1)).astype(np.uint8) + mxscale[:] = mxscale_matrix.reshape(-1) + scale_values = np.ldexp( + np.ones_like(mxscale_matrix, dtype=np.float32), + mxscale_matrix.astype(np.int32) - 127, + ) + scaled = decoded.copy() + for row in range(ROWS): + for group in range(groups): + start = group * 32 + stop = start + 32 + scaled[row, start:stop] *= scale_values[row, group] + dst = np.full(ELEMS, SENTINEL_BF16, dtype=np.uint16) + golden = f32_to_bf16_bits(scaled.reshape(-1)) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mxscale.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/kernel.pto new file mode 100644 index 0000000000..2ffb1b8951 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/kernel.pto @@ -0,0 +1,109 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_anti_mx_f8_bf16_scaled_64x2048_kernel( + %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c2048 = arith.constant 2048 : index + %c32768 = arith.constant 32768 : index + %c23_i32 = arith.constant 23 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c65536_i64 = arith.constant 65536 : i64 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mxscale = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c65536_i64 : i64 -> !pto.ptr + + scf.for %tile = %c0 to %c4 step %c1 { + %src_tile_off = arith.muli %tile, %c32768 : index + %scale_tile_base_off = arith.muli %tile, %c1024 : index + %dst_tile_off = arith.muli %tile, %c32768 : index + %src_tile_gm = pto.addptr %src_gm, %src_tile_off + : !pto.ptr -> !pto.ptr + %scale_tile_gm = pto.addptr %mxscale_gm, %scale_tile_base_off + : !pto.ptr -> !pto.ptr + %dst_tile_gm = pto.addptr %dst_gm, %dst_tile_off + : !pto.ptr -> !pto.ptr + + pto.mte_gm_ub %src_tile_gm, %ub_src_u8, %c0_i64, %c2048_i64 + nburst(%c16_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_tile_gm, %ub_mxscale, %c0_i64, %c8_i64 + nburst(%c128_i64, %c8_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c16 step %c1 { + %row_elem_off = arith.muli %row, %c2048 : index + %row_scale_off = arith.muli %row, %c256 : index + scf.for %col = %c0 to %c2048 step %c256 { + %offset = arith.addi %row_elem_off, %col : index + %packed = pto.vmi.load %ub_src_f8[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_tile = arith.divui %col, %c256 : index + %scale_tile_off = arith.muli %scale_tile, %c32 : index + %scale_off = arith.addi %row_scale_off, %scale_tile_off : index + %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xui8> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %out, %ub_dst[%offset] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_tile_gm, %c65536_i64 + nburst(%c1_i64, %c65536_i64, %c65536_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + } + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/launch.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/launch.cpp new file mode 100644 index 0000000000..3768c6fd0c --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_anti_mx_f8_bf16_scaled_64x2048_kernel(__gm__ uint8_t *src, + __gm__ uint8_t *mxscale, + __gm__ bfloat16_t *dst); + +void LaunchVmi_anti_mx_f8_bf16_scaled_64x2048_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream) { + vmi_anti_mx_f8_bf16_scaled_64x2048_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, + (__gm__ bfloat16_t *)dst); +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/main.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/main.cpp new file mode 100644 index 0000000000..068ff83a6b --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8_bf16_scaled_64x2048_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kRows = 64; + constexpr size_t kCols = 2048; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = kRows * (kCols / 32); + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8_bf16_scaled_64x2048_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/compare.py new file mode 100644 index 0000000000..d5dcd9b576 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + out = np.fromfile("v3.bin", dtype=np.uint16) + + if golden.shape != out.shape or not np.array_equal(golden, out): + diff = np.nonzero(golden != out)[0] if golden.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + golden_value = f"0x{int(golden[idx]):04x}" if idx >= 0 else "n/a" + out_value = f"0x{int(out[idx]):04x}" if idx >= 0 else "n/a" + print( + f"[ERROR] f16 compare failed idx={idx} " + f"golden={golden_value} output={out_value}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/golden.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/golden.py new file mode 100644 index 0000000000..fc1502fce0 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +ELEMS = ROWS * COLS +MXSCALE_BYTES = 32 +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +E8M0_BYTES = np.array([0x7E, 0x7F, 0x80, 0x81], dtype=np.uint8) +SENTINEL_F16 = np.uint16(0x7E00) + + +def f32_to_f16_bits(values: np.ndarray) -> np.ndarray: + return values.astype(np.float16).view(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS) + mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + mxscale_matrix = np.tile(E8M0_BYTES, (ROWS, 1)).astype(np.uint8) + mxscale[: ROWS * 4] = mxscale_matrix.reshape(-1) + scale_values = np.ldexp( + np.ones_like(mxscale_matrix, dtype=np.float32), + mxscale_matrix.astype(np.int32) - 127, + ) + scaled = decoded.copy() + for row in range(ROWS): + for group in range(4): + start = group * 32 + stop = start + 32 + scaled[row, start:stop] *= scale_values[row, group] + dst = np.full(ELEMS, SENTINEL_F16, dtype=np.uint16) + golden = f32_to_f16_bits(scaled.reshape(-1)) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mxscale.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/kernel.pto new file mode 100644 index 0000000000..8ca4aa1654 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/kernel.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_anti_mx_f8_f16_scaled_4x128_kernel( + %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c23_i32 = arith.constant 23 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mxscale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64 + nburst(%c2_i64, %c8_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c512 step %c256 iter_args(%dummy = %c0) -> (index) { + %packed = pto.vmi.load %ub_src_f8[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_tile = arith.divui %offset, %c256 : index + %scale_off = arith.muli %scale_tile, %c32 : index + %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xui8> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf16> + pto.vmi.store %out, %ub_dst[%offset] + : !pto.vmi.vreg<256xf16>, !pto.ptr + scf.yield %dummy : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/launch.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/launch.cpp new file mode 100644 index 0000000000..8f8b341af5 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_anti_mx_f8_f16_scaled_4x128_kernel(__gm__ uint8_t *src, + __gm__ uint8_t *mxscale, + __gm__ half *dst); + +void LaunchVmi_anti_mx_f8_f16_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream) { + vmi_anti_mx_f8_f16_scaled_4x128_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, (__gm__ half *)dst); +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/main.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/main.cpp new file mode 100644 index 0000000000..c357020a90 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8_f16_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = 32; + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8_f16_scaled_4x128_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/compare.py new file mode 100644 index 0000000000..cc852efbc2 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.float32) + out = np.fromfile("v3.bin", dtype=np.float32) + + if golden.shape != out.shape or not np.array_equal(golden, out): + diff = np.nonzero(golden != out)[0] if golden.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + golden_value = golden[idx] if idx >= 0 else "n/a" + out_value = out[idx] if idx >= 0 else "n/a" + print( + f"[ERROR] f32 compare failed idx={idx} " + f"golden={golden_value} output={out_value}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/golden.py b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/golden.py new file mode 100644 index 0000000000..d453724750 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/golden.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +ELEMS = ROWS * COLS +MXSCALE_BYTES = 32 +VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8], dtype=np.uint8) +E8M0_BYTES = np.array([0x7E, 0x7F, 0x80, 0x81], dtype=np.uint8) +SENTINEL_F32 = np.float32(-777.0) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E4M3FN_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS) + mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + mxscale_matrix = np.tile(E8M0_BYTES, (ROWS, 1)).astype(np.uint8) + mxscale[: ROWS * 4] = mxscale_matrix.reshape(-1) + scale_values = np.ldexp( + np.ones_like(mxscale_matrix, dtype=np.float32), + mxscale_matrix.astype(np.int32) - 127, + ) + scaled = decoded.copy() + for row in range(ROWS): + for group in range(4): + start = group * 32 + stop = start + 32 + scaled[row, start:stop] *= scale_values[row, group] + dst = np.full(ELEMS, SENTINEL_F32, dtype=np.float32) + golden = scaled.reshape(-1).astype(np.float32) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mxscale.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/kernel.pto new file mode 100644 index 0000000000..cd291a4465 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/kernel.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_anti_mx_f8_f32_scaled_4x128_kernel( + %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c23_i32 = arith.constant 23 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mxscale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64 + nburst(%c2_i64, %c8_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c512 step %c256 iter_args(%dummy = %c0) -> (index) { + %packed = pto.vmi.load %ub_src_f8[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E4M3FN> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E4M3FN> -> !pto.vmi.vreg<256xf32> + %scale_tile = arith.divui %offset, %c256 : index + %scale_off = arith.muli %scale_tile, %c32 : index + %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xui8> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %scaled, %ub_dst[%offset] + : !pto.vmi.vreg<256xf32>, !pto.ptr + scf.yield %dummy : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/launch.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/launch.cpp new file mode 100644 index 0000000000..66a052ad5c --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_anti_mx_f8_f32_scaled_4x128_kernel(__gm__ uint8_t *src, + __gm__ uint8_t *mxscale, + __gm__ float *dst); + +void LaunchVmi_anti_mx_f8_f32_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + float *dst, + void *stream) { + vmi_anti_mx_f8_f32_scaled_4x128_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, (__gm__ float *)dst); +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/main.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/main.cpp new file mode 100644 index 0000000000..839873e18a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8_f32_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + float *dst, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = 32; + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(float); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + float *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + float *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8_f32_scaled_4x128_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/compare.py new file mode 100644 index 0000000000..f9d83f2328 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + out = np.fromfile("v3.bin", dtype=np.uint16) + + if golden.shape != out.shape or not np.array_equal(golden, out): + diff = np.nonzero(golden != out)[0] if golden.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + golden_value = f"0x{int(golden[idx]):04x}" if idx >= 0 else "n/a" + out_value = f"0x{int(out[idx]):04x}" if idx >= 0 else "n/a" + print( + f"[ERROR] bf16 compare failed idx={idx} " + f"golden={golden_value} output={out_value}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/golden.py b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/golden.py new file mode 100644 index 0000000000..77fed89e97 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/golden.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 16 +COLS = 512 +ELEMS = ROWS * COLS +MXSCALE_BYTES = ROWS * (COLS // 32) +VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 4.0, -4.0, 57344.0], + dtype=np.float32, +) +F8E5M2_BYTES = np.array( + [0x00, 0x3C, 0xBC, 0x38, 0xB8, 0x40, 0xC0, 0x44, 0xC4, 0x7B], + dtype=np.uint8, +) +E8M0_BYTES = np.array( + [0x7E, 0x7F, 0x80, 0x81, 0x7D, 0x82, 0x7C, 0x83], dtype=np.uint8 +) +SENTINEL_BF16 = np.uint16(0x7FC0) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E5M2_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS) + mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + groups = COLS // 32 + scale_repeats = (groups + len(E8M0_BYTES) - 1) // len(E8M0_BYTES) + scale_row = np.tile(E8M0_BYTES, scale_repeats)[:groups].astype(np.uint8) + mxscale_matrix = np.tile(scale_row, (ROWS, 1)).astype(np.uint8) + mxscale[:] = mxscale_matrix.reshape(-1) + scale_values = np.ldexp( + np.ones_like(mxscale_matrix, dtype=np.float32), + mxscale_matrix.astype(np.int32) - 127, + ) + scaled = decoded.copy() + for row in range(ROWS): + for group in range(groups): + start = group * 32 + stop = start + 32 + scaled[row, start:stop] *= scale_values[row, group] + dst = np.full(ELEMS, SENTINEL_BF16, dtype=np.uint16) + golden = f32_to_bf16_bits(scaled.reshape(-1)) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mxscale.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/kernel.pto new file mode 100644 index 0000000000..8b6ce29071 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/kernel.pto @@ -0,0 +1,96 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel( + %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c23_i32 = arith.constant 23 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mxscale = pto.castptr %c16384_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c32768_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64 + nburst(%c32_i64, %c8_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c16 step %c1 { + %row_elem_off = arith.muli %row, %c512 : index + %row_scale_off = arith.muli %row, %c64 : index + scf.for %col = %c0 to %c512 step %c256 { + %offset = arith.addi %row_elem_off, %col : index + %packed = pto.vmi.load %ub_src_f8[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E5M2> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E5M2> -> !pto.vmi.vreg<256xf32> + %scale_tile = arith.divui %col, %c256 : index + %scale_tile_off = arith.muli %scale_tile, %c32 : index + %scale_off = arith.addi %row_scale_off, %scale_tile_off : index + %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xui8> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %out, %ub_dst[%offset] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c16384_i64 + nburst(%c1_i64, %c16384_i64, %c16384_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/launch.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/launch.cpp new file mode 100644 index 0000000000..08d84b318d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel(__gm__ uint8_t *src, + __gm__ uint8_t *mxscale, + __gm__ bfloat16_t *dst); + +void LaunchVmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream) { + vmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, + (__gm__ bfloat16_t *)dst); +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/main.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/main.cpp new file mode 100644 index 0000000000..5700dec3a9 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kRows = 16; + constexpr size_t kCols = 512; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = kRows * (kCols / 32); + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8e5m2_bf16_scaled_16x512_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/compare.py b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/compare.py new file mode 100644 index 0000000000..f9d83f2328 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/compare.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v3.bin", dtype=np.uint16) + out = np.fromfile("v3.bin", dtype=np.uint16) + + if golden.shape != out.shape or not np.array_equal(golden, out): + diff = np.nonzero(golden != out)[0] if golden.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + golden_value = f"0x{int(golden[idx]):04x}" if idx >= 0 else "n/a" + out_value = f"0x{int(out[idx]):04x}" if idx >= 0 else "n/a" + print( + f"[ERROR] bf16 compare failed idx={idx} " + f"golden={golden_value} output={out_value}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/golden.py b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/golden.py new file mode 100644 index 0000000000..e3bc0b6db1 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/golden.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +ELEMS = ROWS * COLS +MXSCALE_BYTES = 32 +VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 4.0, -4.0, 57344.0], + dtype=np.float32, +) +F8E5M2_BYTES = np.array( + [0x00, 0x3C, 0xBC, 0x38, 0xB8, 0x40, 0xC0, 0x44, 0xC4, 0x7B], + dtype=np.uint8, +) +E8M0_BYTES = np.array([0x7E, 0x7F, 0x80, 0x81], dtype=np.uint8) +SENTINEL_BF16 = np.uint16(0x7FC0) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (ELEMS + len(VALUES) - 1) // len(VALUES) + src = np.tile(F8E5M2_BYTES, repeats)[:ELEMS].astype(np.uint8).reshape(ROWS, COLS) + decoded = np.tile(VALUES, repeats)[:ELEMS].astype(np.float32).reshape(ROWS, COLS) + mxscale = np.full(MXSCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + mxscale_matrix = np.tile(E8M0_BYTES, (ROWS, 1)).astype(np.uint8) + mxscale[: ROWS * 4] = mxscale_matrix.reshape(-1) + scale_values = np.ldexp( + np.ones_like(mxscale_matrix, dtype=np.float32), + mxscale_matrix.astype(np.int32) - 127, + ) + scaled = decoded.copy() + for row in range(ROWS): + for group in range(4): + start = group * 32 + stop = start + 32 + scaled[row, start:stop] *= scale_values[row, group] + dst = np.full(ELEMS, SENTINEL_BF16, dtype=np.uint16) + golden = f32_to_bf16_bits(scaled.reshape(-1)) + + output_dir.mkdir(parents=True, exist_ok=True) + src.reshape(-1).tofile(output_dir / "v1.bin") + mxscale.tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v3.bin") + golden.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/kernel.pto new file mode 100644 index 0000000000..dec0c5ecca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/kernel.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel( + %src_gm: !pto.ptr, %mxscale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c23_i32 = arith.constant 23 : i32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src_u8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src_f8 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_mxscale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src_u8, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %mxscale_gm, %ub_mxscale, %c0_i64, %c8_i64 + nburst(%c2_i64, %c8_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %_:1 = scf.for %offset = %c0 to %c512 step %c256 iter_args(%dummy = %c0) -> (index) { + %packed = pto.vmi.load %ub_src_f8[%offset] + : !pto.ptr -> !pto.vmi.vreg<256xf8E5M2> + %wide = pto.vmi.extf %packed + : !pto.vmi.vreg<256xf8E5M2> -> !pto.vmi.vreg<256xf32> + %scale_tile = arith.divui %offset, %c256 : index + %scale_off = arith.muli %scale_tile, %c32 : index + %scale_u8 = pto.vmi.group_slot_load %ub_mxscale[%scale_off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xui8> + %scale_u32 = pto.vmi.extui %scale_u8 + : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + %scale_i32 = pto.vmi.bitcast %scale_u32 + : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_i32, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %wide, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %out, %ub_dst[%offset] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + scf.yield %dummy : index + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/launch.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/launch.cpp new file mode 100644 index 0000000000..0e2fdf35e7 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel(__gm__ uint8_t *src, + __gm__ uint8_t *mxscale, + __gm__ bfloat16_t *dst); + +void LaunchVmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream) { + vmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel<<<1, nullptr, stream>>>( + (__gm__ uint8_t *)src, (__gm__ uint8_t *)mxscale, + (__gm__ bfloat16_t *)dst); +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/main.cpp b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/main.cpp new file mode 100644 index 0000000000..9125352fc8 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel(uint8_t *src, + uint8_t *mxscale, + uint16_t *dst, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kMxScaleBytes = 32; + size_t srcBytes = kElems * sizeof(uint8_t); + size_t mxscaleBytes = kMxScaleBytes; + size_t dstBytes = kElems * sizeof(uint16_t); + uint8_t *srcHost = nullptr; + uint8_t *mxscaleHost = nullptr; + uint16_t *dstHost = nullptr; + uint8_t *srcDevice = nullptr; + uint8_t *mxscaleDevice = nullptr; + uint16_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&mxscaleHost), mxscaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&mxscaleDevice, mxscaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", mxscaleBytes, mxscaleHost, mxscaleBytes); + ReadFile("./v3.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(mxscaleDevice, mxscaleBytes, mxscaleHost, mxscaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_anti_mx_f8e5m2_bf16_scaled_4x128_kernel(srcDevice, mxscaleDevice, + dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(mxscaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(mxscaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/compare.py b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/compare.py similarity index 62% rename from test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/compare.py rename to test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/compare.py index c6e34633b5..98eebe4477 100644 --- a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/compare.py +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/compare.py @@ -12,22 +12,6 @@ import numpy as np -def check_f32(name: str, atol: float, rtol: float) -> bool: - golden = np.fromfile(f"golden_{name}.bin", dtype=np.float32) - output = np.fromfile(f"{name}.bin", dtype=np.float32) - close = golden.shape == output.shape and np.allclose(golden, output, atol=atol, rtol=rtol) - if close: - return True - diff = np.nonzero(~np.isclose(golden, output, atol=atol, rtol=rtol))[0] - idx = int(diff[0]) if diff.size else -1 - print( - f"[ERROR] compare failed {name} idx={idx} " - f"golden={golden[idx] if idx >= 0 else 'n/a'} " - f"output={output[idx] if idx >= 0 else 'n/a'}" - ) - return False - - def check_u8(name: str) -> bool: golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) output = np.fromfile(f"{name}.bin", dtype=np.uint8) @@ -35,12 +19,15 @@ def check_u8(name: str) -> bool: return True diff = np.nonzero(golden != output)[0] idx = int(diff[0]) if diff.size else -1 - print(f"[ERROR] compare failed {name} idx={idx} golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}") + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) return False def main() -> None: - if not check_f32("v2", 1e-5, 1e-5) or not check_u8("v3"): + if not check_u8("v2") or not check_u8("v3") or not check_u8("v4"): sys.exit(2) print("[INFO] compare passed") diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/golden.py b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/golden.py new file mode 100644 index 0000000000..0b646c2c73 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/golden.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +SCALE1_BYTES = 16 +SCALE2_BYTES = 256 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.tile(f32_to_bf16_bits(q_row / np.float32(256.0)), (ROWS, 1)) + golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_scale1 = np.full(SCALE1_BYTES, np.uint8(0x77), dtype=np.uint8) + golden_scale2 = np.zeros(SCALE2_BYTES, dtype=np.uint8) + golden_scale2[0::2] = np.uint8(0x77) + + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + scale1 = np.full(SCALE1_BYTES, SENTINEL_U8, dtype=np.uint8) + scale2 = np.full(SCALE2_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale1.tofile(output_dir / "v3.bin") + scale2.tofile(output_dir / "v4.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale1.tofile(output_dir / "golden_v3.bin") + golden_scale2.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/kernel.pto new file mode 100644 index 0000000000..4010bd6b29 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/kernel.pto @@ -0,0 +1,151 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_mx_quant_bf16_e4m3_4x128_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale1_gm: !pto.ptr, + %scale2_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c53248_i64 = arith.constant 53248 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c8_i32 = arith.constant 8 : i32 + %c119_i32 = arith.constant 119 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_scale1 = pto.castptr %c49152_i64 : i64 -> !pto.ptr + %ub_scale2 = pto.castptr %c53248_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c4 step %c2 { + %elem_off = arith.muli %row, %c128 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + + %scale_slot = arith.divui %row, %c2 : index + %scale_ub_off = arith.muli %scale_slot, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.ptr + + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + %scale2_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_one = pto.vmi.broadcast %c1_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_parity = pto.vmi.andi %scale2_lane, %scale2_one + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale2_zero = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_even = pto.vmi.cmpi "eq", %scale2_parity, %scale2_zero + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %scale2_valid = pto.vmi.broadcast %c119_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_i32 = pto.vmi.select %scale2_even, %scale2_valid, %scale2_zero + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale2_u8 = pto.vmi.trunci %scale2_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %scale2_u8, %ub_scale2[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale1, %scale1_gm, %c8_i64 + nburst(%c2_i64, %c32_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale2, %scale2_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/launch.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/launch.cpp new file mode 100644 index 0000000000..462595a2f6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_mx_quant_bf16_e4m3_4x128_kernel(__gm__ bfloat16_t *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale1, + __gm__ uint8_t *scale2); + +void LaunchVmi_block_mx_quant_bf16_e4m3_4x128_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream) { + vmi_block_mx_quant_bf16_e4m3_4x128_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale1, + (__gm__ uint8_t *)scale2); +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/main.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/main.cpp new file mode 100644 index 0000000000..760e76a1c9 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_mx_quant_bf16_e4m3_4x128_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScale1Bytes = 16; + constexpr size_t kScale2Bytes = 256; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t outBytes = kElems * sizeof(uint8_t); + size_t scale1Bytes = kScale1Bytes; + size_t scale2Bytes = kScale2Bytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scale1Host = nullptr; + uint8_t *scale2Host = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scale1Device = nullptr; + uint8_t *scale2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale1Host), scale1Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale2Host), scale2Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale1Device, scale1Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale2Device, scale2Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scale1Bytes, scale1Host, scale1Bytes); + ReadFile("./v4.bin", scale2Bytes, scale2Host, scale2Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale1Device, scale1Bytes, scale1Host, scale1Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale2Device, scale2Bytes, scale2Host, scale2Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_mx_quant_bf16_e4m3_4x128_kernel( + srcDevice, outDevice, scale1Device, scale2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale1Host, scale1Bytes, scale1Device, scale1Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale2Host, scale2Bytes, scale2Device, scale2Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scale1Host, scale1Bytes); + WriteFile("./v4.bin", scale2Host, scale2Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scale1Device); + aclrtFree(scale2Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scale1Host); + aclrtFreeHost(scale2Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/compare.py b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/compare.py new file mode 100644 index 0000000000..98eebe4477 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3") or not check_u8("v4"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/golden.py b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/golden.py new file mode 100644 index 0000000000..e8c20fb5fe --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/golden.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +SCALE1_BYTES = 16 +SCALE2_BYTES = 256 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 4.0, -4.0, 57344.0], + dtype=np.float32, +) +F8E5M2_BYTES = np.array( + [0x00, 0x3C, 0xBC, 0x38, 0xB8, 0x40, 0xC0, 0x44, 0xC4, 0x7B], + dtype=np.uint8, +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E5M2_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.tile(f32_to_bf16_bits(q_row), (ROWS, 1)) + golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_scale1 = np.full(SCALE1_BYTES, np.uint8(0x7F), dtype=np.uint8) + golden_scale2 = np.zeros(SCALE2_BYTES, dtype=np.uint8) + golden_scale2[0::2] = np.uint8(0x7F) + + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + scale1 = np.full(SCALE1_BYTES, SENTINEL_U8, dtype=np.uint8) + scale2 = np.full(SCALE2_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale1.tofile(output_dir / "v3.bin") + scale2.tofile(output_dir / "v4.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale1.tofile(output_dir / "golden_v3.bin") + golden_scale2.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/kernel.pto new file mode 100644 index 0000000000..7c0ce6174b --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/kernel.pto @@ -0,0 +1,151 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_mx_quant_bf16_e5m2_4x128_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale1_gm: !pto.ptr, + %scale2_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c53248_i64 = arith.constant 53248 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c15_i32 = arith.constant 15 : i32 + %c127_i32 = arith.constant 127 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_scale1 = pto.castptr %c49152_i64 : i64 -> !pto.ptr + %ub_scale2 = pto.castptr %c53248_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c4 step %c2 { + %elem_off = arith.muli %row, %c128 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax = pto.vmi.broadcast %c15_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + + %scale_slot = arith.divui %row, %c2 : index + %scale_ub_off = arith.muli %scale_slot, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.ptr + + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E5M2> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E5M2>, !pto.ptr + } + + %scale2_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_one = pto.vmi.broadcast %c1_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_parity = pto.vmi.andi %scale2_lane, %scale2_one + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale2_zero = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_even = pto.vmi.cmpi "eq", %scale2_parity, %scale2_zero + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %scale2_valid = pto.vmi.broadcast %c127_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_i32 = pto.vmi.select %scale2_even, %scale2_valid, %scale2_zero + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale2_u8 = pto.vmi.trunci %scale2_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %scale2_u8, %ub_scale2[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale1, %scale1_gm, %c8_i64 + nburst(%c2_i64, %c32_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale2, %scale2_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/launch.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/launch.cpp new file mode 100644 index 0000000000..35bd3e233c --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_mx_quant_bf16_e5m2_4x128_kernel(__gm__ bfloat16_t *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale1, + __gm__ uint8_t *scale2); + +void LaunchVmi_block_mx_quant_bf16_e5m2_4x128_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream) { + vmi_block_mx_quant_bf16_e5m2_4x128_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale1, + (__gm__ uint8_t *)scale2); +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/main.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/main.cpp new file mode 100644 index 0000000000..8fc855e6b7 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_mx_quant_bf16_e5m2_4x128_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScale1Bytes = 16; + constexpr size_t kScale2Bytes = 256; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t outBytes = kElems * sizeof(uint8_t); + size_t scale1Bytes = kScale1Bytes; + size_t scale2Bytes = kScale2Bytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scale1Host = nullptr; + uint8_t *scale2Host = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scale1Device = nullptr; + uint8_t *scale2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale1Host), scale1Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale2Host), scale2Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale1Device, scale1Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale2Device, scale2Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scale1Bytes, scale1Host, scale1Bytes); + ReadFile("./v4.bin", scale2Bytes, scale2Host, scale2Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale1Device, scale1Bytes, scale1Host, scale1Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale2Device, scale2Bytes, scale2Host, scale2Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_mx_quant_bf16_e5m2_4x128_kernel( + srcDevice, outDevice, scale1Device, scale2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale1Host, scale1Bytes, scale1Device, scale1Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale2Host, scale2Bytes, scale2Device, scale2Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scale1Host, scale1Bytes); + WriteFile("./v4.bin", scale2Host, scale2Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scale1Device); + aclrtFree(scale2Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scale1Host); + aclrtFreeHost(scale2Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/compare.py b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/compare.py new file mode 100644 index 0000000000..98eebe4477 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3") or not check_u8("v4"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/golden.py b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/golden.py new file mode 100644 index 0000000000..5dc8e5bcce --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/golden.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 64 +COLS = 256 +SCALE1_BYTES = 512 +SCALE2_BYTES = 512 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.tile((q_row / np.float32(256.0)).astype(np.float16), (ROWS, 1)) + golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_scale1 = np.full(SCALE1_BYTES, np.uint8(0x77), dtype=np.uint8) + golden_scale2 = np.full(SCALE2_BYTES, np.uint8(0x77), dtype=np.uint8) + + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + scale1 = np.full(SCALE1_BYTES, SENTINEL_U8, dtype=np.uint8) + scale2 = np.full(SCALE2_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale1.tofile(output_dir / "v3.bin") + scale2.tofile(output_dir / "v4.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale1.tofile(output_dir / "golden_v3.bin") + golden_scale2.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/kernel.pto b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/kernel.pto new file mode 100644 index 0000000000..c248b12f8b --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/kernel.pto @@ -0,0 +1,138 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_mx_quant_f16_e4m3_64x256_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale1_gm: !pto.ptr, + %scale2_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c53248_i64 = arith.constant 53248 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c8_i32 = arith.constant 8 : i32 + %c119_i32 = arith.constant 119 : i32 + %c254_i32 = arith.constant 254 : i32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_scale1 = pto.castptr %c49152_i64 : i64 -> !pto.ptr + %ub_scale2 = pto.castptr %c53248_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c32768_i64 + nburst(%c1_i64, %c32768_i64, %c32768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c64 step %c1 { + %elem_off = arith.muli %row, %c256 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + + %scale_ub_off = arith.muli %row, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.ptr + + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + %scale2_i32 = pto.vmi.broadcast %c119_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_u8 = pto.vmi.trunci %scale2_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %scale2_u8, %ub_scale2[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + pto.vmi.store %scale2_u8, %ub_scale2[%c256] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c16384_i64 + nburst(%c1_i64, %c16384_i64, %c16384_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale1, %scale1_gm, %c8_i64 + nburst(%c64_i64, %c32_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale2, %scale2_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/launch.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/launch.cpp new file mode 100644 index 0000000000..642c2b33d9 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_mx_quant_f16_e4m3_64x256_kernel(__gm__ half *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale1, + __gm__ uint8_t *scale2); + +void LaunchVmi_block_mx_quant_f16_e4m3_64x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream) { + vmi_block_mx_quant_f16_e4m3_64x256_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale1, + (__gm__ uint8_t *)scale2); +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/main.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/main.cpp new file mode 100644 index 0000000000..ec205e289a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_mx_quant_f16_e4m3_64x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream); + +int main() { + constexpr size_t kRows = 64; + constexpr size_t kCols = 256; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScale1Bytes = 512; + constexpr size_t kScale2Bytes = 512; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t outBytes = kElems * sizeof(uint8_t); + size_t scale1Bytes = kScale1Bytes; + size_t scale2Bytes = kScale2Bytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scale1Host = nullptr; + uint8_t *scale2Host = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scale1Device = nullptr; + uint8_t *scale2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale1Host), scale1Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale2Host), scale2Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale1Device, scale1Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale2Device, scale2Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scale1Bytes, scale1Host, scale1Bytes); + ReadFile("./v4.bin", scale2Bytes, scale2Host, scale2Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale1Device, scale1Bytes, scale1Host, scale1Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale2Device, scale2Bytes, scale2Host, scale2Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_mx_quant_f16_e4m3_64x256_kernel( + srcDevice, outDevice, scale1Device, scale2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale1Host, scale1Bytes, scale1Device, scale1Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale2Host, scale2Bytes, scale2Device, scale2Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scale1Host, scale1Bytes); + WriteFile("./v4.bin", scale2Host, scale2Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scale1Device); + aclrtFree(scale2Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scale1Host); + aclrtFreeHost(scale2Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/ptoas.flags b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/compare.py b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/compare.py new file mode 100644 index 0000000000..98eebe4477 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3") or not check_u8("v4"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/golden.py b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/golden.py new file mode 100644 index 0000000000..27efbace62 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 256 +SCALE1_BYTES = 64 +SCALE2_BYTES = 512 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 4.0, -4.0, 57344.0], + dtype=np.float32, +) +F8E5M2_BYTES = np.array( + [0x00, 0x3C, 0xBC, 0x38, 0xB8, 0x40, 0xC0, 0x44, 0xC4, 0x7B], + dtype=np.uint8, +) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E5M2_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.tile(q_row.astype(np.float16), (ROWS, 1)) + golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_scale1 = np.full(SCALE1_BYTES, np.uint8(0x7F), dtype=np.uint8) + golden_scale2 = np.zeros(SCALE2_BYTES, dtype=np.uint8) + golden_scale2[0::2] = np.uint8(0x7F) + + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + scale1 = np.full(SCALE1_BYTES, SENTINEL_U8, dtype=np.uint8) + scale2 = np.full(SCALE2_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale1.tofile(output_dir / "v3.bin") + scale2.tofile(output_dir / "v4.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale1.tofile(output_dir / "golden_v3.bin") + golden_scale2.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/kernel.pto b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/kernel.pto new file mode 100644 index 0000000000..40587088c3 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/kernel.pto @@ -0,0 +1,154 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_mx_quant_f16_e5m2_8x256_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale1_gm: !pto.ptr, + %scale2_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c53248_i64 = arith.constant 53248 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c15_i32 = arith.constant 15 : i32 + %c127_i32 = arith.constant 127 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_scale1 = pto.castptr %c49152_i64 : i64 -> !pto.ptr + %ub_scale2 = pto.castptr %c53248_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c8 step %c1 { + %elem_off = arith.muli %row, %c256 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax = pto.vmi.broadcast %c15_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + + %scale_ub_off = arith.muli %row, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.ptr + + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E5M2> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E5M2>, !pto.ptr + } + + %scale2_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_one = pto.vmi.broadcast %c1_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_parity = pto.vmi.andi %scale2_lane, %scale2_one + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale2_zero = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_even = pto.vmi.cmpi "eq", %scale2_parity, %scale2_zero + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %scale2_valid = pto.vmi.broadcast %c127_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale2_i32 = pto.vmi.select %scale2_even, %scale2_valid, %scale2_zero + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale2_u8 = pto.vmi.trunci %scale2_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %scale2_u8, %ub_scale2[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + pto.vmi.store %scale2_u8, %ub_scale2[%c256] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale1, %scale1_gm, %c8_i64 + nburst(%c8_i64, %c32_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale2, %scale2_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/launch.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/launch.cpp new file mode 100644 index 0000000000..1b139bdee6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_mx_quant_f16_e5m2_8x256_kernel(__gm__ half *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale1, + __gm__ uint8_t *scale2); + +void LaunchVmi_block_mx_quant_f16_e5m2_8x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream) { + vmi_block_mx_quant_f16_e5m2_8x256_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale1, + (__gm__ uint8_t *)scale2); +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/main.cpp b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/main.cpp new file mode 100644 index 0000000000..f5932ec784 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/main.cpp @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_mx_quant_f16_e5m2_8x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale1, uint8_t *scale2, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 256; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScale1Bytes = 64; + constexpr size_t kScale2Bytes = 512; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t outBytes = kElems * sizeof(uint8_t); + size_t scale1Bytes = kScale1Bytes; + size_t scale2Bytes = kScale2Bytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scale1Host = nullptr; + uint8_t *scale2Host = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scale1Device = nullptr; + uint8_t *scale2Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale1Host), scale1Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scale2Host), scale2Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale1Device, scale1Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scale2Device, scale2Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scale1Bytes, scale1Host, scale1Bytes); + ReadFile("./v4.bin", scale2Bytes, scale2Host, scale2Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale1Device, scale1Bytes, scale1Host, scale1Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scale2Device, scale2Bytes, scale2Host, scale2Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_mx_quant_f16_e5m2_8x256_kernel( + srcDevice, outDevice, scale1Device, scale2Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale1Host, scale1Bytes, scale1Device, scale1Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scale2Host, scale2Bytes, scale2Device, scale2Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scale1Host, scale1Bytes); + WriteFile("./v4.bin", scale2Host, scale2Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scale1Device); + aclrtFree(scale2Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scale1Host); + aclrtFreeHost(scale2Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/ptoas.flags b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/compare.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/compare.py new file mode 100644 index 0000000000..c40bdb3270 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=0, atol=0 + ): + diff = np.nonzero(golden_scale != scale)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] fp8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/golden.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/golden.py new file mode 100644 index 0000000000..e139f87144 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 2 +COLS = 128 +SCALE_SLOTS = ROWS +FP8_MAX = np.float32(448.0) +SCALES = np.array([0.25, 0.5], dtype=np.float32) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.empty((ROWS, COLS), dtype=np.uint16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + src[row] = f32_to_bf16_bits(q_row * SCALES[row]) + x_f32 = bf16_bits_to_f32(src[row]) + golden_scale[row] = np.max(np.abs(x_f32)).astype(np.float32) / FP8_MAX + golden_out[row] = f8_row + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/kernel.pto similarity index 50% rename from test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/kernel.pto rename to test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/kernel.pto index f2dcc0cd16..b920c0da85 100644 --- a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/kernel.pto @@ -7,33 +7,30 @@ // See LICENSE in the root of the software repository for the full text of the License. module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { - func.func @vmi_simdvf_per_token_cast_to_fp8_kernel(%src_gm: !pto.ptr, - %scale_gm: !pto.ptr, - %out8_gm: !pto.ptr) attributes {pto.kernel} { + func.func @vmi_block_quant_bf16_fp8_2x128_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { %c0 = arith.constant 0 : index - %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index %c256 = arith.constant 256 : index %c0_i64 = arith.constant 0 : i64 %c1_i64 = arith.constant 1 : i64 + %c8_i64 = arith.constant 8 : i64 %c256_i64 = arith.constant 256 : i64 - %c1024_i64 = arith.constant 1024 : i64 + %c512_i64 = arith.constant 512 : i64 %c4096_i64 = arith.constant 4096 : i64 %c8192_i64 = arith.constant 8192 : i64 - %eps = arith.constant 1.000000e-04 : f32 %fp8_max = arith.constant 4.480000e+02 : f32 - %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr - %ub_out8_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr - %ub_out8_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr - pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 - nburst(%c1_i64, %c1024_i64, %c1024_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 - pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c1024_i64 - nburst(%c1_i64, %c1024_i64, %c1024_i64) - : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 - pto.mte_gm_ub %out8_gm, %ub_out8_u8, %c0_i64, %c256_i64 + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 @@ -42,35 +39,39 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<256xpred> - %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> - %abs = pto.vmi.absf %x : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> - %amax_raw = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + %x_bf16 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> -> !pto.vmi.vreg<256xf32> - %eps1 = pto.vmi.broadcast %eps : f32 -> !pto.vmi.vreg<256xf32> - %amax = pto.vmi.maxf %amax_raw, %eps1 - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> - %fp8_max1 = pto.vmi.broadcast %fp8_max : f32 -> !pto.vmi.vreg<256xf32> - %scale = pto.vmi.divf %amax, %fp8_max1 - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> - pto.vmi.group_store %scale, %ub_scale[%c0], %c8 {num_groups = 2} + %fp8_max_v = pto.vmi.broadcast %fp8_max + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %fp8_max_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.ptr - %scale_vec = pto.vmi.group_broadcast %scale - {num_groups = 2} : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> %q = pto.vmi.divf %x, %scale_vec - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> %q8 = pto.vmi.truncf %q : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> - pto.vmi.store %q8, %ub_out8_f8[%c0] + pto.vmi.store %q8, %ub_out_f8[%c0] : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] - pto.mte_ub_gm %ub_scale, %scale_gm, %c1024_i64 - nburst(%c1_i64, %c1024_i64, %c1024_i64) + pto.mte_ub_gm %ub_scale, %scale_gm, %c8_i64 + nburst(%c1_i64, %c8_i64, %c8_i64) : !pto.ptr, !pto.ptr, i64, i64, i64, i64 - pto.mte_ub_gm %ub_out8_u8, %out8_gm, %c256_i64 + pto.mte_ub_gm %ub_out_u8, %out_gm, %c256_i64 nburst(%c1_i64, %c256_i64, %c256_i64) : !pto.ptr, !pto.ptr, i64, i64, i64, i64 pto.barrier #pto.pipe diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/launch.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/launch.cpp new file mode 100644 index 0000000000..a1bb355958 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_quant_bf16_fp8_2x128_kernel(__gm__ bfloat16_t *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_block_quant_bf16_fp8_2x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream) { + vmi_block_quant_bf16_fp8_2x128_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/main.cpp new file mode 100644 index 0000000000..632018a6ff --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_quant_bf16_fp8_2x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); + +int main() { + constexpr size_t kRows = 2; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kRows; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_bf16_fp8_2x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/compare.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/compare.py new file mode 100644 index 0000000000..c40bdb3270 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=0, atol=0 + ): + diff = np.nonzero(golden_scale != scale)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] fp8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/golden.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/golden.py new file mode 100644 index 0000000000..a166b6aa4a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 32 +COLS = 128 +SCALE_SLOTS = ROWS +FP8_MAX = np.float32(448.0) +SCALES = np.tile(np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32), 8) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.empty((ROWS, COLS), dtype=np.uint16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + src[row] = f32_to_bf16_bits(q_row * SCALES[row]) + x_f32 = bf16_bits_to_f32(src[row]) + golden_scale[row] = np.max(np.abs(x_f32)).astype(np.float32) / FP8_MAX + golden_out[row] = f8_row + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/kernel.pto new file mode 100644 index 0000000000..a7162689b8 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/kernel.pto @@ -0,0 +1,83 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_quant_bf16_fp8_32x128_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %fp8_max = arith.constant 4.480000e+02 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c16384_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c32 step %c2 { + %elem_off = arith.muli %row, %c128 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %fp8_max_v = pto.vmi.broadcast %fp8_max + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %fp8_max_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%row], %c1 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out_u8, %out_gm, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/launch.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/launch.cpp new file mode 100644 index 0000000000..24599cfec9 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_quant_bf16_fp8_32x128_kernel(__gm__ bfloat16_t *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_block_quant_bf16_fp8_32x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream) { + vmi_block_quant_bf16_fp8_32x128_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/main.cpp new file mode 100644 index 0000000000..cd5211e167 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_quant_bf16_fp8_32x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); + +int main() { + constexpr size_t kRows = 32; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kRows; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_bf16_fp8_32x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/compare.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/compare.py new file mode 100644 index 0000000000..c40bdb3270 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=0, atol=0 + ): + diff = np.nonzero(golden_scale != scale)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] fp8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/golden.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/golden.py new file mode 100644 index 0000000000..e4e1f95824 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/golden.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +SCALE_SLOTS = 128 +FP8_MAX = np.float32(448.0) +SCALE_LIMIT = np.float32(0.25) +SCALES = np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [ + 0.0, + 224.0, + -224.0, + 112.0, + -112.0, + 64.0, + -64.0, + 32.0, + -32.0, + 16.0, + -16.0, + 8.0, + -8.0, + 4.0, + -4.0, + 2.0, + -2.0, + 1.0, + -1.0, + 0.5, + -0.5, + ], + dtype=np.float32, +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def decode_f8e4m3fn(byte: int) -> np.float32: + sign = -1.0 if byte & 0x80 else 1.0 + exp = (byte >> 3) & 0x0F + mant = byte & 0x07 + if byte in (0x7F, 0xFF): + return np.float32(np.nan) + if exp == 0: + return np.float32(sign * (mant / 8.0) * (2.0**-6)) + return np.float32(sign * (1.0 + mant / 8.0) * (2.0 ** (exp - 7))) + + +def f8e4m3fn_exact_bytes(values: np.ndarray) -> np.ndarray: + exact = {} + for byte in range(0x100): + decoded = decode_f8e4m3fn(byte) + if not np.isnan(decoded): + exact.setdefault(np.float32(decoded).item(), byte) + return np.array([exact[np.float32(value).item()] for value in values], dtype=np.uint8) + + +def f8e4m3fn_saturating_bytes(values: np.ndarray) -> np.ndarray: + clipped = np.clip(values.astype(np.float32), -FP8_MAX, FP8_MAX) + return f8e4m3fn_exact_bytes(clipped) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + + src = np.empty((ROWS, COLS), dtype=np.uint16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + src[row] = f32_to_bf16_bits(q_row * SCALES[row]) + x_f32 = bf16_bits_to_f32(src[row]) + raw_scale = np.max(np.abs(x_f32)).astype(np.float32) / FP8_MAX + scale = np.minimum(raw_scale, SCALE_LIMIT).astype(np.float32) + golden_scale[row] = scale + golden_out[row] = f8e4m3fn_saturating_bytes(x_f32 / scale) + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/kernel.pto new file mode 100644 index 0000000000..4d2b183fed --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/kernel.pto @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_quant_bf16_fp8_4x128_min_scale_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %fp8_max = arith.constant 4.480000e+02 : f32 + %scale_limit = arith.constant 2.500000e-01 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %fp8_max_v = pto.vmi.broadcast %fp8_max + : f32 -> !pto.vmi.vreg<256xf32> + %scale_raw = pto.vmi.divf %amax, %fp8_max_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scale_limit_v = pto.vmi.broadcast %scale_limit + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.minf %scale_raw, %scale_limit_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%c0] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%c256] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %fp8_max_v = pto.vmi.broadcast %fp8_max + : f32 -> !pto.vmi.vreg<256xf32> + %scale_raw = pto.vmi.divf %amax, %fp8_max_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scale_limit_v = pto.vmi.broadcast %scale_limit + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.minf %scale_raw, %scale_limit_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%c2], %c1 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%c256] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out_u8, %out_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/launch.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/launch.cpp new file mode 100644 index 0000000000..cf7c8a8aa5 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_quant_bf16_fp8_4x128_min_scale_kernel(__gm__ bfloat16_t *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_block_quant_bf16_fp8_4x128_min_scale_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream) { + vmi_block_quant_bf16_fp8_4x128_min_scale_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/main.cpp new file mode 100644 index 0000000000..8218bd476e --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_quant_bf16_fp8_4x128_min_scale_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = 128; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_bf16_fp8_4x128_min_scale_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/compare.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/compare.py new file mode 100644 index 0000000000..c40bdb3270 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=0, atol=0 + ): + diff = np.nonzero(golden_scale != scale)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] fp8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/golden.py b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/golden.py new file mode 100644 index 0000000000..136c639dde --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 128 +SCALE_SLOTS = 128 +FP8_MAX = np.float32(448.0) +SCALES = np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.empty((ROWS, COLS), dtype=np.uint16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + src[row] = f32_to_bf16_bits(q_row * SCALES[row]) + x_f32 = bf16_bits_to_f32(src[row]) + golden_scale[row] = np.max(np.abs(x_f32)).astype(np.float32) / FP8_MAX + golden_out[row] = f8_row + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/kernel.pto new file mode 100644 index 0000000000..330e6341ec --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/kernel.pto @@ -0,0 +1,113 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_quant_bf16_fp8_4x128_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %fp8_max = arith.constant 4.480000e+02 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %fp8_max_v = pto.vmi.broadcast %fp8_max + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %fp8_max_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%c0] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.vecscope { + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%c256] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %fp8_max_v = pto.vmi.broadcast %fp8_max + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %fp8_max_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%c2], %c1 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%c256] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out_u8, %out_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/launch.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/launch.cpp new file mode 100644 index 0000000000..18de0443fd --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_quant_bf16_fp8_4x128_kernel(__gm__ bfloat16_t *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_block_quant_bf16_fp8_4x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream) { + vmi_block_quant_bf16_fp8_4x128_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/main.cpp new file mode 100644 index 0000000000..b021bda736 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_quant_bf16_fp8_4x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = 128; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_bf16_fp8_4x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/compare.py b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/compare.py new file mode 100644 index 0000000000..c40bdb3270 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=0, atol=0 + ): + diff = np.nonzero(golden_scale != scale)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] fp8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/golden.py b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/golden.py new file mode 100644 index 0000000000..934d197785 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/golden.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 16 +COLS = 256 +BLOCK_COLS = 128 +SCALE_SLOTS = 128 +FP8_MAX = np.float32(448.0) +SCALES = np.tile(np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32), 8) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def generate(output_dir: Path) -> None: + repeats = (BLOCK_COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_block = np.tile(Q_VALUES, repeats)[:BLOCK_COLS].astype(np.float32) + f8_block = np.tile(F8E4M3FN_BYTES, repeats)[:BLOCK_COLS].astype(np.uint8) + + src = np.empty((ROWS, COLS), dtype=np.float16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + for block in range(COLS // BLOCK_COLS): + group = row * (COLS // BLOCK_COLS) + block + start = block * BLOCK_COLS + stop = start + BLOCK_COLS + src[row, start:stop] = (q_block * SCALES[group]).astype(np.float16) + x_f32 = src[row, start:stop].astype(np.float32) + golden_scale[group] = np.max(np.abs(x_f32)).astype(np.float32) / FP8_MAX + golden_out[row, start:stop] = f8_block + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/kernel.pto new file mode 100644 index 0000000000..014cb9b8ce --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/kernel.pto @@ -0,0 +1,89 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_quant_f16_fp8_16x256_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %fp8_max = arith.constant 4.480000e+02 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c16384_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c32768_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %chunk = %c0 to %c16 step %c1 { + %elem_off = arith.muli %chunk, %c256 : index + %scale_off = arith.muli %chunk, %c2 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %fp8_max_v = pto.vmi.broadcast %fp8_max + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %fp8_max_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%scale_off], %c1 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out_u8, %out_gm, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/launch.cpp b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/launch.cpp new file mode 100644 index 0000000000..9aaa904f48 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_quant_f16_fp8_16x256_kernel(__gm__ half *src, __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_block_quant_f16_fp8_16x256_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream) { + vmi_block_quant_f16_fp8_16x256_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/main.cpp new file mode 100644 index 0000000000..677cf3c033 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_quant_f16_fp8_16x256_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); + +int main() { + constexpr size_t kRows = 16; + constexpr size_t kCols = 256; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = 128; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_f16_fp8_16x256_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/compare.py b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/compare.py new file mode 100644 index 0000000000..c40bdb3270 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=0, atol=0 + ): + diff = np.nonzero(golden_scale != scale)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] fp8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/golden.py b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/golden.py new file mode 100644 index 0000000000..14e3870858 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/golden.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 256 +BLOCK_COLS = 128 +SCALE_SLOTS = 128 +FP8_MAX = np.float32(448.0) +SCALES = np.tile(np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32), 2) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def generate(output_dir: Path) -> None: + repeats = (BLOCK_COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_block = np.tile(Q_VALUES, repeats)[:BLOCK_COLS].astype(np.float32) + f8_block = np.tile(F8E4M3FN_BYTES, repeats)[:BLOCK_COLS].astype(np.uint8) + + src = np.empty((ROWS, COLS), dtype=np.float16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + for block in range(COLS // BLOCK_COLS): + group = row * (COLS // BLOCK_COLS) + block + start = block * BLOCK_COLS + stop = start + BLOCK_COLS + src[row, start:stop] = (q_block * SCALES[group]).astype(np.float16) + x_f32 = src[row, start:stop].astype(np.float32) + golden_scale[group] = np.max(np.abs(x_f32)).astype(np.float32) / FP8_MAX + golden_out[row, start:stop] = f8_block + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/kernel.pto new file mode 100644 index 0000000000..5a447251cc --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/kernel.pto @@ -0,0 +1,89 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_quant_f16_fp8_4x256_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %fp8_max = arith.constant 4.480000e+02 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %chunk = %c0 to %c4 step %c1 { + %elem_off = arith.muli %chunk, %c256 : index + %scale_off = arith.muli %chunk, %c2 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %fp8_max_v = pto.vmi.broadcast %fp8_max + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %fp8_max_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%scale_off], %c1 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out_u8, %out_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/launch.cpp b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/launch.cpp new file mode 100644 index 0000000000..123cfafdc5 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_quant_f16_fp8_4x256_kernel(__gm__ half *src, __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_block_quant_f16_fp8_4x256_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream) { + vmi_block_quant_f16_fp8_4x256_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/main.cpp similarity index 72% rename from test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/main.cpp rename to test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/main.cpp index cbb7149b86..74c90073ae 100644 --- a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/main.cpp +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/main.cpp @@ -25,21 +25,23 @@ using namespace PtoTestCommon; } \ } while (0) -void LaunchVmi_simdvf_per_token_cast_to_fp8_kernel(float *src, float *scale, - uint8_t *out8, - void *stream); +void LaunchVmi_block_quant_f16_fp8_4x256_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); int main() { - constexpr size_t kElems = 256; - size_t srcBytes = kElems * sizeof(float); - size_t scaleBytes = kElems * sizeof(float); - size_t out8Bytes = kElems * sizeof(uint8_t); - float *srcHost = nullptr; + constexpr size_t kRows = 4; + constexpr size_t kCols = 256; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = 128; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; float *scaleHost = nullptr; - uint8_t *out8Host = nullptr; - float *srcDevice = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; float *scaleDevice = nullptr; - uint8_t *out8Device = nullptr; + uint8_t *outDevice = nullptr; int rc = 0; bool aclInited = false; bool deviceSet = false; @@ -55,32 +57,32 @@ int main() { ACL_CHECK(aclrtCreateStream(&stream)); ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); - ACL_CHECK(aclrtMallocHost((void **)(&out8Host), out8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc((void **)&out8Device, out8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); - ReadFile("./v3.bin", out8Bytes, out8Host, out8Bytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); - ACL_CHECK(aclrtMemcpy(out8Device, out8Bytes, out8Host, out8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); - LaunchVmi_simdvf_per_token_cast_to_fp8_kernel(srcDevice, scaleDevice, - out8Device, stream); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_f16_fp8_4x256_kernel(srcDevice, scaleDevice, + outDevice, stream); ACL_CHECK(aclrtSynchronizeStream(stream)); ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); - ACL_CHECK(aclrtMemcpy(out8Host, out8Bytes, out8Device, out8Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); WriteFile("./v2.bin", scaleHost, scaleBytes); - WriteFile("./v3.bin", out8Host, out8Bytes); + WriteFile("./v3.bin", outHost, outBytes); cleanup: aclrtFree(srcDevice); aclrtFree(scaleDevice); - aclrtFree(out8Device); + aclrtFree(outDevice); aclrtFreeHost(srcHost); aclrtFreeHost(scaleHost); - aclrtFreeHost(out8Host); + aclrtFreeHost(outHost); if (stream) aclrtDestroyStream(stream); if (deviceSet) diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/compare.py b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/compare.py new file mode 100644 index 0000000000..c40bdb3270 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/compare.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=0, atol=0 + ): + diff = np.nonzero(golden_scale != scale)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] fp8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/golden.py b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/golden.py new file mode 100644 index 0000000000..6345d8a069 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/golden.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 128 +SCALE_SLOTS = 128 +FP8_MAX = np.float32(448.0) +SCALES = np.tile(np.array([0.25, 0.5, 1.0, 2.0], dtype=np.float32), 2) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.empty((ROWS, COLS), dtype=np.float16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + src[row] = (q_row * SCALES[row]).astype(np.float16) + x_f32 = src[row].astype(np.float32) + golden_scale[row] = np.max(np.abs(x_f32)).astype(np.float32) / FP8_MAX + golden_out[row] = f8_row + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/kernel.pto new file mode 100644 index 0000000000..e4f00cab6e --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/kernel.pto @@ -0,0 +1,89 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_block_quant_f16_fp8_8x128_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %fp8_max = arith.constant 4.480000e+02 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out_u8, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %chunk = %c0 to %c4 step %c1 { + %elem_off = arith.muli %chunk, %c256 : index + %scale_off = arith.muli %chunk, %c2 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %fp8_max_v = pto.vmi.broadcast %fp8_max + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %fp8_max_v + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%scale_off], %c1 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%elem_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out_u8, %out_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/launch.cpp b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/launch.cpp new file mode 100644 index 0000000000..75929341b7 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_block_quant_f16_fp8_8x128_kernel(__gm__ half *src, __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_block_quant_f16_fp8_8x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream) { + vmi_block_quant_f16_fp8_8x128_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/main.cpp b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/main.cpp new file mode 100644 index 0000000000..433abffbd7 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_block_quant_f16_fp8_8x128_kernel(uint16_t *src, float *scale, + uint8_t *out, void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = 128; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_block_quant_f16_fp8_8x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/ptoas.flags b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/golden.py new file mode 100644 index 0000000000..f9bc779e12 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/golden.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 128 +COLS = 128 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +ROW_Q = np.array( + [ + -127, + -96, + -64, + -32, + -7, + -1, + 0, + 1, + 7, + 16, + 31, + 63, + 95, + 120, + 127, + 64, + ], + dtype=np.float32, +) +COL_SCALES = np.array([0.125, 0.25, 0.5, 1.0, 2.0], dtype=np.float32) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + q = np.tile(ROW_Q, (ROWS + len(ROW_Q) - 1) // len(ROW_Q))[:ROWS] + col_scales = np.tile(COL_SCALES, (COLS + len(COL_SCALES) - 1) // len(COL_SCALES))[:COLS] + + src = f32_to_bf16_bits(q[:, None] * col_scales[None, :]) + x_f32 = bf16_bits_to_f32(src) + golden_scale = (np.max(np.abs(x_f32), axis=0) / INT8_MAX).astype(np.float32) + scale_safe = np.where(golden_scale > 0, golden_scale, np.ones_like(golden_scale)) + golden_out = np.round(x_f32 / scale_safe[None, :]).astype(np.float32) + golden_out = np.clip(golden_out, -128, 127).astype(np.int8) + + scale = np.full(COLS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/kernel.pto new file mode 100644 index 0000000000..81d857b8f6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/kernel.pto @@ -0,0 +1,188 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_perchannel_bf16_128x128_kernel( + %src_gm: !pto.ptr, %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c7_i32 = arith.constant 7 : i32 + %c127_i32 = arith.constant 127 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c102400_i64 = arith.constant 102400 : i64 + %c106496_i64 = arith.constant 106496 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scratch = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_scale_padded = pto.castptr %c98304_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c102400_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c106496_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c32768_i64 + nburst(%c1_i64, %c32768_i64, %c32768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c128 step %c1 { + %elem_offset = arith.muli %row, %c128 : index + %x_bf16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<128xbf16> + %x_f32 = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<128xbf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.store %x_f32, %ub_scratch[%elem_offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + scf.for %col = %c0 to %c128 step %c1 { + %col_i32 = arith.index_cast %col : index to i32 + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<128xf32> + %init = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<1xf32> + %lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %stride = pto.vmi.broadcast %c128_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %row_offsets = pto.vmi.muli %lane, %stride + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %col_vec = pto.vmi.broadcast %col_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %indices = pto.vmi.addi %row_offsets, %col_vec + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %col_values = pto.vmi.gather %ub_scratch[%indices], %mask, %zero + : !pto.ptr, !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %abs = pto.vmi.absf %col_values + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %amax = pto.vmi.reduce_maxf %abs, %init, %mask + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<1xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<1xf32>, !pto.vmi.vreg<1xf32> + -> !pto.vmi.vreg<1xf32> + %padded_col = arith.muli %col, %c8 : index + pto.vmi.store %scale, %ub_scale_padded[%padded_col] + : !pto.vmi.vreg<1xf32>, !pto.ptr + } + + %scale_mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %scale_zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<128xf32> + %scale_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %scale_stride = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %scale_indices = pto.vmi.muli %scale_lane, %scale_stride + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %scale_dense = pto.vmi.gather %ub_scale_padded[%scale_indices], %scale_mask, %scale_zero + : !pto.ptr, !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %scale_dense, %ub_scale[%c0] + : !pto.vmi.vreg<128xf32>, !pto.ptr + + %scale_pair_mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %scale_pair_zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale_pair_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_group_mask = pto.vmi.broadcast %c127_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_col = pto.vmi.andi %scale_pair_lane, %scale_pair_group_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_pair_stride = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_indices = pto.vmi.muli %scale_pair_col, %scale_pair_stride + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_vec = pto.vmi.gather %ub_scale_padded[%scale_pair_indices], %scale_pair_mask, %scale_pair_zero + : !pto.ptr, !pto.vmi.vreg<256xi32>, + !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + scf.for %pair = %c0 to %c64 step %c1 { + %row = arith.muli %pair, %c2 : index + %elem_offset = arith.muli %row, %c128 : index + %x = pto.vmi.load %ub_scratch[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c16384_i64 + nburst(%c1_i64, %c16384_i64, %c16384_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/launch.cpp new file mode 100644 index 0000000000..08a9931a37 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_perchannel_bf16_128x128_kernel(__gm__ bfloat16_t *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_perchannel_bf16_128x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + vmi_dynamic_quant_perchannel_bf16_128x128_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/main.cpp new file mode 100644 index 0000000000..451bb794b6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_perchannel_bf16_128x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 128; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kCols; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_perchannel_bf16_128x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-bf16-128x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/golden.py new file mode 100644 index 0000000000..a249eec0ee --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 128 +COLS = 128 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +ROW_Q = np.array( + [ + -127, + -96, + -64, + -32, + -7, + -1, + 0, + 1, + 7, + 16, + 31, + 63, + 95, + 120, + 127, + 64, + ], + dtype=np.float32, +) +COL_SCALES = np.array([0.125, 0.25, 0.5, 1.0, 2.0], dtype=np.float32) + + +def generate(output_dir: Path) -> None: + q = np.tile(ROW_Q, (ROWS + len(ROW_Q) - 1) // len(ROW_Q))[:ROWS] + col_scales = np.tile(COL_SCALES, (COLS + len(COL_SCALES) - 1) // len(COL_SCALES))[:COLS] + + src = (q[:, None] * col_scales[None, :]).astype(np.float16) + x_f32 = src.astype(np.float32) + golden_scale = (np.max(np.abs(x_f32), axis=0) / INT8_MAX).astype(np.float32) + scale_safe = np.where(golden_scale > 0, golden_scale, np.ones_like(golden_scale)) + golden_out = np.round(x_f32 / scale_safe[None, :]).astype(np.float32) + golden_out = np.clip(golden_out, -128, 127).astype(np.int8) + + scale = np.full(COLS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/kernel.pto new file mode 100644 index 0000000000..0bcf64283d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/kernel.pto @@ -0,0 +1,188 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_perchannel_f16_128x128_kernel( + %src_gm: !pto.ptr, %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c7_i32 = arith.constant 7 : i32 + %c127_i32 = arith.constant 127 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c102400_i64 = arith.constant 102400 : i64 + %c106496_i64 = arith.constant 106496 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scratch = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_scale_padded = pto.castptr %c98304_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c102400_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c106496_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c32768_i64 + nburst(%c1_i64, %c32768_i64, %c32768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c128 step %c1 { + %elem_offset = arith.muli %row, %c128 : index + %x_f16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x_f32 = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.store %x_f32, %ub_scratch[%elem_offset] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + + scf.for %col = %c0 to %c128 step %c1 { + %col_i32 = arith.index_cast %col : index to i32 + %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<128xf32> + %init = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<1xf32> + %lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %stride = pto.vmi.broadcast %c128_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %row_offsets = pto.vmi.muli %lane, %stride + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %col_vec = pto.vmi.broadcast %col_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %indices = pto.vmi.addi %row_offsets, %col_vec + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %col_values = pto.vmi.gather %ub_scratch[%indices], %mask, %zero + : !pto.ptr, !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %abs = pto.vmi.absf %col_values + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %amax = pto.vmi.reduce_maxf %abs, %init, %mask + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<1xf32>, + !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<1xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<1xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<1xf32>, !pto.vmi.vreg<1xf32> + -> !pto.vmi.vreg<1xf32> + %padded_col = arith.muli %col, %c8 : index + pto.vmi.store %scale, %ub_scale_padded[%padded_col] + : !pto.vmi.vreg<1xf32>, !pto.ptr + } + + %scale_mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %scale_zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<128xf32> + %scale_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %scale_stride = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<128xi32> + %scale_indices = pto.vmi.muli %scale_lane, %scale_stride + : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<128xi32> + %scale_dense = pto.vmi.gather %ub_scale_padded[%scale_indices], %scale_mask, %scale_zero + : !pto.ptr, !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + pto.vmi.store %scale_dense, %ub_scale[%c0] + : !pto.vmi.vreg<128xf32>, !pto.ptr + + %scale_pair_mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %scale_pair_zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale_pair_lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_group_mask = pto.vmi.broadcast %c127_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_col = pto.vmi.andi %scale_pair_lane, %scale_pair_group_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_pair_stride = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_pair_indices = pto.vmi.muli %scale_pair_col, %scale_pair_stride + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_vec = pto.vmi.gather %ub_scale_padded[%scale_pair_indices], %scale_pair_mask, %scale_pair_zero + : !pto.ptr, !pto.vmi.vreg<256xi32>, + !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + scf.for %pair = %c0 to %c64 step %c1 { + %row = arith.muli %pair, %c2 : index + %elem_offset = arith.muli %row, %c128 : index + %x = pto.vmi.load %ub_scratch[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c16384_i64 + nburst(%c1_i64, %c16384_i64, %c16384_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/launch.cpp new file mode 100644 index 0000000000..abec34be2e --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_perchannel_f16_128x128_kernel(__gm__ half *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_perchannel_f16_128x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + vmi_dynamic_quant_perchannel_f16_128x128_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/main.cpp new file mode 100644 index 0000000000..741d5e3d9e --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_perchannel_f16_128x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 128; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kCols; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_perchannel_f16_128x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-128x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/golden.py new file mode 100644 index 0000000000..dc9933534f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 256 +COLS = 256 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +ROW_Q = np.array( + [ + -127, + -96, + -64, + -32, + -7, + -1, + 0, + 1, + 7, + 16, + 31, + 63, + 95, + 120, + 127, + 64, + ], + dtype=np.float32, +) +COL_SCALES = np.array([0.125, 0.25, 0.5, 1.0, 2.0], dtype=np.float32) + + +def generate(output_dir: Path) -> None: + q = np.tile(ROW_Q, (ROWS + len(ROW_Q) - 1) // len(ROW_Q))[:ROWS] + col_scales = np.tile(COL_SCALES, (COLS + len(COL_SCALES) - 1) // len(COL_SCALES))[:COLS] + + src = (q[:, None] * col_scales[None, :]).astype(np.float16) + x_f32 = src.astype(np.float32) + golden_scale = (np.max(np.abs(x_f32), axis=0) / INT8_MAX).astype(np.float32) + scale_safe = np.where(golden_scale > 0, golden_scale, np.ones_like(golden_scale)) + golden_out = np.round(x_f32 / scale_safe[None, :]).astype(np.float32) + golden_out = np.clip(golden_out, -128, 127).astype(np.int8) + + scale = np.full(COLS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/kernel.pto new file mode 100644 index 0000000000..89507a9228 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/kernel.pto @@ -0,0 +1,117 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_perchannel_f16_256x256_kernel( + %src_gm: !pto.ptr, %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c135168_i64 = arith.constant 135168 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c135168_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c131072_i64 + nburst(%c1_i64, %c131072_i64, %c131072_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %zero_acc = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale_acc = scf.for %row = %c0 to %c256 step %c1 + iter_args(%acc = %zero_acc) -> (!pto.vmi.vreg<256xf32>) { + %elem_offset = arith.muli %row, %c256 : index + %x_f16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %next = pto.vmi.maxf %acc, %abs + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + scf.yield %next : !pto.vmi.vreg<256xf32> + } + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %scale_acc, %max_int8 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.store %scale, %ub_scale[%c0] + : !pto.vmi.vreg<256xf32>, !pto.ptr + + scf.for %row = %c0 to %c256 step %c1 { + %elem_offset = arith.muli %row, %c256 : index + %x_f16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c65536_i64 + nburst(%c1_i64, %c65536_i64, %c65536_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/launch.cpp new file mode 100644 index 0000000000..a805d4f3f4 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_perchannel_f16_256x256_kernel(__gm__ half *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_perchannel_f16_256x256_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + vmi_dynamic_quant_perchannel_f16_256x256_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/main.cpp new file mode 100644 index 0000000000..0b840eff7a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_perchannel_f16_256x256_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 256; + constexpr size_t kCols = 256; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kCols; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_perchannel_f16_256x256_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-perchannel-f16-256x256/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/golden.py new file mode 100644 index 0000000000..dc44fd67ba --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/golden.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 32 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127], + dtype=np.float32, +) +ROW_SCALES = np.array( + [ + 0.25, + 0.5, + 1.0, + 2.0, + ], + dtype=np.float32, +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + + src = np.empty((ROWS, COLS), dtype=np.uint16) + golden_scale = np.empty(ROWS, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.int8) + for row in range(ROWS): + src[row] = f32_to_bf16_bits(q_row * ROW_SCALES[row]) + x_f32 = bf16_bits_to_f32(src[row]) + scale = (np.max(np.abs(x_f32)) / INT8_MAX).astype(np.float32) + golden_scale[row] = scale + quant = np.round(x_f32 / scale).astype(np.float32) + golden_out[row] = np.clip(quant, -128, 127).astype(np.int8) + + scale = np.full(ROWS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/kernel.pto new file mode 100644 index 0000000000..76941c1daa --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/kernel.pto @@ -0,0 +1,112 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_pertoken_bf16_4x32_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1_bf16 = arith.constant 1.000000e+00 : bf16 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %pad_bf16 = pto.vmi.broadcast %c1_bf16 + : bf16 -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %pad_bf16, %ub_src[%c128] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + pto.mem_bar "VST_VLD" + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c16_i64 + nburst(%c1_i64, %c16_i64, %c16_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/launch.cpp new file mode 100644 index 0000000000..f8514dfdb6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_pertoken_bf16_4x32_kernel(__gm__ bfloat16_t *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_pertoken_bf16_4x32_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + vmi_dynamic_quant_pertoken_bf16_4x32_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/main.cpp new file mode 100644 index 0000000000..f11096d3e6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_pertoken_bf16_4x32_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 32; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kRows; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_pertoken_bf16_4x32_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/golden.py new file mode 100644 index 0000000000..e17afddf0a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/golden.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 16 +COLS = 128 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127], + dtype=np.float32, +) +ROW_SCALES = np.array( + [ + 0.25, + 0.5, + 1.0, + 2.0, + 0.375, + 0.75, + 1.5, + 3.0, + 0.125, + 0.625, + 1.25, + 2.5, + 0.3125, + 0.9375, + 1.875, + 3.75, + ], + dtype=np.float32, +) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + + src = np.empty((ROWS, COLS), dtype=np.float16) + golden_scale = np.empty(ROWS, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.int8) + for row in range(ROWS): + src[row] = (q_row * ROW_SCALES[row]).astype(np.float16) + x_f32 = src[row].astype(np.float32) + scale = (np.max(np.abs(x_f32)) / INT8_MAX).astype(np.float32) + golden_scale[row] = scale + quant = np.round(x_f32 / scale).astype(np.float32) + golden_out[row] = np.clip(quant, -128, 127).astype(np.int8) + + scale = np.full(ROWS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/kernel.pto new file mode 100644 index 0000000000..f40ca1d9d7 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/kernel.pto @@ -0,0 +1,118 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_pertoken_f16_16x128_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out, %c0_i64, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %tile = %c0 to %c8 step %c1 { + %row = arith.muli %tile, %c2 : index + %elem_offset = arith.muli %row, %c128 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%row], %c1 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/launch.cpp new file mode 100644 index 0000000000..5f63392588 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_pertoken_f16_16x128_kernel(__gm__ half *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_pertoken_f16_16x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + vmi_dynamic_quant_pertoken_f16_16x128_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/main.cpp new file mode 100644 index 0000000000..1385903f12 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_pertoken_f16_16x128_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 16; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kRows; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_pertoken_f16_16x128_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/compare.py new file mode 100644 index 0000000000..0142cbc20f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v2.bin", dtype=np.float32) + scale = np.fromfile("v2.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/golden.py new file mode 100644 index 0000000000..351c27ef75 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/golden.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +COLS = 32 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127], + dtype=np.float32, +) +ROW_SCALES = np.array( + [ + 0.25, + 0.5, + 1.0, + 2.0, + ], + dtype=np.float32, +) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + + src = np.empty((ROWS, COLS), dtype=np.float16) + golden_scale = np.empty(ROWS, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.int8) + for row in range(ROWS): + src[row] = (q_row * ROW_SCALES[row]).astype(np.float16) + x_f32 = src[row].astype(np.float32) + scale = (np.max(np.abs(x_f32)) / INT8_MAX).astype(np.float32) + golden_scale[row] = scale + quant = np.round(x_f32 / scale).astype(np.float32) + golden_out[row] = np.clip(quant, -128, 127).astype(np.int8) + + scale = np.full(ROWS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + scale.tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_scale.tofile(output_dir / "golden_v2.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/kernel.pto new file mode 100644 index 0000000000..deba0f577f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/kernel.pto @@ -0,0 +1,112 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_pertoken_f16_4x32_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1_f16 = arith.constant 1.000000e+00 : f16 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c16_i64 = arith.constant 16 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c1024_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c2048_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %pad_f16 = pto.vmi.broadcast %c1_f16 + : f16 -> !pto.vmi.vreg<256xf16> + pto.vmi.store %pad_f16, %ub_src[%c128] + : !pto.vmi.vreg<256xf16>, !pto.ptr + pto.mem_bar "VST_VLD" + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c16_i64 + nburst(%c1_i64, %c16_i64, %c16_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/launch.cpp new file mode 100644 index 0000000000..4c4675e167 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_pertoken_f16_4x32_kernel(__gm__ half *src, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_pertoken_f16_4x32_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream) { + vmi_dynamic_quant_pertoken_f16_4x32_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/main.cpp new file mode 100644 index 0000000000..e0616d1d83 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_pertoken_f16_4x32_kernel(uint16_t *src, + float *scale, + uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kCols = 32; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleElems = kRows; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_pertoken_f16_4x32_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", scaleHost, scaleBytes); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/compare.py new file mode 100644 index 0000000000..35bf41a58d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v3.bin", dtype=np.float32) + scale = np.fromfile("v3.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v4.bin", dtype=np.uint8) + out = np.fromfile("v4.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/golden.py new file mode 100644 index 0000000000..45952cc50d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/golden.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 64 +SCALE_SLOTS = 16 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) +SMOOTH_VALUES = np.array( + [0.5, 0.75, 1.0, 1.25, 1.5, 0.625, 0.875, 1.125], + dtype=np.float32, +) + +Q_VALUES = np.array( + [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127], + dtype=np.float32, +) +ROW_SCALES = np.array( + [ + 0.25, + 0.5, + 1.0, + 2.0, + 0.375, + 0.75, + 1.5, + 3.0, + 0.125, + 0.625, + 1.25, + 2.5, + 0.3125, + 0.9375, + 1.875, + 3.75, + ], + dtype=np.float32, +) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def bf16_bits_to_f32(values: np.ndarray) -> np.ndarray: + return (values.astype(np.uint32) << 16).view(np.float32) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + smooth_repeats = (COLS + len(SMOOTH_VALUES) - 1) // len(SMOOTH_VALUES) + smooth = f32_to_bf16_bits(np.tile(SMOOTH_VALUES, smooth_repeats)[:COLS]) + + src = np.empty((ROWS, COLS), dtype=np.uint16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.int8) + for row in range(ROWS): + src[row] = f32_to_bf16_bits(q_row * ROW_SCALES[row]) + x_f32 = bf16_bits_to_f32(src[row]) * bf16_bits_to_f32(smooth) + scale = (np.max(np.abs(x_f32)) / INT8_MAX).astype(np.float32) + golden_scale[(row // 4) * 8 + (row % 4)] = scale + quant = np.round(x_f32 / scale).astype(np.float32) + golden_out[row] = np.clip(quant, -128, 127).astype(np.int8) + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + smooth.tofile(output_dir / "v2.bin") + scale.tofile(output_dir / "v3.bin") + out.tofile(output_dir / "v4.bin") + golden_scale.tofile(output_dir / "golden_v3.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/kernel.pto new file mode 100644 index 0000000000..e5b384e96b --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/kernel.pto @@ -0,0 +1,156 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel(%src_gm: !pto.ptr, + %smooth_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %c12416_i64 = arith.constant 12416 : i64 + %c12544_i64 = arith.constant 12544 : i64 + %c12672_i64 = arith.constant 12672 : i64 + %c12800_i64 = arith.constant 12800 : i64 + %c12928_i64 = arith.constant 12928 : i64 + %c13056_i64 = arith.constant 13056 : i64 + %c13184_i64 = arith.constant 13184 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_smooth = pto.castptr %c12288_i64 : i64 -> !pto.ptr + %ub_smooth1 = pto.castptr %c12416_i64 : i64 -> !pto.ptr + %ub_smooth2 = pto.castptr %c12544_i64 : i64 -> !pto.ptr + %ub_smooth3 = pto.castptr %c12672_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth1, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth2, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth3, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %tile = %c0 to %c2 step %c1 { + %row = arith.muli %tile, %c4 : index + %elem_offset = arith.muli %row, %c64 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_bf16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x = pto.vmi.extf %x_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %smooth_bf16 = pto.vmi.load %ub_smooth[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %smooth = pto.vmi.extf %smooth_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %smoothed = pto.vmi.mulf %x, %smooth + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %smoothed + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 4} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scale_offset = arith.muli %tile, %c8 : index + pto.vmi.group_store %scale, %ub_scale[%scale_offset], %c1 {num_groups = 4} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 4} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %smoothed, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/launch.cpp new file mode 100644 index 0000000000..c601fb3637 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel(__gm__ bfloat16_t *src, + __gm__ bfloat16_t *smooth, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel( + uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out, + void *stream) { + vmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ bfloat16_t *)smooth, (__gm__ float *)scale, + (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/main.cpp new file mode 100644 index 0000000000..ec2867832a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/main.cpp @@ -0,0 +1,105 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel( + uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 64; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kSmoothElems = kCols; + constexpr size_t kScaleElems = 16; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t smoothBytes = kSmoothElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + uint16_t *smoothHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + uint16_t *smoothDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&smoothHost), smoothBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&smoothDevice, smoothBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", smoothBytes, smoothHost, smoothBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v4.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(smoothDevice, smoothBytes, smoothHost, smoothBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_pertoken_smooth_bf16_8x64_kernel( + srcDevice, smoothDevice, scaleDevice, outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", scaleHost, scaleBytes); + WriteFile("./v4.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(smoothDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(smoothHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/compare.py new file mode 100644 index 0000000000..35bf41a58d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v3.bin", dtype=np.float32) + scale = np.fromfile("v3.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v4.bin", dtype=np.uint8) + out = np.fromfile("v4.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/golden.py new file mode 100644 index 0000000000..a0592fd234 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/golden.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 16 +COLS = 128 +SCALE_SLOTS = 64 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) +SMOOTH_VALUES = np.array( + [0.5, 0.75, 1.0, 1.25, 1.5, 0.625, 0.875, 1.125], + dtype=np.float32, +) + +Q_VALUES = np.array( + [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127], + dtype=np.float32, +) +ROW_SCALES = np.array( + [ + 0.25, + 0.5, + 1.0, + 2.0, + 0.375, + 0.75, + 1.5, + 3.0, + 0.125, + 0.625, + 1.25, + 2.5, + 0.3125, + 0.9375, + 1.875, + 3.75, + ], + dtype=np.float32, +) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + smooth_repeats = (COLS + len(SMOOTH_VALUES) - 1) // len(SMOOTH_VALUES) + smooth = np.tile(SMOOTH_VALUES, smooth_repeats)[:COLS].astype(np.float16) + + src = np.empty((ROWS, COLS), dtype=np.float16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.int8) + for row in range(ROWS): + src[row] = (q_row * ROW_SCALES[row]).astype(np.float16) + x_f32 = src[row].astype(np.float32) * smooth.astype(np.float32) + scale = (np.max(np.abs(x_f32)) / INT8_MAX).astype(np.float32) + golden_scale[(row // 2) * 8 + (row % 2)] = scale + quant = np.round(x_f32 / scale).astype(np.float32) + golden_out[row] = np.clip(quant, -128, 127).astype(np.int8) + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + smooth.view(np.uint16).tofile(output_dir / "v2.bin") + scale.tofile(output_dir / "v3.bin") + out.tofile(output_dir / "v4.bin") + golden_scale.tofile(output_dir / "golden_v3.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/kernel.pto new file mode 100644 index 0000000000..8bd115dac2 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/kernel.pto @@ -0,0 +1,160 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel(%src_gm: !pto.ptr, + %smooth_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %c12416_i64 = arith.constant 12416 : i64 + %c12544_i64 = arith.constant 12544 : i64 + %c12672_i64 = arith.constant 12672 : i64 + %c12800_i64 = arith.constant 12800 : i64 + %c12928_i64 = arith.constant 12928 : i64 + %c13056_i64 = arith.constant 13056 : i64 + %c13184_i64 = arith.constant 13184 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_smooth = pto.castptr %c12288_i64 : i64 -> !pto.ptr + %ub_smooth1 = pto.castptr %c12416_i64 : i64 -> !pto.ptr + %ub_smooth2 = pto.castptr %c12544_i64 : i64 -> !pto.ptr + %ub_smooth3 = pto.castptr %c12672_i64 : i64 -> !pto.ptr + %ub_smooth4 = pto.castptr %c12800_i64 : i64 -> !pto.ptr + %ub_smooth5 = pto.castptr %c12928_i64 : i64 -> !pto.ptr + %ub_smooth6 = pto.castptr %c13056_i64 : i64 -> !pto.ptr + %ub_smooth7 = pto.castptr %c13184_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth2, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth4, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth6, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out, %c0_i64, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %tile = %c0 to %c8 step %c1 { + %row = arith.muli %tile, %c2 : index + %elem_offset = arith.muli %row, %c128 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %smooth_f16 = pto.vmi.load %ub_smooth[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %smooth = pto.vmi.extf %smooth_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %smoothed = pto.vmi.mulf %x, %smooth + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %smoothed + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scale_offset = arith.muli %tile, %c8 : index + pto.vmi.group_store %scale, %ub_scale[%scale_offset], %c1 {num_groups = 2} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %smoothed, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c2048_i64 + nburst(%c1_i64, %c2048_i64, %c2048_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/launch.cpp new file mode 100644 index 0000000000..0a95b0c6dd --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel(__gm__ half *src, + __gm__ half *smooth, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel( + uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out, + void *stream) { + vmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ half *)smooth, (__gm__ float *)scale, + (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/main.cpp new file mode 100644 index 0000000000..d17c9c9d3f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/main.cpp @@ -0,0 +1,105 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel( + uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 16; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kSmoothElems = kCols; + constexpr size_t kScaleElems = 64; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t smoothBytes = kSmoothElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + uint16_t *smoothHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + uint16_t *smoothDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&smoothHost), smoothBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&smoothDevice, smoothBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", smoothBytes, smoothHost, smoothBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v4.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(smoothDevice, smoothBytes, smoothHost, smoothBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_pertoken_smooth_f16_16x128_kernel( + srcDevice, smoothDevice, scaleDevice, outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", scaleHost, scaleBytes); + WriteFile("./v4.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(smoothDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(smoothHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/compare.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/compare.py new file mode 100644 index 0000000000..35bf41a58d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/compare.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_scale = np.fromfile("golden_v3.bin", dtype=np.float32) + scale = np.fromfile("v3.bin", dtype=np.float32) + golden_out = np.fromfile("golden_v4.bin", dtype=np.uint8) + out = np.fromfile("v4.bin", dtype=np.uint8) + + if golden_scale.shape != scale.shape or not np.allclose( + golden_scale, scale, rtol=1.0e-6, atol=1.0e-6 + ): + if golden_scale.shape != scale.shape: + idx = -1 + else: + diff = np.nonzero(~np.isclose(golden_scale, scale, rtol=1.0e-6, atol=1.0e-6))[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] scale compare failed idx={idx} " + f"golden={golden_scale[idx] if idx >= 0 else 'n/a'} " + f"output={scale[idx] if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] if golden_out.shape == out.shape else [] + idx = int(diff[0]) if len(diff) else -1 + print( + f"[ERROR] int8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/golden.py b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/golden.py new file mode 100644 index 0000000000..df65413c04 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/golden.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 8 +COLS = 64 +SCALE_SLOTS = 16 +INT8_MAX = np.float32(127.0) +SENTINEL_F32 = np.float32(-777.0) +SENTINEL_U8 = np.uint8(0xA5) +SMOOTH_VALUES = np.array( + [0.5, 0.75, 1.0, 1.25, 1.5, 0.625, 0.875, 1.125], + dtype=np.float32, +) + +Q_VALUES = np.array( + [-127, -96, -64, -32, -7, -1, 0, 1, 7, 16, 31, 63, 95, 120, 127], + dtype=np.float32, +) +ROW_SCALES = np.array( + [ + 0.25, + 0.5, + 1.0, + 2.0, + 0.375, + 0.75, + 1.5, + 3.0, + 0.125, + 0.625, + 1.25, + 2.5, + 0.3125, + 0.9375, + 1.875, + 3.75, + ], + dtype=np.float32, +) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + smooth_repeats = (COLS + len(SMOOTH_VALUES) - 1) // len(SMOOTH_VALUES) + smooth = np.tile(SMOOTH_VALUES, smooth_repeats)[:COLS].astype(np.float16) + + src = np.empty((ROWS, COLS), dtype=np.float16) + golden_scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + golden_out = np.empty((ROWS, COLS), dtype=np.int8) + for row in range(ROWS): + src[row] = (q_row * ROW_SCALES[row]).astype(np.float16) + x_f32 = src[row].astype(np.float32) * smooth.astype(np.float32) + scale = (np.max(np.abs(x_f32)) / INT8_MAX).astype(np.float32) + golden_scale[(row // 4) * 8 + (row % 4)] = scale + quant = np.round(x_f32 / scale).astype(np.float32) + golden_out[row] = np.clip(quant, -128, 127).astype(np.int8) + + scale = np.full(SCALE_SLOTS, SENTINEL_F32, dtype=np.float32) + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + smooth.view(np.uint16).tofile(output_dir / "v2.bin") + scale.tofile(output_dir / "v3.bin") + out.tofile(output_dir / "v4.bin") + golden_scale.tofile(output_dir / "golden_v3.bin") + golden_out.view(np.uint8).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/kernel.pto new file mode 100644 index 0000000000..4c8691ad4d --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/kernel.pto @@ -0,0 +1,156 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel(%src_gm: !pto.ptr, + %smooth_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %c12416_i64 = arith.constant 12416 : i64 + %c12544_i64 = arith.constant 12544 : i64 + %c12672_i64 = arith.constant 12672 : i64 + %c12800_i64 = arith.constant 12800 : i64 + %c12928_i64 = arith.constant 12928 : i64 + %c13056_i64 = arith.constant 13056 : i64 + %c13184_i64 = arith.constant 13184 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_smooth = pto.castptr %c12288_i64 : i64 -> !pto.ptr + %ub_smooth1 = pto.castptr %c12416_i64 : i64 -> !pto.ptr + %ub_smooth2 = pto.castptr %c12544_i64 : i64 -> !pto.ptr + %ub_smooth3 = pto.castptr %c12672_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth1, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth2, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %smooth_gm, %ub_smooth3, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_gm, %ub_out, %c0_i64, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %tile = %c0 to %c2 step %c1 { + %row = arith.muli %tile, %c4 : index + %elem_offset = arith.muli %row, %c64 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x_f16 = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %smooth_f16 = pto.vmi.load %ub_smooth[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %smooth = pto.vmi.extf %smooth_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %smoothed = pto.vmi.mulf %x, %smooth + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %smoothed + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 4} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + %max_int8 = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %scale = pto.vmi.divf %amax, %max_int8 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scale_offset = arith.muli %tile, %c8 : index + pto.vmi.group_store %scale, %ub_scale[%scale_offset], %c1 {num_groups = 4} + : !pto.vmi.vreg<256xf32>, !pto.ptr + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 4} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.divf %smoothed, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_out[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_scale, %scale_gm, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out, %out_gm, %c512_i64 + nburst(%c1_i64, %c512_i64, %c512_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/launch.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/launch.cpp new file mode 100644 index 0000000000..851f3282a0 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/launch.cpp @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel(__gm__ half *src, + __gm__ half *smooth, + __gm__ float *scale, + __gm__ uint8_t *out); + +void LaunchVmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel( + uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out, + void *stream) { + vmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ half *)smooth, (__gm__ float *)scale, + (__gm__ uint8_t *)out); +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/main.cpp b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/main.cpp new file mode 100644 index 0000000000..a65f375355 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/main.cpp @@ -0,0 +1,105 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel( + uint16_t *src, uint16_t *smooth, float *scale, uint8_t *out, + void *stream); + +int main() { + constexpr size_t kRows = 8; + constexpr size_t kCols = 64; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kSmoothElems = kCols; + constexpr size_t kScaleElems = 16; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t smoothBytes = kSmoothElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + uint16_t *smoothHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + uint16_t *smoothDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&smoothHost), smoothBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&smoothDevice, smoothBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", smoothBytes, smoothHost, smoothBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v4.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(smoothDevice, smoothBytes, smoothHost, smoothBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_dynamic_quant_pertoken_smooth_f16_8x64_kernel( + srcDevice, smoothDevice, scaleDevice, outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", scaleHost, scaleBytes); + WriteFile("./v4.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(smoothDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(smoothHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/ptoas.flags b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py b/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py deleted file mode 100644 index 39f0af76f7..0000000000 --- a/test/vpto/cases/vmi/kernels/simdvf-per-token-cast-to-fp8/golden.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import argparse -from pathlib import Path - -import numpy as np - -ELEMS = 256 -GROUPS = 2 -GROUP_SIZE = ELEMS // GROUPS -FP8_MAX = np.float32(448.0) -SCALES = np.array([0.25, 0.5], dtype=np.float32) -SENTINEL_F32 = np.float32(-777.0) -SENTINEL_U8 = np.uint8(0xA5) - -Q_VALUES = np.array([0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32) -F8E4M3FN_BYTES = np.array([0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8) - - -def generate(output_dir: Path) -> None: - repeats = (GROUP_SIZE + len(Q_VALUES) - 1) // len(Q_VALUES) - q_group = np.tile(Q_VALUES, repeats)[:GROUP_SIZE].astype(np.float32) - q = np.concatenate([q_group, q_group]).astype(np.float32) - src = np.empty(ELEMS, dtype=np.float32) - golden_scale = np.full(ELEMS, SENTINEL_F32, dtype=np.float32) - for group in range(GROUPS): - begin = group * GROUP_SIZE - end = begin + GROUP_SIZE - src[begin:end] = (q_group * SCALES[group]).astype(np.float32) - amax = np.max(np.abs(src[begin:end])).astype(np.float32) - scale = np.maximum(amax, np.float32(1.0e-4)) / FP8_MAX - golden_scale[group * 8] = scale - golden_out8_group = np.tile(F8E4M3FN_BYTES, repeats)[:GROUP_SIZE].astype(np.uint8) - golden_out8 = np.concatenate([golden_out8_group, golden_out8_group]).astype(np.uint8) - - scale_out = np.full(ELEMS, SENTINEL_F32, dtype=np.float32) - out8 = np.full(ELEMS, SENTINEL_U8, dtype=np.uint8) - - output_dir.mkdir(parents=True, exist_ok=True) - src.tofile(output_dir / "v1.bin") - scale_out.tofile(output_dir / "v2.bin") - out8.tofile(output_dir / "v3.bin") - golden_scale.tofile(output_dir / "golden_v2.bin") - golden_out8.tofile(output_dir / "golden_v3.bin") - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument("--output-dir", type=Path, default=Path(".")) - args = parser.parse_args() - generate(args.output_dir) - - -if __name__ == "__main__": - main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/compare.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/compare.py new file mode 100644 index 0000000000..bc45e33883 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/golden.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/golden.py new file mode 100644 index 0000000000..9c6b791567 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/golden.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +INPUT_COLS = 8 +OUT_COLS = 4 +SCALE_BYTES = 4 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array([1.0, -1.0, 0.5, 448.0], dtype=np.float32) +F8E4M3FN_BYTES = np.array([0x38, 0xB8, 0x30, 0x7E], dtype=np.uint8) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + x2 = f32_to_bf16_bits(Q_VALUES / np.float32(4096.0)) + x1 = f32_to_bf16_bits(np.full(OUT_COLS, np.float32(16.0), dtype=np.float32)) + src_row = np.concatenate([x2, x1]) + src = np.tile(src_row, (ROWS, 1)) + golden_out = np.tile(F8E4M3FN_BYTES, (ROWS, 1)).astype(np.uint8) + golden_scale = np.full(SCALE_BYTES, np.uint8(0x77), dtype=np.uint8) + + out = np.full((ROWS, OUT_COLS), SENTINEL_U8, dtype=np.uint8) + scale = np.full(SCALE_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale.tofile(output_dir / "v3.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto new file mode 100644 index 0000000000..49aacafc8f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto @@ -0,0 +1,157 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c8_i32 = arith.constant 8 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_bf16 = arith.constant 0.000000e+00 : bf16 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1_f32 = arith.constant 1.000000e+00 : f32 + + %ub_x2 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_x1 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c128_i64 : i64 -> !pto.ptr + %src_x1_gm = pto.addptr %src_gm, %c4 : !pto.ptr -> !pto.ptr + + pto.vecscope { + %zero_pad = pto.vmi.broadcast %c0_bf16 + : bf16 -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %zero_pad, %ub_x2[%c0] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + pto.vmi.store %zero_pad, %ub_x1[%c0] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + } + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVENT_ID0"] + + pto.mte_gm_ub %src_gm, %ub_x2, %c0_i64, %c8_i64 + nburst(%c4_i64, %c16_i64, %c64_i64) + pad(%c0_bf16, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + pad bf16, i64, i64 + pto.mte_gm_ub %src_x1_gm, %ub_x1, %c0_i64, %c8_i64 + nburst(%c4_i64, %c16_i64, %c64_i64) + pad(%c0_bf16, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + pad bf16, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_group_mask %c4 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %x2_bf16 = pto.vmi.load %ub_x2[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x1_bf16 = pto.vmi.load %ub_x1[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x2 = pto.vmi.extf %x2_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %x1 = pto.vmi.extf %x1_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %neg_x1 = pto.vmi.subf %zero, %x1 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %exp_neg = pto.vmi.exp %neg_x1 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %one = pto.vmi.broadcast %c1_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %den = pto.vmi.addf %one, %exp_neg + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %silu_x1 = pto.vmi.divf %x1, %den + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %swiglu = pto.vmi.mulf %silu_x1, %x2 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %swiglu + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_u8 = pto.vmi.trunci %e8m0_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.group_store %e8m0_u8, %ub_scale[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xui8>, !pto.ptr + + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %swiglu, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%c0] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c4_i64 + nburst(%c4_i64, %c32_i64, %c4_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale, %scale_gm, %c4_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/launch.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/launch.cpp new file mode 100644 index 0000000000..09fbbaa897 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel(__gm__ bfloat16_t *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale); + +void LaunchVmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream) { + vmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ uint8_t *)out, + (__gm__ uint8_t *)scale); +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/main.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/main.cpp new file mode 100644 index 0000000000..ed9deaa718 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kInputCols = 8; + constexpr size_t kOutCols = 4; + constexpr size_t kInputElems = kRows * kInputCols; + constexpr size_t kOutElems = kRows * kOutCols; + constexpr size_t kScaleBytes = kRows; + size_t srcBytes = kInputElems * sizeof(uint16_t); + size_t outBytes = kOutElems * sizeof(uint8_t); + size_t scaleBytes = kScaleBytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scaleHost = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scaleDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_swiglu_mx_quant_bf16_e4m3_4x8_kernel( + srcDevice, outDevice, scaleDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scaleHost, scaleBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scaleDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scaleHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/ptoas.flags b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/compare.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/compare.py new file mode 100644 index 0000000000..bc45e33883 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/golden.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/golden.py new file mode 100644 index 0000000000..9634af6120 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/golden.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 4 +INPUT_COLS = 8 +OUT_COLS = 4 +SCALE_BYTES = 4 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array([1.0, -1.0, 0.5, 57344.0], dtype=np.float32) +F8E5M2_BYTES = np.array([0x3C, 0xBC, 0x38, 0x7B], dtype=np.uint8) + + +def f32_to_bf16_bits(values: np.ndarray) -> np.ndarray: + bits = values.astype(np.float32).view(np.uint32) + lsb = (bits >> 16) & 1 + rounded = bits + np.uint32(0x7FFF) + lsb + return (rounded >> 16).astype(np.uint16) + + +def generate(output_dir: Path) -> None: + x2 = f32_to_bf16_bits(Q_VALUES / np.float32(16.0)) + x1 = f32_to_bf16_bits(np.full(OUT_COLS, np.float32(16.0), dtype=np.float32)) + src_row = np.concatenate([x2, x1]) + src = np.tile(src_row, (ROWS, 1)) + golden_out = np.tile(F8E5M2_BYTES, (ROWS, 1)).astype(np.uint8) + golden_scale = np.full(SCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + + out = np.full((ROWS, OUT_COLS), SENTINEL_U8, dtype=np.uint8) + scale = np.full(SCALE_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale.tofile(output_dir / "v3.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto new file mode 100644 index 0000000000..b3cd60e99a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto @@ -0,0 +1,157 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %c16_i64 = arith.constant 16 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c15_i32 = arith.constant 15 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_bf16 = arith.constant 0.000000e+00 : bf16 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1_f32 = arith.constant 1.000000e+00 : f32 + + %ub_x2 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_x1 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c128_i64 : i64 -> !pto.ptr + %src_x1_gm = pto.addptr %src_gm, %c4 : !pto.ptr -> !pto.ptr + + pto.vecscope { + %zero_pad = pto.vmi.broadcast %c0_bf16 + : bf16 -> !pto.vmi.vreg<256xbf16> + pto.vmi.store %zero_pad, %ub_x2[%c0] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + pto.vmi.store %zero_pad, %ub_x1[%c0] + : !pto.vmi.vreg<256xbf16>, !pto.ptr + } + pto.set_flag["PIPE_V", "PIPE_MTE2", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE2", "EVENT_ID0"] + + pto.mte_gm_ub %src_gm, %ub_x2, %c0_i64, %c8_i64 + nburst(%c4_i64, %c16_i64, %c64_i64) + pad(%c0_bf16, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + pad bf16, i64, i64 + pto.mte_gm_ub %src_x1_gm, %ub_x1, %c0_i64, %c8_i64 + nburst(%c4_i64, %c16_i64, %c64_i64) + pad(%c0_bf16, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64, + pad bf16, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask = pto.vmi.create_group_mask %c4 {num_groups = 8, group_size = 32} + : index -> !pto.vmi.mask<256xpred> + %x2_bf16 = pto.vmi.load %ub_x2[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x1_bf16 = pto.vmi.load %ub_x1[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + %x2 = pto.vmi.extf %x2_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %x1 = pto.vmi.extf %x1_bf16 + : !pto.vmi.vreg<256xbf16> -> !pto.vmi.vreg<256xf32> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %neg_x1 = pto.vmi.subf %zero, %x1 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %exp_neg = pto.vmi.exp %neg_x1 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %one = pto.vmi.broadcast %c1_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %den = pto.vmi.addf %one, %exp_neg + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %silu_x1 = pto.vmi.divf %x1, %den + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %swiglu = pto.vmi.mulf %silu_x1, %x2 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %swiglu + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax = pto.vmi.broadcast %c15_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_u8 = pto.vmi.trunci %e8m0_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.group_store %e8m0_u8, %ub_scale[%c0], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xui8>, !pto.ptr + + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %swiglu, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E5M2> + pto.vmi.store %q8, %ub_out_f8[%c0] + : !pto.vmi.vreg<256xf8E5M2>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c4_i64 + nburst(%c4_i64, %c32_i64, %c4_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale, %scale_gm, %c4_i64 + nburst(%c1_i64, %c0_i64, %c0_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/launch.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/launch.cpp new file mode 100644 index 0000000000..51b4c50290 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel(__gm__ bfloat16_t *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale); + +void LaunchVmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream) { + vmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel<<<1, nullptr, stream>>>( + (__gm__ bfloat16_t *)src, (__gm__ uint8_t *)out, + (__gm__ uint8_t *)scale); +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/main.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/main.cpp new file mode 100644 index 0000000000..bdfc82d090 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream); + +int main() { + constexpr size_t kRows = 4; + constexpr size_t kInputCols = 8; + constexpr size_t kOutCols = 4; + constexpr size_t kInputElems = kRows * kInputCols; + constexpr size_t kOutElems = kRows * kOutCols; + constexpr size_t kScaleBytes = kRows; + size_t srcBytes = kInputElems * sizeof(uint16_t); + size_t outBytes = kOutElems * sizeof(uint8_t); + size_t scaleBytes = kScaleBytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scaleHost = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scaleDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_swiglu_mx_quant_bf16_e5m2_4x8_kernel( + srcDevice, outDevice, scaleDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scaleHost, scaleBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scaleDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scaleHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/ptoas.flags b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/compare.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/compare.py new file mode 100644 index 0000000000..bc45e33883 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/golden.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/golden.py new file mode 100644 index 0000000000..b26ff0646b --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/golden.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 64 +INPUT_COLS = 512 +OUT_COLS = 256 +SCALE_BYTES = 512 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def generate(output_dir: Path) -> None: + repeats = (OUT_COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:OUT_COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:OUT_COLS].astype(np.uint8) + + x2 = (q_row / np.float32(4096.0)).astype(np.float16) + x1 = np.full(OUT_COLS, np.float16(16.0), dtype=np.float16) + src_row = np.concatenate([x2, x1]) + src = np.tile(src_row, (ROWS, 1)) + golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_scale = np.full(SCALE_BYTES, np.uint8(0x77), dtype=np.uint8) + + out = np.full((ROWS, OUT_COLS), SENTINEL_U8, dtype=np.uint8) + scale = np.full(SCALE_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale.tofile(output_dir / "v3.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/kernel.pto b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/kernel.pto new file mode 100644 index 0000000000..32d9fc4985 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/kernel.pto @@ -0,0 +1,153 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_swiglu_mx_quant_f16_e4m3_64x512_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c53248_i64 = arith.constant 53248 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c8_i32 = arith.constant 8 : i32 + %c119_i32 = arith.constant 119 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1_f32 = arith.constant 1.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c65536_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c98304_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c65536_i64 + nburst(%c1_i64, %c65536_i64, %c65536_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c64 step %c1 { + %src_row_off = arith.muli %row, %c512 : index + %x1_off = arith.addi %src_row_off, %c256 : index + %out_off = arith.muli %row, %c256 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x2_f16 = pto.vmi.load %ub_src[%src_row_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x1_f16 = pto.vmi.load %ub_src[%x1_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x2 = pto.vmi.extf %x2_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %x1 = pto.vmi.extf %x1_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %neg_x1 = pto.vmi.subf %zero, %x1 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %exp_neg = pto.vmi.exp %neg_x1 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %one = pto.vmi.broadcast %c1_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %den = pto.vmi.addf %one, %exp_neg + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %silu_x1 = pto.vmi.divf %x1, %den + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %swiglu = pto.vmi.mulf %silu_x1, %x2 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %swiglu + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax = pto.vmi.broadcast %c8_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + + %scale_ub_off = arith.muli %row, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_scale[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.ptr + + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %swiglu, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_f8[%out_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c16384_i64 + nburst(%c1_i64, %c16384_i64, %c16384_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale, %scale_gm, %c8_i64 + nburst(%c64_i64, %c32_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/launch.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/launch.cpp new file mode 100644 index 0000000000..32dfb4b472 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_swiglu_mx_quant_f16_e4m3_64x512_kernel(__gm__ half *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale); + +void LaunchVmi_swiglu_mx_quant_f16_e4m3_64x512_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream) { + vmi_swiglu_mx_quant_f16_e4m3_64x512_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale); +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/main.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/main.cpp new file mode 100644 index 0000000000..e20dc8e25f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_swiglu_mx_quant_f16_e4m3_64x512_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream); + +int main() { + constexpr size_t kRows = 64; + constexpr size_t kInputCols = 512; + constexpr size_t kOutCols = 256; + constexpr size_t kInputElems = kRows * kInputCols; + constexpr size_t kOutElems = kRows * kOutCols; + constexpr size_t kScaleBytes = 512; + size_t srcBytes = kInputElems * sizeof(uint16_t); + size_t outBytes = kOutElems * sizeof(uint8_t); + size_t scaleBytes = kScaleBytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scaleHost = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scaleDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_swiglu_mx_quant_f16_e4m3_64x512_kernel( + srcDevice, outDevice, scaleDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scaleHost, scaleBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scaleDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scaleHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/ptoas.flags b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/compare.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/compare.py new file mode 100644 index 0000000000..bc45e33883 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/golden.py b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/golden.py new file mode 100644 index 0000000000..20aea94e4f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/golden.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 128 +INPUT_COLS = 256 +OUT_COLS = 128 +SCALE_BYTES = 512 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 4.0, -4.0, 57344.0], + dtype=np.float32, +) +F8E5M2_BYTES = np.array( + [0x00, 0x3C, 0xBC, 0x38, 0xB8, 0x40, 0xC0, 0x44, 0xC4, 0x7B], + dtype=np.uint8, +) + + +def generate(output_dir: Path) -> None: + repeats = (OUT_COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:OUT_COLS].astype(np.float32) + f8_row = np.tile(F8E5M2_BYTES, repeats)[:OUT_COLS].astype(np.uint8) + + x2 = (q_row / np.float32(16.0)).astype(np.float16) + x1 = np.full(OUT_COLS, np.float16(16.0), dtype=np.float16) + src_row = np.concatenate([x2, x1]) + src = np.tile(src_row, (ROWS, 1)) + golden_out = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_scale = np.full(SCALE_BYTES, np.uint8(0x7F), dtype=np.uint8) + + out = np.full((ROWS, OUT_COLS), SENTINEL_U8, dtype=np.uint8) + scale = np.full(SCALE_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + out.tofile(output_dir / "v2.bin") + scale.tofile(output_dir / "v3.bin") + golden_out.tofile(output_dir / "golden_v2.bin") + golden_scale.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/kernel.pto b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/kernel.pto new file mode 100644 index 0000000000..5cbb12c575 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/kernel.pto @@ -0,0 +1,163 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_swiglu_mx_quant_f16_e5m2_128x256_kernel(%src_gm: !pto.ptr, + %out_gm: !pto.ptr, + %scale_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + %c32768 = arith.constant 32768 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c128_i64 = arith.constant 128 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c2048_i64 = arith.constant 2048 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c49152_i64 = arith.constant 49152 : i64 + %c53248_i64 = arith.constant 53248 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c98304_i64 = arith.constant 98304 : i64 + %c131072_i64 = arith.constant 131072 : i64 + %c163840_i64 = arith.constant 163840 : i64 + %c2139095040_i32 = arith.constant 2139095040 : i32 + %c23_i32 = arith.constant 23 : i32 + %c15_i32 = arith.constant 15 : i32 + %c127_i32 = arith.constant 127 : i32 + %c254_i32 = arith.constant 254 : i32 + %c0_f16 = arith.constant 0.000000e+00 : f16 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c1_f32 = arith.constant 1.000000e+00 : f32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_u8 = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %ub_out_f8 = pto.castptr %c131072_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c163840_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c65536_i64 + nburst(%c1_i64, %c65536_i64, %c65536_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %zero_pad = pto.vmi.broadcast %c0_f16 + : f16 -> !pto.vmi.vreg<256xf16> + pto.vmi.store %zero_pad, %ub_src[%c32768] + : !pto.vmi.vreg<256xf16>, !pto.ptr + scf.for %row = %c0 to %c128 step %c1 { + %src_row_off = arith.muli %row, %c256 : index + %x1_off = arith.addi %src_row_off, %c128 : index + %out_off = arith.muli %row, %c256 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x2_f16 = pto.vmi.load %ub_src[%src_row_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x1_f16 = pto.vmi.load %ub_src[%x1_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x2 = pto.vmi.extf %x2_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %x1 = pto.vmi.extf %x1_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %neg_x1 = pto.vmi.subf %zero, %x1 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %exp_neg = pto.vmi.exp %neg_x1 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %one = pto.vmi.broadcast %c1_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %den = pto.vmi.addf %one, %exp_neg + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %silu_x1 = pto.vmi.divf %x1, %den + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %swiglu = pto.vmi.mulf %silu_x1, %x2 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %swiglu + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask = pto.vmi.broadcast %c2139095040_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c23_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax = pto.vmi.broadcast %c15_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_exp_bias = pto.vmi.broadcast %c254_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + + %scale_ub_off = arith.muli %row, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_scale[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.ptr + + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %swiglu, %scale_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E5M2> + pto.vmi.store %q8, %ub_out_f8[%out_off] + : !pto.vmi.vreg<256xf8E5M2>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_u8, %out_gm, %c128_i64 + nburst(%c128_i64, %c256_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_scale, %scale_gm, %c4_i64 + nburst(%c128_i64, %c32_i64, %c4_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/launch.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/launch.cpp new file mode 100644 index 0000000000..dbfc94e8b9 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_swiglu_mx_quant_f16_e5m2_128x256_kernel(__gm__ half *src, + __gm__ uint8_t *out, + __gm__ uint8_t *scale); + +void LaunchVmi_swiglu_mx_quant_f16_e5m2_128x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream) { + vmi_swiglu_mx_quant_f16_e5m2_128x256_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ uint8_t *)out, (__gm__ uint8_t *)scale); +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/main.cpp b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/main.cpp new file mode 100644 index 0000000000..cfe997291f --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/main.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_swiglu_mx_quant_f16_e5m2_128x256_kernel( + uint16_t *src, uint8_t *out, uint8_t *scale, void *stream); + +int main() { + constexpr size_t kRows = 128; + constexpr size_t kInputCols = 256; + constexpr size_t kOutCols = 128; + constexpr size_t kInputElems = kRows * kInputCols; + constexpr size_t kOutElems = kRows * kOutCols; + constexpr size_t kScaleBytes = 512; + size_t srcBytes = kInputElems * sizeof(uint16_t); + size_t outBytes = kOutElems * sizeof(uint8_t); + size_t scaleBytes = kScaleBytes; + uint16_t *srcHost = nullptr; + uint8_t *outHost = nullptr; + uint8_t *scaleHost = nullptr; + uint16_t *srcDevice = nullptr; + uint8_t *outDevice = nullptr; + uint8_t *scaleDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outBytes, outHost, outBytes); + ReadFile("./v3.bin", scaleBytes, scaleHost, scaleBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_swiglu_mx_quant_f16_e5m2_128x256_kernel( + srcDevice, outDevice, scaleDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(scaleHost, scaleBytes, scaleDevice, scaleBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outHost, outBytes); + WriteFile("./v3.bin", scaleHost, scaleBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outDevice); + aclrtFree(scaleDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(outHost); + aclrtFreeHost(scaleHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/ptoas.flags b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/compare.py b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/compare.py new file mode 100644 index 0000000000..85beaede4b --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/compare.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v4 idx={idx} " + f"golden={int(golden[idx]) if idx >= 0 else 'n/a'} " + f"output={int(output[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/golden.py b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/golden.py new file mode 100644 index 0000000000..7e0bce8527 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/golden.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 64 +COLS = 128 +RNG_SEED = 19 + + +def generate(output_dir: Path) -> None: + np.random.seed(RNG_SEED) + src = np.random.uniform(low=-2, high=2, size=(ROWS, COLS)).astype(np.float32) + row_min = np.min(src, axis=1, keepdims=True) + row_max = np.max(src, axis=1, keepdims=True) + scale = ((row_max - row_min) / np.float32(255.0)).astype(np.float32) + inv_scale = np.where(scale != 0, np.float32(1.0) / scale, np.float32(0.0)).astype(np.float32) + offset = np.clip(np.round(-row_min / scale), 0, 255).astype(np.float32) + rounded = np.round(src * inv_scale + offset).astype(np.float32) + golden = np.clip(rounded.astype(np.float16), 0, 255).astype(np.uint8) + dst = np.full((ROWS, COLS), 0xA5, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + inv_scale.reshape(ROWS).tofile(output_dir / "v2.bin") + offset.reshape(ROWS).tofile(output_dir / "v3.bin") + dst.tofile(output_dir / "v4.bin") + golden.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/kernel.pto b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/kernel.pto new file mode 100644 index 0000000000..6724cd56d0 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/kernel.pto @@ -0,0 +1,106 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_tquant_int8_asym_64x128_kernel(%src_gm: !pto.ptr, + %inv_scale_gm: !pto.ptr, + %offset_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c7_i32 = arith.constant 7 : i32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c36864_i64 = arith.constant 36864 : i64 + %c40960_i64 = arith.constant 40960 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_inv_scale = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_offset = pto.castptr %c36864_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c40960_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c32768_i64 + nburst(%c1_i64, %c32768_i64, %c32768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %inv_scale_gm, %ub_inv_scale, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %offset_gm, %ub_offset, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %pair = %c0 to %c32 step %c1 { + %row = arith.muli %pair, %c2 : index + %elem_offset = arith.muli %row, %c128 : index + %x = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %row_i32 = arith.index_cast %row : index to i32 + %gather_mask = pto.vmi.create_mask %c256 + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c7_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %local_group = pto.vmi.shrui %lane, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %row_base = pto.vmi.broadcast %row_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %param_indices = pto.vmi.addi %row_base, %local_group + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.gather %ub_inv_scale[%param_indices], %gather_mask, %zero + : !pto.ptr, !pto.vmi.vreg<256xi32>, + !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %offset = pto.vmi.gather %ub_offset[%param_indices], %gather_mask, %zero + : !pto.ptr, !pto.vmi.vreg<256xi32>, + !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scale + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %shifted = pto.vmi.addf %scaled, %offset + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %shifted + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_dst[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/launch.cpp b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/launch.cpp new file mode 100644 index 0000000000..aeaba21d8a --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_tquant_int8_asym_64x128_kernel(__gm__ float *src, + __gm__ float *inv_scale, + __gm__ float *offset, __gm__ uint8_t *dst); + +void LaunchVmi_tquant_int8_asym_64x128_kernel(float *src, float *inv_scale, + float *offset, uint8_t *dst, + void *stream) { + vmi_tquant_int8_asym_64x128_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)inv_scale, (__gm__ float *)offset, + (__gm__ uint8_t *)dst); +} diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/main.cpp b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/main.cpp new file mode 100644 index 0000000000..70bd93e436 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/main.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_tquant_int8_asym_64x128_kernel(float *src, float *inv_scale, + float *offset, uint8_t *dst, + void *stream); + +int main() { + constexpr size_t kRows = 64; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + size_t srcBytes = kElems * sizeof(float); + size_t scaleBytes = kRows * sizeof(float); + size_t dstBytes = kElems * sizeof(uint8_t); + float *srcHost = nullptr; + float *scaleHost = nullptr; + float *offsetHost = nullptr; + uint8_t *dstHost = nullptr; + float *srcDevice = nullptr; + float *scaleDevice = nullptr; + float *offsetDevice = nullptr; + uint8_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&offsetHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&offsetDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", scaleBytes, offsetHost, scaleBytes); + ReadFile("./v4.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(offsetDevice, scaleBytes, offsetHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_tquant_int8_asym_64x128_kernel(srcDevice, scaleDevice, + offsetDevice, dstDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(offsetDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(offsetHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/ptoas.flags b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-asym-64x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/compare.py b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/compare.py new file mode 100644 index 0000000000..85beaede4b --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/compare.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden = np.fromfile("golden_v4.bin", dtype=np.uint8) + output = np.fromfile("v4.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + print("[INFO] compare passed") + return + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed v4 idx={idx} " + f"golden={int(golden[idx]) if idx >= 0 else 'n/a'} " + f"output={int(output[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/golden.py b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/golden.py new file mode 100644 index 0000000000..3e5bc93768 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/golden.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 64 +COLS = 128 +RNG_SEED = 19 + + +def generate(output_dir: Path) -> None: + np.random.seed(RNG_SEED) + src = np.random.uniform(low=-2, high=2, size=(ROWS, COLS)).astype(np.float32) + scale = (np.max(np.abs(src), axis=1, keepdims=True) / np.float32(127.0)).astype(np.float32) + inv_scale = np.where(scale != 0, np.float32(1.0) / scale, np.float32(0.0)).astype(np.float32) + rounded = np.round(src * inv_scale).astype(np.float32) + golden = np.clip(rounded.astype(np.float16), -128, 127).astype(np.int8) + dst = np.full((ROWS, COLS), 0xA5, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + inv_scale.reshape(ROWS).tofile(output_dir / "v2.bin") + dst.tofile(output_dir / "v4.bin") + golden.view(np.uint8).tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/kernel.pto b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/kernel.pto new file mode 100644 index 0000000000..70b04d7efe --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/kernel.pto @@ -0,0 +1,121 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_tquant_int8_sym_64x128_kernel(%src_gm: !pto.ptr, + %inv_scale_gm: !pto.ptr, + %dst_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i32 = arith.constant 0 : i32 + %c7_i32 = arith.constant 7 : i32 + %c256_i32 = arith.constant 256 : i32 + %c-128_f32 = arith.constant -1.280000e+02 : f32 + %c0_f32 = arith.constant 0.000000e+00 : f32 + %c127_f32 = arith.constant 1.270000e+02 : f32 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i64 = arith.constant 256 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c36864_i64 = arith.constant 36864 : i64 + %c40960_i64 = arith.constant 40960 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_inv_scale = pto.castptr %c32768_i64 : i64 -> !pto.ptr + %ub_dst = pto.castptr %c40960_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c32768_i64 + nburst(%c1_i64, %c32768_i64, %c32768_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %inv_scale_gm, %ub_inv_scale, %c0_i64, %c256_i64 + nburst(%c1_i64, %c256_i64, %c256_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %dst_gm, %ub_dst, %c0_i64, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %pair = %c0 to %c32 step %c1 { + %row = arith.muli %pair, %c2 : index + %elem_offset = arith.muli %row, %c128 : index + %x = pto.vmi.load %ub_src[%elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %row_i32 = arith.index_cast %row : index to i32 + %gather_mask = pto.vmi.create_mask %c256 + : index -> !pto.vmi.mask<256xpred> + %zero = pto.vmi.broadcast %c0_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %lo = pto.vmi.broadcast %c-128_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %hi = pto.vmi.broadcast %c127_f32 + : f32 -> !pto.vmi.vreg<256xf32> + %lane = pto.vmi.iota %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %c7_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %local_group = pto.vmi.shrui %lane, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %row_base = pto.vmi.broadcast %row_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %param_indices = pto.vmi.addi %row_base, %local_group + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale = pto.vmi.gather %ub_inv_scale[%param_indices], %gather_mask, %zero + : !pto.ptr, !pto.vmi.vreg<256xi32>, + !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scale + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped_lo = pto.vmi.maxf %scaled, %lo + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %clamped = pto.vmi.minf %clamped_lo, %hi + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %i32 = pto.vmi.fptosi %clamped + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %zero_i32 = pto.vmi.broadcast %c0_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %byte_bias = pto.vmi.broadcast %c256_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %neg = pto.vmi.cmpi "slt", %i32, %zero_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.mask<256xpred> + %wrapped = pto.vmi.addi %i32, %byte_bias + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %byte_i32 = pto.vmi.select %neg, %wrapped, %i32 + : !pto.vmi.mask<256xpred>, !pto.vmi.vreg<256xi32>, + !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %u8 = pto.vmi.trunci %byte_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %u8, %ub_dst[%elem_offset] + : !pto.vmi.vreg<256xui8>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_dst, %dst_gm, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/launch.cpp b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/launch.cpp new file mode 100644 index 0000000000..24c20be676 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/launch.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_tquant_int8_sym_64x128_kernel(__gm__ float *src, __gm__ float *inv_scale, + __gm__ uint8_t *dst); + +void LaunchVmi_tquant_int8_sym_64x128_kernel(float *src, float *inv_scale, + uint8_t *dst, void *stream) { + vmi_tquant_int8_sym_64x128_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ float *)inv_scale, (__gm__ uint8_t *)dst); +} diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/main.cpp b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/main.cpp new file mode 100644 index 0000000000..3e4222a58e --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/main.cpp @@ -0,0 +1,90 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_tquant_int8_sym_64x128_kernel(float *src, float *inv_scale, + uint8_t *dst, void *stream); + +int main() { + constexpr size_t kRows = 64; + constexpr size_t kCols = 128; + constexpr size_t kElems = kRows * kCols; + size_t srcBytes = kElems * sizeof(float); + size_t scaleBytes = kRows * sizeof(float); + size_t dstBytes = kElems * sizeof(uint8_t); + float *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *dstHost = nullptr; + float *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *dstDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&dstHost), dstBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&dstDevice, dstBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v4.bin", dstBytes, dstHost, dstBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(dstDevice, dstBytes, dstHost, dstBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_tquant_int8_sym_64x128_kernel(srcDevice, scaleDevice, dstDevice, + stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(dstHost, dstBytes, dstDevice, dstBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v4.bin", dstHost, dstBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(dstDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(dstHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/ptoas.flags b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-int8-sym-64x128/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/compare.py b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/compare.py new file mode 100644 index 0000000000..bc45e33883 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v2") or not check_u8("v3"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/golden.py b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/golden.py new file mode 100644 index 0000000000..787c4a61d3 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/golden.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 32 +COLS = 32 +E8M0_BYTES = 256 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.tile(q_row / np.float32(256.0), (ROWS, 1)).astype(np.float32) + golden_fp8 = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_e8m0 = np.full(E8M0_BYTES, SENTINEL_U8, dtype=np.uint8) + golden_e8m0[:ROWS] = np.uint8(0x77) + + out_fp8 = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + out_e8m0 = np.full(E8M0_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + out_fp8.tofile(output_dir / "v2.bin") + out_e8m0.tofile(output_dir / "v3.bin") + golden_fp8.tofile(output_dir / "golden_v2.bin") + golden_e8m0.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto new file mode 100644 index 0000000000..da8dca54e6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto @@ -0,0 +1,114 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_tquant_mxfp8_32x32_nd_kernel(%src_gm: !pto.ptr, + %out_fp8_gm: !pto.ptr, + %out_e8m0_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c4_i64 = arith.constant 4 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %exp_mask_i32 = arith.constant 2139095040 : i32 + %shift_i32 = arith.constant 23 : i32 + %emax_i32 = arith.constant 8 : i32 + %scale_exp_bias_i32 = arith.constant 254 : i32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_out_fp8_u8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_fp8_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out_e8m0 = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %out_fp8_gm, %ub_out_fp8_u8, %c0_i64, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %row = %c0 to %c32 step %c8 { + %elem_off = arith.muli %row, %c32 : index + %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> + %x = pto.vmi.load %ub_src[%elem_off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask = pto.vmi.broadcast %exp_mask_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %shift_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax = pto.vmi.broadcast %emax_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_exp_bias = pto.vmi.broadcast %scale_exp_bias_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_slot = arith.divui %row, %c8 : index + %scale_ub_off = arith.muli %scale_slot, %c32 : index + pto.vmi.group_store %e8m0_i32, %ub_out_e8m0[%scale_ub_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xi32>, !pto.ptr + + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scaling = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scaling_vec = pto.vmi.group_broadcast %scaling {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scaling_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_out_fp8_f8[%elem_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out_fp8_u8, %out_fp8_gm, %c1024_i64 + nburst(%c1_i64, %c1024_i64, %c1024_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_out_e8m0, %out_e8m0_gm, %c8_i64 + nburst(%c4_i64, %c32_i64, %c8_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/launch.cpp b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/launch.cpp new file mode 100644 index 0000000000..1af7bbdc45 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/launch.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_tquant_mxfp8_32x32_nd_kernel(__gm__ float *src, __gm__ uint8_t *out_fp8, + __gm__ uint8_t *out_e8m0); + +void LaunchVmi_tquant_mxfp8_32x32_nd_kernel(float *src, uint8_t *out_fp8, + uint8_t *out_e8m0, void *stream) { + vmi_tquant_mxfp8_32x32_nd_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ uint8_t *)out_fp8, + (__gm__ uint8_t *)out_e8m0); +} diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/main.cpp b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/main.cpp new file mode 100644 index 0000000000..827877248e --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/main.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_tquant_mxfp8_32x32_nd_kernel(float *src, uint8_t *out_fp8, + uint8_t *out_e8m0, void *stream); + +int main() { + constexpr size_t kRows = 32; + constexpr size_t kCols = 32; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kE8m0Bytes = 256; + size_t srcBytes = kElems * sizeof(float); + size_t outFp8Bytes = kElems * sizeof(uint8_t); + size_t outE8m0Bytes = kE8m0Bytes * sizeof(uint8_t); + float *srcHost = nullptr; + uint8_t *outFp8Host = nullptr; + uint8_t *outE8m0Host = nullptr; + float *srcDevice = nullptr; + uint8_t *outFp8Device = nullptr; + uint8_t *outE8m0Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outFp8Host), outFp8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outE8m0Host), outE8m0Bytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outFp8Device, outFp8Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outE8m0Device, outE8m0Bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", outFp8Bytes, outFp8Host, outFp8Bytes); + ReadFile("./v3.bin", outE8m0Bytes, outE8m0Host, outE8m0Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outFp8Device, outFp8Bytes, outFp8Host, outFp8Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outE8m0Device, outE8m0Bytes, outE8m0Host, outE8m0Bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_tquant_mxfp8_32x32_nd_kernel(srcDevice, outFp8Device, + outE8m0Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outFp8Host, outFp8Bytes, outFp8Device, outFp8Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outE8m0Host, outE8m0Bytes, outE8m0Device, outE8m0Bytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v2.bin", outFp8Host, outFp8Bytes); + WriteFile("./v3.bin", outE8m0Host, outE8m0Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(outFp8Device); + aclrtFree(outE8m0Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(outFp8Host); + aclrtFreeHost(outE8m0Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/ptoas.flags b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/compare.py b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/compare.py new file mode 100644 index 0000000000..cffec13f08 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/compare.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def check_u8(name: str) -> bool: + golden = np.fromfile(f"golden_{name}.bin", dtype=np.uint8) + output = np.fromfile(f"{name}.bin", dtype=np.uint8) + if golden.shape == output.shape and np.array_equal(golden, output): + return True + diff = np.nonzero(golden != output)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] compare failed {name} idx={idx} " + f"golden=0x{int(golden[idx]):02x} output=0x{int(output[idx]):02x}" + ) + return False + + +def main() -> None: + if not check_u8("v3") or not check_u8("v4"): + sys.exit(2) + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/golden.py b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/golden.py new file mode 100644 index 0000000000..e3daa6ae34 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 32 +COLS = 64 +GROUPS = ROWS * COLS // 32 +E8M0_BYTES = 256 +IDX_BYTES = 256 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def make_e8m0_zz_indices() -> np.ndarray: + index_array = np.arange(GROUPS, dtype=np.int64).reshape(ROWS, COLS // 32) + index_reshaped = index_array.reshape(ROWS // 16, 16, (COLS // 32) // 2, 2) + index_zz = np.transpose(index_reshaped, [0, 2, 1, 3]).flatten() + return (index_zz // 2)[::2].astype(np.uint16) + + +def generate(output_dir: Path) -> None: + repeats = (COLS + len(Q_VALUES) - 1) // len(Q_VALUES) + q_row = np.tile(Q_VALUES, repeats)[:COLS].astype(np.float32) + f8_row = np.tile(F8E4M3FN_BYTES, repeats)[:COLS].astype(np.uint8) + + src = np.tile(q_row / np.float32(256.0), (ROWS, 1)).astype(np.float32) + fp8_nd = np.tile(f8_row, (ROWS, 1)).astype(np.uint8) + golden_fp8 = np.transpose(fp8_nd.reshape(ROWS, COLS // 32, 32), [1, 0, 2]).flatten() + + e8m0_nd = np.full((ROWS, COLS // 32), np.uint8(0x77), dtype=np.uint8) + e8m0_zz = np.transpose(e8m0_nd.reshape(ROWS // 16, 16, (COLS // 32) // 2, 2), [0, 2, 1, 3]).flatten() + golden_e8m0 = np.full(E8M0_BYTES, SENTINEL_U8, dtype=np.uint8) + golden_e8m0[:GROUPS] = e8m0_zz + + idx = np.zeros(IDX_BYTES // np.dtype(np.uint16).itemsize, dtype=np.uint16) + zz_indices = make_e8m0_zz_indices() + idx[: zz_indices.size] = zz_indices + + out_fp8 = np.full(ROWS * COLS, SENTINEL_U8, dtype=np.uint8) + out_e8m0 = np.full(E8M0_BYTES, SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.tofile(output_dir / "v1.bin") + idx.tofile(output_dir / "v2.bin") + out_fp8.tofile(output_dir / "v3.bin") + out_e8m0.tofile(output_dir / "v4.bin") + golden_fp8.tofile(output_dir / "golden_v3.bin") + golden_e8m0.tofile(output_dir / "golden_v4.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/kernel.pto b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/kernel.pto new file mode 100644 index 0000000000..8d09bf51d6 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/kernel.pto @@ -0,0 +1,174 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_tquant_mxfp8_32x64_nz_kernel(%src_gm: !pto.ptr, + %idx_gm: !pto.ptr, + %out_fp8_gm: !pto.ptr, + %out_e8m0_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c1056 = arith.constant 1056 : index + %c1024 = arith.constant 1024 : index + %c0_i16 = arith.constant 0 : i16 + %c1_i16 = arith.constant 1 : i16 + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c8_i64 = arith.constant 8 : i64 + %c32_i64 = arith.constant 32 : i64 + %c64_i64 = arith.constant 64 : i64 + %c256_i64 = arith.constant 256 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c131328_i64 = arith.constant 131328 : i64 + %c131584_i64 = arith.constant 131584 : i64 + %c196864_i64 = arith.constant 196864 : i64 + %c197120_i64 = arith.constant 197120 : i64 + %exp_mask_i32 = arith.constant 2139095040 : i32 + %shift_i32 = arith.constant 23 : i32 + %emax_i32 = arith.constant 8 : i32 + %scale_exp_bias_i32 = arith.constant 254 : i32 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_fp8_nd_f8 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_max = pto.castptr %c131328_i64 : i64 -> !pto.ptr + %ub_e8m0_nd_u8 = pto.castptr %c131584_i64 : i64 -> !pto.ptr + %ub_e8m0_nd_u16 = pto.castptr %c131584_i64 : i64 -> !pto.ptr + %ub_e8m0_zz_u8 = pto.castptr %c196864_i64 : i64 -> !pto.ptr + %ub_e8m0_zz_u16 = pto.castptr %c196864_i64 : i64 -> !pto.ptr + %ub_idx_u16 = pto.castptr %c197120_i64 : i64 -> !pto.ptr + %out_fp8_hi_gm = pto.addptr %out_fp8_gm, %c1024 + : !pto.ptr -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %idx_gm, %ub_idx_u16, %c0_i64, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + %mask_f32 = pto.vmi.create_mask %c256 + : index -> !pto.vmi.mask<256xpred> + scf.for %row = %c0 to %c32 step %c4 { + %src_off = arith.muli %row, %c64 : index + %scale_off = arith.muli %row, %c2 : index + %nd_off = arith.muli %row, %c64 : index + + %x = pto.vmi.load %ub_src[%src_off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask_f32 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> + -> !pto.vmi.vreg<256xf32> + + %amax_bits = pto.vmi.bitcast %amax + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask = pto.vmi.broadcast %exp_mask_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift = pto.vmi.broadcast %shift_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax = pto.vmi.broadcast %emax_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %scale_exp_bias = pto.vmi.broadcast %scale_exp_bias_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp_bits = pto.vmi.andi %amax_bits, %exp_mask + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp = pto.vmi.shrui %exp_bits, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_i32 = pto.vmi.subi %exp, %emax + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + pto.vmi.group_store %amax, %ub_max[%scale_off], %c1 {num_groups = 8} + : !pto.vmi.vreg<256xf32>, !pto.ptr + + %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scale_bits = pto.vmi.shli %scale_exp, %shift + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %scaling = pto.vmi.bitcast %scale_bits + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + %scaling_vec = pto.vmi.group_broadcast %scaling {num_groups = 8} + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x, %scaling_vec + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %q8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %q8, %ub_fp8_nd_f8[%nd_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + + %max256 = pto.vmi.load %ub_max[%c0] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %max256_bits = pto.vmi.bitcast %max256 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + %exp_mask256 = pto.vmi.broadcast %exp_mask_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %shift256 = pto.vmi.broadcast %shift_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %emax256 = pto.vmi.broadcast %emax_i32 + : i32 -> !pto.vmi.vreg<256xi32> + %exp256_bits = pto.vmi.andi %max256_bits, %exp_mask256 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %exp256 = pto.vmi.shrui %exp256_bits, %shift256 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_256_i32 = pto.vmi.subi %exp256, %emax256 + : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<256xi32> + %e8m0_256_u8 = pto.vmi.trunci %e8m0_256_i32 + : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + pto.vmi.store %e8m0_256_u8, %ub_e8m0_nd_u8[%c0] + : !pto.vmi.vreg<256xui8>, !pto.ptr + + %idx_mask = pto.vmi.create_mask %c32 + : index -> !pto.vmi.mask<128xpred> + %idx_vec = pto.vmi.load %ub_idx_u16[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xui16> + %e8m0_zz = pto.vmi.gather %ub_e8m0_nd_u16[%idx_vec], %idx_mask, %idx_vec + : !pto.ptr, !pto.vmi.vreg<128xui16>, + !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xui16> + -> !pto.vmi.vreg<128xui16> + pto.vmi.store %e8m0_zz, %ub_e8m0_zz_u16[%c0] + : !pto.vmi.vreg<128xui16>, !pto.ptr + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_fp8_nd_f8, %out_fp8_gm, %c32_i64 + nburst(%c32_i64, %c64_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + %ub_fp8_nd_hi_f8 = pto.addptr %ub_fp8_nd_f8, %c32 + : !pto.ptr -> !pto.ptr + pto.mte_ub_gm %ub_fp8_nd_hi_f8, %out_fp8_hi_gm, %c32_i64 + nburst(%c32_i64, %c64_i64, %c32_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.mte_ub_gm %ub_e8m0_zz_u8, %out_e8m0_gm, %c64_i64 + nburst(%c1_i64, %c64_i64, %c64_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/launch.cpp b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/launch.cpp new file mode 100644 index 0000000000..4959bd44b0 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/launch.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_tquant_mxfp8_32x64_nz_kernel(__gm__ float *src, __gm__ uint16_t *idx, + __gm__ uint8_t *out_fp8, + __gm__ uint8_t *out_e8m0); + +void LaunchVmi_tquant_mxfp8_32x64_nz_kernel(float *src, uint16_t *idx, + uint8_t *out_fp8, + uint8_t *out_e8m0, void *stream) { + vmi_tquant_mxfp8_32x64_nz_kernel<<<1, nullptr, stream>>>( + (__gm__ float *)src, (__gm__ uint16_t *)idx, + (__gm__ uint8_t *)out_fp8, (__gm__ uint8_t *)out_e8m0); +} diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/main.cpp b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/main.cpp new file mode 100644 index 0000000000..9440162a71 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/main.cpp @@ -0,0 +1,116 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_tquant_mxfp8_32x64_nz_kernel(float *src, uint16_t *idx, + uint8_t *out_fp8, uint8_t *out_e8m0, + void *stream); + +int main() { + constexpr size_t kRows = 32; + constexpr size_t kCols = 64; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kE8m0Bytes = 256; + constexpr size_t kIdxBytes = 256; + size_t srcBytes = kElems * sizeof(float); + size_t idxBytes = kIdxBytes; + size_t outFp8Bytes = kElems * sizeof(uint8_t); + size_t outE8m0Bytes = kE8m0Bytes * sizeof(uint8_t); + float *srcHost = nullptr; + uint16_t *idxHost = nullptr; + uint8_t *outFp8Host = nullptr; + uint8_t *outE8m0Host = nullptr; + float *srcDevice = nullptr; + uint16_t *idxDevice = nullptr; + uint8_t *outFp8Device = nullptr; + uint8_t *outE8m0Device = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&idxHost), idxBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outFp8Host), outFp8Bytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outE8m0Host), outE8m0Bytes)); + ACL_CHECK( + aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK( + aclrtMalloc((void **)&idxDevice, idxBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outFp8Device, outFp8Bytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outE8m0Device, outE8m0Bytes, + ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", idxBytes, idxHost, idxBytes); + ReadFile("./v3.bin", outFp8Bytes, outFp8Host, outFp8Bytes); + ReadFile("./v4.bin", outE8m0Bytes, outE8m0Host, outE8m0Bytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(idxDevice, idxBytes, idxHost, idxBytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outFp8Device, outFp8Bytes, outFp8Host, outFp8Bytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outE8m0Device, outE8m0Bytes, outE8m0Host, outE8m0Bytes, + ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_tquant_mxfp8_32x64_nz_kernel(srcDevice, idxDevice, outFp8Device, + outE8m0Device, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outFp8Host, outFp8Bytes, outFp8Device, outFp8Bytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + ACL_CHECK(aclrtMemcpy(outE8m0Host, outE8m0Bytes, outE8m0Device, outE8m0Bytes, + ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", outFp8Host, outFp8Bytes); + WriteFile("./v4.bin", outE8m0Host, outE8m0Bytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(idxDevice); + aclrtFree(outFp8Device); + aclrtFree(outE8m0Device); + aclrtFreeHost(srcHost); + aclrtFreeHost(idxHost); + aclrtFreeHost(outFp8Host); + aclrtFreeHost(outE8m0Host); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/ptoas.flags b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From a94920d3d460c881baa486b1f9db89a877d176b8 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Sun, 28 Jun 2026 20:21:10 +0800 Subject: [PATCH 34/54] fix: adapt vgather2 u16 offset carrier --- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 29 +++++++++++++++++-- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 29 +++++++++++++++++-- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index 4d4b82f5a8..ee1da2c13f 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -3675,6 +3675,22 @@ static FailureOr buildVgather2Callee(MLIRContext *context, .getValue(); } +static FailureOr getVgather2OffsetsCarrierType(PatternRewriter &rewriter, + Type resultType, + Type offsetsType) { + Type elementType = getElementTypeFromVectorLike(resultType); + auto lanes = getElementCountFromVectorLike(resultType); + if (!elementType || !lanes) + return failure(); + + if (pto::getPTOStorageElemBitWidth(elementType) == 16) { + if (*lanes % 2 != 0) + return failure(); + return VectorType::get({*lanes / 2}, rewriter.getI32Type()); + } + return offsetsType; +} + static FailureOr buildVgather2BcCallee(MLIRContext *context, Type resultType) { return buildLaneTypedCallee(context, resultType, "vgather2.bc", ""); @@ -7225,13 +7241,22 @@ class LowerVgather2OpPattern final if (failed(calleeName)) return rewriter.notifyMatchFailure(op, "unsupported vgather2 signature"); + Value offsets = adaptor.getOffsets(); + FailureOr offsetsCarrierType = getVgather2OffsetsCarrierType( + rewriter, op.getResult().getType(), offsets.getType()); + if (failed(offsetsCarrierType)) + return rewriter.notifyMatchFailure(op, "unsupported vgather2 offsets carrier"); + if (offsets.getType() != *offsetsCarrierType) + offsets = rewriter.create(op.getLoc(), *offsetsCarrierType, + offsets); + auto funcType = rewriter.getFunctionType( - TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + TypeRange{adaptor.getSource().getType(), *offsetsCarrierType, adaptor.getMask().getType()}, TypeRange{resultType}); auto call = rewriter.create( op.getLoc(), *calleeName, TypeRange{resultType}, - ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); + ValueRange{adaptor.getSource(), offsets, adaptor.getMask()}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); rewriter.replaceOp(op, call.getResults()); return success(); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index bee22fed58..2de59cc620 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -3611,6 +3611,22 @@ static FailureOr buildVgather2Callee(MLIRContext *context, .getValue(); } +static FailureOr getVgather2OffsetsCarrierType(PatternRewriter &rewriter, + Type resultType, + Type offsetsType) { + Type elementType = getElementTypeFromVectorLike(resultType); + auto lanes = getElementCountFromVectorLike(resultType); + if (!elementType || !lanes) + return failure(); + + if (pto::getPTOStorageElemBitWidth(elementType) == 16) { + if (*lanes % 2 != 0) + return failure(); + return VectorType::get({*lanes / 2}, rewriter.getI32Type()); + } + return offsetsType; +} + static FailureOr buildVgather2BcCallee(MLIRContext *context, Type resultType) { return buildLaneTypedCallee(context, resultType, "vgather2.bc", ""); @@ -7152,13 +7168,22 @@ class LowerVgather2OpPattern final if (failed(calleeName)) return rewriter.notifyMatchFailure(op, "unsupported vgather2 signature"); + Value offsets = adaptor.getOffsets(); + FailureOr offsetsCarrierType = getVgather2OffsetsCarrierType( + rewriter, op.getResult().getType(), offsets.getType()); + if (failed(offsetsCarrierType)) + return rewriter.notifyMatchFailure(op, "unsupported vgather2 offsets carrier"); + if (offsets.getType() != *offsetsCarrierType) + offsets = rewriter.create(op.getLoc(), *offsetsCarrierType, + offsets); + auto funcType = rewriter.getFunctionType( - TypeRange{adaptor.getSource().getType(), adaptor.getOffsets().getType(), + TypeRange{adaptor.getSource().getType(), *offsetsCarrierType, adaptor.getMask().getType()}, TypeRange{resultType}); auto call = rewriter.create( op.getLoc(), *calleeName, TypeRange{resultType}, - ValueRange{adaptor.getSource(), adaptor.getOffsets(), adaptor.getMask()}); + ValueRange{adaptor.getSource(), offsets, adaptor.getMask()}); state.plannedDecls.push_back(PlannedDecl{calleeName->str(), funcType}); rewriter.replaceOp(op, call.getResults()); return success(); From a6ca1300606464c8784364b1d387423fc4990d8b Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Sun, 28 Jun 2026 21:28:32 +0800 Subject: [PATCH 35/54] Rename VMI sparse layout to lane stride --- docs/designs/vmi-implementation-manual.md | 4 +- docs/designs/vmi-introduction.md | 16 ++++--- .../vmi-layout-assignment-implementation.md | 8 ++++ .../vmi-layout-assignment-lowering-design.md | 13 ++++-- docs/designs/vmi-layout-lowering-cases.md | 39 +++++++++------- include/PTO/IR/VMIAttrs.td | 13 +++--- lib/PTO/IR/VMI.cpp | 45 +++++-------------- lib/PTO/Transforms/VMILayoutAssignment.cpp | 2 +- lib/PTO/Transforms/VMIToVPTO.cpp | 16 +++---- ..._layout_assignment_trunci_lane_stride.pto} | 8 ++-- test/lit/vmi/vmi_to_vpto_integer_casts.pto | 8 ++-- 11 files changed, 88 insertions(+), 84 deletions(-) rename test/lit/vmi/{vmi_layout_assignment_trunci_sparse.pto => vmi_layout_assignment_trunci_lane_stride.pto} (92%) diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index e9790f56ea..4baf0756da 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -827,7 +827,7 @@ deinterleaved=4: part3 chunks for lanes 3,7,11,... num_groups=G: - sparse group-slot reduce result layout + group-slot reduce result layout physical storage is contiguous chunk order only canonical group_slot(g) lanes contain semantic values ``` @@ -3098,7 +3098,7 @@ pto.vmi.group_reduce_addf: requires {reassoc} N = logical lane count; G = num_groups; S = N / G L = physical lanes per 256B chunk for the element type. - The result carries #pto.vmi.layout, a sparse + The result carries #pto.vmi.layout, a group-slot group-slot layout. It is not a dense vector layout: only slot lanes have semantic values. Supported K values are: K = 8 for VCGADD-style packed results, where group g is stored in diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index 8e884bf677..59c03655b2 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -125,18 +125,24 @@ element-parity consumer,assignment 可以选择 `deinterleaved=4, block_elems= ```mlir #pto.vmi.layout #pto.vmi.layout +#pto.vmi.layout ``` -这是 sparse group-result layout。它不表示全部 `N` 个 logical lane 都有语义值。 +这是 group-slot result layout。它不表示全部 `N` 个 logical lane 都有语义值。 只有 `G` 个 group 结果 slot 有语义值。 ```text slot_block(g) = g / K -slot_lane(g) = g % K +slot_lane(g) = (g % K) * lane_stride physical part slot_block(g) 的 lane slot_lane(g) 保存 group g 的结果 ``` +`lane_stride` 缺省为 1,单位是 logical element-sized physical slot。 +它描述 group result 在物理存储中的固定间距,不改变 VMI 的逻辑元素类型。 +例如 `ui8 lane_stride=4` 表示 group slot 存在 byte lane 0, 4, 8, ... +这种形态可以 lower 为 `PK4_B32` store,物理上使用 b32 carrier 的 low byte。 + `num_groups=16, slots=8` 的例子: ```text @@ -369,7 +375,7 @@ baseline assignment 保留 C2 已有的 natural layout;若没有 natural layou group_reduce: source 需要适配 group reduce 指令形态; - result 使用 group_slots(num_groups, slots) 描述 sparse group result。 + result 使用 group_slots(num_groups, slots) 描述 group-slot result。 cast: widening/narrowing 根据 cast support 决定 source request 和 result layout。 @@ -406,7 +412,7 @@ group_store value: stride_store value: wants contiguous。block/repeat stride 只描述 memory write address map, - 不表示 source vreg 是 sparse 或 NZ layout。 + 不表示 source vreg 是 lane-strided 或 NZ layout。 truncf/trunci/extf/extsi/extui source: wants cast support 给出的 source layout @@ -664,7 +670,7 @@ pto.vmi.store %sum, %dst[%off] 原因: ```text -dense store 不能把 sparse group_slots 当 dense vector 读取。 +dense store 不能把 group_slots 当 dense vector 读取。 应使用 group_store、group_broadcast 或显式支持的 group-to-dense op。 ``` diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index eb634ede79..1a5ef9f35a 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -199,6 +199,7 @@ Represent layout as a closed attribute family: #pto.vmi.layout #pto.vmi.layout #pto.vmi.layout +#pto.vmi.layout ``` C++ form: @@ -216,6 +217,7 @@ struct VMILayoutKey { int64_t blockElems = 1; int64_t numGroups = 0; int64_t slots = 0; + int64_t laneStride = 1; }; ``` @@ -235,6 +237,7 @@ group_slots: K > 0 G % K == 0 K fits in one physical vreg for element type + LS > 0 ``` Parser compatibility during migration: @@ -251,11 +254,16 @@ New `vmi-layout-assignment` output must print one of: ```text #pto.vmi.layout #pto.vmi.layout +#pto.vmi.layout ``` so `vmi-to-vpto` can lower from the assigned type without reconstructing group slot placement from producer or consumer context. +`lane_stride` is counted in logical element-sized physical slots and records a +regular gap between stored group slots. It is used for carrier-style packed +stores such as `ui8` group slots lowered through b32 `PK4_B32`. + ### 3.2 VMI Types Surface: diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 40edc656ad..1bfc46174f 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -165,7 +165,7 @@ design so far: physical dense layout: contiguous, deinterleaved=2/4, block_elems=1/8 -sparse result layout: +group-slot result layout: group_slots(G, slots=8) for packed VCG results group_slots(G, slots=1) for row-local S=64 results @@ -291,22 +291,29 @@ deinterleaved=2, block_elems=8 are different layouts. They cannot be treated as compatible because `F` is the same. -### 2.2 Sparse Group-Slot Layouts +### 2.2 Group-Slot Layouts ```text #pto.vmi.layout +#pto.vmi.layout ``` Only `G` lanes have semantic values: ```text slot_block(g) = g / K -slot_lane(g) = g % K +slot_lane(g) = (g % K) * LS ``` All non-slot lanes are undefined and may only be read by group-aware operations. Ordinary dense `add/mul/store/truncf` cannot consume `group_slots`. +`LS` defaults to 1 and is measured in logical element-sized physical slots. It +is not a new group semantic; it records regular physical spacing for each stored +group slot. For example, `ui8 lane_stride=4` maps slot values to byte lanes +0, 4, 8, ... and lets `group_store` lower through a b32 carrier `PK4_B32` +store. + `K` is selected by the assigned producer/result contract: ```text diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 3fbe2bd3f1..ab77a35632 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -56,12 +56,13 @@ B > 0 N % (F * B) == 0 for the direct full-chunk paths in this document ``` -### 1.2 Sparse Group-Slot Layout +### 1.2 Group-Slot Layout -Sparse group-slot layout is not dense. Only `G` lanes have semantic values. +Group-slot layout is not dense. Only `G` lanes have semantic values. ```text #pto.vmi.layout +#pto.vmi.layout ``` Physical slot mapping: @@ -71,7 +72,7 @@ N = logical lane count S = N / G // logical lanes per source group slot_block(g) = g / K -slot_lane(g) = g % K +slot_lane(g) = (g % K) * LS ``` Required invariants: @@ -81,8 +82,14 @@ G > 0 K > 0 G % K == 0 K must fit in the physical vreg element count +LS > 0 ``` +`LS` defaults to 1 and is counted in logical element-sized physical slots. It +is used when the group result value is intentionally stored with a regular lane +gap. For example, `ui8 lane_stride=4` places group slots in byte positions 0, +4, 8, ... and can be lowered to a b32 carrier plus `PK4_B32` store. + `K` is selected by the producer/consumer layout support rule. It is not always 8. For `VCGADD`-packed results, `K = 8` matches the eight 32B block results written to the low lanes of one destination vreg. For row-local reductions where each @@ -123,7 +130,7 @@ Illegal consumer mix: group_slots value -> ordinary dense store/add/mul ``` -This must fail unless an explicit semantic op converts the sparse value: +This must fail unless an explicit semantic op converts the group-slot value: ```text group_broadcast @@ -643,7 +650,7 @@ sum_out[group_tile_off + 1] = reduce row1 lanes 0..15 sum_out[group_tile_off + 7] = reduce row7 lanes 0..15 ``` -This endpoint is fully specified: the only sparse value is `%sum`; `group_store` +This endpoint is fully specified: the only group-slot value is `%sum`; `group_store` stores the low 8 slot lanes with an ordinary prefix store. #### 3.5.2 Reduce, Broadcast, Elementwise, Reduce, Store @@ -1659,7 +1666,7 @@ It must not be diagnosed as: dense store materializes group slots implicitly ``` -That behavior would silently reinterpret a sparse group-slot value as a dense +That behavior would silently reinterpret a group-slot value as a dense vector. ### 3.10 Non-Load Producer Feeding S=32 `group_reduce` @@ -2272,7 +2279,7 @@ group_load: loads group_size data elements per group and produces dense grouped data. group_slot_load: - loads one scalar value per group and produces sparse group slots. + loads one scalar value per group and produces group slots. ``` Surface form: @@ -2290,7 +2297,7 @@ semantic group slot g = base[off + g * source_group_stride] ``` The result logical lane count `N` remains the surrounding VMI value shape. Only -the `G` group slots are semantic. Layout assignment chooses the sparse physical +the `G` group slots are semantic. Layout assignment chooses the group-slot physical placement requested by the consumer: ```text @@ -2760,7 +2767,7 @@ VMI-LAYOUT-CONTRACT: ### 3.20 `group_slots` Control-Flow Join `group_slots` values must be allowed to cross control flow. The join type is a -sparse physical tuple, not a dense vector. +group-slot physical tuple, not a dense vector. VMI input: @@ -4249,7 +4256,7 @@ This does not legalize packed `slots = 8` casts from section 3.13. ### 3.35 `group_slots` Fanout To `group_store` And `group_broadcast` -This case fixes the fanout rule for sparse values. A `group_slots` value may +This case fixes the fanout rule for group-slot values. A `group_slots` value may feed multiple group-aware consumers directly. Layout assignment must not materialize it as dense just because one later use broadcasts it. @@ -4315,11 +4322,11 @@ VPTO lowering result for one full 8-row tile: : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> %sum_block = pto.vadd %lo_sum, %hi_sum, %slot8 : !pto.vreg<64xf32> -// First sparse consumer: store the group slots without changing layout. +// First group-slot consumer: store the group slots without changing layout. pto.vsts %sum_block, %sum_out[%group_off], %slot8 {dist = "NORM_B32"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask -// Second sparse consumer: materialize only this use as dense grouped data. +// Second group-slot consumer: materialize only this use as dense grouped data. %broadcast_idx0 = compute index vector [0 repeated 16, 1 repeated 16, 2 repeated 16, 3 repeated 16] : !pto.vreg<64xi32> @@ -4363,7 +4370,7 @@ Required assignment rule: `%sum` keeps one assigned layout: #pto.vmi.layout -`group_store` consumes that sparse layout directly. +`group_store` consumes that group-slot layout directly. `group_broadcast` is a use-site materialization to a dense layout. It must not rewrite the defining `group_reduce` result or the sibling `group_store` use. ``` @@ -4372,7 +4379,7 @@ rewrite the defining `group_reduce` result or the sibling `group_store` use. The same memory scalar stream may be used by both packed S=16 group-slot compute and row-local S=64 group-slot compute. The two uses require different -logical vector shapes and different sparse layouts, so the source must be +logical vector shapes and different group-slot layouts, so the source must be rematerialized as two VMI values. There is no single `group_slots` layout that serves both uses. @@ -4964,7 +4971,7 @@ materialization after seeing both users. ### 3.42 `group_slots` `scf.for` Loop-Carried Accumulator -Section 3.22 covers dense loop-carried values. Sparse group-slot values need a +Section 3.22 covers dense loop-carried values. Group-slot values need a separate case because the loop-carried block argument has no dense lane semantics outside the live group slots. @@ -5284,7 +5291,7 @@ Lowering: %x_p0, %x_p2 = pto.vdintlv %x01_lo, %x23_lo %x_p1, %x_p3 = pto.vdintlv %x01_hi, %x23_hi -// The reduce-side grouped mask is not built by guessing the final sparse +// The reduce-side grouped mask is not built by guessing the final group-slot // predicate image. It is first materialized as the same contiguous grouped // mask used by masked_load, then converted to the reduce layout with predicate // deinterleave. This keeps predicate reordering identical to the data diff --git a/include/PTO/IR/VMIAttrs.td b/include/PTO/IR/VMIAttrs.td index 867567de30..111c63fa8a 100644 --- a/include/PTO/IR/VMIAttrs.td +++ b/include/PTO/IR/VMIAttrs.td @@ -28,23 +28,20 @@ def VMILayoutAttr : PTO_Attr<"VMILayout", "vmi.layout"> { static VMILayoutAttr getDeinterleaved(::mlir::MLIRContext *context, int64_t factor, int64_t blockElems = 1); - static VMILayoutAttr getSparse(::mlir::MLIRContext *context, - int64_t sparseFactor); static VMILayoutAttr getGroupSlots(::mlir::MLIRContext *context, int64_t numGroups, int64_t slots = 0, - int64_t sparseFactor = 1); + int64_t laneStride = 1); bool isContiguous() const { return getKind() == "contiguous"; } bool isDeinterleaved() const { return getKind() == "deinterleaved"; } - bool isSparse() const { return getKind() == "sparse"; } bool isGroupSlots() const { return getKind() == "num_groups"; } int64_t getNumGroups() const { return getFactor(); } - bool hasSparseFactor() const { - return isSparse() || (isGroupSlots() && getBlockElems() != 1); + bool hasLaneStride() const { + return isGroupSlots() && getBlockElems() != 1; } - int64_t getSparseFactor() const { - return isSparse() ? getFactor() : getBlockElems(); + int64_t getLaneStride() const { + return isGroupSlots() ? getBlockElems() : 1; } }]; } diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index d07423ae60..f3cb3f27b5 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -167,17 +167,17 @@ static FailureOr getLayoutBlockElems(Type type) { static FailureOr getVMIPhysicalElementType(VMIVRegType type) { Type elementType = type.getElementType(); VMILayoutAttr layout = type.getLayoutAttr(); - if (!layout || !layout.hasSparseFactor()) + if (!layout || !layout.hasLaneStride()) return elementType; auto integerType = dyn_cast(elementType); if (!integerType || !integerType.isUnsigned()) return failure(); unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); - int64_t sparseFactor = layout.getSparseFactor(); - if (elementBits == 0 || sparseFactor <= 1) + int64_t laneStride = layout.getLaneStride(); + if (elementBits == 0 || laneStride <= 1) return failure(); - int64_t physicalBits = static_cast(elementBits) * sparseFactor; + int64_t physicalBits = static_cast(elementBits) * laneStride; if (physicalBits != 16 && physicalBits != 32) return failure(); return IntegerType::get(type.getContext(), physicalBits); @@ -462,15 +462,10 @@ VMILayoutAttr VMILayoutAttr::getDeinterleaved(MLIRContext *context, return VMILayoutAttr::get(context, "deinterleaved", factor, blockElems, 0); } -VMILayoutAttr VMILayoutAttr::getSparse(MLIRContext *context, - int64_t sparseFactor) { - return VMILayoutAttr::get(context, "sparse", sparseFactor, 1, 0); -} - VMILayoutAttr VMILayoutAttr::getGroupSlots(MLIRContext *context, int64_t numGroups, int64_t slots, - int64_t sparseFactor) { - return VMILayoutAttr::get(context, "num_groups", numGroups, sparseFactor, + int64_t laneStride) { + return VMILayoutAttr::get(context, "num_groups", numGroups, laneStride, slots); } @@ -499,9 +494,6 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { return {}; } } - } else if (kind == "sparse") { - if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) - return {}; } else if (kind == "num_groups") { if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) return {}; @@ -512,20 +504,20 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { if (field == "slots") { if (failed(parser.parseInteger(slots))) return {}; - } else if (field == "sparse") { + } else if (field == "lane_stride") { if (failed(parser.parseInteger(blockElems))) return {}; } else { parser.emitError(parser.getCurrentLocation(), "expected 'slots = ' or " - "'sparse = '"); + "'lane_stride = '"); return {}; } } } else { parser.emitError(parser.getCurrentLocation(), "expected VMI layout kind 'contiguous' or " - "'deinterleaved' or 'sparse' or 'num_groups'"); + "'deinterleaved' or 'num_groups'"); return {}; } @@ -542,14 +534,12 @@ void VMILayoutAttr::print(AsmPrinter &printer) const { printer << " = " << getFactor(); if (getBlockElems() != 1) printer << ", block_elems = " << getBlockElems(); - } else if (isSparse()) { - printer << " = " << getFactor(); } else if (isGroupSlots()) { printer << " = " << getFactor(); if (getSlots() != 0) printer << ", slots = " << getSlots(); if (getBlockElems() != 1) - printer << ", sparse = " << getBlockElems(); + printer << ", lane_stride = " << getBlockElems(); } printer << ">"; } @@ -580,24 +570,13 @@ VMILayoutAttr::verify(function_ref emitError, return success(); } - if (kind == "sparse") { - if (factor <= 1) - return emitError() << "#pto.vmi.layout requires sparse factor greater than 1"; - if (blockElems != 1 || slots != 0) - return emitError() << "#pto.vmi.layout requires block_elems and slots to be their " - "defaults"; - return success(); - } - if (kind == "num_groups") { if (factor <= 0) return emitError() << "#pto.vmi.layout requires num_groups to be positive"; if (blockElems <= 0) return emitError() << "#pto.vmi.layout requires sparse factor to be positive"; + << "> requires lane_stride to be positive"; if (slots < 0) return emitError() << "#pto.vmi.layoutkind == VMICastLayoutKind::Narrow4x) resultLayout = VMILayoutAttr::getGroupSlots( ctx, sourceLayout.getNumGroups(), sourceLayout.getSlots(), - /*sparseFactor=*/4); + /*laneStride=*/4); if (failed(setNaturalLayout(trunci.getResult(), resultLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 4e0c7f3829..9255d54ce1 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -244,17 +244,17 @@ materializeVMIToVPTO(OpBuilder &builder, TypeRange resultTypes, Value input, static FailureOr getVMIVRegPhysicalElementType(VMIVRegType type) { Type elementType = type.getElementType(); VMILayoutAttr layout = type.getLayoutAttr(); - if (!layout || !layout.hasSparseFactor()) + if (!layout || !layout.hasLaneStride()) return elementType; auto integerType = dyn_cast(elementType); if (!integerType || !integerType.isUnsigned()) return failure(); unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); - int64_t sparseFactor = layout.getSparseFactor(); - if (elementBits == 0 || sparseFactor <= 1) + int64_t laneStride = layout.getLaneStride(); + if (elementBits == 0 || laneStride <= 1) return failure(); - int64_t physicalBits = static_cast(elementBits) * sparseFactor; + int64_t physicalBits = static_cast(elementBits) * laneStride; if (physicalBits != 16 && physicalBits != 32) return failure(); return IntegerType::get(type.getContext(), physicalBits); @@ -4903,7 +4903,7 @@ struct OneToNVMIGroupStoreOpPattern bool packedByteStore = isPackedByteGroupStore( op.getDestination().getType(), firstVRegType); if (packedByteStore) { - bool sparsePackedByteStore = layout.hasSparseFactor(); + bool laneStridedPackedByteStore = layout.hasLaneStride(); for (Value value : valueParts) { auto vregType = dyn_cast(value.getType()); if (!vregType || vregType != firstVRegType) @@ -4916,7 +4916,7 @@ struct OneToNVMIGroupStoreOpPattern if (failed(maskType)) return rewriter.notifyMatchFailure( op, "unsupported element type for packed group_store mask"); - if (!sparsePackedByteStore && numGroups == 8 && + if (!laneStridedPackedByteStore && numGroups == 8 && valueParts.size() == 1 && isKnownIndexMultipleOf(*offset, 32)) { MLIRContext *ctx = rewriter.getContext(); @@ -7028,8 +7028,8 @@ struct OneToNVMITruncIOpPattern : OneToNOpConversionPattern { unsigned physicalResultBits = pto::getPTOStorageElemBitWidth(resultType.getElementType()); - if (resultLayout.hasSparseFactor() && - resultLayout.getSparseFactor() == 4 && + if (resultLayout.hasLaneStride() && + resultLayout.getLaneStride() == 4 && pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()) == 8 && physicalResultBits == 32) { diff --git a/test/lit/vmi/vmi_layout_assignment_trunci_sparse.pto b/test/lit/vmi/vmi_layout_assignment_trunci_lane_stride.pto similarity index 92% rename from test/lit/vmi/vmi_layout_assignment_trunci_sparse.pto rename to test/lit/vmi/vmi_layout_assignment_trunci_lane_stride.pto index 9ab9d16090..ec9da3042c 100644 --- a/test/lit/vmi/vmi_layout_assignment_trunci_sparse.pto +++ b/test/lit/vmi/vmi_layout_assignment_trunci_lane_stride.pto @@ -12,7 +12,7 @@ // RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { - func.func @vmi_layout_assignment_group_slot_trunci_sparse( + func.func @vmi_layout_assignment_group_slot_trunci_lane_stride( %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index) { @@ -26,12 +26,12 @@ module { } } -// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slot_trunci_sparse( +// ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slot_trunci_lane_stride( // ASSIGN: %[[NARROW:.*]] = pto.vmi.trunci -// ASSIGN-SAME: -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[NARROW]] -// LOWER-LABEL: func.func @vmi_layout_assignment_group_slot_trunci_sparse( +// LOWER-LABEL: func.func @vmi_layout_assignment_group_slot_trunci_lane_stride( // LOWER-NOT: pto.vcvt // LOWER-NOT: pto.vpack // LOWER: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<64xi32>, !pto.ptr, !pto.mask diff --git a/test/lit/vmi/vmi_to_vpto_integer_casts.pto b/test/lit/vmi/vmi_to_vpto_integer_casts.pto index d0ab02f361..77ac39ada7 100644 --- a/test/lit/vmi/vmi_to_vpto_integer_casts.pto +++ b/test/lit/vmi/vmi_to_vpto_integer_casts.pto @@ -83,16 +83,16 @@ module { return } - func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8_sparse( + func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8_lane_stride( %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index %narrow = pto.vmi.trunci %wide : !pto.vmi.vreg<256xi32, #pto.vmi.layout> - -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + : !pto.vmi.vreg<256xui8, #pto.vmi.layout>, !pto.ptr return } @@ -146,7 +146,7 @@ module { // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast -// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8_sparse( +// CHECK-LABEL: func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8_lane_stride( // CHECK-NOT: pto.vcvt // CHECK-NOT: pto.vpack // CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<64xi32>, !pto.ptr, !pto.mask From b8e672434cf81cabb606bdc66ce6eb5f16f8c249 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Sun, 28 Jun 2026 22:53:37 +0800 Subject: [PATCH 36/54] Adjust VMI reduction result shapes --- docs/designs/vmi-implementation-manual.md | 52 ++++++------ docs/designs/vmi-introduction.md | 12 +-- .../vmi-layout-assignment-implementation.md | 23 +++--- include/PTO/IR/VMIOps.td | 8 +- lib/PTO/IR/VMI.cpp | 79 ++++++++++++++----- lib/PTO/Transforms/VMILayoutAssignment.cpp | 32 +++----- lib/PTO/Transforms/VMILayoutSupport.cpp | 24 +++--- lib/PTO/Transforms/VMIToVPTO.cpp | 27 +++---- .../vmi/vmi_group_reduce_addi_i16_invalid.pto | 2 +- .../vmi/vmi_group_reduce_addi_i8_invalid.pto | 2 +- .../vmi/vmi_group_reduce_maxi_i8_invalid.pto | 2 +- ...assignment_broadcast_dense_group_users.pto | 6 +- ...yout_assignment_call_argument_boundary.pto | 6 +- ...ayout_assignment_create_group_mask_s16.pto | 6 +- ...signment_create_group_mask_s32_dynamic.pto | 4 +- ...ment_dense_group_reduce_multi_consumer.pto | 8 +- ...gnment_dense_store_group_slots_invalid.pto | 8 +- ..._layout_assignment_f32_f8_store_reduce.pto | 6 +- ...ignment_group_broadcast_multi_consumer.pto | 12 +-- ...yout_assignment_group_broadcast_slots8.pto | 4 +- ...ut_assignment_group_load_block8_truncf.pto | 4 +- ...roup_load_s16_compact_stride12_invalid.pto | 4 +- ...assignment_group_load_s16_stride_store.pto | 6 +- ...roup_load_s16_unaligned_stride_invalid.pto | 4 +- ...group_load_s32_stride_broadcast_reduce.pto | 12 +-- ...assignment_group_load_s32_stride_store.pto | 8 +- ...roup_load_s32_unaligned_stride_invalid.pto | 4 +- ...out_assignment_group_reduce_maxf_quant.pto | 22 +++--- ...assignment_group_reduce_partial_slots8.pto | 20 ++--- ...ut_assignment_group_reduce_s12_invalid.pto | 6 +- ...yout_assignment_group_reduce_s16_store.pto | 8 +- ...roup_reduce_s16_truncf_broadcast_store.pto | 13 +-- ...mi_layout_assignment_group_reduce_s256.pto | 10 +-- ...ment_group_reduce_s32_broadcast_reduce.pto | 10 +-- ...nment_group_reduce_s32_multitile_store.pto | 8 +- ...yout_assignment_group_reduce_s32_store.pto | 8 +- ...gnment_group_reduce_s32_tail_full_tile.pto | 12 +-- ...p_reduce_s32_tail_no_full_tile_invalid.pto | 4 +- ...vmi_layout_assignment_group_reduce_s64.pto | 10 +-- ...ment_group_reduce_s64_broadcast_reduce.pto | 12 +-- ...assignment_group_reduce_s64_tail_store.pto | 6 +- ...out_assignment_group_reduce_s64_truncf.pto | 12 +-- ..._layout_assignment_group_reduce_slots8.pto | 10 +-- ...t_assignment_group_reduce_slots8_store.pto | 8 +- ...i_layout_assignment_group_reduce_typed.pto | 12 +-- .../vmi_layout_assignment_group_slot_load.pto | 42 +++++----- ...assignment_group_slot_load_dual_layout.pto | 32 ++++---- ...lot_load_slots1_dynamic_stride_invalid.pto | 4 +- ...t_load_slots1_unaligned_stride_invalid.pto | 4 +- ..._layout_assignment_group_slots_cf_join.pto | 18 ++--- ...i_layout_assignment_group_slots_fanout.pto | 16 ++-- ..._layout_assignment_group_slots_scf_for.pto | 26 +++--- ...ignment_group_store_slots1_unit_stride.pto | 4 +- ...signment_masked_load_dense_group_users.pto | 6 +- ..._assignment_masked_load_group_tail_s32.pto | 4 +- ..._layout_assignment_non_load_s32_reduce.pto | 8 +- ...ment_packed_group_slots_truncf_invalid.pto | 10 +-- ...i_layout_assignment_trunci_lane_stride.pto | 10 +-- ...yout_assignment_widen_f16_store_reduce.pto | 6 +- .../vmi_layout_gate_bitcast_group_slots.pto | 10 +-- ...t_gate_group_broadcast_support_invalid.pto | 4 +- ...te_group_reduce_slots1_support_invalid.pto | 2 +- ...yout_gate_group_reduce_support_invalid.pto | 2 +- ...t_gate_group_slot_load_support_invalid.pto | 2 +- ..._group_slots_unsupported_slots_invalid.pto | 6 +- ...ayout_gate_group_store_support_invalid.pto | 4 +- ...vmi_layout_gate_helper_support_invalid.pto | 6 +- test/lit/vmi/vmi_layout_gate_support.pto | 2 +- ...vmi_layout_gate_truncf_support_invalid.pto | 6 +- .../vmi/vmi_layout_group_slots_invalid.pto | 2 +- test/lit/vmi/vmi_op_verifier_basic.pto | 4 +- .../vmi/vmi_to_vpto_bitcast_group_slots.pto | 10 +-- .../vmi/vmi_to_vpto_group_broadcast_deint.pto | 4 +- ...group_broadcast_s32_deint2_small_group.pto | 4 +- .../vmi_to_vpto_group_broadcast_slots8.pto | 4 +- ...to_vpto_group_broadcast_slots8_support.pto | 4 +- .../vmi/vmi_to_vpto_group_broadcast_vselr.pto | 4 +- test/lit/vmi/vmi_to_vpto_group_ops.pto | 12 +-- ...vpto_group_reduce_legacy_slots_invalid.pto | 4 +- ...mi_to_vpto_group_reduce_partial_slots8.pto | 12 +-- ...mi_to_vpto_group_reduce_s256_broadcast.pto | 4 +- test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto | 4 +- .../vmi_to_vpto_group_reduce_s64_support.pto | 4 +- .../vmi/vmi_to_vpto_group_reduce_slots8.pto | 4 +- ...mi_to_vpto_group_reduce_slots8_support.pto | 4 +- .../vmi/vmi_to_vpto_group_reduce_typed.pto | 12 +-- .../vmi/vmi_to_vpto_group_reduce_vcgadd.pto | 4 +- ...to_vpto_group_reduce_vcgadd_multichunk.pto | 4 +- test/lit/vmi/vmi_to_vpto_group_slot_load.pto | 44 +++++------ ...group_slot_load_nonunit_slots8_invalid.pto | 6 +- .../vmi_to_vpto_group_slot_load_support.pto | 4 +- .../vmi_to_vpto_group_slot_truncf_slots1.pto | 8 +- ..._vpto_group_slot_truncf_slots1_support.pto | 8 +- .../vmi_to_vpto_group_store_slots1_1pt.pto | 4 +- ...oup_store_slots1_unit_stride_alignment.pto | 8 +- ...pto_group_store_slots8_nonunit_invalid.pto | 4 +- ...to_vpto_group_store_slots8_packed_byte.pto | 8 +- .../vmi/vmi_to_vpto_integer_cast_reduce.pto | 6 +- test/lit/vmi/vmi_to_vpto_integer_casts.pto | 24 +++--- test/lit/vmi/vmi_type_attr_parse.pto | 8 +- .../anti-mx-f8-bf16-scaled-16x512/kernel.pto | 16 ++-- .../anti-mx-f8-bf16-scaled-4x128/kernel.pto | 16 ++-- .../anti-mx-f8-bf16-scaled-64x2048/kernel.pto | 16 ++-- .../anti-mx-f8-f16-scaled-4x128/kernel.pto | 16 ++-- .../anti-mx-f8-f32-scaled-4x128/kernel.pto | 16 ++-- .../kernel.pto | 16 ++-- .../kernel.pto | 16 ++-- .../block-mx-quant-bf16-e4m3-4x128/kernel.pto | 41 +++++----- .../block-mx-quant-bf16-e5m2-4x128/kernel.pto | 41 +++++----- .../block-mx-quant-f16-e4m3-64x256/kernel.pto | 41 +++++----- .../block-mx-quant-f16-e5m2-8x256/kernel.pto | 41 +++++----- .../block-quant-bf16-fp8-2x128/kernel.pto | 12 +-- .../block-quant-bf16-fp8-32x128/kernel.pto | 12 +-- .../kernel.pto | 36 ++++----- .../block-quant-bf16-fp8-4x128/kernel.pto | 24 +++--- .../block-quant-f16-fp8-16x256/kernel.pto | 12 +-- .../block-quant-f16-fp8-4x256/kernel.pto | 12 +-- .../block-quant-f16-fp8-8x128/kernel.pto | 12 +-- .../kernel.pto | 12 +-- .../kernel.pto | 12 +-- .../kernel.pto | 12 +-- .../kernel.pto | 12 +-- .../kernel.pto | 12 +-- .../kernel.pto | 12 +-- .../swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto | 42 +++++----- .../swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto | 42 +++++----- .../kernel.pto | 41 +++++----- .../kernel.pto | 41 +++++----- .../kernels/tquant-mxfp8-32x32-nd/kernel.pto | 40 +++++----- .../kernels/tquant-mxfp8-32x64-nz/kernel.pto | 40 +++++----- 130 files changed, 876 insertions(+), 864 deletions(-) diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 4baf0756da..8824d605b4 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -2534,8 +2534,9 @@ bitcast: 65xf32 -> 130xi16, where the second physical chunk carries 32 logical bits on both sides, and uneven deinterleaved tails such as 129xf32 -> 129xi32. Partial/tail bitcast remains unsupported if source padding bits would become - result logical bits. group_slots bitcast is unsupported until a slot-wise - bitcast contract is defined. + result logical bits. group_slots bitcast follows the same rule: it is valid + only when the source/result group_slots layout is identical and every + physical group-slot chunk carries the same logical bit footprint. load: baseline result layout is deterministic from explicit layout attrs or the @@ -3019,7 +3020,7 @@ pto.vmi.reduce_addi: current direct lowering: source element width must be 32 bits; narrower vcadd widens its result and needs a separate result type plan source must materialize to one or more full physical chunks with no padding logical lanes - init/result must be rank-0 VMI vectors and each materialize to one physical chunk + init/result must be 1-lane VMI vectors and each materialize to one physical chunk mask must materialize to the same number of physical chunks as source lower as: first_lane = pto.pge_b32 "PAT_VL1" @@ -3048,7 +3049,7 @@ pto.vmi.reduce_addf: current direct lowering: source element type must be f32 source must materialize to one or more full physical chunks with no padding logical lanes - init/result must be rank-0 VMI vectors and each materialize to one physical chunk + init/result must be 1-lane VMI vectors and each materialize to one physical chunk mask must materialize to the same number of b32 physical chunks as source lower as: first_lane = pto.pge_b32 "PAT_VL1" @@ -3111,8 +3112,10 @@ pto.vmi.group_reduce_addf: Non-slot lanes are not consumed by pto.vmi.group_broadcast. The current direct lowering materializes them as zero where the hardware path does not already define them. - The result remains a VMI vector with the same element type and logical lane - count as the source, but its layout is an explicit group-slot layout. + The result remains a VMI vector with the same element type as the source, + but its logical lane count is G: one scalar result per group. Its layout + is an explicit group-slot layout that describes where those G scalars are + placed in physical registers. layout assignment: source use is requested as contiguous result natural layout is #pto.vmi.layout @@ -3120,7 +3123,9 @@ pto.vmi.group_reduce_addf: element width current direct lowering: source/result element type must be f32 - source, result, and mask must have matching physical arity and full chunks + source and mask must have compatible full physical chunks. The result is + `GxT` group-slot data and may have different physical arity from the + source tile. if S=8 for f32, lower each physical chunk with pto.vcgadd. This is the hardware 32B VLane group reduction path for f32: each source chunk produces eight 8-lane group sums in the low lanes of that physical chunk. The @@ -3139,18 +3144,19 @@ pto.vmi.group_reduce_addf: pto.vmi.group_broadcast: semantic: - N = logical lane count; G = num_groups; S = N / G - source must carry #pto.vmi.layout. For each group - g, the source value is read from the slot lane defined by K. The result broadcasts it back to - each logical group: + source logical lane count is G; result logical lane count is N. + S = N / G. + source must carry #pto.vmi.layout. For each + group g, the source value is read from the slot lane defined by K. The + result broadcasts it back to each logical group: result[g * S + i] = source[group_slot(g)] layout assignment: source use is requested as #pto.vmi.layout result is consumer-driven. If no consumer requests another layout, it defaults to contiguous. current direct lowering: - source must carry #pto.vmi.layout with full - physical chunks + source must carry #pto.vmi.layout with one + logical lane per group result may be contiguous with full physical chunks result may also be deinterleaved when S is large enough that every physical result chunk stays inside one logical group, for example N=512, G=2, S=256, @@ -3184,7 +3190,7 @@ pto.vmi.reduce_maxf / pto.vmi.reduce_minf: NaN and signed-zero behavior follows pto.vcmax/pto.vcmin for the chunk reduction and pto.vmax/pto.vmin for serial chunk accumulation. The index lane produced by pto.vcmax/pto.vcmin is ignored because VMI exposes only the - rank-0 value result. + 1-lane value result. layout assignment: source use is requested as contiguous init use is requested as contiguous @@ -3193,7 +3199,7 @@ pto.vmi.reduce_maxf / pto.vmi.reduce_minf: current direct lowering: source element type must be f16 or f32 source must materialize to one or more full physical chunks with no padding logical lanes - init/result must be rank-0 VMI vectors and each materialize to one physical chunk + init/result must be 1-lane VMI vectors and each materialize to one physical chunk mask must materialize to the same number of physical chunks as source lower reduce_maxf as: first_lane = pto.pge_b16/b32 "PAT_VL1" @@ -3261,12 +3267,11 @@ pto.vmi.truncf, direct path: pto.vmi.bitcast: for each physical part: emit pto.vbitcast(source_part) -> result_part_type - source/result layouts must match and must be contiguous/deinterleaved, - physical arity must match, and every corresponding physical chunk must carry - the same number of logical bits. Padding bits may map only to result padding - bits; any shape where source padding would become result logical data remains - unsupported. group_slots bitcast is rejected before vmi-to-vpto until it has - a slot-wise contract. + source/result layouts must match, physical arity must match, and every + corresponding physical chunk must carry the same number of logical bits. + This includes contiguous, deinterleaved, and identical group_slots layouts. + Padding bits may map only to result padding bits; any shape where source + padding would become result logical data remains unsupported. pto.vmi.channel_split / pto.vmi.channel_merge: support 2-way and 4-way channel transforms for contiguous per-channel values @@ -3596,7 +3601,7 @@ Unsupported diagnostics: or f32 deinterleaved=4 source parts to one contiguous fp8-like result chunk unsupported pto.vmi.bitcast shape: - VMI-UNSUPPORTED: pto.vmi.bitcast requires matching non-group_slots source/result layouts with identical physical + VMI-UNSUPPORTED: pto.vmi.bitcast requires matching source/result layouts with identical physical arity and matching per-chunk logical bit footprints (...) unsupported pto.vmi.channel_split / pto.vmi.channel_merge channel count: @@ -4594,8 +4599,7 @@ use VMI-UNSUPPORTED in preflight: partial/tail memory access pred-only constant mask without concrete b8/b16/b32 granularity shuffle that requires vselr index-vector materialization - bitcast with mismatched per-chunk logical bit footprints or group_slots - bitcast without a slot-wise contract + bitcast with mismatched layouts or per-chunk logical bit footprints use VMI-RESIDUAL-OP: conversion framework finished but VMI op/type/helper/cast remains. diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index 59c03655b2..8c61be40a7 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -660,10 +660,10 @@ vmi-to-vpto 之前没有物理 VPTO value 泄漏到 VMI IR 中 ```mlir %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} - : ... -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + : ... -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> pto.vmi.store %sum, %dst[%off] - : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr ``` @@ -689,7 +689,7 @@ VMI op 做 local lowering。 -> 两个 physical !pto.vreg<64xf32> part part0 携带 even lanes,part1 携带 odd lanes -!pto.vmi.vreg<256xf32, #pto.vmi.layout> +!pto.vmi.vreg<32xf32, #pto.vmi.layout> -> 四个 physical part part0 携带 group 0..7,part1 携带 group 8..15,... ``` @@ -824,8 +824,8 @@ Surface 意图: ```mlir %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} -%sum16 = pto.vmi.truncf %sum32 -%rows16 = pto.vmi.group_broadcast %sum16 {num_groups = 8} +%rows32 = pto.vmi.group_broadcast %sum32 {num_groups = 8} +%rows16 = pto.vmi.truncf %rows32 pto.vmi.store %rows16, %dst[%off] ``` @@ -833,7 +833,7 @@ pto.vmi.store %rows16, %dst[%off] ```mlir %sum32 = pto.vmi.group_reduce_addf ... - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> %rows32 = pto.vmi.group_broadcast %sum32 {num_groups = 8} -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 1a5ef9f35a..d3778ab8fa 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -550,7 +550,11 @@ data/mask helper materialization: group_slot_load: assigned result layout is group_slots(G, slots=8) for packed slots or - group_slots(G, slots=1) for row-local slots. + group_slots(G, slots=1) for row-local slots. Because the result type is + `GxT`, assignment does not derive this choice from result lane count. A + constant unit `source_group_stride` selects slots=8; non-unit or dynamic + stride selects slots=1 first, then the support query rejects dynamic or + unaligned row-local lowering when the target cannot materialize it. block8 group_load: assigned result layout is deinterleaved=2/4 with block_elems=8 only when the @@ -588,10 +592,9 @@ extsi/extui/trunci: slot-wise transform is represented explicitly. bitcast: - per-part vbitcast is valid for contiguous/deinterleaved layouts when - source/result layouts match, physical arity matches, and every physical chunk - carries the same logical bit footprint. group_slots bitcast is unsupported - until a slot-wise bitcast contract is defined. + per-part vbitcast is valid when source/result layouts match, physical arity + matches, and every physical chunk carries the same logical bit footprint. + This includes contiguous, deinterleaved, and identical group_slots layouts. ``` `vmi-layout-fold-consumers`, rematerialization, sink/hoist, and private @@ -940,7 +943,7 @@ Canonical assigned IR shape for `group_broadcast` multi-use: ```text %b = pto.vmi.group_broadcast %slots - : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> %b_c = pto.vmi.ensure_layout %b @@ -1730,8 +1733,8 @@ runtime SIM: test/vpto/cases/vmi/group-slots-fanout-store-broadcast ``` -Current checked-in coverage for 3.8 `group_reduce -> truncf -> -group_broadcast -> dense store` and 3.17 `group_broadcast` feeding a +Current checked-in coverage for 3.8 `group_reduce -> group_broadcast -> +truncf -> dense store` and 3.17 `group_broadcast` feeding a deinterleaved consumer: ```text @@ -2078,7 +2081,7 @@ test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto test/lit/vmi/vmi_layout_gate_bitcast_support_invalid.pto -test/lit/vmi/vmi_layout_gate_bitcast_group_slots_invalid.pto +test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto ``` Current checked-in direct `vmi-to-vpto` preflight coverage for bitcast local @@ -2086,7 +2089,7 @@ lowering is: ```text test/lit/vmi/vmi_to_vpto_bitcast_footprint_invalid.pto -test/lit/vmi/vmi_to_vpto_bitcast_group_slots_invalid.pto +test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto ``` Current checked-in coverage for 3.32 f32 feeding f8 store and S=32 reduce: diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 1b44fc0e40..cc29ea4666 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -50,7 +50,7 @@ def VMIConstantOp : VMI_Op<"constant"> { } def VMIBroadcastOp : VMI_Op<"broadcast"> { - let summary = "Broadcast one scalar or rank-0 VMI vector to a VMI logical vector"; + let summary = "Broadcast one scalar or 1-lane VMI vector to a VMI logical vector"; let arguments = (ins AnyType:$value); let results = (outs VMI_VRegTypeConstraint:$result); let hasVerifier = 1; @@ -367,7 +367,7 @@ def VMICompressStoreOp : VMI_Op<"compress_store", [DeclareOpInterfaceMethods { - let summary = "VMI masked integer add reduction with a rank-0 vector init"; + let summary = "VMI masked integer add reduction with a 1-lane vector init"; let arguments = (ins VMI_VRegTypeConstraint:$source, VMI_VRegTypeConstraint:$init, VMI_MaskTypeConstraint:$mask); @@ -388,7 +388,7 @@ def VMIReduceAddFOp : VMI_Op<"reduce_addf"> { } def VMIReduceMaxFOp : VMI_Op<"reduce_maxf"> { - let summary = "VMI masked floating-point maximum reduction with a rank-0 vector init"; + let summary = "VMI masked floating-point maximum reduction with a 1-lane vector init"; let arguments = (ins VMI_VRegTypeConstraint:$source, VMI_VRegTypeConstraint:$init, VMI_MaskTypeConstraint:$mask); @@ -398,7 +398,7 @@ def VMIReduceMaxFOp : VMI_Op<"reduce_maxf"> { } def VMIReduceMinFOp : VMI_Op<"reduce_minf"> { - let summary = "VMI masked floating-point minimum reduction with a rank-0 vector init"; + let summary = "VMI masked floating-point minimum reduction with a 1-lane vector init"; let arguments = (ins VMI_VRegTypeConstraint:$source, VMI_VRegTypeConstraint:$init, VMI_MaskTypeConstraint:$mask); diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index f3cb3f27b5..b84d63f1c4 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -641,11 +641,11 @@ LogicalResult VMIVRegType::verify(function_ref emitError, << "' expected layout to be #pto.vmi.layout"; if (auto layoutAttr = llvm::dyn_cast_or_null(layout)) { if (layoutAttr.isGroupSlots() && - elementCount % layoutAttr.getNumGroups() != 0) + elementCount != layoutAttr.getNumGroups()) return emitError() << "'" << formatVMIVRegType(elementCount, elementType, layout) - << "' expected num_groups layout to evenly divide " - "the VMI logical lane count"; + << "' expected num_groups layout to describe exactly " + "one logical result lane per group"; } return success(); @@ -1116,7 +1116,7 @@ LogicalResult VMIReduceAddIOp::verify() { return emitOpError( "requires source, init, and result element types to match"); if (initType.getElementCount() != 1 || resultType.getElementCount() != 1) - return emitOpError("requires init and result to be rank-0 VMI vectors"); + return emitOpError("requires init and result to be 1-lane VMI vectors"); if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), {initType, resultType}, /*requireSameElement=*/true))) @@ -1140,7 +1140,7 @@ LogicalResult VMIReduceAddFOp::verify() { return emitOpError( "requires source, init, and result element types to match"); if (initType.getElementCount() != 1 || resultType.getElementCount() != 1) - return emitOpError("requires init and result to be rank-0 VMI vectors"); + return emitOpError("requires init and result to be 1-lane VMI vectors"); if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), {initType, resultType}, /*requireSameElement=*/true))) @@ -1161,7 +1161,7 @@ template LogicalResult verifyReduceMinMaxFOp(OpTy op) { return op.emitOpError( "requires source, init, and result element types to match"); if (initType.getElementCount() != 1 || resultType.getElementCount() != 1) - return op.emitOpError("requires init and result to be rank-0 VMI vectors"); + return op.emitOpError("requires init and result to be 1-lane VMI vectors"); if (failed(verifyAllSameVRegShapeAndLayout(op.getOperation(), {initType, resultType}, /*requireSameElement=*/true))) @@ -1185,9 +1185,9 @@ static LogicalResult verifyGroupReduceFloatOp(OpTy op, bool requiresReassoc) { if (!isVMIFloatLikeType(sourceType.getElementType())) return op.emitOpError( "requires floating-point-like VMI source element type"); - if (sourceType.getElementCount() != resultType.getElementCount()) + if (resultType.getElementCount() != op.getNumGroupsAttr().getInt()) return op.emitOpError( - "requires source and result logical lane counts to match"); + "requires result logical lane count to match num_groups"); if (sourceType.getElementType() != resultType.getElementType()) return op.emitOpError("requires source and result element types to match"); if (auto sourceLayout = sourceType.getLayoutAttr()) { @@ -1238,9 +1238,9 @@ static LogicalResult verifyGroupReduceIntegerOp(OpTy op) { "requires i32 accumulator element type; cast i8/i16 storage to i32 " "before grouped reduction because integer reduction widens narrow " "inputs"); - if (sourceType.getElementCount() != resultType.getElementCount()) + if (resultType.getElementCount() != op.getNumGroupsAttr().getInt()) return op.emitOpError( - "requires source and result logical lane counts to match"); + "requires result logical lane count to match num_groups"); if (sourceType.getElementType() != resultType.getElementType()) return op.emitOpError("requires source and result element types to match"); if (auto sourceLayout = sourceType.getLayoutAttr()) { @@ -1281,25 +1281,28 @@ LogicalResult VMIGroupReduceMaxIOp::verify() { LogicalResult VMIGroupBroadcastOp::verify() { auto sourceType = cast(getSource().getType()); auto resultType = cast(getResult().getType()); - if (sourceType.getElementCount() != resultType.getElementCount()) + int64_t numGroups = getNumGroupsAttr().getInt(); + if (sourceType.getElementCount() != numGroups) return emitOpError( - "requires source and result logical lane counts to match"); + "requires source logical lane count to match num_groups"); + if (resultType.getElementCount() % numGroups != 0) + return emitOpError( + "requires num_groups to evenly divide result logical lane count"); if (sourceType.getElementType() != resultType.getElementType()) return emitOpError("requires source and result element types to match"); if (auto sourceLayout = sourceType.getLayoutAttr()) { if (!sourceLayout.isGroupSlots() || - sourceLayout.getNumGroups() != getNumGroupsAttr().getInt()) + sourceLayout.getNumGroups() != numGroups) return emitOpError() << "requires layout-assigned source to use " "#pto.vmi.layout"; + << numGroups << ">"; } if (auto resultLayout = resultType.getLayoutAttr()) { if (resultLayout.isGroupSlots()) return emitOpError( "requires layout-assigned result to use a dense VMI layout"); } - return verifyNumGroups(getOperation(), sourceType, - getNumGroupsAttr().getInt()); + return verifyNumGroups(getOperation(), resultType, numGroups); } template static LogicalResult verifyVMIHistogramOp(OpTy op) { @@ -1536,18 +1539,21 @@ void VMIGroupLoadOp::getEffects( LogicalResult VMIGroupSlotLoadOp::verify() { auto resultType = cast(getResult().getType()); + int64_t numGroups = getNumGroupsAttr().getInt(); + if (resultType.getElementCount() != numGroups) + return emitOpError( + "requires result logical lane count to match num_groups"); if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), resultType, "source"))) return failure(); if (auto resultLayout = resultType.getLayoutAttr()) { if (!resultLayout.isGroupSlots() || - resultLayout.getNumGroups() != getNumGroupsAttr().getInt()) + resultLayout.getNumGroups() != numGroups) return emitOpError() << "requires layout-assigned result to use " "#pto.vmi.layout"; + << numGroups << ">"; } - return verifyNumGroups(getOperation(), resultType, - getNumGroupsAttr().getInt()); + return verifyNumGroups(getOperation(), resultType, numGroups); } void VMIGroupSlotLoadOp::getEffects( @@ -1958,6 +1964,16 @@ mlir::pto::mapLogicalLaneToPhysical(Type type, int64_t logicalLane) { if (logicalLane < 0 || logicalLane >= *elementCount) return failure(); + FailureOr layout = getAssignedVMILayout(type); + if (succeeded(layout) && (*layout).isGroupSlots() && + (*layout).getSlots() > 0) { + int64_t slots = (*layout).getSlots(); + int64_t lane = logicalLane % slots; + if (lane >= *lanesPerPart) + return failure(); + return VMIPhysicalLane{/*part=*/0, logicalLane / slots, lane}; + } + int64_t part = 0; std::optional indexInPart = mapDenseLogicalLaneToPartIndex( *elementCount, *factor, *blockElems, logicalLane, part); @@ -1981,6 +1997,18 @@ FailureOr mlir::pto::mapPhysicalLaneToLogical(Type type, int64_t part, lane >= *lanesPerPart) return failure(); + FailureOr layout = getAssignedVMILayout(type); + if (succeeded(layout) && (*layout).isGroupSlots() && + (*layout).getSlots() > 0) { + int64_t slots = (*layout).getSlots(); + if (part != 0 || lane >= slots) + return failure(); + int64_t logicalLane = chunk * slots + lane; + if (logicalLane >= *elementCount) + return failure(); + return logicalLane; + } + int64_t indexInPart = chunk * *lanesPerPart + lane; std::optional logicalLane = mapDensePartIndexToLogicalLane( *elementCount, *factor, *blockElems, part, indexInPart); @@ -2002,6 +2030,17 @@ FailureOr mlir::pto::isPaddingLane(Type type, int64_t part, int64_t chunk, lane >= *lanesPerPart) return failure(); + FailureOr layout = getAssignedVMILayout(type); + if (succeeded(layout) && (*layout).isGroupSlots() && + (*layout).getSlots() > 0) { + int64_t slots = (*layout).getSlots(); + if (part != 0) + return true; + if (lane >= slots) + return true; + return chunk * slots + lane >= *elementCount; + } + int64_t lanesInPart = getDenseLogicalLanesInPart(*elementCount, *factor, *blockElems, part); int64_t indexInPart = chunk * *lanesPerPart + lane; diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index be60564746..0350c48166 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -269,19 +269,17 @@ struct LayoutSolver { return getContiguousLayout(); } - VMILayoutAttr getPreferredGroupSlotLoadLayout(VMIVRegType type, - int64_t numGroups) { + VMILayoutAttr getPreferredGroupSlotLoadLayout(VMIGroupSlotLoadOp op) { + auto type = cast(op.getResult().getType()); + int64_t numGroups = op.getNumGroupsAttr().getInt(); if (VMILayoutAttr existing = type.getLayoutAttr()) if (existing.isGroupSlots() && existing.getSlots() > 0) return existing; - if (numGroups > 0 && type.getElementCount() % numGroups == 0) { - int64_t groupSize = type.getElementCount() / numGroups; - FailureOr lanesPerPart = - getDataLanesPerPart(type.getElementType()); - if (succeeded(lanesPerPart) && groupSize == *lanesPerPart) - return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); - } - return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + std::optional sourceGroupStride = + getConstantIndexValue(op.getSourceGroupStride()); + if (sourceGroupStride && *sourceGroupStride == 1) + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/8); + return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); } VMILayoutAttr getPreferredGroupBroadcastSourceLayout(Value value, @@ -296,8 +294,8 @@ struct LayoutSolver { if (solved && solved.isGroupSlots() && solved.getNumGroups() == numGroups && solved.getSlots() > 0) return solved; - if (value.getDefiningOp()) - return getPreferredGroupSlotLoadLayout(type, numGroups); + if (auto load = value.getDefiningOp()) + return getPreferredGroupSlotLoadLayout(load); return getPreferredGroupSlotsLayout(type, numGroups); } @@ -370,8 +368,8 @@ struct LayoutSolver { value.getDefiningOp() || value.getDefiningOp()) return getPreferredGroupSlotsLayout(type, numGroups); - if (value.getDefiningOp()) - return getPreferredGroupSlotLoadLayout(type, numGroups); + if (auto load = value.getDefiningOp()) + return getPreferredGroupSlotLoadLayout(load); return getContiguousLayout(); } @@ -1244,12 +1242,8 @@ struct LayoutSolver { return WalkResult::advance(); } if (auto load = dyn_cast(op)) { - auto resultType = cast(load.getResult().getType()); if (failed(setNaturalLayout( - load.getResult(), - getPreferredGroupSlotLoadLayout( - resultType, load.getNumGroupsAttr().getInt()), - op))) + load.getResult(), getPreferredGroupSlotLoadLayout(load), op))) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index 393b5b687c..6920b0ca38 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -647,10 +647,12 @@ FailureOr VMILayoutSupport::getGroupSlotLoadSupport( }; auto resultType = cast(op.getResult().getType()); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (resultType.getElementCount() != numGroups) + return fail("requires result logical lane count to match num_groups"); VMILayoutAttr layout = resultType.getLayoutAttr(); if (!layout || !layout.isGroupSlots() || - layout.getNumGroups() != op.getNumGroupsAttr().getInt() || - layout.getSlots() <= 0) + layout.getNumGroups() != numGroups || layout.getSlots() <= 0) return fail("requires explicit group_slots result layout matching " "num_groups"); @@ -858,8 +860,8 @@ FailureOr getGroupReduceAddSupportImpl( if (sourceType.getElementType() != resultType.getElementType()) return fail("stable group_reduce_add layout support requires matching " "source/result element types"); - if (sourceType.getElementCount() != resultType.getElementCount()) - return fail("requires source/result lane count to match"); + if (resultType.getElementCount() != numGroups) + return fail("requires result lane count to match num_groups"); FailureOr groupSize = getGroupSizeFromNumGroups(sourceType, numGroups, reason); @@ -1028,9 +1030,12 @@ FailureOr VMILayoutSupport::getGroupBroadcastSupport( return failure(); }; - if (sourceType.getElementType() != resultType.getElementType() || - sourceType.getElementCount() != resultType.getElementCount()) - return fail("requires source/result shape and element type to match"); + if (sourceType.getElementType() != resultType.getElementType()) + return fail("requires source/result element type to match"); + if (sourceType.getElementCount() != numGroups) + return fail("requires source lane count to match num_groups"); + if (resultType.getElementCount() % numGroups != 0) + return fail("requires num_groups to evenly divide result lane count"); VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); VMILayoutAttr resultLayout = resultType.getLayoutAttr(); @@ -1046,9 +1051,6 @@ FailureOr VMILayoutSupport::getGroupBroadcastSupport( "layouts"); std::string fullChunkReason; - if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) - return fail(Twine("requires full source physical chunks; ") + - fullChunkReason); if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) return fail(Twine("requires full result physical chunks; ") + fullChunkReason); @@ -1062,7 +1064,7 @@ FailureOr VMILayoutSupport::getGroupBroadcastSupport( return fail("requires matching physical lanes per part"); FailureOr groupSize = - getGroupSizeFromNumGroups(sourceType, numGroups, reason); + getGroupSizeFromNumGroups(resultType, numGroups, reason); if (failed(groupSize)) return failure(); if (*lanesPerPart % *groupSize != 0 && *groupSize % *lanesPerPart != 0) diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 9255d54ce1..2308337386 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -2444,20 +2444,19 @@ LogicalResult checkFullGroupSlotSourceShape( VMILayoutAttr layout = type.getLayoutAttr(); if (!layout || !layout.isGroupSlots() || layout.getNumGroups() != numGroups) return fail("group slot op requires matching num_groups VMI layout"); - if (failed(checkFullDataPhysicalChunks(type, nullptr))) - return fail("group slot op requires full physical chunks"); + if (type.getElementCount() != numGroups) + return fail("group slot op requires one logical lane per group"); FailureOr lanes = getDataLanesPerPart(type.getElementType()); if (failed(lanes)) return fail("group slot op requires known physical lanes per part"); - if (groupSize <= 0 || type.getElementCount() % groupSize != 0) - return fail("group slot op requires derived group size to evenly divide " - "lane count"); + if (groupSize <= 0) + return fail("group slot op requires positive derived group size"); if (*lanes % groupSize != 0 && groupSize % *lanes != 0) return fail("group slot op requires group size to divide or be a " "multiple of physical lanes per part"); *lanesPerPart = *lanes; - *groupCount = type.getElementCount() / groupSize; + *groupCount = numGroups; return success(); } @@ -6148,7 +6147,7 @@ struct OneToNVMIGroupBroadcastOpPattern auto sourceVMIType = cast(op.getSource().getType()); auto resultVMIType = cast(op.getResult().getType()); FailureOr groupSize = getGroupSizeFromNumGroups( - sourceVMIType, op.getNumGroupsAttr().getInt()); + resultVMIType, op.getNumGroupsAttr().getInt()); if (failed(groupSize)) return rewriter.notifyMatchFailure( op, @@ -8182,10 +8181,9 @@ LogicalResult checkSupportedGroupBroadcastShape( (void)capabilities; auto sourceType = cast(op.getSource().getType()); auto resultType = cast(op.getResult().getType()); - if (sourceType.getElementType() != resultType.getElementType() || - sourceType.getElementCount() != resultType.getElementCount()) { + if (sourceType.getElementType() != resultType.getElementType()) { if (reason) - *reason = "requires source/result shape and element type to match"; + *reason = "requires source/result element type to match"; return failure(); } auto fail = [&](const Twine &message) -> LogicalResult { @@ -8201,6 +8199,10 @@ LogicalResult checkSupportedGroupBroadcastShape( VMILayoutSupport supports; if (succeeded(supports.getGroupBroadcastSupport(capabilities, op, nullptr))) return success(); + if (sourceType.getElementCount() != op.getNumGroupsAttr().getInt()) + return fail("requires source lane count to match num_groups"); + if (resultType.getElementCount() % op.getNumGroupsAttr().getInt() != 0) + return fail("requires num_groups to evenly divide result lane count"); if (!sourceLayout.isGroupSlots() || sourceLayout.getNumGroups() != op.getNumGroupsAttr().getInt()) return fail("requires matching num_groups source layout"); @@ -8213,9 +8215,6 @@ LogicalResult checkSupportedGroupBroadcastShape( "layouts"); std::string fullChunkReason; - if (failed(checkFullDataPhysicalChunks(sourceType, &fullChunkReason))) - return fail(Twine("requires full source physical chunks; ") + - fullChunkReason); if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) return fail(Twine("requires full result physical chunks; ") + fullChunkReason); @@ -8228,7 +8227,7 @@ LogicalResult checkSupportedGroupBroadcastShape( *lanesPerPart != *resultLanesPerPart) return fail("requires matching physical lanes per part"); FailureOr groupSize = getGroupSizeFromNumGroups( - sourceType, op.getNumGroupsAttr().getInt(), reason); + resultType, op.getNumGroupsAttr().getInt(), reason); if (failed(groupSize)) return failure(); if (*lanesPerPart % *groupSize != 0 && *groupSize % *lanesPerPart != 0) diff --git a/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto b/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto index 948dfe9c54..33f3516efd 100644 --- a/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto +++ b/test/lit/vmi/vmi_group_reduce_addi_i16_invalid.pto @@ -16,7 +16,7 @@ module { %mask: !pto.vmi.mask<128xpred>) { %sum = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} : !pto.vmi.vreg<128xi16>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xi16> + -> !pto.vmi.vreg<8xi16> return } } diff --git a/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto b/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto index 578acc00b9..973a57450e 100644 --- a/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto +++ b/test/lit/vmi/vmi_group_reduce_addi_i8_invalid.pto @@ -16,7 +16,7 @@ module { %mask: !pto.vmi.mask<256xpred>) { %sum = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} : !pto.vmi.vreg<256xi8>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xi8> + -> !pto.vmi.vreg<8xi8> return } } diff --git a/test/lit/vmi/vmi_group_reduce_maxi_i8_invalid.pto b/test/lit/vmi/vmi_group_reduce_maxi_i8_invalid.pto index b9416e81cf..756a8b3527 100644 --- a/test/lit/vmi/vmi_group_reduce_maxi_i8_invalid.pto +++ b/test/lit/vmi/vmi_group_reduce_maxi_i8_invalid.pto @@ -16,7 +16,7 @@ module { %mask: !pto.vmi.mask<256xpred>) { %max = pto.vmi.group_reduce_maxi %source, %mask {num_groups = 8} : !pto.vmi.vreg<256xi8>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xi8> + -> !pto.vmi.vreg<8xi8> return } } diff --git a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto index 1cbdeea1d8..4c8c8e3142 100644 --- a/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto +++ b/test/lit/vmi/vmi_layout_assignment_broadcast_dense_group_users.pto @@ -36,9 +36,9 @@ module { %sum = pto.vmi.group_reduce_addf %prod, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -60,7 +60,7 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[PROD]], %[[MASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // LOWER-LABEL: func.func @vmi_layout_assignment_broadcast_dense_group_users( diff --git a/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto index 00879170b1..c7049c75a2 100644 --- a/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto +++ b/test/lit/vmi/vmi_layout_assignment_call_argument_boundary.pto @@ -18,9 +18,9 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %out[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } @@ -48,7 +48,7 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN-LABEL: func.func @caller( // ASSIGN: %[[X:.*]] = pto.vmi.load diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto index 5999ace148..868624d330 100644 --- a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s16.pto @@ -21,9 +21,9 @@ module { : index -> !pto.vmi.mask<128xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -36,7 +36,7 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] // ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_create_group_mask_s16( // LOWER: pto.pset_b32 "PAT_ALL" diff --git a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto index fe5920c07b..447f4591b7 100644 --- a/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto +++ b/test/lit/vmi/vmi_layout_assignment_create_group_mask_s32_dynamic.pto @@ -28,9 +28,9 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } diff --git a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto index 6ffab1471d..a7be1a67a4 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_group_reduce_multi_consumer.pto @@ -21,9 +21,9 @@ module { %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr pto.vmi.store %x, %copy_out[%off] : !pto.vmi.vreg<256xf32>, !pto.ptr return @@ -40,9 +40,9 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] -// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr // ASSIGN: pto.vmi.store %[[X]] // ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr diff --git a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto index c8ded49a2f..a92d0c52f2 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_store_group_slots_invalid.pto @@ -17,11 +17,11 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> - -> !pto.vmi.vreg<64xf32> - // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store operand #0 has type '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<64xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion - // CHECK: failed helper conversion '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<64xf32, #pto.vmi.layout>' (unsupported source/result layout pair) + -> !pto.vmi.vreg<8xf32> + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.store operand #0 has type '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<8xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: failed helper conversion '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' (unsupported source/result layout pair) pto.vmi.store %sum, %dst[%off] - : !pto.vmi.vreg<64xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto index 8dfe2292cf..783baac750 100644 --- a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -22,9 +22,9 @@ module { %mask = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> %sumv = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sumv, %sum[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %x8 = pto.vmi.truncf %x32 : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> pto.vmi.store %x8, %out8[%off] @@ -41,7 +41,7 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN: %[[X8:.*]] = pto.vmi.truncf %[[X32]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto index 20c2754e60..400c10093a 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto @@ -21,19 +21,19 @@ module { %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> %b_for_mul = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> %y = pto.vmi.mulf %x, %b_for_mul : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %ysum, %sum_out[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %b_for_cast = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> %h = pto.vmi.truncf %b_for_cast : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> pto.vmi.store %h, %dense_out[%off] @@ -46,7 +46,7 @@ module { // ASSIGN: %[[X:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[B_MUL:.*]] = pto.vmi.group_broadcast %[[SUM]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B_MUL]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto index 2c0f4f8ca7..8ac9e7dc06 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_slots8.pto @@ -10,10 +10,10 @@ module { func.func @vmi_layout_assignment_group_broadcast_slots8( - %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> !pto.vmi.vreg<1024xf32> { %out = pto.vmi.group_broadcast %source {num_groups = 128} - : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<1024xf32> return %out : !pto.vmi.vreg<1024xf32> } diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto index 9cf2720915..b96238db35 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto @@ -23,9 +23,9 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %h = pto.vmi.truncf %x : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> pto.vmi.store %h, %dense_dst[%off] diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto index 113467b492..8b60309aa6 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_compact_stride12_invalid.pto @@ -23,9 +23,9 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto index a3f045e503..60a0e75884 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_stride_store.pto @@ -22,9 +22,9 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -35,7 +35,7 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s16_stride_store( diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto index ed2ed892f9..61b30fd778 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s16_unaligned_stride_invalid.pto @@ -21,9 +21,9 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto index df03683335..cb6011a650 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_broadcast_reduce.pto @@ -22,18 +22,18 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> %b = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %y = pto.vmi.mulf %x, %b : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %ysum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -44,7 +44,7 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[B:.*]] = pto.vmi.group_broadcast %[[SUM]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[Y:.*]] = pto.vmi.mulf %[[X]], %[[B]] @@ -52,7 +52,7 @@ module { // ASSIGN: %[[MASK2:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[YSUM:.*]] = pto.vmi.group_reduce_addf %[[Y]], %[[MASK2]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[YSUM]] // LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_broadcast_reduce( diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto index abe3301b90..34de7b5064 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_stride_store.pto @@ -22,9 +22,9 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -35,9 +35,9 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] -// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr // LOWER-LABEL: func.func @vmi_layout_assignment_group_load_s32_stride_store( // LOWER-COUNT-4: pto.vsldb diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto index 7cd5ffd85d..43b566f895 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_s32_unaligned_stride_invalid.pto @@ -21,9 +21,9 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto index 1ae3f90a15..e4600d1f77 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto @@ -24,19 +24,19 @@ module { %abs = pto.vmi.absf %x : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> %amax_raw = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> - %eps2 = pto.vmi.broadcast %eps : f32 -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> + %eps2 = pto.vmi.broadcast %eps : f32 -> !pto.vmi.vreg<2xf32> %amax = pto.vmi.maxf %amax_raw, %eps2 - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> - %fp8_max2 = pto.vmi.broadcast %fp8_max : f32 -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> + %fp8_max2 = pto.vmi.broadcast %fp8_max : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.divf %amax, %fp8_max2 - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> pto.vmi.group_store %scale, %scale_out[%off], %c8 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %q = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> @@ -53,9 +53,9 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[ABS:.*]] = pto.vmi.absf %[[X]] // ASSIGN: %[[AMAX_RAW:.*]] = pto.vmi.group_reduce_maxf %[[ABS]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> // ASSIGN: %[[SCALE:.*]] = pto.vmi.divf -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SCALE]] // ASSIGN: %[[SCALE_VEC:.*]] = pto.vmi.group_broadcast %[[SCALE]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto index e828ba6b2d..eb9296da40 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_partial_slots8.pto @@ -12,30 +12,30 @@ module { func.func @vmi_layout_assignment_group_reduce_f16_s64_g4( %source: !pto.vmi.vreg<256xf16>, %mask: !pto.vmi.mask<256xpred>) - -> !pto.vmi.vreg<256xf16> { + -> !pto.vmi.vreg<4xf16> { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 4, reassoc} : !pto.vmi.vreg<256xf16>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf16> - return %out : !pto.vmi.vreg<256xf16> + -> !pto.vmi.vreg<4xf16> + return %out : !pto.vmi.vreg<4xf16> } func.func @vmi_layout_assignment_group_reduce_f16_s64_g12( %source: !pto.vmi.vreg<768xf16>, %mask: !pto.vmi.mask<768xpred>) - -> !pto.vmi.vreg<768xf16> { + -> !pto.vmi.vreg<12xf16> { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 12, reassoc} : !pto.vmi.vreg<768xf16>, !pto.vmi.mask<768xpred> - -> !pto.vmi.vreg<768xf16> - return %out : !pto.vmi.vreg<768xf16> + -> !pto.vmi.vreg<12xf16> + return %out : !pto.vmi.vreg<12xf16> } } // CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_f16_s64_g4( // CHECK-SAME: %arg0: !pto.vmi.vreg<256xf16, #pto.vmi.layout> // CHECK-SAME: %arg1: !pto.vmi.mask<256xb32, #pto.vmi.layout> -// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<4xf16, #pto.vmi.layout> // CHECK: %[[SRC4:.*]] = pto.vmi.ensure_layout %arg0 // CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> // CHECK: %[[MASK4_LAYOUT:.*]] = pto.vmi.ensure_mask_layout %arg1 @@ -43,13 +43,13 @@ module { // CHECK: %[[MASK4:.*]] = pto.vmi.ensure_mask_granularity %[[MASK4_LAYOUT]] // CHECK-SAME: -> !pto.vmi.mask<256xb16, #pto.vmi.layout> // CHECK: %[[OUT4:.*]] = pto.vmi.group_reduce_addf %[[SRC4]], %[[MASK4]] -// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<4xf16, #pto.vmi.layout> // CHECK: return %[[OUT4]] // CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_f16_s64_g12( // CHECK-SAME: %arg0: !pto.vmi.vreg<768xf16, #pto.vmi.layout> // CHECK-SAME: %arg1: !pto.vmi.mask<768xb32, #pto.vmi.layout> -// CHECK-SAME: -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<12xf16, #pto.vmi.layout> // CHECK: %[[SRC12:.*]] = pto.vmi.ensure_layout %arg0 // CHECK-SAME: -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> // CHECK: %[[MASK12_LAYOUT:.*]] = pto.vmi.ensure_mask_layout %arg1 @@ -57,5 +57,5 @@ module { // CHECK: %[[MASK12:.*]] = pto.vmi.ensure_mask_granularity %[[MASK12_LAYOUT]] // CHECK-SAME: -> !pto.vmi.mask<768xb16, #pto.vmi.layout> // CHECK: %[[OUT12:.*]] = pto.vmi.group_reduce_addf %[[SRC12]], %[[MASK12]] -// CHECK-SAME: -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<12xf16, #pto.vmi.layout> // CHECK: return %[[OUT12]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto index b63d134392..04dbe3952c 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s12_invalid.pto @@ -19,13 +19,13 @@ module { // CHECK-SAME: stable group_reduce_add slots=8 support group sizes VLaneElems, 2*VLaneElems, or 4*VLaneElems // CHECK-SAME: VMI types: operand#0=!pto.vmi.vreg<96xf32, #pto.vmi.layout> // CHECK-SAME: operand#1=!pto.vmi.mask<96xb32, #pto.vmi.layout> - // CHECK-SAME: result#0=!pto.vmi.vreg<96xf32, #pto.vmi.layout> + // CHECK-SAME: result#0=!pto.vmi.vreg<8xf32, #pto.vmi.layout> %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<96xf32>, !pto.vmi.mask<96xpred> - -> !pto.vmi.vreg<96xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<96xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto index fb25c2bd91..d73d66570b 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_store.pto @@ -19,9 +19,9 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -34,9 +34,9 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_store( // LOWER: %[[LO:.*]], %[[HI:.*]] = pto.vdintlv diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto index 6339aa15bc..a492118cad 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto @@ -17,12 +17,13 @@ module { %off: index) { %sum32 = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> - %sum16 = pto.vmi.truncf %sum32 + -> !pto.vmi.vreg<8xf32> + %rows32 = pto.vmi.group_broadcast %sum32 {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> + %rows16 = pto.vmi.truncf %rows32 : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> - %rows = pto.vmi.group_broadcast %sum16 {num_groups = 8} - : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf16> - pto.vmi.store %rows, %dst[%off] : !pto.vmi.vreg<128xf16>, !pto.ptr + pto.vmi.store %rows16, %dst[%off] + : !pto.vmi.vreg<128xf16>, !pto.ptr return } } @@ -35,7 +36,7 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %arg1 // ASSIGN-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]], %[[MASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[B32:.*]] = pto.vmi.group_broadcast %[[SUM32]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[B32_SPLIT:.*]] = pto.vmi.ensure_layout %[[B32]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto index 15fba5a1de..59fee4b8b9 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s256.pto @@ -11,18 +11,18 @@ module { func.func @vmi_layout_assignment_group_reduce_s256( %source: !pto.vmi.vreg<512xf32>, - %mask: !pto.vmi.mask<512xpred>) -> !pto.vmi.vreg<512xf32> { + %mask: !pto.vmi.mask<512xpred>) -> !pto.vmi.vreg<2xf32> { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 2, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> - return %out : !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<2xf32> + return %out : !pto.vmi.vreg<2xf32> } } // CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_s256( // CHECK-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> -// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto index 7a72876ff9..d564ea6544 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_broadcast_reduce.pto @@ -19,18 +19,18 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %source, %broadcast : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %scaled_sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -38,7 +38,7 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_broadcast_reduce( // ASSIGN-SAME: %[[SOURCE:arg[0-9]+]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf %[[SOURCE]], %[[BROADCAST]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto index b0d5a12676..93d471291b 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_multitile_store.pto @@ -19,9 +19,9 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 16, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<16xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 16} - : !pto.vmi.vreg<512xf32>, !pto.ptr + : !pto.vmi.vreg<16xf32>, !pto.ptr return } } @@ -34,9 +34,9 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.mask<512xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<16xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] -// ASSIGN-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<16xf32, #pto.vmi.layout>, !pto.ptr // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_multitile_store( // LOWER-COUNT-8: pto.vdintlv diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto index 7fe8c425bf..443b8d822c 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_store.pto @@ -19,9 +19,9 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -34,9 +34,9 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[SOURCE_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] -// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_store( // LOWER-COUNT-4: pto.vdintlv diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto index 31e83e37d7..372b445342 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_full_tile.pto @@ -20,9 +20,9 @@ module { %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<192xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> - -> !pto.vmi.vreg<192xf32> + -> !pto.vmi.vreg<6xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} - : !pto.vmi.vreg<192xf32>, !pto.ptr + : !pto.vmi.vreg<6xf32>, !pto.ptr return } @@ -36,9 +36,9 @@ module { %mask = pto.vmi.create_mask %c192 : index -> !pto.vmi.mask<256xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -51,7 +51,7 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] // ASSIGN-SAME: !pto.vmi.mask<192xb32, #pto.vmi.layout> -> !pto.vmi.mask<192xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<192xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<6xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile( @@ -76,7 +76,7 @@ module { // ASSIGN: %[[PMASK:.*]] = pto.vmi.ensure_mask_layout %[[PMASK0]] // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_reduce_addf %[[PX]], %[[PMASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s32_tail_full_tile_ptr_masked( // LOWER-COUNT-4: pto.vlds diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto index af78715f95..de04988cb0 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s32_tail_no_full_tile_invalid.pto @@ -24,9 +24,9 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 6, reassoc} : !pto.vmi.vreg<192xf32>, !pto.vmi.mask<192xpred> - -> !pto.vmi.vreg<192xf32> + -> !pto.vmi.vreg<6xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 6} - : !pto.vmi.vreg<192xf32>, !pto.ptr + : !pto.vmi.vreg<6xf32>, !pto.ptr return } } diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto index 2901a43f7e..1f10d8a6ee 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64.pto @@ -11,18 +11,18 @@ module { func.func @vmi_layout_assignment_group_reduce_s64( %source: !pto.vmi.vreg<512xf32>, - %mask: !pto.vmi.mask<512xpred>) -> !pto.vmi.vreg<512xf32> { + %mask: !pto.vmi.mask<512xpred>) -> !pto.vmi.vreg<8xf32> { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> - return %out : !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<8xf32> + return %out : !pto.vmi.vreg<8xf32> } } // CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_s64( // CHECK-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> // CHECK-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> -// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto index 882d0cd30e..0e298f6cf1 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_broadcast_reduce.pto @@ -19,29 +19,29 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<8xf32> %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<512xf32> %scaled = pto.vmi.mulf %source, %broadcast : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %scaled_sum, %dst[%off], %c8 {num_groups = 8} - : !pto.vmi.vreg<512xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] // ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_broadcast_reduce( // LOWER-COUNT-8: pto.vcgadd diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto index 1a68df9d86..ec2efac143 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_tail_store.pto @@ -18,9 +18,9 @@ module { %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<384xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} : !pto.vmi.vreg<384xf32>, !pto.vmi.mask<384xpred> - -> !pto.vmi.vreg<384xf32> + -> !pto.vmi.vreg<6xf32> pto.vmi.group_store %sum, %dst[%off], %c8 {num_groups = 6} - : !pto.vmi.vreg<384xf32>, !pto.ptr + : !pto.vmi.vreg<6xf32>, !pto.ptr return } } @@ -29,7 +29,7 @@ module { // ASSIGN: %[[X:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<384xf32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]] -// ASSIGN-SAME: -> !pto.vmi.vreg<384xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<6xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s64_tail_store( diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto index d97210bc7b..ff0e67b9ad 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s64_truncf.pto @@ -18,11 +18,11 @@ module { %c16 = arith.constant 16 : index %sum32 = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<8xf32> %sum16 = pto.vmi.truncf %sum32 - : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf16> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xf16> pto.vmi.group_store %sum16, %dst[%off], %c16 {num_groups = 8} - : !pto.vmi.vreg<512xf16>, !pto.ptr + : !pto.vmi.vreg<8xf16>, !pto.ptr return } } @@ -31,11 +31,11 @@ module { // ASSIGN-SAME: %arg0: !pto.vmi.vreg<512xf32, #pto.vmi.layout> // ASSIGN-SAME: %arg1: !pto.vmi.mask<512xb32, #pto.vmi.layout> // ASSIGN: %[[SUM32:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[SUM16:.*]] = pto.vmi.truncf %[[SUM32]] -// ASSIGN-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout> -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM16]] -// ASSIGN-SAME: !pto.vmi.vreg<512xf16, #pto.vmi.layout>, !pto.ptr>, !pto.ptr, - %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<64xf32> { + %mask: !pto.vmi.mask<64xpred>) -> !pto.vmi.vreg<8xf32> { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> - -> !pto.vmi.vreg<64xf32> - return %out : !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<8xf32> + return %out : !pto.vmi.vreg<8xf32> } } // CHECK-LABEL: func.func @vmi_layout_assignment_group_reduce_slots8( // CHECK-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> // CHECK-SAME: %arg1: !pto.vmi.mask<64xb32, #pto.vmi.layout> -// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// CHECK-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto index 0042c64a15..f5af0d2ed1 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_slots8_store.pto @@ -19,9 +19,9 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> - -> !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<64xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -30,9 +30,9 @@ module { // ASSIGN-SAME: %arg0: !pto.vmi.vreg<64xf32, #pto.vmi.layout> // ASSIGN-SAME: %arg1: !pto.vmi.mask<64xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %arg0, %arg1 -// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] -// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_slots8_store( // LOWER: %[[SUM:.*]] = pto.vcgadd %arg0, %arg1 : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto index 34bf1c9633..7d4e4b6bc5 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_typed.pto @@ -20,15 +20,15 @@ module { %mi32: !pto.vmi.mask<128xpred>) { %sum_f16 = pto.vmi.group_reduce_addf %f16, %mf16 {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf16>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf16> + -> !pto.vmi.vreg<8xf16> %wide_i16 = pto.vmi.extsi %i16 : !pto.vmi.vreg<128xi16> -> !pto.vmi.vreg<128xi32> %sum_i16 = pto.vmi.group_reduce_addi %wide_i16, %mi16 {num_groups = 8} : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<8xi32> %sum_i32 = pto.vmi.group_reduce_addi %i32, %mi32 {num_groups = 8} : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<8xi32> return } } @@ -41,16 +41,16 @@ module { // CHECK: %[[MF16_B16:.*]] = pto.vmi.ensure_mask_granularity %[[MF16_SPLIT]] // CHECK-SAME: -> !pto.vmi.mask<256xb16, #pto.vmi.layout> // CHECK: pto.vmi.group_reduce_addf %[[F16_SPLIT]], %[[MF16_B16]] -// CHECK-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> // CHECK: %[[WIDE_I16:.*]] = pto.vmi.extsi %arg2 // CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> // CHECK: %[[MI16_SPLIT:.*]] = pto.vmi.ensure_mask_layout // CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // CHECK: pto.vmi.group_reduce_addi %[[WIDE_I16]], %[[MI16_SPLIT]] -// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> // CHECK: %[[I32_SPLIT:.*]] = pto.vmi.ensure_layout // CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> // CHECK: %[[MI32_SPLIT:.*]] = pto.vmi.ensure_mask_layout // CHECK-SAME: -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // CHECK: pto.vmi.group_reduce_addi %[[I32_SPLIT]], %[[MI32_SPLIT]] -// CHECK-SAME: -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto index 00c15db0a0..153408b78b 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load.pto @@ -10,65 +10,65 @@ module { func.func @vmi_layout_assignment_group_slot_load_slots8( - %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<128xf32> { + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<8xf32> { %c1 = arith.constant 1 : index %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> - return %out : !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> + return %out : !pto.vmi.vreg<8xf32> } func.func @vmi_layout_assignment_group_slot_load_slots1( %src: !pto.ptr, %off: index) - -> !pto.vmi.vreg<512xf32> { + -> !pto.vmi.vreg<8xf32> { %c8 = arith.constant 8 : index %out = pto.vmi.group_slot_load %src[%off], %c8 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<512xf32> - return %out : !pto.vmi.vreg<512xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> + return %out : !pto.vmi.vreg<8xf32> } func.func @vmi_layout_assignment_group_slot_load_slots8_store( %src: !pto.ptr, %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } func.func @vmi_layout_assignment_group_slot_load_extui( - %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<256xui32> { + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<8xui32> { %c1 = arith.constant 1 : index %narrow = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<256xui8> + : !pto.ptr -> !pto.vmi.vreg<8xui8> %wide = pto.vmi.extui %narrow - : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> - return %wide : !pto.vmi.vreg<256xui32> + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> + return %wide : !pto.vmi.vreg<8xui32> } } // CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8( -// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load -// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] // CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots1( -// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load -// CHECK-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // CHECK: return %[[OUT]] // CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_slots8_store( // CHECK: %[[OUT:.*]] = pto.vmi.group_slot_load -// CHECK-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // CHECK: pto.vmi.group_store %[[OUT]] -// CHECK-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr +// CHECK-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr // CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_load_extui( -// CHECK-SAME: -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xui32, #pto.vmi.layout> // CHECK: %[[NARROW:.*]] = pto.vmi.group_slot_load -// CHECK-SAME: -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> +// CHECK-SAME: -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> // CHECK: %[[WIDE:.*]] = pto.vmi.extui %[[NARROW]] -// CHECK-SAME: !pto.vmi.vreg<256xui8, #pto.vmi.layout> -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> +// CHECK-SAME: !pto.vmi.vreg<8xui8, #pto.vmi.layout> -> !pto.vmi.vreg<8xui32, #pto.vmi.layout> // CHECK: return %[[WIDE]] diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto index 9992655422..075a9d58a3 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_dual_layout.pto @@ -22,45 +22,45 @@ module { %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %rhs16 = pto.vmi.group_slot_load %rhs_base[%off], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> %sum16 = pto.vmi.group_reduce_addf %source16, %mask16 {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> %outv16 = pto.vmi.addf %sum16, %rhs16 - : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> - -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %outv16, %out16[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %rhs64 = pto.vmi.group_slot_load %rhs_base[%off], %c8 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<512xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> %sum64 = pto.vmi.group_reduce_addf %source64, %mask64 {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<8xf32> %outv64 = pto.vmi.addf %sum64, %rhs64 - : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> - -> !pto.vmi.vreg<512xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %outv64, %out64[%off], %c8 {num_groups = 8} - : !pto.vmi.vreg<512xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slot_load_dual_layout( // ASSIGN: %[[RHS16:.*]] = pto.vmi.group_slot_load -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[SUM16:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.addf %[[SUM16]], %[[RHS16]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[RHS64:.*]] = pto.vmi.group_slot_load -// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[SUM64:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.addf %[[SUM64]], %[[RHS64]] -// ASSIGN-SAME: -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_group_slot_load_dual_layout( // LOWER: pto.pge_b32 "PAT_VL8" diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto index 01dab5b003..397afcb6da 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_dynamic_stride_invalid.pto @@ -16,9 +16,9 @@ module { // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements // CHECK-SAME: packed or unaligned scalar load lowering is not implemented // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" - // CHECK-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> + // CHECK-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout> %out = pto.vmi.group_slot_load %src[%off], %stride {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<512xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> return } } diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto index 1589e531dc..6dbb34ad97 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_load_slots1_unaligned_stride_invalid.pto @@ -17,9 +17,9 @@ module { // CHECK-SAME: requires constant positive source_group_stride divisible by 8 elements // CHECK-SAME: packed or unaligned scalar load lowering is not implemented // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_slot_load" - // CHECK-SAME: !pto.vmi.vreg<512xf32, #pto.vmi.layout> + // CHECK-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout> %out = pto.vmi.group_slot_load %src[%off], %c2 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<512xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> return } } diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto index d327a7b8bc..99e91bf98c 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_cf_join.pto @@ -18,25 +18,25 @@ module { %c1 = arith.constant 1 : index %c128 = arith.constant 128 : index %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> - %sum = scf.if %cond -> !pto.vmi.vreg<128xf32> { + %sum = scf.if %cond -> !pto.vmi.vreg<8xf32> { %x = pto.vmi.load %src[%off] : !pto.ptr -> !pto.vmi.vreg<128xf32> %a = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> - scf.yield %a : !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> + scf.yield %a : !pto.vmi.vreg<8xf32> } else { %b = pto.vmi.group_slot_load %rhs[%off], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> - scf.yield %b : !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> + scf.yield %b : !pto.vmi.vreg<8xf32> } %bias = pto.vmi.group_slot_load %rhs[%off], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> %out = pto.vmi.addf %sum, %bias - : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> - -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto index 16905f1210..ccf77511d6 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_fanout.pto @@ -20,35 +20,35 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %sum_dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> %scaled = pto.vmi.mulf %source, %broadcast : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %scaled_sum, %out[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_fanout( // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr // ASSIGN: %[[BROADCAST:.*]] = pto.vmi.group_broadcast %[[SUM]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED:.*]] = pto.vmi.mulf // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[SCALED_SUM:.*]] = pto.vmi.group_reduce_addf %[[SCALED]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SCALED_SUM]] // LOWER-LABEL: func.func @vmi_layout_assignment_group_slots_fanout( diff --git a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto index 95fa93474d..ec00c48c70 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slots_scf_for.pto @@ -20,9 +20,9 @@ module { %c2 = arith.constant 2 : index %c16 = arith.constant 16 : index %acc0 = pto.vmi.group_slot_load %init[%off], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> %acc = scf.for %i = %c0 to %c2 step %c1 - iter_args(%arg = %acc0) -> (!pto.vmi.vreg<128xf32>) { + iter_args(%arg = %acc0) -> (!pto.vmi.vreg<8xf32>) { %x = pto.vmi.group_load %base[%off], %c16 {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<128xf32> %mask = pto.vmi.create_group_mask %c16 @@ -31,24 +31,24 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> %next = pto.vmi.addf %arg, %sum - : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> - -> !pto.vmi.vreg<128xf32> - scf.yield %next : !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + scf.yield %next : !pto.vmi.vreg<8xf32> } pto.vmi.group_store %acc, %out[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slots_scf_for( // ASSIGN: %[[ACC0:.*]] = pto.vmi.group_slot_load -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[ACC:.*]] = scf.for // ASSIGN-SAME: iter_args(%[[ARG:.*]] = %[[ACC0]]) -// ASSIGN-SAME: -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) +// ASSIGN-SAME: -> (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) // ASSIGN: %[[X:.*]] = pto.vmi.group_load // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: %[[MASK0:.*]] = pto.vmi.create_group_mask @@ -56,13 +56,13 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] // ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.addf %[[ARG]], %[[SUM]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: scf.yield -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[ACC]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr // LOWER-LABEL: func.func @vmi_layout_assignment_group_slots_scf_for( // LOWER: pto.vsldb diff --git a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride.pto b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride.pto index 89e67118ab..44625678ac 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_store_slots1_unit_stride.pto @@ -18,9 +18,9 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<512xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto index 796f446b60..e179f5ccc6 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_dense_group_users.pto @@ -30,9 +30,9 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -51,7 +51,7 @@ module { // ASSIGN: %[[MASK_SPLIT:.*]] = pto.vmi.ensure_mask_layout %[[MASK]] // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X_SPLIT]], %[[MASK_SPLIT]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // LOWER-LABEL: func.func @vmi_layout_assignment_masked_load_dense_group_users( diff --git a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto index 9d3147aaea..debe3fd571 100644 --- a/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto +++ b/test/lit/vmi/vmi_layout_assignment_masked_load_group_tail_s32.pto @@ -28,9 +28,9 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %sum_out[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } diff --git a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto index dd8f2910ab..8bde94f611 100644 --- a/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_non_load_s32_reduce.pto @@ -26,9 +26,9 @@ module { %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } } @@ -45,9 +45,9 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X]], %[[MASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] -// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr // LOWER-LABEL: func.func @vmi_layout_assignment_non_load_s32_reduce( // LOWER-COUNT-4: pto.vdup %arg1 diff --git a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto index 71d282577a..f7bc538518 100644 --- a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto @@ -18,13 +18,13 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> - // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<128xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion - // CHECK: failed helper conversion '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' (unsupported source/result layout pair) + -> !pto.vmi.vreg<8xf32> + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<8xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: failed helper conversion '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' (unsupported source/result layout pair) %h = pto.vmi.truncf %sum - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xf16> pto.vmi.group_store %h, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf16>, !pto.ptr + : !pto.vmi.vreg<8xf16>, !pto.ptr return } } diff --git a/test/lit/vmi/vmi_layout_assignment_trunci_lane_stride.pto b/test/lit/vmi/vmi_layout_assignment_trunci_lane_stride.pto index ec9da3042c..4a36d5882a 100644 --- a/test/lit/vmi/vmi_layout_assignment_trunci_lane_stride.pto +++ b/test/lit/vmi/vmi_layout_assignment_trunci_lane_stride.pto @@ -13,22 +13,22 @@ module { func.func @vmi_layout_assignment_group_slot_trunci_lane_stride( - %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + %wide: !pto.vmi.vreg<8xi32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index %narrow = pto.vmi.trunci %wide - : !pto.vmi.vreg<256xi32, #pto.vmi.layout> - -> !pto.vmi.vreg<256xui8> + : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui8> pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xui8>, !pto.ptr + : !pto.vmi.vreg<8xui8>, !pto.ptr return } } // ASSIGN-LABEL: func.func @vmi_layout_assignment_group_slot_trunci_lane_stride( // ASSIGN: %[[NARROW:.*]] = pto.vmi.trunci -// ASSIGN-SAME: -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[NARROW]] // LOWER-LABEL: func.func @vmi_layout_assignment_group_slot_trunci_lane_stride( diff --git a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto index e9553c2c9d..9bf53802d4 100644 --- a/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_widen_f16_store_reduce.pto @@ -22,9 +22,9 @@ module { %mask = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> %sumv = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sumv, %sum[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr pto.vmi.store %x32, %dense[%off] : !pto.vmi.vreg<128xf32>, !pto.ptr return @@ -41,7 +41,7 @@ module { // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] // ASSIGN-SAME: !pto.vmi.mask<128xb32, #pto.vmi.layout> -> !pto.vmi.mask<128xb32, #pto.vmi.layout> // ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN: %[[X32_DENSE:.*]] = pto.vmi.ensure_layout %[[X32]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto index 50ac7a6b3f..8a69c96385 100644 --- a/test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto +++ b/test/lit/vmi/vmi_layout_gate_bitcast_group_slots.pto @@ -10,12 +10,12 @@ module { func.func @vmi_layout_gate_bitcast_group_slots( - %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) - -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> { %out = pto.vmi.bitcast %source - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> - -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> - return %out : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<8xi32, #pto.vmi.layout> } } diff --git a/test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto index 64681b5dd3..b2124db3c2 100644 --- a/test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_broadcast_support_invalid.pto @@ -10,12 +10,12 @@ module { func.func @vmi_layout_gate_group_broadcast_support_invalid( - %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_broadcast has no registered layout support // CHECK-SAME: supports only slots=8 or slots=1 group_broadcast source layouts // CHECK: note: see current operation: %{{.*}} = "pto.vmi.group_broadcast" %out = pto.vmi.group_broadcast %source {num_groups = 8} - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> return } diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto index 0c792693f3..9bebc83b97 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_slots1_support_invalid.pto @@ -19,7 +19,7 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.vmi.mask<256xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> return } } diff --git a/test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto index 734c9dd497..172600b145 100644 --- a/test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_reduce_support_invalid.pto @@ -19,7 +19,7 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<96xf32, #pto.vmi.layout>, !pto.vmi.mask<96xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<96xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> return } } diff --git a/test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto index 334be3d744..9c34f6a261 100644 --- a/test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_slot_load_support_invalid.pto @@ -17,7 +17,7 @@ module { %out = pto.vmi.group_slot_load %src[%off], %stride {num_groups = 8} : !pto.ptr - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> return } } diff --git a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto index f3263148b3..b6f309f693 100644 --- a/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_slots_unsupported_slots_invalid.pto @@ -10,13 +10,13 @@ module { func.func @vmi_layout_gate_group_store_slots2_invalid( - %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index, %row_stride: index) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots layout support // CHECK-SAME: group_slots group_store currently supports only slots=1 or unit-stride slots=8 pto.vmi.group_store %value, %dst[%off], %row_stride {num_groups = 8} - : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr return } @@ -34,7 +34,7 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.mask<128xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> return } } diff --git a/test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto index db0794748d..2676b55a1f 100644 --- a/test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_group_store_support_invalid.pto @@ -10,14 +10,14 @@ module { func.func @vmi_layout_gate_group_store_support_invalid( - %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index, %row_stride: index) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_store has no registered group_slots layout support // CHECK-SAME: slots=8 group_store currently requires constant unit row_stride // CHECK: note: see current operation: "pto.vmi.group_store" pto.vmi.group_store %value, %dst[%off], %row_stride {num_groups = 8} - : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr return } diff --git a/test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto index 90e49c52dd..1cf2548b79 100644 --- a/test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_helper_support_invalid.pto @@ -10,10 +10,10 @@ module { func.func @vmi_layout_gate_helper_support_invalid( - %value: !pto.vmi.vreg<64xf32, #pto.vmi.layout>) { + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) { %bad = pto.vmi.ensure_layout %value - : !pto.vmi.vreg<64xf32, #pto.vmi.layout> - -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> return } } diff --git a/test/lit/vmi/vmi_layout_gate_support.pto b/test/lit/vmi/vmi_layout_gate_support.pto index 629b85c208..9a48ea0721 100644 --- a/test/lit/vmi/vmi_layout_gate_support.pto +++ b/test/lit/vmi/vmi_layout_gate_support.pto @@ -15,7 +15,7 @@ module { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.mask<64xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> return } } diff --git a/test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto index 3021b88a7d..385aa56191 100644 --- a/test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_truncf_support_invalid.pto @@ -10,13 +10,13 @@ module { func.func @vmi_layout_gate_truncf_support_invalid( - %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.truncf has no registered layout support // CHECK-SAME: group-slot truncf requires matching group_slots(num_groups=G, slots=1) // CHECK: note: see current operation: %{{.*}} = "pto.vmi.truncf" %out = pto.vmi.truncf %source - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> - -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> return } } diff --git a/test/lit/vmi/vmi_layout_group_slots_invalid.pto b/test/lit/vmi/vmi_layout_group_slots_invalid.pto index 1f4ccd2856..0f3717ee4b 100644 --- a/test/lit/vmi/vmi_layout_group_slots_invalid.pto +++ b/test/lit/vmi/vmi_layout_group_slots_invalid.pto @@ -10,7 +10,7 @@ module { func.func @vmi_layout_group_slots_invalid( - %arg0: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %arg0: !pto.vmi.vreg<10xf32, #pto.vmi.layout>) { return } } diff --git a/test/lit/vmi/vmi_op_verifier_basic.pto b/test/lit/vmi/vmi_op_verifier_basic.pto index 0970134935..4ff2199aa7 100644 --- a/test/lit/vmi/vmi_op_verifier_basic.pto +++ b/test/lit/vmi/vmi_op_verifier_basic.pto @@ -41,9 +41,9 @@ module { %loaded = pto.vmi.load %ptr[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> %slot_loaded = pto.vmi.group_slot_load %ptr[%c0], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %slot_loaded, %ptr[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr pto.vmi.store %loaded, %ptr[%c0] : !pto.vmi.vreg<128xf32>, !pto.ptr %small = "pto.vmi.shuffle"(%broadcast) { diff --git a/test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto b/test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto index d20be7e33e..e9ccc14ac2 100644 --- a/test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto +++ b/test/lit/vmi/vmi_to_vpto_bitcast_group_slots.pto @@ -10,12 +10,12 @@ module { func.func @vmi_to_vpto_bitcast_group_slots( - %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) - -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> { + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> { %out = pto.vmi.bitcast %source - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> - -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> - return %out : !pto.vmi.vreg<128xi32, #pto.vmi.layout> + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<8xi32, #pto.vmi.layout> } } diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto index 9c2aff3759..7f9e02f144 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_deint.pto @@ -10,13 +10,13 @@ module { func.func @vmi_to_vpto_group_broadcast_deint( - %sum: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %sum: !pto.vmi.vreg<2xf32, #pto.vmi.layout>, %src_f8: !pto.vmi.vreg<512xf8E4M3FN>) -> !pto.vmi.vreg<512xf32> { %src_f32 = pto.vmi.extf %src_f8 : !pto.vmi.vreg<512xf8E4M3FN> -> !pto.vmi.vreg<512xf32> %sum_vec = pto.vmi.group_broadcast %sum {num_groups = 2} - : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + : !pto.vmi.vreg<2xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf32> %out = pto.vmi.mulf %sum_vec, %src_f32 : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_s32_deint2_small_group.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_s32_deint2_small_group.pto index f82a877737..7b19a1254a 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_s32_deint2_small_group.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_s32_deint2_small_group.pto @@ -10,10 +10,10 @@ module { func.func @vmi_to_vpto_group_broadcast_s32_deint2_small_group( - %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) + %source: !pto.vmi.vreg<4xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { %broadcast = pto.vmi.group_broadcast %source {num_groups = 4} - : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + : !pto.vmi.vreg<4xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> %p0, %p1 = "pto.vmi.unpack"(%broadcast) : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto index 01e40aaae7..6f23678fea 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8.pto @@ -10,14 +10,14 @@ module { func.func @vmi_to_vpto_group_broadcast_slots8( - %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_broadcast %source {num_groups = 128} - : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto index 55ed864da1..c8268869a3 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_slots8_support.pto @@ -10,13 +10,13 @@ module { func.func @vmi_to_vpto_group_broadcast_slots8_support( - %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_broadcast %source {num_groups = 128} - : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto index 3c40457460..3b2452545c 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_vselr.pto @@ -10,13 +10,13 @@ module { func.func @vmi_to_vpto_group_broadcast_vselr( - %source: !pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + %source: !pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) { %out = pto.vmi.group_broadcast %source {num_groups = 128} - : !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) diff --git a/test/lit/vmi/vmi_to_vpto_group_ops.pto b/test/lit/vmi/vmi_to_vpto_group_ops.pto index edef94f273..5abb0de3ab 100644 --- a/test/lit/vmi/vmi_to_vpto_group_ops.pto +++ b/test/lit/vmi/vmi_to_vpto_group_ops.pto @@ -21,12 +21,9 @@ module { %r = pto.vmi.group_reduce_addf %v, %mask {num_groups = 2, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.vmi.mask<512xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> - %b = pto.vmi.group_broadcast %r {num_groups = 2} - : !pto.vmi.vreg<512xf32, #pto.vmi.layout> - -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> - pto.vmi.group_store %b, %dst[%c0], %row_stride {num_groups = 2} - : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.ptr + -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> + pto.vmi.group_store %r, %dst[%c0], %row_stride {num_groups = 2} + : !pto.vmi.vreg<2xf32, #pto.vmi.layout>, !pto.ptr return } } @@ -36,8 +33,7 @@ module { // CHECK: pto.vcgadd // CHECK: pto.vselr // CHECK-COUNT-7: pto.vcgadd -// CHECK-COUNT-8: {position = "LOWEST"} -// CHECK-COUNT-8: pto.vsts +// CHECK-COUNT-2: pto.vsts {{.*}} {dist = "1PT_B32"} // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto index b3e48c56b4..1287a859c6 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_legacy_slots_invalid.pto @@ -18,9 +18,9 @@ module { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.mask<64xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> return %part : !pto.vreg<64xf32> } diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto index 8efe26cf22..d2886fa5aa 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_partial_slots8.pto @@ -18,9 +18,9 @@ module { {num_groups = 4, reassoc} : !pto.vmi.vreg<256xf16, #pto.vmi.layout>, !pto.vmi.mask<256xb16, #pto.vmi.layout> - -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<4xf16, #pto.vmi.layout> pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 4} - : !pto.vmi.vreg<256xf16, #pto.vmi.layout>, + : !pto.vmi.vreg<4xf16, #pto.vmi.layout>, !pto.ptr return } @@ -34,9 +34,9 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf16, #pto.vmi.layout>, !pto.vmi.mask<512xb16, #pto.vmi.layout> - -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<512xf16, #pto.vmi.layout>, + : !pto.vmi.vreg<8xf16, #pto.vmi.layout>, !pto.ptr return } @@ -50,9 +50,9 @@ module { {num_groups = 12, reassoc} : !pto.vmi.vreg<768xf16, #pto.vmi.layout>, !pto.vmi.mask<768xb16, #pto.vmi.layout> - -> !pto.vmi.vreg<768xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<12xf16, #pto.vmi.layout> pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 12} - : !pto.vmi.vreg<768xf16, #pto.vmi.layout>, + : !pto.vmi.vreg<12xf16, #pto.vmi.layout>, !pto.ptr return } diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto index 5d06c115ac..cb790d5f91 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s256_broadcast.pto @@ -18,9 +18,9 @@ module { %sum = pto.vmi.group_reduce_addf %source, %mask {num_groups = 2, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.vmi.mask<512xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<2xf32, #pto.vmi.layout> %broadcast = pto.vmi.group_broadcast %sum {num_groups = 2} - : !pto.vmi.vreg<512xf32, #pto.vmi.layout> + : !pto.vmi.vreg<2xf32, #pto.vmi.layout> -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%broadcast) : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto index 935eb2b80f..83b7ee9ae7 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64.pto @@ -19,9 +19,9 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.vmi.mask<512xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto index 7b7bc76761..aba242c03f 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_s64_support.pto @@ -18,9 +18,9 @@ module { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, !pto.vmi.mask<512xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto index 2343869ceb..c8d5a85757 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8.pto @@ -17,9 +17,9 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.mask<64xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> return %part : !pto.vreg<64xf32> } diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto index 9e6a9faf00..d88c1cf1ad 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_slots8_support.pto @@ -16,9 +16,9 @@ module { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.mask<64xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> return %part : !pto.vreg<64xf32> } diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto index f01c6865a1..88a8598a82 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_typed.pto @@ -18,9 +18,9 @@ module { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.vmi.mask<128xb16, #pto.vmi.layout> - -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<128xf16, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf16, #pto.vmi.layout>) -> !pto.vreg<128xf16> return %part : !pto.vreg<128xf16> } @@ -35,9 +35,9 @@ module { %out = pto.vmi.group_reduce_addi %wide, %mask {num_groups = 8} : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, !pto.vmi.mask<128xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xi32, #pto.vmi.layout>) -> !pto.vreg<64xi32> return %part : !pto.vreg<64xi32> } @@ -49,9 +49,9 @@ module { %out = pto.vmi.group_reduce_addi %source, %mask {num_groups = 8} : !pto.vmi.vreg<128xi32, #pto.vmi.layout>, !pto.vmi.mask<128xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<128xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<128xi32, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xi32, #pto.vmi.layout>) -> !pto.vreg<64xi32> return %part : !pto.vreg<64xi32> } diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto index d6b52468b4..c9e77a1b67 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd.pto @@ -16,9 +16,9 @@ module { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.mask<64xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<64xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> return %part : !pto.vreg<64xf32> } diff --git a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto index d6265bd490..580ab5f4e5 100644 --- a/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto +++ b/test/lit/vmi/vmi_to_vpto_group_reduce_vcgadd_multichunk.pto @@ -19,10 +19,10 @@ module { %out = pto.vmi.group_reduce_addf %source, %mask {num_groups = 128, reassoc} : !pto.vmi.vreg<1024xf32, #pto.vmi.layout>, !pto.vmi.mask<1024xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<1024xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7, %p8, %p9, %p10, %p11, %p12, %p13, %p14, %p15 = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<1024xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto index 7403599d5f..9a2d80feb1 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load.pto @@ -15,9 +15,9 @@ module { %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} : !pto.ptr - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> return %part : !pto.vreg<64xf32> } @@ -30,9 +30,9 @@ module { %out = pto.vmi.group_slot_load %src[%off], %c8 {num_groups = 8} : !pto.ptr - -> !pto.vmi.vreg<512xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<512xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.vreg<64xf32>) return %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 @@ -46,9 +46,9 @@ module { %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} : !pto.ptr - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> pto.vmi.group_store %out, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr return } @@ -59,24 +59,24 @@ module { %c23_i32 = arith.constant 23 : i32 %scale_u8 = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} : !pto.ptr - -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> %scale_u32 = pto.vmi.extui %scale_u8 - : !pto.vmi.vreg<256xui8, #pto.vmi.layout> - -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> + : !pto.vmi.vreg<8xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui32, #pto.vmi.layout> %scale_i32 = pto.vmi.bitcast %scale_u32 - : !pto.vmi.vreg<256xui32, #pto.vmi.layout> - -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + : !pto.vmi.vreg<8xui32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + : i32 -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> %bits = pto.vmi.shli %scale_i32, %shift - : !pto.vmi.vreg<256xi32, #pto.vmi.layout>, - !pto.vmi.vreg<256xi32, #pto.vmi.layout> - -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + : !pto.vmi.vreg<8xi32, #pto.vmi.layout>, + !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> %scale = pto.vmi.bitcast %bits - : !pto.vmi.vreg<256xi32, #pto.vmi.layout> - -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> %vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> pto.vmi.store %vec, %dst[%off] : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, @@ -89,12 +89,12 @@ module { %c1 = arith.constant 1 : index %scale_u16 = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} : !pto.ptr - -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui16, #pto.vmi.layout> %scale_u32 = pto.vmi.extui %scale_u16 - : !pto.vmi.vreg<256xui16, #pto.vmi.layout> - -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> + : !pto.vmi.vreg<8xui16, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui32, #pto.vmi.layout> %vec = pto.vmi.group_broadcast %scale_u32 {num_groups = 8} - : !pto.vmi.vreg<256xui32, #pto.vmi.layout> + : !pto.vmi.vreg<8xui32, #pto.vmi.layout> -> !pto.vmi.vreg<256xui32, #pto.vmi.layout> pto.vmi.store %vec, %dst[%off] : !pto.vmi.vreg<256xui32, #pto.vmi.layout>, diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto index 8e58305a01..e16ca46fa3 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_nonunit_slots8_invalid.pto @@ -11,12 +11,12 @@ module { func.func @vmi_to_vpto_group_slot_load_nonunit_slots8_invalid( %src: !pto.ptr, %off: index, %stride: index) - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> { + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> { %out = pto.vmi.group_slot_load %src[%off], %stride {num_groups = 8} : !pto.ptr - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> - return %out : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<8xf32, #pto.vmi.layout> } } diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto b/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto index c519205638..44ea0d5e54 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_load_support.pto @@ -14,9 +14,9 @@ module { %c1 = arith.constant 1 : index %out = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} : !pto.ptr - -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> %part = "pto.vmi.unpack"(%out) - : (!pto.vmi.vreg<128xf32, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf32, #pto.vmi.layout>) -> !pto.vreg<64xf32> return %part : !pto.vreg<64xf32> } diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto index 3f03f4669a..3b87c4c684 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1.pto @@ -10,15 +10,15 @@ module { func.func @vmi_to_vpto_group_slot_truncf_slots1( - %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>) { %narrow = pto.vmi.truncf %source - : !pto.vmi.vreg<512xf32, #pto.vmi.layout> - -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%narrow) - : (!pto.vmi.vreg<512xf16, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf16, #pto.vmi.layout>) -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>) diff --git a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto index 4874117e69..5b44773a33 100644 --- a/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto +++ b/test/lit/vmi/vmi_to_vpto_group_slot_truncf_slots1_support.pto @@ -10,15 +10,15 @@ module { func.func @vmi_to_vpto_group_slot_truncf_slots1_support( - %source: !pto.vmi.vreg<512xf32, #pto.vmi.layout>) + %source: !pto.vmi.vreg<8xf32, #pto.vmi.layout>) -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>) { %narrow = pto.vmi.truncf %source - : !pto.vmi.vreg<512xf32, #pto.vmi.layout> - -> !pto.vmi.vreg<512xf16, #pto.vmi.layout> + : !pto.vmi.vreg<8xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xf16, #pto.vmi.layout> %p0, %p1, %p2, %p3, %p4, %p5, %p6, %p7 = "pto.vmi.unpack"(%narrow) - : (!pto.vmi.vreg<512xf16, #pto.vmi.layout>) + : (!pto.vmi.vreg<8xf16, #pto.vmi.layout>) -> (!pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.vreg<128xf16>) diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto index 8f949fc0f7..3dc813a4f2 100644 --- a/test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots1_1pt.pto @@ -10,12 +10,12 @@ module { func.func @vmi_to_vpto_group_store_slots1_1pt( - %value: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index) { %c2 = arith.constant 2 : index pto.vmi.group_store %value, %dst[%off], %c2 {num_groups = 8} - : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr return } diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots1_unit_stride_alignment.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots1_unit_stride_alignment.pto index dc68aed9db..d5a55bd75e 100644 --- a/test/lit/vmi/vmi_to_vpto_group_store_slots1_unit_stride_alignment.pto +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots1_unit_stride_alignment.pto @@ -10,25 +10,25 @@ module { func.func @aligned_unit_stride_group_store( - %value: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, %dst: !pto.ptr) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index pto.vmi.group_store %value, %dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr return } func.func @unaligned_unit_stride_group_store( - %value: !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, %dst: !pto.ptr, %row: index) { %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %off = arith.muli %row, %c2 : index pto.vmi.group_store %value, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<512xf32, #pto.vmi.layout>, + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr return } diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto index 305b039d72..c09de38e51 100644 --- a/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots8_nonunit_invalid.pto @@ -10,12 +10,12 @@ module { func.func @vmi_to_vpto_group_store_slots8_nonunit_invalid( - %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + %value: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index, %row_stride: index) { pto.vmi.group_store %value, %dst[%off], %row_stride {num_groups = 8} - : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + : !pto.vmi.vreg<8xf32, #pto.vmi.layout>, !pto.ptr return } diff --git a/test/lit/vmi/vmi_to_vpto_group_store_slots8_packed_byte.pto b/test/lit/vmi/vmi_to_vpto_group_store_slots8_packed_byte.pto index 08fa565554..4f0055bb4f 100644 --- a/test/lit/vmi/vmi_to_vpto_group_store_slots8_packed_byte.pto +++ b/test/lit/vmi/vmi_to_vpto_group_store_slots8_packed_byte.pto @@ -10,23 +10,23 @@ module { func.func @vmi_to_vpto_group_store_slots8_i32_to_u8( - %value: !pto.vmi.vreg<1024xi32, #pto.vmi.layout>, + %value: !pto.vmi.vreg<32xi32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index pto.vmi.group_store %value, %dst[%off], %c1 {num_groups = 32} - : !pto.vmi.vreg<1024xi32, #pto.vmi.layout>, + : !pto.vmi.vreg<32xi32, #pto.vmi.layout>, !pto.ptr return } func.func @vmi_to_vpto_group_store_slots8_i32_to_u8_padded( - %value: !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + %value: !pto.vmi.vreg<8xi32, #pto.vmi.layout>, %dst: !pto.ptr) { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index pto.vmi.group_store %value, %dst[%c32], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + : !pto.vmi.vreg<8xi32, #pto.vmi.layout>, !pto.ptr return } diff --git a/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto b/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto index c3e7403e91..4f844e469a 100644 --- a/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto +++ b/test/lit/vmi/vmi_to_vpto_integer_cast_reduce.pto @@ -14,15 +14,15 @@ module { func.func @vmi_extsi_i8_to_i32_group_reduce( %source: !pto.vmi.vreg<256xi8>, %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>) - -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> { + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> { %wide = pto.vmi.extsi %source : !pto.vmi.vreg<256xi8> -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> %sum = pto.vmi.group_reduce_addi %wide, %mask {num_groups = 8} : !pto.vmi.vreg<256xi32, #pto.vmi.layout>, !pto.vmi.mask<256xb32, #pto.vmi.layout> - -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> - return %sum : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xi32, #pto.vmi.layout> + return %sum : !pto.vmi.vreg<8xi32, #pto.vmi.layout> } } diff --git a/test/lit/vmi/vmi_to_vpto_integer_casts.pto b/test/lit/vmi/vmi_to_vpto_integer_casts.pto index 77ac39ada7..caf4a09525 100644 --- a/test/lit/vmi/vmi_to_vpto_integer_casts.pto +++ b/test/lit/vmi/vmi_to_vpto_integer_casts.pto @@ -56,43 +56,43 @@ module { } func.func @vmi_to_vpto_group_slot_trunci_i32_to_ui8( - %wide: !pto.vmi.vreg<128xi32, #pto.vmi.layout>, + %wide: !pto.vmi.vreg<8xi32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index %narrow = pto.vmi.trunci %wide - : !pto.vmi.vreg<128xi32, #pto.vmi.layout> - -> !pto.vmi.vreg<128xui8, #pto.vmi.layout> + : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xui8, #pto.vmi.layout>, + : !pto.vmi.vreg<8xui8, #pto.vmi.layout>, !pto.ptr return } func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8( - %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + %wide: !pto.vmi.vreg<8xi32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index %narrow = pto.vmi.trunci %wide - : !pto.vmi.vreg<256xi32, #pto.vmi.layout> - -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> + : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + : !pto.vmi.vreg<8xui8, #pto.vmi.layout>, !pto.ptr return } func.func @vmi_to_vpto_group_slot_trunci_slots8_i32_to_ui8_lane_stride( - %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>, + %wide: !pto.vmi.vreg<8xi32, #pto.vmi.layout>, %dst: !pto.ptr, %off: index) { %c1 = arith.constant 1 : index %narrow = pto.vmi.trunci %wide - : !pto.vmi.vreg<256xi32, #pto.vmi.layout> - -> !pto.vmi.vreg<256xui8, #pto.vmi.layout> + : !pto.vmi.vreg<8xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<8xui8, #pto.vmi.layout> pto.vmi.group_store %narrow, %dst[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xui8, #pto.vmi.layout>, + : !pto.vmi.vreg<8xui8, #pto.vmi.layout>, !pto.ptr return } diff --git a/test/lit/vmi/vmi_type_attr_parse.pto b/test/lit/vmi/vmi_type_attr_parse.pto index b2001c29f0..41c795bcdb 100644 --- a/test/lit/vmi/vmi_type_attr_parse.pto +++ b/test/lit/vmi/vmi_type_attr_parse.pto @@ -24,8 +24,8 @@ module attributes { %wide2: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %wide4: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %wide4_block8: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, - %group_slots8: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, - %group_slots_partial: !pto.vmi.vreg<640xf32, #pto.vmi.layout>, + %group_slots8: !pto.vmi.vreg<8xf32, #pto.vmi.layout>, + %group_slots_partial: !pto.vmi.vreg<10xf32, #pto.vmi.layout>, %surface_mask: !pto.vmi.mask<128xpred>, %mask_b8: !pto.vmi.mask<128xb8, #pto.vmi.layout>, %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, @@ -47,8 +47,8 @@ module attributes { // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<640xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<8xf32, #pto.vmi.layout> +// CHECK-SAME: %{{.*}}: !pto.vmi.vreg<10xf32, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xpred> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb8, #pto.vmi.layout> // CHECK-SAME: %{{.*}}: !pto.vmi.mask<128xb16, #pto.vmi.layout> diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/kernel.pto index 130e50768d..391687c3a5 100644 --- a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/kernel.pto +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-16x512/kernel.pto @@ -60,20 +60,20 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xui8> + : !pto.ptr -> !pto.vmi.vreg<8xui8> %scale_u32 = pto.vmi.extui %scale_u8 - : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> %scale_i32 = pto.vmi.bitcast %scale_u32 - : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_i32, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %wide, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/kernel.pto index 8063c2406b..f0def9ef3f 100644 --- a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-4x128/kernel.pto @@ -49,20 +49,20 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xui8> + : !pto.ptr -> !pto.vmi.vreg<8xui8> %scale_u32 = pto.vmi.extui %scale_u8 - : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> %scale_i32 = pto.vmi.bitcast %scale_u32 - : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_i32, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %wide, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/kernel.pto index 2ffb1b8951..2e73913362 100644 --- a/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/kernel.pto +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-bf16-scaled-64x2048/kernel.pto @@ -72,20 +72,20 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xui8> + : !pto.ptr -> !pto.vmi.vreg<8xui8> %scale_u32 = pto.vmi.extui %scale_u8 - : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> %scale_i32 = pto.vmi.bitcast %scale_u32 - : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_i32, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %wide, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/kernel.pto index 8ca4aa1654..d346c28a3c 100644 --- a/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f16-scaled-4x128/kernel.pto @@ -49,20 +49,20 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xui8> + : !pto.ptr -> !pto.vmi.vreg<8xui8> %scale_u32 = pto.vmi.extui %scale_u8 - : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> %scale_i32 = pto.vmi.bitcast %scale_u32 - : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_i32, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %wide, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/kernel.pto index cd291a4465..e05ab49cc7 100644 --- a/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8-f32-scaled-4x128/kernel.pto @@ -49,20 +49,20 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xui8> + : !pto.ptr -> !pto.vmi.vreg<8xui8> %scale_u32 = pto.vmi.extui %scale_u8 - : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> %scale_i32 = pto.vmi.bitcast %scale_u32 - : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_i32, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %wide, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/kernel.pto index 8b6ce29071..1d31ac88ac 100644 --- a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/kernel.pto +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-16x512/kernel.pto @@ -60,20 +60,20 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xui8> + : !pto.ptr -> !pto.vmi.vreg<8xui8> %scale_u32 = pto.vmi.extui %scale_u8 - : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> %scale_i32 = pto.vmi.bitcast %scale_u32 - : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_i32, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %wide, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/kernel.pto index dec0c5ecca..4907537215 100644 --- a/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/anti-mx-f8e5m2-bf16-scaled-4x128/kernel.pto @@ -49,20 +49,20 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xui8> + : !pto.ptr -> !pto.vmi.vreg<8xui8> %scale_u32 = pto.vmi.extui %scale_u8 - : !pto.vmi.vreg<256xui8> -> !pto.vmi.vreg<256xui32> + : !pto.vmi.vreg<8xui8> -> !pto.vmi.vreg<8xui32> %scale_i32 = pto.vmi.bitcast %scale_u32 - : !pto.vmi.vreg<256xui32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xui32> -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_i32, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %wide, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/kernel.pto index 4010bd6b29..48d6c0d146 100644 --- a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e4m3-4x128/kernel.pto @@ -64,43 +64,40 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> - + -> !pto.vmi.vreg<8xf32> %amax_bits = pto.vmi.bitcast %amax - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> %exp_mask = pto.vmi.broadcast %c2139095040_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %emax = pto.vmi.broadcast %c8_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_exp_bias = pto.vmi.broadcast %c254_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %exp_bits = pto.vmi.andi %amax_bits, %exp_mask - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %exp = pto.vmi.shrui %exp_bits, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_i32 = pto.vmi.subi %exp, %emax - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> - + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_slot = arith.divui %row, %c2 : index %scale_ub_off = arith.muli %scale_slot, %c32 : index pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xi32>, !pto.ptr - + : !pto.vmi.vreg<8xi32>, !pto.ptr %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_exp, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/kernel.pto index 7c0ce6174b..4d42e59c69 100644 --- a/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-bf16-e5m2-4x128/kernel.pto @@ -64,43 +64,40 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> - + -> !pto.vmi.vreg<8xf32> %amax_bits = pto.vmi.bitcast %amax - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> %exp_mask = pto.vmi.broadcast %c2139095040_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %emax = pto.vmi.broadcast %c15_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_exp_bias = pto.vmi.broadcast %c254_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %exp_bits = pto.vmi.andi %amax_bits, %exp_mask - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %exp = pto.vmi.shrui %exp_bits, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_i32 = pto.vmi.subi %exp, %emax - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> - + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_slot = arith.divui %row, %c2 : index %scale_ub_off = arith.muli %scale_slot, %c32 : index pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xi32>, !pto.ptr - + : !pto.vmi.vreg<8xi32>, !pto.ptr %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_exp, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/kernel.pto b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/kernel.pto index c248b12f8b..a71ca57db9 100644 --- a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e4m3-64x256/kernel.pto @@ -66,42 +66,39 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> - + -> !pto.vmi.vreg<8xf32> %amax_bits = pto.vmi.bitcast %amax - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> %exp_mask = pto.vmi.broadcast %c2139095040_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %emax = pto.vmi.broadcast %c8_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_exp_bias = pto.vmi.broadcast %c254_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %exp_bits = pto.vmi.andi %amax_bits, %exp_mask - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %exp = pto.vmi.shrui %exp_bits, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_i32 = pto.vmi.subi %exp, %emax - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> - + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_ub_off = arith.muli %row, %c32 : index pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xi32>, !pto.ptr - + : !pto.vmi.vreg<8xi32>, !pto.ptr %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_exp, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/kernel.pto b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/kernel.pto index 40587088c3..d8c2377acc 100644 --- a/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-mx-quant-f16-e5m2-8x256/kernel.pto @@ -66,42 +66,39 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> - + -> !pto.vmi.vreg<8xf32> %amax_bits = pto.vmi.bitcast %amax - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> %exp_mask = pto.vmi.broadcast %c2139095040_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %emax = pto.vmi.broadcast %c15_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_exp_bias = pto.vmi.broadcast %c254_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %exp_bits = pto.vmi.andi %amax_bits, %exp_mask - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %exp = pto.vmi.shrui %exp_bits, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_i32 = pto.vmi.subi %exp, %emax - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> - + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_ub_off = arith.muli %row, %c32 : index pto.vmi.group_store %e8m0_i32, %ub_scale1[%scale_ub_off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xi32>, !pto.ptr - + : !pto.vmi.vreg<8xi32>, !pto.ptr %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_exp, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/kernel.pto index b920c0da85..21664de067 100644 --- a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-2x128/kernel.pto @@ -47,16 +47,16 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> %fp8_max_v = pto.vmi.broadcast %fp8_max - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.divf %amax, %fp8_max_v - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %q = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/kernel.pto index a7162689b8..551f06b94d 100644 --- a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-32x128/kernel.pto @@ -49,16 +49,16 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> %fp8_max_v = pto.vmi.broadcast %fp8_max - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.divf %amax, %fp8_max_v - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> pto.vmi.group_store %scale, %ub_scale[%row], %c1 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %q = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/kernel.pto index 4d2b183fed..cc6456ea14 100644 --- a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128-min-scale/kernel.pto @@ -51,21 +51,21 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> %fp8_max_v = pto.vmi.broadcast %fp8_max - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale_raw = pto.vmi.divf %amax, %fp8_max_v - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> %scale_limit_v = pto.vmi.broadcast %scale_limit - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.minf %scale_raw, %scale_limit_v - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %q = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> @@ -85,21 +85,21 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> %fp8_max_v = pto.vmi.broadcast %fp8_max - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale_raw = pto.vmi.divf %amax, %fp8_max_v - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> %scale_limit_v = pto.vmi.broadcast %scale_limit - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.minf %scale_raw, %scale_limit_v - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> pto.vmi.group_store %scale, %ub_scale[%c2], %c1 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %q = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/kernel.pto index 330e6341ec..b9a9b04a78 100644 --- a/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-quant-bf16-fp8-4x128/kernel.pto @@ -51,16 +51,16 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> %fp8_max_v = pto.vmi.broadcast %fp8_max - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.divf %amax, %fp8_max_v - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %q = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> @@ -80,16 +80,16 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> %fp8_max_v = pto.vmi.broadcast %fp8_max - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.divf %amax, %fp8_max_v - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> pto.vmi.group_store %scale, %ub_scale[%c2], %c1 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %q = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/kernel.pto index 014cb9b8ce..5655f47848 100644 --- a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-16x256/kernel.pto @@ -55,16 +55,16 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> %fp8_max_v = pto.vmi.broadcast %fp8_max - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.divf %amax, %fp8_max_v - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> pto.vmi.group_store %scale, %ub_scale[%scale_off], %c1 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %q = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/kernel.pto index 5a447251cc..8362807227 100644 --- a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-4x256/kernel.pto @@ -55,16 +55,16 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> %fp8_max_v = pto.vmi.broadcast %fp8_max - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.divf %amax, %fp8_max_v - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> pto.vmi.group_store %scale, %ub_scale[%scale_off], %c1 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %q = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/kernel.pto b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/kernel.pto index e4f00cab6e..ad3fa32760 100644 --- a/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/block-quant-f16-fp8-8x128/kernel.pto @@ -55,16 +55,16 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> %fp8_max_v = pto.vmi.broadcast %fp8_max - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.divf %amax, %fp8_max_v - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> pto.vmi.group_store %scale, %ub_scale[%scale_off], %c1 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %q = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/kernel.pto index 76941c1daa..80172d52e1 100644 --- a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/kernel.pto +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-bf16-4x32/kernel.pto @@ -53,16 +53,16 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> %max_int8 = pto.vmi.broadcast %c127_f32 - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<8xf32> %scale = pto.vmi.divf %amax, %max_int8 - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/kernel.pto index f40ca1d9d7..babe48feb5 100644 --- a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-16x128/kernel.pto @@ -58,16 +58,16 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> %max_int8 = pto.vmi.broadcast %c127_f32 - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.divf %amax, %max_int8 - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> pto.vmi.group_store %scale, %ub_scale[%row], %c1 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/kernel.pto index deba0f577f..761ff9be84 100644 --- a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/kernel.pto +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-f16-4x32/kernel.pto @@ -53,16 +53,16 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> %max_int8 = pto.vmi.broadcast %c127_f32 - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<8xf32> %scale = pto.vmi.divf %amax, %max_int8 - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %scale, %ub_scale[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.divf %x, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/kernel.pto index e5b384e96b..35b641663a 100644 --- a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/kernel.pto +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-bf16-8x64/kernel.pto @@ -95,17 +95,17 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 4} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<4xf32> %max_int8 = pto.vmi.broadcast %c127_f32 - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<4xf32> %scale = pto.vmi.divf %amax, %max_int8 - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<4xf32>, !pto.vmi.vreg<4xf32> + -> !pto.vmi.vreg<4xf32> %scale_offset = arith.muli %tile, %c8 : index pto.vmi.group_store %scale, %ub_scale[%scale_offset], %c1 {num_groups = 4} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<4xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 4} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<4xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.divf %smoothed, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/kernel.pto index 8bd115dac2..2fd9708cc9 100644 --- a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/kernel.pto +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-16x128/kernel.pto @@ -99,17 +99,17 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 2} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<2xf32> %max_int8 = pto.vmi.broadcast %c127_f32 - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<2xf32> %scale = pto.vmi.divf %amax, %max_int8 - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32>, !pto.vmi.vreg<2xf32> + -> !pto.vmi.vreg<2xf32> %scale_offset = arith.muli %tile, %c8 : index pto.vmi.group_store %scale, %ub_scale[%scale_offset], %c1 {num_groups = 2} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<2xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 2} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.divf %smoothed, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/kernel.pto b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/kernel.pto index 4c8691ad4d..37d6ec43b2 100644 --- a/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/kernel.pto +++ b/test/vpto/cases/vmi/kernels/dynamic-quant-pertoken-smooth-f16-8x64/kernel.pto @@ -95,17 +95,17 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 4} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<4xf32> %max_int8 = pto.vmi.broadcast %c127_f32 - : f32 -> !pto.vmi.vreg<256xf32> + : f32 -> !pto.vmi.vreg<4xf32> %scale = pto.vmi.divf %amax, %max_int8 - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<4xf32>, !pto.vmi.vreg<4xf32> + -> !pto.vmi.vreg<4xf32> %scale_offset = arith.muli %tile, %c8 : index pto.vmi.group_store %scale, %ub_scale[%scale_offset], %c1 {num_groups = 4} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<4xf32>, !pto.ptr %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 4} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<4xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.divf %smoothed, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto index 49aacafc8f..0500e465c8 100644 --- a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e4m3-4x8/kernel.pto @@ -98,42 +98,40 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> - + -> !pto.vmi.vreg<8xf32> %amax_bits = pto.vmi.bitcast %amax - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> %exp_mask = pto.vmi.broadcast %c2139095040_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %emax = pto.vmi.broadcast %c8_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_exp_bias = pto.vmi.broadcast %c254_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %exp_bits = pto.vmi.andi %amax_bits, %exp_mask - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %exp = pto.vmi.shrui %exp_bits, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_i32 = pto.vmi.subi %exp, %emax - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_u8 = pto.vmi.trunci %e8m0_i32 - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xui8> pto.vmi.group_store %e8m0_u8, %ub_scale[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xui8>, !pto.ptr - + : !pto.vmi.vreg<8xui8>, !pto.ptr %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_exp, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %swiglu, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto index b3cd60e99a..8375e83240 100644 --- a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-bf16-e5m2-4x8/kernel.pto @@ -98,42 +98,40 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> - + -> !pto.vmi.vreg<8xf32> %amax_bits = pto.vmi.bitcast %amax - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> %exp_mask = pto.vmi.broadcast %c2139095040_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %emax = pto.vmi.broadcast %c15_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_exp_bias = pto.vmi.broadcast %c254_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %exp_bits = pto.vmi.andi %amax_bits, %exp_mask - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %exp = pto.vmi.shrui %exp_bits, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_i32 = pto.vmi.subi %exp, %emax - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_u8 = pto.vmi.trunci %e8m0_i32 - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xui8> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xui8> pto.vmi.group_store %e8m0_u8, %ub_scale[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xui8>, !pto.ptr - + : !pto.vmi.vreg<8xui8>, !pto.ptr %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_exp, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %swiglu, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/kernel.pto b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/kernel.pto index 32d9fc4985..bd0b116a53 100644 --- a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/kernel.pto +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e4m3-64x512/kernel.pto @@ -93,42 +93,39 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> - + -> !pto.vmi.vreg<8xf32> %amax_bits = pto.vmi.bitcast %amax - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> %exp_mask = pto.vmi.broadcast %c2139095040_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %emax = pto.vmi.broadcast %c8_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_exp_bias = pto.vmi.broadcast %c254_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %exp_bits = pto.vmi.andi %amax_bits, %exp_mask - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %exp = pto.vmi.shrui %exp_bits, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_i32 = pto.vmi.subi %exp, %emax - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> - + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_ub_off = arith.muli %row, %c32 : index pto.vmi.group_store %e8m0_i32, %ub_scale[%scale_ub_off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xi32>, !pto.ptr - + : !pto.vmi.vreg<8xi32>, !pto.ptr %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_exp, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %swiglu, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/kernel.pto b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/kernel.pto index 5cbb12c575..86aa4e4a35 100644 --- a/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/kernel.pto +++ b/test/vpto/cases/vmi/kernels/swiglu-mx-quant-f16-e5m2-128x256/kernel.pto @@ -103,42 +103,39 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> - + -> !pto.vmi.vreg<8xf32> %amax_bits = pto.vmi.bitcast %amax - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> %exp_mask = pto.vmi.broadcast %c2139095040_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %c23_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %emax = pto.vmi.broadcast %c15_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_exp_bias = pto.vmi.broadcast %c254_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %exp_bits = pto.vmi.andi %amax_bits, %exp_mask - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %exp = pto.vmi.shrui %exp_bits, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_i32 = pto.vmi.subi %exp, %emax - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> - + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_ub_off = arith.muli %row, %c32 : index pto.vmi.group_store %e8m0_i32, %ub_scale[%scale_ub_off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xi32>, !pto.ptr - + : !pto.vmi.vreg<8xi32>, !pto.ptr %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_exp, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scale_vec = pto.vmi.group_broadcast %scale {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %swiglu, %scale_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto index da8dca54e6..1b880ec786 100644 --- a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x32-nd/kernel.pto @@ -54,42 +54,40 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> - + -> !pto.vmi.vreg<8xf32> %amax_bits = pto.vmi.bitcast %amax - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> %exp_mask = pto.vmi.broadcast %exp_mask_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %shift_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %emax = pto.vmi.broadcast %emax_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_exp_bias = pto.vmi.broadcast %scale_exp_bias_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %exp_bits = pto.vmi.andi %amax_bits, %exp_mask - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %exp = pto.vmi.shrui %exp_bits, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_i32 = pto.vmi.subi %exp, %emax - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_slot = arith.divui %row, %c8 : index %scale_ub_off = arith.muli %scale_slot, %c32 : index pto.vmi.group_store %e8m0_i32, %ub_out_e8m0[%scale_ub_off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xi32>, !pto.ptr - + : !pto.vmi.vreg<8xi32>, !pto.ptr %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_exp, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scaling = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scaling_vec = pto.vmi.group_broadcast %scaling {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %x, %scaling_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> diff --git a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/kernel.pto b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/kernel.pto index 8d09bf51d6..fe9e8ba375 100644 --- a/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/kernel.pto +++ b/test/vpto/cases/vmi/kernels/tquant-mxfp8-32x64-nz/kernel.pto @@ -76,40 +76,38 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %amax = pto.vmi.group_reduce_maxf %abs, %mask_f32 {num_groups = 8} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> - + -> !pto.vmi.vreg<8xf32> %amax_bits = pto.vmi.bitcast %amax - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xi32> %exp_mask = pto.vmi.broadcast %exp_mask_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %shift = pto.vmi.broadcast %shift_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %emax = pto.vmi.broadcast %emax_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %scale_exp_bias = pto.vmi.broadcast %scale_exp_bias_i32 - : i32 -> !pto.vmi.vreg<256xi32> + : i32 -> !pto.vmi.vreg<8xi32> %exp_bits = pto.vmi.andi %amax_bits, %exp_mask - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %exp = pto.vmi.shrui %exp_bits, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %e8m0_i32 = pto.vmi.subi %exp, %emax - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> pto.vmi.group_store %amax, %ub_max[%scale_off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr - + : !pto.vmi.vreg<8xf32>, !pto.ptr %scale_exp = pto.vmi.subi %scale_exp_bias, %e8m0_i32 - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scale_bits = pto.vmi.shli %scale_exp, %shift - : !pto.vmi.vreg<256xi32>, !pto.vmi.vreg<256xi32> - -> !pto.vmi.vreg<256xi32> + : !pto.vmi.vreg<8xi32>, !pto.vmi.vreg<8xi32> + -> !pto.vmi.vreg<8xi32> %scaling = pto.vmi.bitcast %scale_bits - : !pto.vmi.vreg<256xi32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xi32> -> !pto.vmi.vreg<8xf32> %scaling_vec = pto.vmi.group_broadcast %scaling {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %x, %scaling_vec : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> From abc4af547e9d7fff55a7afa428a92476ad5ab565 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Sun, 28 Jun 2026 23:23:47 +0800 Subject: [PATCH 37/54] Add VMI simdvf per-block FP8 cast kernel case --- test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md | 19 +++- test/vpto/cases/vmi/kernels/README.md | 3 +- .../simdvf-per-block-cast-to-fp8/compare.py | 33 +++++++ .../simdvf-per-block-cast-to-fp8/golden.py | 75 +++++++++++++++ .../simdvf-per-block-cast-to-fp8/kernel.pto | 78 +++++++++++++++ .../simdvf-per-block-cast-to-fp8/launch.cpp | 43 +++++++++ .../simdvf-per-block-cast-to-fp8/main.cpp | 94 +++++++++++++++++++ .../simdvf-per-block-cast-to-fp8/ptoas.flags | 1 + 8 files changed, 341 insertions(+), 5 deletions(-) create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/compare.py create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/golden.py create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/kernel.pto create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/launch.cpp create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/main.cpp create mode 100644 test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/ptoas.flags diff --git a/test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md b/test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md index c92e060d37..dd5df9a0bc 100644 --- a/test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md +++ b/test/vpto/cases/vmi/kernels/CCE_CASE_SCOPE.md @@ -30,6 +30,7 @@ VMI kernel 迁移范围。当前审计快照为本地 clone 的 | `dequant/anti_mx_quant` | `dequant/anti_mx_quant/test/test_equivalence.py` | 16 | 先支持 FP8 case | FP4 输入因 VMI FP4 surface 未设计而暂缓 | | `block_mx_quant` | `block_mx_quant/test/test_equivalence.py`; `test_cce.py` 是更宽的 smoke/correctness surface | 14 canonical,30 full union | 先支持 canonical FP8/OCP/rint 行 | FP4、DDR `scale_alg=2` 和额外 `test_cce.py` union 行暂缓 | | `swiglu_mx_quant` | `swiglu_mx_quant/test/test_equivalence.py` | 14 | 先支持 FP8/OCP/rint f16/bf16 行 | FP4 暂缓;CCE 源码中 `scale_alg=1` CUBLAS 路径异常 | +| `simdvf_per_block_cast` | PTOAS PR #488 | 1 | 支持 16x256 f16 + 4x8 f32 scale -> fp8 cast | 当前只迁移该 PR 提供的确定 shape | | `tutorial/block_mx_quant` | `tutorial/block_mx_quant/README.md` | 已由 `block_mx_quant` 覆盖 | BF16 FP8 tutorial shape 作为代表覆盖 | tutorial FP4 与主 `block_mx_quant` 共用 FP4 blocker | ## quant_minimum @@ -64,6 +65,14 @@ VMI kernel 迁移范围。当前审计快照为本地 clone 的 和 bandwidth sweep 主要验证更大 streaming shape;只有当它们暴露新的 VMI memory/layout 约束时,才增加代表性 runtime case。 +## simdvf per-block cast to FP8 + +来源:PTOAS PR #488。 + +| 目标 case | VMI 支持状态 | +| --- | --- | +| FP16 input `(16,256)`,FP32 scale `(4,8)`,每个 scale 覆盖 4 行 x 32 列,输出 FP8 E4M3 | 必须支持 | + ## dynamic_quant 来源:`.work/external/a5-kernel-standalone/cce/dynamic_quant/test/test_dq_equivalence.py`。 @@ -186,9 +195,9 @@ BF16 output-type case,cross-check 有 7 个 byte-exact case,但当前快照 该目录已裁剪为 target-scoped runtime case。删除的 case 只有在目标仓库新增匹配的正确性入口, 或迁移到独立的非目标 probe suite 后,才应重新引入。 -当前 `test/vpto/cases/vmi/kernels` 已缩减为 35 个 case 目录。上面的目标 CCE canonical -正确性范围在 `block_mx_quant` 采用 14-case canonical suite 时有 64 行;如果把 -`block_mx_quant/test_cce.py` 作为完整 small-shape surface union 计入,则有 80 行。 +当前 `test/vpto/cases/vmi/kernels` 已缩减为 36 个 case 目录。上面的目标 CCE canonical +正确性范围在 `block_mx_quant` 采用 14-case canonical suite 时有 65 行;如果把 +`block_mx_quant/test_cce.py` 作为完整 small-shape surface union 计入,则有 81 行。 这些数量不能直接和当前支持集比较,因为目标列表仍包含当前 VMI 有意暂缓的 FP4 行。 | Area | 当前 VMI 目录数 | 目标 canonical 正确性 | 差异 | @@ -199,13 +208,14 @@ BF16 output-type case,cross-check 有 7 个 byte-exact case,但当前快照 | `anti_mx_quant` | 7 | 16 | 保留当前 FP8 目标行;暂缓的 FP4 行不表达 | | `block_mx_quant` | 4 | 14 canonical / 30 full union | 保留 canonical FP8 目标行;暂缓的 FP4/DDR 和额外 `test_cce.py` union 行不表达 | | `swiglu_mx_quant` | 4 | 14 | 保留当前 FP8/OCP 目标行;暂缓的 FP4 和异常 CUBLAS 行不表达 | +| `simdvf_per_block_cast` | 1 | 1 | 对齐 PR #488 | | historical `anti_quant` | 0 | 0 | 已从 target-scoped 目录移除 | | historical `swiglu_quant` | 0 | 0 | 已从 target-scoped 目录移除 | | other probe | 0 | 0 | 已从 target-scoped 目录移除 | ## 当前支持目录清单 -当前 target-scoped runtime 目录精确包含以下 35 个 VMI case: +当前 target-scoped runtime 目录精确包含以下 36 个 VMI case: | CCE family | VMI case 目录 | | --- | --- | @@ -215,6 +225,7 @@ BF16 output-type case,cross-check 有 7 个 byte-exact case,但当前快照 | `dequant/anti_mx_quant` | `anti-mx-f8-bf16-scaled-4x128`, `anti-mx-f8-f32-scaled-4x128`, `anti-mx-f8-f16-scaled-4x128`, `anti-mx-f8-bf16-scaled-16x512`, `anti-mx-f8-bf16-scaled-64x2048`, `anti-mx-f8e5m2-bf16-scaled-4x128`, `anti-mx-f8e5m2-bf16-scaled-16x512` | | `block_mx_quant` | `block-mx-quant-bf16-e4m3-4x128`, `block-mx-quant-f16-e4m3-64x256`, `block-mx-quant-bf16-e5m2-4x128`, `block-mx-quant-f16-e5m2-8x256` | | `swiglu_mx_quant` | `swiglu-mx-quant-bf16-e4m3-4x8`, `swiglu-mx-quant-f16-e4m3-64x512`, `swiglu-mx-quant-bf16-e5m2-4x8`, `swiglu-mx-quant-f16-e5m2-128x256` | +| `simdvf_per_block_cast` | `simdvf-per-block-cast-to-fp8` | | 已移除的 VMI 区域 | 范围说明 | | --- | --- | diff --git a/test/vpto/cases/vmi/kernels/README.md b/test/vpto/cases/vmi/kernels/README.md index ff0fad599f..cab387fdb4 100644 --- a/test/vpto/cases/vmi/kernels/README.md +++ b/test/vpto/cases/vmi/kernels/README.md @@ -19,7 +19,7 @@ See LICENSE in the root of the software repository for the full text of the Lice ## 当前目录范围 -当前目录保留 35 个 runtime case: +当前目录保留 36 个 runtime case: | CCE family | 当前 case 数 | 范围 | | --- | ---: | --- | @@ -29,6 +29,7 @@ See LICENSE in the root of the software repository for the full text of the Lice | `dequant/anti_mx_quant` | 7 | 当前保留 VMI 能表达的 FP8 行;FP4 输入暂缓 | | `block_mx_quant` | 4 | 当前保留 canonical FP8/OCP 等价性行;FP4、DDR `scale_alg=2` 和额外 `test_cce.py` union 行暂缓 | | `swiglu_mx_quant` | 4 | 当前保留 FP8/OCP 等价性行;FP4 和 CCE 已标记异常的 CUBLAS `scale_alg=1` 暂缓 | +| `simdvf_per_block_cast` | 1 | 对齐 PTOAS PR #488 中的 16x256 f16 + 4x8 scale -> fp8 per-block cast case | ## 设计上暂缓 diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/compare.py b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/compare.py new file mode 100644 index 0000000000..2d18033478 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/compare.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import sys + +import numpy as np + + +def main() -> None: + golden_out = np.fromfile("golden_v3.bin", dtype=np.uint8) + out = np.fromfile("v3.bin", dtype=np.uint8) + + if golden_out.shape != out.shape or not np.array_equal(golden_out, out): + diff = np.nonzero(golden_out != out)[0] + idx = int(diff[0]) if diff.size else -1 + print( + f"[ERROR] fp8 compare failed idx={idx} " + f"golden={int(golden_out[idx]) if idx >= 0 else 'n/a'} " + f"output={int(out[idx]) if idx >= 0 else 'n/a'}" + ) + sys.exit(2) + + print("[INFO] compare passed") + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/golden.py b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/golden.py new file mode 100644 index 0000000000..22ff7c2109 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/golden.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import argparse +from pathlib import Path + +import numpy as np + +ROWS = 16 +COLS = 256 +SCALE_ROWS = 4 +SCALE_COLS = 8 +TOKENS_PER_SCALE_ROW = 4 +CHANNELS_PER_SCALE = 32 +SENTINEL_U8 = np.uint8(0xA5) + +Q_VALUES = np.array( + [0.0, 1.0, -1.0, 0.5, 2.0, -2.0, 4.0, -4.0, 448.0], dtype=np.float32 +) +F8E4M3FN_BYTES = np.array( + [0x00, 0x38, 0xB8, 0x30, 0x40, 0xC0, 0x48, 0xC8, 0x7E], dtype=np.uint8 +) + + +def generate(output_dir: Path) -> None: + scale = np.array( + [ + [0.25, 0.5, 1.0, 2.0, 0.25, 0.5, 1.0, 2.0], + [0.5, 1.0, 2.0, 4.0, 0.5, 1.0, 2.0, 4.0], + [1.0, 2.0, 4.0, 0.25, 1.0, 2.0, 4.0, 0.25], + [2.0, 4.0, 0.25, 0.5, 2.0, 4.0, 0.25, 0.5], + ], + dtype=np.float32, + ) + + repeats = (CHANNELS_PER_SCALE + len(Q_VALUES) - 1) // len(Q_VALUES) + q_block = np.tile(Q_VALUES, repeats)[:CHANNELS_PER_SCALE].astype(np.float32) + f8_block = np.tile(F8E4M3FN_BYTES, repeats)[:CHANNELS_PER_SCALE] + + src = np.empty((ROWS, COLS), dtype=np.float16) + golden_out = np.empty((ROWS, COLS), dtype=np.uint8) + for row in range(ROWS): + scale_row = row // TOKENS_PER_SCALE_ROW + for scale_col in range(SCALE_COLS): + start = scale_col * CHANNELS_PER_SCALE + stop = start + CHANNELS_PER_SCALE + src[row, start:stop] = (q_block / scale[scale_row, scale_col]).astype( + np.float16 + ) + golden_out[row, start:stop] = f8_block + + out = np.full((ROWS, COLS), SENTINEL_U8, dtype=np.uint8) + + output_dir.mkdir(parents=True, exist_ok=True) + src.view(np.uint16).tofile(output_dir / "v1.bin") + scale.reshape(-1).tofile(output_dir / "v2.bin") + out.tofile(output_dir / "v3.bin") + golden_out.tofile(output_dir / "golden_v3.bin") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=Path, default=Path(".")) + args = parser.parse_args() + generate(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/kernel.pto b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/kernel.pto new file mode 100644 index 0000000000..0e6c80648e --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/kernel.pto @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_simdvf_per_block_cast_to_fp8_kernel(%src_gm: !pto.ptr, + %scale_gm: !pto.ptr, + %out8_gm: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %num_per_tokens = arith.constant 4 : index + %num_sf_rows_per_block = arith.constant 4 : index + %num_per_channels = arith.constant 32 : index + %block_k = arith.muli %c8, %num_per_channels : index + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %c128_i64 = arith.constant 128 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + + %ub_src = pto.castptr %c0_i64 : i64 -> !pto.ptr + %ub_scale = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %ub_out8_u8 = pto.castptr %c12288_i64 : i64 -> !pto.ptr + %ub_out8_f8 = pto.castptr %c12288_i64 : i64 -> !pto.ptr + + pto.mte_gm_ub %src_gm, %ub_src, %c0_i64, %c8192_i64 + nburst(%c1_i64, %c8192_i64, %c8192_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + pto.mte_gm_ub %scale_gm, %ub_scale, %c0_i64, %c128_i64 + nburst(%c1_i64, %c128_i64, %c128_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64, i64 + + pto.set_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + pto.wait_flag["PIPE_MTE2", "PIPE_V", "EVENT_ID0"] + + pto.vecscope { + scf.for %sf_i = %c0 to %num_sf_rows_per_block step %c1 { + %sf_row_offset = arith.muli %sf_i, %c8 : index + %sf_slots = pto.vmi.group_slot_load %ub_scale[%sf_row_offset], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %sf = pto.vmi.group_broadcast %sf_slots {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + scf.for %token_j = %c0 to %num_per_tokens step %c1 { + %sf_token_row_base = arith.muli %sf_i, %c4 : index + %token_row = arith.addi %sf_token_row_base, %token_j : index + %row_elem_offset = arith.muli %token_row, %block_k : index + %x16 = pto.vmi.load %ub_src[%row_elem_offset] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %scaled = pto.vmi.mulf %x32, %sf + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + %out8 = pto.vmi.truncf %scaled + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + pto.vmi.store %out8, %ub_out8_f8[%row_elem_offset] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + } + } + + pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.wait_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] + pto.mte_ub_gm %ub_out8_u8, %out8_gm, %c4096_i64 + nburst(%c1_i64, %c4096_i64, %c4096_i64) + : !pto.ptr, !pto.ptr, i64, i64, i64, i64 + pto.barrier #pto.pipe + return + } +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/launch.cpp b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/launch.cpp new file mode 100644 index 0000000000..f053235e1b --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/launch.cpp @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef __VEC_SCOPE__ +#define __VEC_SCOPE__ +#endif +#if defined(__CCE_AICORE__) && defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201) +typedef struct { unsigned char v; } hifloat8_t; +typedef struct { unsigned char v; } float8_e4m3_t; +typedef struct { unsigned char v; } float8_e5m2_t; +typedef struct { unsigned char v; } float8_e8m0_t; +typedef struct { unsigned char v; } float4_e1m2x2_t; +typedef struct { unsigned char v; } float4_e2m1x2_t; +#endif +#include +#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) +struct MrgSortExecutedNumList { + uint16_t mrgSortList0; + uint16_t mrgSortList1; + uint16_t mrgSortList2; + uint16_t mrgSortList3; +}; +#endif +#ifndef __CPU_SIM +#include "acl/acl.h" +#endif + +extern "C" __global__ [aicore] void +vmi_simdvf_per_block_cast_to_fp8_kernel(__gm__ half *src, + __gm__ float *scale, + __gm__ uint8_t *out8); + +void LaunchVmi_simdvf_per_block_cast_to_fp8_kernel(uint16_t *src, float *scale, + uint8_t *out8, + void *stream) { + vmi_simdvf_per_block_cast_to_fp8_kernel<<<1, nullptr, stream>>>( + (__gm__ half *)src, (__gm__ float *)scale, (__gm__ uint8_t *)out8); +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/main.cpp b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/main.cpp new file mode 100644 index 0000000000..2fa1b028d4 --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "acl/acl.h" +#include "test_common.h" +#include +#include +#include + +using namespace PtoTestCommon; + +#define ACL_CHECK(expr) \ + do { \ + const aclError _ret = (expr); \ + if (_ret != ACL_SUCCESS) { \ + std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ + (int)_ret, __FILE__, __LINE__); \ + rc = 1; \ + goto cleanup; \ + } \ + } while (0) + +void LaunchVmi_simdvf_per_block_cast_to_fp8_kernel(uint16_t *src, float *scale, + uint8_t *out8, + void *stream); + +int main() { + constexpr size_t kRows = 16; + constexpr size_t kCols = 256; + constexpr size_t kElems = kRows * kCols; + constexpr size_t kScaleRows = 4; + constexpr size_t kScaleCols = 8; + constexpr size_t kScaleElems = kScaleRows * kScaleCols; + size_t srcBytes = kElems * sizeof(uint16_t); + size_t scaleBytes = kScaleElems * sizeof(float); + size_t outBytes = kElems * sizeof(uint8_t); + uint16_t *srcHost = nullptr; + float *scaleHost = nullptr; + uint8_t *outHost = nullptr; + uint16_t *srcDevice = nullptr; + float *scaleDevice = nullptr; + uint8_t *outDevice = nullptr; + int rc = 0; + bool aclInited = false; + bool deviceSet = false; + int deviceId = 0; + aclrtStream stream = nullptr; + + ACL_CHECK(aclInit(nullptr)); + aclInited = true; + if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) + deviceId = std::atoi(envDevice); + ACL_CHECK(aclrtSetDevice(deviceId)); + deviceSet = true; + ACL_CHECK(aclrtCreateStream(&stream)); + ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&scaleHost), scaleBytes)); + ACL_CHECK(aclrtMallocHost((void **)(&outHost), outBytes)); + ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&scaleDevice, scaleBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc((void **)&outDevice, outBytes, ACL_MEM_MALLOC_HUGE_FIRST)); + + ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); + ReadFile("./v2.bin", scaleBytes, scaleHost, scaleBytes); + ReadFile("./v3.bin", outBytes, outHost, outBytes); + ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(scaleDevice, scaleBytes, scaleHost, scaleBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + ACL_CHECK(aclrtMemcpy(outDevice, outBytes, outHost, outBytes, ACL_MEMCPY_HOST_TO_DEVICE)); + LaunchVmi_simdvf_per_block_cast_to_fp8_kernel(srcDevice, scaleDevice, + outDevice, stream); + ACL_CHECK(aclrtSynchronizeStream(stream)); + ACL_CHECK(aclrtMemcpy(outHost, outBytes, outDevice, outBytes, ACL_MEMCPY_DEVICE_TO_HOST)); + WriteFile("./v3.bin", outHost, outBytes); + +cleanup: + aclrtFree(srcDevice); + aclrtFree(scaleDevice); + aclrtFree(outDevice); + aclrtFreeHost(srcHost); + aclrtFreeHost(scaleHost); + aclrtFreeHost(outHost); + if (stream) + aclrtDestroyStream(stream); + if (deviceSet) + aclrtResetDevice(deviceId); + if (aclInited) + aclFinalize(); + return rc; +} diff --git a/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/ptoas.flags b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/ptoas.flags new file mode 100644 index 0000000000..a79aede1ca --- /dev/null +++ b/test/vpto/cases/vmi/kernels/simdvf-per-block-cast-to-fp8/ptoas.flags @@ -0,0 +1 @@ +--pto-arch a5 --pto-backend=vpto --enable-vmi From 90aa4839e3159fc5dfcf35d34288e5dc810a0faf Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:23:40 +0800 Subject: [PATCH 38/54] Support two-way VMI interleaved memory ops --- .../vmi-layout-assignment-lowering-design.md | 7 + docs/designs/vmi-layout-lowering-cases.md | 15 ++ include/PTO/IR/VMIOps.td | 17 ++ lib/PTO/IR/VMI.cpp | 60 +++++ lib/PTO/Transforms/VMILayoutAssignment.cpp | 11 + lib/PTO/Transforms/VMIToVPTO.cpp | 242 +++++++++++++++++- test/lit/vmi/vmi_interleaved_memory_ops.pto | 56 ++++ .../vmi_interleaved_memory_ops_invalid.pto | 23 ++ 8 files changed, 427 insertions(+), 4 deletions(-) create mode 100644 test/lit/vmi/vmi_interleaved_memory_ops.pto create mode 100644 test/lit/vmi/vmi_interleaved_memory_ops_invalid.pto diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 1bfc46174f..bc16e3bb20 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -558,6 +558,13 @@ dense store: requests contiguous source if the stored value is assigned deinterleaved, baseline assignment inserts ensure_layout at the store use + +two-way interleaved memory ops: + `pto.vmi.deinterleave_load` produces two dense logical streams and requests + contiguous layouts for both results + `pto.vmi.interleave_store` consumes two dense logical streams and requests + contiguous layouts for both inputs + the deinterleave/interleave memory pattern is op semantics, not a VMI layout ``` ### 5.2 Baseline Group Layout Requests diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index ab77a35632..16a87696ab 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -184,6 +184,21 @@ This optimization is legal only for full physical chunks and supported `DINTLV_B8/B16/B32` element widths. Tail and masked loads keep their explicit safe lowering until a masked or guarded `vldsx2` strategy is designed. +Two-way logical interleaved memory access is represented by dedicated VMI ops, +not by exposing assigned layouts in surface IR: + +```mlir +%x, %y = pto.vmi.deinterleave_load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> + +pto.vmi.interleave_store %x, %y, %dst[%off] + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, !pto.ptr +``` + +Each VMI value is an ordinary dense logical vector. Layout assignment requests +contiguous layouts for both streams. Lowering maps full-chunk 8/16/32-bit cases +to `vldsx2 DINTLV_B*` and `vstsx2 INTLV_B*`. + ## 3. Lowering Results The following examples use symbolic VPTO names. `PAT_ALL_B*` means an all-true diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index cc29ea4666..87f99898f8 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -547,6 +547,14 @@ def VMILoadOp : VMI_Op<"load", [DeclareOpInterfaceMethods]> { + let summary = "VMI two-way logical deinterleave load"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset); + let results = (outs VMI_VRegTypeConstraint:$low, VMI_VRegTypeConstraint:$high); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` attr-dict `:` type($source) `->` type($low) `,` type($high)"; +} + def VMIGroupLoadOp : VMI_Op<"group_load", [DeclareOpInterfaceMethods]> { let summary = "VMI logical grouped vector load with a row stride between groups"; let arguments = (ins PtrOrMemRef:$source, Index:$offset, Index:$row_stride, @@ -614,6 +622,15 @@ def VMIStoreOp : VMI_Op<"store", [DeclareOpInterfaceMethods]> { + let summary = "VMI two-way logical interleave store"; + let arguments = (ins VMI_VRegTypeConstraint:$low, VMI_VRegTypeConstraint:$high, + PtrOrMemRef:$destination, Index:$offset); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "$low `,` $high `,` $destination `[` $offset `]` attr-dict `:` type($low) `,` type($high) `,` type($destination)"; +} + def VMIGroupStoreOp : VMI_Op<"group_store", [DeclareOpInterfaceMethods]> { let summary = "VMI logical grouped vector store with a row stride between groups"; let arguments = (ins VMI_VRegTypeConstraint:$value, PtrOrMemRef:$destination, diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index b84d63f1c4..97d43cadbc 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -339,6 +339,17 @@ static LogicalResult verifyMemoryElementMatches(Operation *op, Type memoryType, return success(); } +static LogicalResult verifyContiguousIfLayoutAssigned(Operation *op, + VMIVRegType type, + StringRef role) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (layout && !layout.isContiguous()) + return op->emitOpError() + << "requires layout-assigned " << role + << " to use #pto.vmi.layout"; + return success(); +} + static bool isPackedByteGroupStore(Type memoryType, VMIVRegType dataType) { Type memoryElementType = getMemoryElementType(memoryType); if (!memoryElementType) @@ -1522,6 +1533,30 @@ void VMILoadOp::getEffects( effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); } +LogicalResult VMIDeinterleaveLoadOp::verify() { + auto lowType = cast(getLow().getType()); + auto highType = cast(getHigh().getType()); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {lowType, highType}, + /*requireSameElement=*/true))) + return failure(); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + lowType, "source"))) + return failure(); + if (failed(verifyContiguousIfLayoutAssigned(getOperation(), lowType, + "low result")) || + failed(verifyContiguousIfLayoutAssigned(getOperation(), highType, + "high result"))) + return failure(); + return success(); +} + +void VMIDeinterleaveLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + LogicalResult VMIGroupLoadOp::verify() { auto resultType = cast(getResult().getType()); if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), @@ -1654,6 +1689,31 @@ void VMIStoreOp::getEffects( effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); } +LogicalResult VMIInterleaveStoreOp::verify() { + auto lowType = cast(getLow().getType()); + auto highType = cast(getHigh().getType()); + if (failed(verifyAllSameVRegShapeAndLayout(getOperation(), + {lowType, highType}, + /*requireSameElement=*/true))) + return failure(); + if (failed(verifyMemoryElementMatches(getOperation(), + getDestination().getType(), lowType, + "destination"))) + return failure(); + if (failed(verifyContiguousIfLayoutAssigned(getOperation(), lowType, + "low input")) || + failed(verifyContiguousIfLayoutAssigned(getOperation(), highType, + "high input"))) + return failure(); + return success(); +} + +void VMIInterleaveStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getDestinationMutable()); +} + LogicalResult VMIGroupStoreOp::verify() { auto valueType = cast(getValue().getType()); if (!isPackedByteGroupStore(getDestination().getType(), valueType) && diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 0350c48166..113529607a 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -1206,6 +1206,12 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto load = dyn_cast(op)) { + if (failed(setNaturalLayout(load.getLow(), getContiguousLayout(), op)) || + failed(setNaturalLayout(load.getHigh(), getContiguousLayout(), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto load = dyn_cast(op)) { requestDataUse(load.getPassthruMutable(), getContiguousLayout()); if (failed( @@ -1262,6 +1268,11 @@ struct LayoutSolver { requestDataUse(store.getValueMutable(), getContiguousLayout()); return WalkResult::advance(); } + if (auto store = dyn_cast(op)) { + requestDataUse(store.getLowMutable(), getContiguousLayout()); + requestDataUse(store.getHighMutable(), getContiguousLayout()); + return WalkResult::advance(); + } if (auto store = dyn_cast(op)) { requestDataUse( store.getValueMutable(), diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 2308337386..e230cc0526 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -52,6 +52,9 @@ using namespace mlir::pto; namespace { +std::optional getX2MemoryDistToken(Type elementType, + StringRef prefix); + bool isVMIType(Type type) { return isa(type); } bool containsVMIType(Type type) { @@ -1163,6 +1166,40 @@ checkSupportedLoadShape(const VMITargetCapabilityRegistry &capabilities, "; fallback decision: " + accessPlan.fallbackDecision.reason); } +LogicalResult checkSupportedDeinterleaveLoadShape( + const VMITargetCapabilityRegistry &capabilities, + VMIDeinterleaveLoadOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto lowType = cast(op.getLow().getType()); + auto highType = cast(op.getHigh().getType()); + VMILayoutAttr lowLayout = lowType.getLayoutAttr(); + VMILayoutAttr highLayout = highType.getLayoutAttr(); + if (!lowLayout || !highLayout || !lowLayout.isContiguous() || + !highLayout.isContiguous()) + return fail("requires assigned contiguous low/high result layouts"); + if (lowType.getElementCount() != highType.getElementCount() || + lowType.getElementType() != highType.getElementType()) + return fail("requires matching low/high result shape and element type"); + if (!getX2MemoryDistToken(lowType.getElementType(), "DINTLV")) + return fail("requires 8/16/32-bit element type for vldsx2 DINTLV"); + + VMIMemoryAccessPlan accessPlan = buildReadAccessPlan( + capabilities, op.getSource(), op.getSource().getType(), lowType, + getConstantIndexValue(op.getOffset()), VMIMemoryValidMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(lowType, &fullChunkReason))) + return fail(Twine("requires full physical chunks; ") + fullChunkReason); + return success(); +} + LogicalResult checkSupportedStoreShape(const VMITargetCapabilityRegistry &capabilities, VMIVRegType type, Value destination, @@ -1206,6 +1243,43 @@ checkSupportedStoreShape(const VMITargetCapabilityRegistry &capabilities, fullChunkReason + ", materialization " + materializationReason); } +LogicalResult checkSupportedInterleaveStoreShape( + const VMITargetCapabilityRegistry &capabilities, + VMIInterleaveStoreOp op, std::string *reason) { + auto fail = [&](const Twine &message) -> LogicalResult { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto lowType = cast(op.getLow().getType()); + auto highType = cast(op.getHigh().getType()); + VMILayoutAttr lowLayout = lowType.getLayoutAttr(); + VMILayoutAttr highLayout = highType.getLayoutAttr(); + if (!lowLayout || !highLayout || !lowLayout.isContiguous() || + !highLayout.isContiguous()) + return fail("requires assigned contiguous low/high input layouts"); + if (lowType.getElementCount() != highType.getElementCount() || + lowType.getElementType() != highType.getElementType()) + return fail("requires matching low/high input shape and element type"); + if (!getX2MemoryDistToken(lowType.getElementType(), "INTLV")) + return fail("requires 8/16/32-bit element type for vstsx2 INTLV"); + + VMIMemoryAccessPlan accessPlan = + buildWriteAccessPlan(capabilities, op.getDestination(), + op.getDestination().getType(), lowType, + VMIMemoryWriteMaskKind::AllTrue); + if (!accessPlan.targetCapability.isSupported()) + return fail(accessPlan.targetCapability.reason); + if (failed(checkSupportedMaskableVReg(capabilities, lowType, reason))) + return failure(); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(lowType, &fullChunkReason))) + return fail(Twine("requires full physical chunks; ") + fullChunkReason); + return success(); +} + FailureOr getGroupSizeFromNumGroups(VMIVRegType type, int64_t numGroups, std::string *reason = nullptr) { @@ -4142,6 +4216,73 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { } }; +struct OneToNVMIDeinterleaveLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIDeinterleaveLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIDeinterleaveLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto lowVMIType = cast(op.getLow().getType()); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "deinterleave_load source must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "deinterleave_load offset must convert to one value", + rewriter); + if (failed(source) || failed(offset)) + return failure(); + + FailureOr lanesPerPart = + getDataLanesPerPart(lowVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "deinterleave_load requires known physical lanes per part"); + + std::optional dist = + getX2MemoryDistToken(lowVMIType.getElementType(), "DINTLV"); + if (!dist) + return rewriter.notifyMatchFailure( + op, "deinterleave_load requires vldsx2 DINTLV element support"); + + TypeRange lowTypes = adaptor.getResultMapping().getConvertedTypes(0); + TypeRange highTypes = adaptor.getResultMapping().getConvertedTypes(1); + if (lowTypes.size() != highTypes.size()) + return rewriter.notifyMatchFailure( + op, "deinterleave_load requires matching low/high physical arity"); + + SmallVector lows; + SmallVector highs; + lows.reserve(lowTypes.size()); + highs.reserve(highTypes.size()); + for (size_t index = 0, e = lowTypes.size(); index < e; ++index) { + Type lowType = lowTypes[index]; + Type highType = highTypes[index]; + if (lowType != highType) + return rewriter.notifyMatchFailure( + op, "deinterleave_load requires matching low/high physical types"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, static_cast(index) * 2 * *lanesPerPart, + rewriter); + auto load = rewriter.create( + op.getLoc(), lowType, highType, *source, chunkOffset, + rewriter.getStringAttr(*dist)); + lows.push_back(load.getLow()); + highs.push_back(load.getHigh()); + } + + SmallVector results; + results.reserve(lows.size() + highs.size()); + results.append(lows); + results.append(highs); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + struct OneToNVMIGroupLoadOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; @@ -4751,6 +4892,72 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { } }; +struct OneToNVMIInterleaveStoreOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIInterleaveStoreOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIInterleaveStoreOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto lowVMIType = cast(op.getLow().getType()); + FailureOr lanesPerPart = + getDataLanesPerPart(lowVMIType.getElementType()); + if (failed(lanesPerPart)) + return rewriter.notifyMatchFailure( + op, "interleave_store requires known physical lanes per part"); + + std::optional dist = + getX2MemoryDistToken(lowVMIType.getElementType(), "INTLV"); + if (!dist) + return rewriter.notifyMatchFailure( + op, "interleave_store requires vstsx2 INTLV element support"); + + FailureOr destination = + getSingleValue(op, adaptor.getDestination(), + "interleave_store destination must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "interleave_store offset must convert to one value", + rewriter); + if (failed(destination) || failed(offset)) + return failure(); + + ValueRange lowParts = adaptor.getLow(); + ValueRange highParts = adaptor.getHigh(); + if (lowParts.size() != highParts.size()) + return rewriter.notifyMatchFailure( + op, "interleave_store requires matching low/high physical arity"); + + for (size_t index = 0, e = lowParts.size(); index < e; ++index) { + Value low = lowParts[index]; + Value high = highParts[index]; + if (low.getType() != high.getType()) + return rewriter.notifyMatchFailure( + op, "interleave_store requires matching low/high physical types"); + auto vregType = dyn_cast(low.getType()); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "interleave_store value must be vreg"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), vregType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for interleave_store mask"); + Value chunkOffset = createChunkOffset( + op.getLoc(), *offset, static_cast(index) * 2 * *lanesPerPart, + rewriter); + rewriter.create(op.getLoc(), low, high, *destination, + chunkOffset, rewriter.getStringAttr(*dist), + *mask); + } + + rewriter.eraseOp(op); + return success(); + } +}; + struct OneToNVMIGroupStoreOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; @@ -7671,12 +7878,13 @@ void populateVMIOneToNConversionPatterns( OneToNVMIMaskBinaryOpPattern, OneToNVMIMaskBinaryOpPattern, OneToNVMIMaskUnaryOpPattern, OneToNVMILoadOpPattern, - OneToNVMIGroupLoadOpPattern, OneToNVMIGroupSlotLoadOpPattern, - OneToNVMIStrideLoadOpPattern, + OneToNVMIDeinterleaveLoadOpPattern, OneToNVMIGroupLoadOpPattern, + OneToNVMIGroupSlotLoadOpPattern, OneToNVMIStrideLoadOpPattern, OneToNVMIMaskedLoadOpPattern, OneToNVMIGatherOpPattern, OneToNVMIExpandLoadOpPattern, OneToNVMIStoreOpPattern, - OneToNVMIGroupStoreOpPattern, OneToNVMIMaskedStoreOpPattern, - OneToNVMIStrideStoreOpPattern, OneToNVMIScatterOpPattern, + OneToNVMIInterleaveStoreOpPattern, OneToNVMIGroupStoreOpPattern, + OneToNVMIMaskedStoreOpPattern, OneToNVMIStrideStoreOpPattern, + OneToNVMIScatterOpPattern, OneToNVMIBinaryOpPattern, OneToNVMIBinaryOpPattern, OneToNVMIBinaryOpPattern, @@ -8454,6 +8662,19 @@ verifySupportedVMIToVPTOOps(ModuleOp module, op, "pto.vmi.load", cast(load.getResult().getType()), load.getSource(), getConstantIndexValue(load.getOffset())); } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedDeinterleaveLoadShape(capabilities, load, &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.deinterleave_load lowers through pto.vldsx2 only for " + "matching contiguous full low/high result chunks with a supported " + "UB source and 8/16/32-bit element type (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto load = dyn_cast(op)) { std::string reason; if (succeeded(checkSupportedStrideLoadShape(capabilities, load, &reason))) @@ -8550,6 +8771,19 @@ verifySupportedVMIToVPTOOps(ModuleOp module, << reason << ")"; return WalkResult::interrupt(); } + if (auto store = dyn_cast(op)) { + std::string reason; + if (succeeded( + checkSupportedInterleaveStoreShape(capabilities, store, &reason))) + return WalkResult::advance(); + store.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.interleave_store lowers through pto.vstsx2 only for " + "matching contiguous full low/high input chunks with a supported " + "UB destination and 8/16/32-bit element type (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto store = dyn_cast(op)) { std::string reason; if (succeeded( diff --git a/test/lit/vmi/vmi_interleaved_memory_ops.pto b/test/lit/vmi/vmi_interleaved_memory_ops.pto new file mode 100644 index 0000000000..26aa6324ef --- /dev/null +++ b/test/lit/vmi/vmi_interleaved_memory_ops.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_deinterleave_load( + %src: !pto.ptr, + %offset: index) -> (!pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>) { + %low, %high = pto.vmi.deinterleave_load %src[%offset] + : !pto.ptr -> !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> + return %low, %high : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> + } + + func.func @vmi_interleave_store( + %low: !pto.vmi.vreg<64xf32>, + %high: !pto.vmi.vreg<64xf32>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.interleave_store %low, %high, %dst[%offset] + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_deinterleave_load( +// ASSIGN-SAME: -> (!pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.vreg<64xf32, #pto.vmi.layout>) +// ASSIGN: %[[LOW:.*]], %[[HIGH:.*]] = pto.vmi.deinterleave_load +// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: return %[[LOW]], %[[HIGH]] : !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.vreg<64xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_deinterleave_load( +// LOWER: %[[LOW:.*]], %[[HIGH:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// LOWER: return %[[LOW]], %[[HIGH]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast + +// ASSIGN-LABEL: func.func @vmi_interleave_store( +// ASSIGN-SAME: %[[LOW_ARG:[^:]+]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN-SAME: %[[HIGH_ARG:[^:]+]]: !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.interleave_store %[[LOW_ARG]], %[[HIGH_ARG]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_interleave_store( +// LOWER: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" +// LOWER: pto.vstsx2 %arg0, %arg1, %arg2[%arg3], "INTLV_B32", %[[MASK]] +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_interleaved_memory_ops_invalid.pto b/test/lit/vmi/vmi_interleaved_memory_ops_invalid.pto new file mode 100644 index 0000000000..81aaa858ae --- /dev/null +++ b/test/lit/vmi/vmi_interleaved_memory_ops_invalid.pto @@ -0,0 +1,23 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s 2>&1 | FileCheck %s + +module { + func.func @vmi_interleave_store_mismatch( + %low: !pto.vmi.vreg<64xf32>, + %high: !pto.vmi.vreg<128xf32>, + %dst: !pto.ptr, + %offset: index) { + pto.vmi.interleave_store %low, %high, %dst[%offset] + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<128xf32>, !pto.ptr + return + } +} + +// CHECK: 'pto.vmi.interleave_store' op requires all VMI data values to have the same logical lane count From c49d7f53148c0a4191004451f0e3626a45429328 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 29 Jun 2026 14:26:15 +0800 Subject: [PATCH 39/54] Add VMI group broadcast E2B lowering --- .../vmi-e2b-scale-broadcast-optimization.md | 992 ++++++++++++++++++ docs/designs/vmi-introduction.md | 7 +- include/PTO/IR/VMIOps.td | 9 + include/PTO/Transforms/VMILayoutSupport.h | 14 + lib/PTO/IR/VMI.cpp | 25 + lib/PTO/Transforms/PTOValidateVMIIR.cpp | 11 + lib/PTO/Transforms/VMILayoutAssignment.cpp | 99 ++ lib/PTO/Transforms/VMILayoutSupport.cpp | 71 ++ lib/PTO/Transforms/VMIToVPTO.cpp | 155 ++- ...ssignment_group_broadcast_load_e2b_b16.pto | 39 + ...ment_group_slot_broadcast_load_e2b_b16.pto | 74 ++ ...assignment_group_slot_broadcast_no_e2b.pto | 29 + ..._slot_broadcast_partial_packet_invalid.pto | 25 + ...i_to_vpto_group_broadcast_load_e2b_b16.pto | 43 + ..._broadcast_load_e2b_b16_stride_invalid.pto | 25 + 15 files changed, 1613 insertions(+), 5 deletions(-) create mode 100644 docs/designs/vmi-e2b-scale-broadcast-optimization.md create mode 100644 test/lit/vmi/vmi_layout_assignment_group_broadcast_load_e2b_b16.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_load_e2b_b16.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_no_e2b.pto create mode 100644 test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_partial_packet_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16.pto create mode 100644 test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid.pto diff --git a/docs/designs/vmi-e2b-scale-broadcast-optimization.md b/docs/designs/vmi-e2b-scale-broadcast-optimization.md new file mode 100644 index 0000000000..9777162f3b --- /dev/null +++ b/docs/designs/vmi-e2b-scale-broadcast-optimization.md @@ -0,0 +1,992 @@ +# VMI E2B Scale Broadcast Optimization Study + +本文推演 VMI 是否能把 block quant 中的 scale broadcast 自动优化成 +`E2B_B16` load。结论是: + +```text +group_slot_load + group_broadcast 足以表达逻辑语义。 + +它不足以单独触发 E2B,因为 E2B 是某个 physical chunk layout 下的 +materialization,不是 dense logical broadcast 的直接 lowering。 + +如果后续 layout 已经由 consumer requirement 或 target-specific layout +optimization 选成 E2B-compatible 形态,vmi-to-vpto 可以把对应 chunk lower +成 E2B。 + +如果想从普通 dense quant IR 自动得到 CCE 的 DINTLV/E2B 形态,需要一个 +target-specific layout optimization/cost selection 阶段整体选择这套计划。 +``` + +## 1. Logical Quant Semantics + +`ComputeY1ToFP8` 的 surface VMI 语义应保持 dense quant: + +```text +for i in 0..255: + y[i] = fp8(x[i] * scale[i / 32]) +``` + +也就是 8 个 scale,每个覆盖 32 个 dense logical lanes: + +```text +s0 x32, s1 x32, ..., s7 x32 +``` + +对应 VMI 形态是: + +```text +%x = pto.vmi.load %x_base[%x_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + +%scale_slots = pto.vmi.group_slot_load %scale_base[%scale_off], %c1 + {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xbf16> + +%scale = pto.vmi.group_broadcast %scale_slots {num_groups = 8} + : !pto.vmi.vreg<8xbf16> -> !pto.vmi.vreg<256xbf16> +``` + +This form is the canonical logical IR. The source scale should be BF16 payload, +not FP16, because the CCE implementation loads `uint16_t` values and later +reinterprets them as `vector_bf16`. + +`num_groups = 16` would express a different algorithm: + +```text +16 scale values, each covering 16 dense lanes +``` + +That is not equivalent unless the input memory redundantly stores +`s0, s0, s1, s1, ...`, which is not what the CCE kernel does. + +## 2. E2B_B16 Semantics + +`E2B_B16` is a VPTO load distribution mode. For a b16 destination register it +loads 8 source elements and expands each one to 16 consecutive destination +lanes: + +```text +dst[j] = src[floor(j / 16)] for j = 0..127 +``` + +The result is: + +```text +s0 x16, s1 x16, ..., s7 x16 +``` + +So `E2B_B16` does not directly materialize the dense VMI broadcast +`8 -> 256`. It materializes one 128-lane physical view that becomes useful only +after the x data and later f32 computation have been split into compatible +physical chunks. + +## 3. Why CCE Can Use E2B + +The CCE FP16 path uses a physical implementation shape like: + +```text +vlds(x0F16, x1F16, xHalf, stride, DINTLV_B16, POST_UPDATE) +vlds(scaleForMulFP16, scale_base, 0, E2B_B16) + +vcvt(x0_even_f32, x0F16, PART_EVEN) +vcvt(x0_odd_f32, x0F16, PART_ODD) +vcvt(x1_even_f32, x1F16, PART_EVEN) +vcvt(x1_odd_f32, x1F16, PART_ODD) + +vcvt(scale_f32, (vector_bf16 &)scaleForMulFP16, PART_EVEN) +``` + +`DINTLV_B16` splits the dense 256-element row into two 128-lane physical streams. +After each stream is converted from f16 to f32, the computation is effectively +four 64-lane f32 chunks: + +```text +x0 even part +x0 odd part +x1 even part +x1 odd part +``` + +For every one of those chunks, the needed scale pattern is: + +```text +s0 x8, s1 x8, ..., s7 x8 +``` + +`E2B_B16` produces: + +```text +s0 x16, s1 x16, ..., s7 x16 +``` + +Then `vcvt PART_EVEN` produces: + +```text +s0 x8, s1 x8, ..., s7 x8 +``` + +Because every scale value is duplicated in adjacent even/odd b16 positions, +`PART_EVEN` and `PART_ODD` would produce the same f32 scale chunk. The CCE code +computes the scale chunk once and reuses it for all four x chunks. + +## 4. What Is A Legal Automatic Optimization? + +The following rewrite is not legal as a standalone local rule: + +```text +group_slot_load + group_broadcast(8 -> 256) => E2B_B16 +``` + +It is invalid because the left side is a dense 256-lane logical value, while +`E2B_B16` produces a 128-lane physical value with a different lane repetition +count. + +A legal E2B lowering must be conditional on the assigned physical layout: + +```text +if the broadcasted scale value is required in physical chunks where each chunk +needs s0 x16 ... s7 x16 at b16 width, or s0 x8 ... s7 x8 after bf16->f32, +then that chunk may be materialized with E2B_B16. +``` + +In other words: + +```text +group_slot_load + group_broadcast + is the logical source pattern + +consumer-required or target-selected layout + determines whether any physical chunk is E2B-compatible + +vmi-to-vpto + lowers only those compatible chunks to E2B +``` + +`group_slot_load` alone cannot lower to E2B. A group-slot value has only group +slots as semantic lanes. `E2B_B16` already produces broadcasted physical lanes. +The `group_broadcast` use is required to justify reading those lanes. + +## 5. Layout Selection Boundary + +Deinterleaved layout must not be inferred only because E2B would be cheaper. +The selected layout must be explicit before `vmi-to-vpto`. That layout can come +from either side: + +```text +consumer requirement: + a later op requires a particular layout. + +producer natural layout: + the producing op has a declared, deterministic natural layout that is legal + for all of its uses. +``` + +`group_broadcast` is a materialization op, so it may define or participate in an +E2B-friendly natural layout when that layout is part of the declared layout +contract. That is still a layout-assignment decision, not a hidden +`vmi-to-vpto` peephole. Do not reuse `block_elems` as an ad-hoc broadcast split +knob; `block_elems` belongs to the dense deinterleaved layout contract and has +existing producer/consumer meanings. + +Baseline layout assignment may still choose conservative contiguous layouts even +when a target-specific fused implementation exists. + +Therefore this optimization has two valid implementation levels. + +### 5.1 Compatible-Layout Lowering Shortcut + +If some earlier layout pass has already assigned an E2B-compatible physical +layout, `vmi-to-vpto` may lower the scale chunk with `E2B_B16`. + +This is a local deterministic lowering. It does not discover the CCE plan by +itself. It only avoids a generic `vsldb + vselr` materialization when the +assigned layout has already made the required physical chunk shape explicit. + +### 5.2 Producer Natural Layout + +For simple broadcasts, the producer itself may choose an E2B-friendly natural +layout when that layout satisfies every use. + +Example for b16, using an existing DINTLV-like element-parity layout: + +```text +logical 1 -> 32: + s0 x32 + +layout: + deinterleaved = 2, block_elems = 1 + +physical part 0: + s0 x16 + +physical part 1: + s0 x16 +``` + +The two physical parts can share one E2B materialization or use two identical +E2B materializations. This is a general layout choice for the broadcast result, +not a quant-specific graph rewrite. + +For a uniform `1 -> 32` or per-group `x32` broadcast, `deinterleaved = 2, +block_elems = 1` yields 16 lanes of the same group per physical part and is +closer to an even/odd `DINTLV_B16` data layout. + +For the MX quant scale: + +```text +logical 8 -> 256: + s0 x32, s1 x32, ..., s7 x32 + +layout: + deinterleaved = 2, block_elems = 1 + +physical part 0: + s0 x16, s1 x16, ..., s7 x16 + +physical part 1: + s0 x16, s1 x16, ..., s7 x16 +``` + +Each physical part is directly `E2B_B16`-compatible. +The implementation should run the E2B compatibility query over the assigned lane +mapping. It should not infer a new meaning for `block_elems`. + +### 5.3 Target-Specific Layout Optimization + +To automatically discover the complete CCE plan from canonical dense quant IR, +add an optional target-specific layout optimization before `vmi-to-vpto`. + +That pass may select a cheaper equivalent implementation for the whole quant +subgraph: + +```text +dense x load +f16/bf16 -> f32 conversion +scale group_slot_load + group_broadcast 8 -> 256 +scale bf16 -> f32 conversion +mul +fp32 -> fp8 conversion/store +``` + +The pass must rewrite or annotate the VMI layout-assigned IR so that +`vmi-to-vpto` no longer has to infer the plan from context. + +Expected selected physical plan: + +```text +x load: + vlds DINTLV_B16 into two b16 streams + +scale load: + vlds E2B_B16 into one b16 stream + +scale conversion: + vcvt PART_EVEN into one 64-lane f32 stream + +mul: + reuse that scale f32 stream for the four x f32 chunks +``` + +This is an optimization, not a correctness requirement. If the optimizer does +not fire, the canonical dense VMI program still has a valid generic lowering. + +## 6. Candidate Match Preconditions + +A target-specific optimization may match the CCE-style scale pattern only under +strict conditions: + +```text +scale_slots: + pto.vmi.group_slot_load + num_groups = 8 + source_group_stride = 1 + source element width = 16 bits + semantic type is bf16 or a bitcastable ui16 payload later interpreted as bf16 + +scale broadcast: + pto.vmi.group_broadcast + same num_groups = 8 + dense logical result has 256 b16 lanes for this case + +scale conversion: + bf16 -> f32 + conversion has no rounding/exception behavior that distinguishes duplicated + even and odd source lanes + +x path: + dense logical row has 256 f16 or bf16 lanes + the target plan can legally compute the row as four 64-lane f32 chunks + +uses: + no user observes the intermediate dense scale layout in a way that prevents + rematerialization or chunk reuse +``` + +The optimization should reject or skip the pattern if any of these conditions are +not proven. + +## 7. Correctness Sketch + +Let the logical dense lane be `i`. + +The canonical VMI scale value is: + +```text +scale_dense[i] = s[floor(i / 32)] +``` + +The CCE physical decomposition maps each dense lane into one of four f32 chunks. +For a chunk-local f32 lane `k`: + +```text +dense lane = 4 * k + delta +delta in {0, 1, 2, 3} +``` + +Then: + +```text +floor((4 * k + delta) / 32) = floor(k / 8) +``` + +So every f32 chunk needs: + +```text +scale_chunk[k] = s[floor(k / 8)] +``` + +`E2B_B16` plus `vcvt PART_EVEN` gives: + +```text +e2b_b16[j] = s[floor(j / 16)] for j = 0..127 +scale_f32[k] = e2b_b16[2 * k] + = s[floor((2 * k) / 16)] + = s[floor(k / 8)] +``` + +That matches the required `scale_chunk[k]` for all four f32 chunks. + +## 8. Recommendation + +Prefer adding a target-agnostic VMI `group_broadcast_load` logical memory op if +we want to make this optimization robust and local. The op should mean: + +```text +load one source value per logical group, then broadcast that value to every lane +in the group. +``` + +It must not mean `E2B`. `E2B_B16` is only one possible lowering when the +assigned layout is compatible. + +The unfused logical IR remains valid: + +```text +group_slot_load + group_broadcast +``` + +but a canonicalization/layout-prep pass may fuse it to: + +```text +group_broadcast_load +``` + +when the group-slot value has no separate semantic users. + +Then implement E2B support in phases: + +```text +1. Ensure the example PTO uses the correct logical semantics: + bf16 scale, num_groups = 8, dense 8 -> 256 broadcast. + +2. Add group_broadcast_load as a logical VMI memory op, plus canonicalization + from group_slot_load + group_broadcast when legal. + +3. Add a compatible-layout lowering shortcut: + when layout assignment already exposes an E2B-compatible chunk, lower the + group_broadcast_load chunk with vlds E2B_B16. + +4. Add an optional target-specific quant layout optimization: + recognize the whole dense quant subgraph and select the DINTLV/E2B plan when + it is legal and profitable. +``` + +This keeps VMI logical semantics independent from physical layout, while still +leaving a clear path to recover the CCE optimization automatically. + +## 9. Generalized E2B Broadcast Optimization + +The scale case above is one instance of a broader rule: E2B is a physical +materialization primitive for a packet of repeated group slots. It is not tied +to MX quant, but its legality depends on the physical chunk layout and on the +load distribution's carrier element width. + +### 9.1 E2B As A Packet Primitive + +For the verified `B16` case: + +```text +E2B_B16 packet: + source slots per packet = 8 + destination lanes per packet = 128 b16 lanes + repeat per source slot = 16 b16 lanes + +dst[lane] = src[base_slot + floor(lane / 16)] +``` + +This can materialize a physical chunk that needs: + +```text +s0 x16, s1 x16, ..., s7 x16 +``` + +The optimization should reason in terms of physical chunks: + +```text +logical group_broadcast + source group slot for logical lane i = floor(i / logical_group_size) + +assigned physical layout + maps physical chunk lane l to logical lane i(l) + +E2B-compatible chunk + floor(i(l) / logical_group_size) = base_slot + floor(l / 16) +``` + +If this equality holds for a b16 physical chunk, the chunk can be loaded with +`E2B_B16` instead of materializing the broadcast with `vselr`. + +### 9.2 Direct 1 -> 16 + +A logical `1 -> 16` b16 broadcast is directly compatible with one E2B group: + +```text +s0 x16 +``` + +However, `E2B_B16` is naturally an 8-group packet: + +```text +s0 x16, s1 x16, ..., s7 x16 +``` + +So a single `1 -> 16` use may lower to E2B only under one of these conditions: + +```text +packed case: + the compiler can pack eight independent 1 -> 16 broadcasts into one E2B load. + +partial-live case: + only one 16-lane group is live, and the target semantics prove inactive E2B + groups do not require valid source memory or can be safely over-read. + +full-packet case: + the logical IR actually contains eight adjacent groups, even if the current + consumer observes only one group through a layout/mask. +``` + +If these conditions are not proven, `BRC_B16`, `vdup`, or the existing generic +broadcast lowering is safer than E2B. In particular, do not introduce an E2B +load that reads seven extra source values unless the memory safety rule is +explicit. + +### 9.3 1 -> 32 Via Deinterleaved Reuse + +A dense logical `1 -> 32` b16 broadcast does not fit one E2B group in a single +contiguous physical chunk: + +```text +logical: s0 x32 +E2B group: s0 x16 +``` + +It becomes E2B-compatible when the assigned physical layout splits those 32 +logical lanes into two 16-lane physical uses: + +```text +physical use A: s0 x16 +physical use B: s0 x16 +``` + +This split can use the existing DINTLV-like element-parity layout: + +```text +#pto.vmi.layout +``` + +For logical lanes `0..31`, this maps: + +```text +even lanes 0,2,...,30 -> physical part 0 lanes 0..15 +odd lanes 1,3,...,31 -> physical part 1 lanes 0..15 +``` + +Because all 32 logical lanes carry the same `s0`, each part still sees +`s0 x16`. The lowering rule should check the resulting group index function, +not invent a new layout spelling. + +Then the compiler has two valid strategies: + +```text +reuse: + materialize one E2B group/chunk and map both physical uses to the same value. + +duplicate: + materialize the same E2B group twice if reuse would violate scheduling, + lifetime, or destructive-update constraints. +``` + +This is the mechanism behind the MX quant scale case: + +```text +dense logical scale: 8 groups, each x32 +physical f16/bf16 streams: each group appears as x16 per stream +``` + +The optimization is legal only if the 32 logical lanes are split by layout. It +is not legal as a direct E2B chunk load for a contiguous physical chunk that +genuinely needs `s0 x32` inside one chunk; that would require a separate +duplicate/interleave/concat materialization. + +### 9.4 N -> N * 16 And N -> N * 32 + +For b16 group broadcasts with consecutive slots and unit source stride: + +```text +N -> N * 16 +``` + +can be lowered by E2B in packets of 8 groups when the physical chunk sees the +groups in E2B order: + +```text +for base_slot in 0, 8, 16, ... + load src[base_slot : base_slot + 8] with E2B_B16 +``` + +Tail packets require either a proven safe masked/partial E2B form or a generic +fallback. + +For: + +```text +N -> N * 32 +``` + +E2B is profitable when the assigned layout decomposes each 32-lane logical group +into two 16-lane physical uses. That assigned layout may be the +`group_broadcast` producer's natural layout, or it may be required by a +downstream consumer. The lowering then reuses or duplicates the corresponding +E2B materialization for those two uses. This rule extends to: + +```text +N -> N * (16 * F) +``` + +when the layout decomposes each logical group into `F` physical 16-lane uses. + +### 9.5 Type Generalization + +E2B is a carrier-width load distribution. For `E2B_B16`, the load itself is +valid for 16-bit carriers: + +```text +bf16 +f16 +ui16 / si16 payloads +other 16-bit bit patterns whose consumers preserve the intended interpretation +``` + +The optimization must keep type interpretation outside the load: + +```text +bf16 scale + extf to f32: + E2B_B16 may feed vcvt bf16 -> f32. + +f16 broadcast: + E2B_B16 may materialize repeated f16 lanes if the consumer expects f16. + +ui16 payload later bitcast to bf16: + E2B_B16 may load the ui16 carrier, but the later bitcast/interpretation must + remain explicit in VMI or in the selected lowering plan. +``` + +Do not infer a floating-point type from E2B itself. `E2B_B16` only says how UB +bytes are placed into b16 lanes. + +`E2B_B32` is the b32 member of the same distribution family. The VPTO verifier +accepts `E2B_B32`, the ISA docs list E2B for `b16` and `b32`, and CCE quant code +uses `E2B_B32` in FP32 paths. It follows the same 8-source-slot packet rule: + +```text +E2B_B16: 8 source slots * 16 lanes/slot = 128 b16 lanes +E2B_B32: 8 source slots * 8 lanes/slot = 64 b32 lanes +``` + +The implemented E2B broadcast optimization therefore supports: + +```text +b16 contiguous: logical 1 -> 16 +b16 deinterleaved=2: logical 1 -> 32 +b32 contiguous: logical 1 -> 8 +b32 deinterleaved=2: logical 1 -> 16 +``` + +There is no `E2B_B8` in the documented load distribution family, so b8 +broadcasts should use other distributions or generic materialization. + +### 9.6 Broadcast Generalization + +E2B can optimize `pto.vmi.group_broadcast` when all of these are true: + +```text +source: + group slots come from consecutive memory slots + source_group_stride = 1 + slot type matches the E2B carrier width + +broadcast: + each physical chunk needs a run-length pattern compatible with the E2B repeat + count for that carrier width + +layout: + the run-length pattern is visible in the assigned layout before vmi-to-vpto + +uses: + rematerializing or reusing the E2B packet does not change observable memory or + arithmetic semantics +``` + +E2B is generally not the right primitive for ordinary scalar `pto.vmi.broadcast` +unless the scalar value is already in memory as an E2B packet or the compiler can +pack several independent scalar broadcasts into one E2B load. For a scalar +stored once in memory and needed in every lane, `BRC_B16/B32`, `BRC_BLK`, or a +register `vdup` is usually the more direct representation. + +### 9.7 Implementation Shape + +The recommended implementation order is: + +```text +1. Keep VMI semantics canonical: + group_slot_load + group_broadcast is the desugared meaning. + +2. Optionally canonicalize to group_broadcast_load: + this keeps memory source and broadcast semantics in one local op. + +3. Add an E2B compatibility query over assigned physical chunks: + given source slots, result layout, carrier width, and live lanes, answer + whether a chunk's group-index function is E2B-shaped. + +4. Lower compatible chunks to E2B packets: + generate one E2B load per needed packet, or reuse an existing packet when + multiple physical uses require identical contents. + +5. Add a later target-specific layout optimizer: + it may choose layouts that expose E2B-compatible chunks, but only by + rewriting/annotating layout-assigned VMI before vmi-to-vpto. +``` + +The compatibility query should return a reason when it rejects a candidate: + +```text +non-unit source stride +non-consecutive group slots +unsupported carrier width +tail packet lacks safe partial E2B semantics +physical lane mapping is not E2B-shaped +extra source memory read would be unsafe +consumer observes a different dense layout +``` + +This keeps the optimization auditable and prevents E2B from becoming an implicit +layout-changing peephole. + +## 10. Recognition, Solidification, Propagation, Lowering + +This section describes how an implementation should carry the optimization from +canonical VMI to VPTO without making `vmi-to-vpto` rediscover hidden context. + +### 10.1 Recognize Information + +Run recognition after hard layout assignment, when every relevant value already +has an explicit layout. + +Recognize the source shape: + +```text +%slots = pto.vmi.group_slot_load %base[%off], %stride {num_groups = G} +%bcast = pto.vmi.group_broadcast %slots {num_groups = G} +``` + +or the already-fused form: + +```text +%bcast = pto.vmi.group_broadcast_load %base[%off], %stride {num_groups = G} +``` + +Collect candidate facts: + +```text +source memory: + base pointer + offset + source_group_stride + element carrier width + memory element type + +logical broadcast: + num_groups = G + logical lanes = N + logical group size S = N / G + +assigned result layout: + physical arity + physical lanes per chunk + logical lane mapped to each physical lane + +uses: + whether the broadcast feeds elementwise ops, extf/truncf, stores, or multiple + independent consumers +``` + +Then compute an E2B packet plan per physical chunk. For `E2B_B16`, a physical +chunk is compatible when: + +```text +group_index_for_physical_lane(l) = base_slot + floor(l / 16) +``` + +for all live lanes in that chunk. + +Reject the candidate if: + +```text +source_group_stride != 1 +source slots are not consecutive +carrier width is unsupported +the assigned layout does not produce E2B-shaped chunks +tail/partial packet would read memory that is not proven valid +the group_slot_load has other non-rematerializable users +``` + +This recognition is an analysis step. It must not silently change layouts. + +### 10.2 Solidify Information + +`vmi-to-vpto` should not have to look at an arbitrary +`group_slot_load -> group_broadcast` use-def chain and decide to suppress one +load while replacing another op with E2B. The optimization pass must solidify +the decision in the layout-assigned IR. + +The preferred solidification is a target-agnostic logical memory op: + +```text +%bcast = pto.vmi.group_broadcast_load %base[%off], %stride {num_groups = G} + : !pto.ptr -> !pto.vmi.vreg +``` + +Semantic definition: + +```text +group_size = N / G +for logical lane i: + group = floor(i / group_size) + result[i] = base[off + group * stride] +``` + +This op is not target-specific and does not promise E2B. It is exactly the +fused logical form of: + +```text +%slots = pto.vmi.group_slot_load %base[%off], %stride {num_groups = G} +%bcast = pto.vmi.group_broadcast %slots {num_groups = G} +``` + +The fused op makes lowering local because the memory source, stride, group count, +result type, and assigned layout are all available on one op. A generic lowering +can still materialize it with `vsldb + vselr`; an optimized lowering may choose +`E2B_B16` for compatible physical chunks. + +The current implementation is intentionally narrower: because +`group_broadcast_load` does not yet have a generic `vsldb + vselr` lowering, +layout assignment fuses `group_slot_load + group_broadcast` only when the fused +op is already an E2B-compatible b16 candidate. Non-E2B shapes stay in the +unfused form and continue to use the existing `group_slot_load` plus +`group_broadcast` lowering path. + +Canonicalization rules: + +```text +group_slot_load + group_broadcast -> group_broadcast_load + when the group_slot_load has exactly that broadcast use, or when cloning the + load is legal and profitable for that use. + +group_broadcast_load -> group_slot_load + group_broadcast + remains a valid conceptual expansion for verification, documentation, and + generic fallback reasoning. +``` + +Solidification must preserve semantics for multi-use values: + +```text +if all uses consume only the broadcasted value: + replace with one shared group_broadcast_load. + +if only one use can benefit from the fused form: + clone/rematerialize that use-site load as group_broadcast_load and keep the + original group_slot_load for other users. + +if the group_slot_load itself has semantic group-slot users: + do not delete it; add a separate group_broadcast_load only if the extra memory + read is legal or if load cloning is otherwise proven safe. +``` + +### 10.3 Propagate Information + +After solidification, propagation should use ordinary VMI layout rules whenever +possible: + +```text +elementwise ops: + preserve the assigned layout when operands agree. + +ensure_layout: + makes layout transitions explicit when one use needs E2B-compatible chunks and + another use needs a different layout. + +rematerialization: + may clone group_broadcast_load per use-site instead of forcing a single layout + for all consumers. +``` + +For casts, propagation may need a targeted rule. The important MX quant case is: + +```text +E2B_B16 gives: + s0 x16, s1 x16, ..., s7 x16 + +bf16 -> f32 PART_EVEN gives: + s0 x8, s1 x8, ..., s7 x8 +``` + +If multiple f32 physical chunks require that same `s0 x8 ... s7 x8` pattern, +the post-assignment plan may mark them as the same rematerialized value. The +lowerer can then generate one `vcvt PART_EVEN` and map several logical physical +chunks to the same VPTO value. + +This reuse fact must be derived from the assigned lane mapping and the E2B packet +plan. It must not rely on a later CSE pass accidentally proving the duplicate. + +### 10.4 Implement Lowering + +`vmi-to-vpto` should lower `group_broadcast_load` locally. It may choose E2B +only when the op's assigned layout and source facts produce an explicit +E2B-compatible packet plan. + +For each E2B packet: + +```text +1. compute the source pointer: + base + packet_base_slot + +2. emit: + pto.vlds {dist = "E2B_B16"} + +3. map the emitted VPTO value to the physical result chunk(s) recorded in the + group_broadcast_load packet plan. +``` + +For `1 -> 32` under `deinterleaved = 2, block_elems = 1`: + +```text +logical group: + s0 x32 + +physical part 0: + s0 x16 + +physical part 1: + s0 x16 + +lowering: + emit one E2B packet if reuse is legal, or two identical E2B packets if + scheduling/lifetime constraints require duplication. +``` + +For MX quant scale after bf16->f32: + +```text +1. emit E2B_B16 for the b16 scale packet. +2. emit vcvt PART_EVEN to produce the f32 packet. +3. map that f32 packet to every physical f32 chunk whose lane mapping requires + s0 x8, s1 x8, ..., s7 x8. +4. lower mulf normally using the assigned physical chunks. +``` + +### 10.5 Where Layout Choices Happen + +There are three levels of optimization: + +```text +level 0: no E2B + canonical group_broadcast lowers through generic vselr materialization. + +level 1: E2B for already-compatible layouts + recognition sees the assigned layout is E2B-shaped and solidifies an E2B + materialization. + +level 2: choose E2B-compatible layouts + an optional layout optimization changes/rematerializes layouts before + recognition, for example selecting deinterleaved=2/block_elems=1 for a + broadcast use when all consumers can accept that layout. +``` + +The full CCE-like optimization for `ComputeY1ToFP8` is level 2: + +```text +x path: + select DINTLV-compatible layout for the dense x load/cast path. + +scale path: + select an E2B-compatible broadcast materialization. + +compute path: + keep mul/trunc/store in the selected physical chunk layout or insert explicit + layout materialization where required. +``` + +### 10.6 Test Plan + +Add focused tests in phases: + +```text +positive: + bf16 group_slot_load stride=1 + group_broadcast 8->256 assigned to + deinterleaved=2/block_elems=1 lowers scale chunks with E2B_B16. + +positive: + f16 1->16 or packed 8*(1->16) lowers to E2B only when source memory safety is + proven by full packet or supported partial semantics. + +positive: + 1->32 assigned to deinterleaved=2/block_elems=1 maps two physical uses to one + E2B packet or to two explicit duplicate packets. + +positive: + f32 1->8 lowers to E2B_B32, and f32 1->16 under deinterleaved=2/block_elems=1 + maps two physical uses to one E2B_B32 packet. + +negative: + source_group_stride != 1 falls back or diagnoses the E2B optimization. + +negative: + non-E2B-shaped assigned layout falls back to generic group_broadcast lowering. + +negative: + partial packet without proven safe memory read does not emit E2B. + +deferred: + E2B_B32 remains disabled until simulator/spec tests confirm the exact lane + mapping. +``` diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index 8c61be40a7..e1cc5974f1 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -380,9 +380,10 @@ group_reduce: cast: widening/narrowing 根据 cast support 决定 source request 和 result layout。 -group_load / group_slot_load: +group_load / group_slot_load / group_broadcast_load: result 根据 group size、row stride 和目标能力选择 contiguous、deinterleaved - 或 group_slots。 + 或 group_slots。group_broadcast_load 表达“每个 logical group load 一个值并 + 广播到组内 lanes”的逻辑语义;E2B 只是兼容 layout 下的一种 lowering。 stride_load: result 是 contiguous。block/repeat stride 只描述 memory address map, @@ -588,7 +589,7 @@ constant_mask 这个 pass 不 rematerialize: ```text -load / masked_load / group_load / group_slot_load +load / masked_load / group_load / group_slot_load / group_broadcast_load stride_load reduce / group_reduce control-flow results diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index 87f99898f8..eb8ad6e98c 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -573,6 +573,15 @@ def VMIGroupSlotLoadOp : VMI_Op<"group_slot_load", [DeclareOpInterfaceMethods]> { + let summary = "VMI load one scalar value per logical group and broadcast it to group lanes"; + let arguments = (ins PtrOrMemRef:$source, Index:$offset, Index:$source_group_stride, + I64Attr:$num_groups); + let results = (outs VMI_VRegTypeConstraint:$result); + let hasVerifier = 1; + let assemblyFormat = "$source `[` $offset `]` `,` $source_group_stride attr-dict `:` type($source) `->` type($result)"; +} + def VMIStrideLoadOp : VMI_Op<"stride_load", [DeclareOpInterfaceMethods]> { let summary = "VMI block-strided vector load"; let arguments = (ins PtrOrMemRef:$source, Index:$offset, diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h index f4fb3744dc..adff29dff9 100644 --- a/include/PTO/Transforms/VMILayoutSupport.h +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -140,6 +140,15 @@ struct VMIGroupBroadcastSupport { VMIGroupBroadcastSupportKind::GroupSlotsVselr; }; +enum class VMIGroupBroadcastLoadSupportKind { + E2BVlds, +}; + +struct VMIGroupBroadcastLoadSupport { + VMIGroupBroadcastLoadSupportKind kind = + VMIGroupBroadcastLoadSupportKind::E2BVlds; +}; + enum class VMITruncFSupportKind { Deinterleaved2F32ToContiguousF16, Deinterleaved4F32ToContiguousF8, @@ -291,6 +300,11 @@ class VMILayoutSupport { int64_t numGroups, std::string *reason = nullptr) const; + FailureOr + getGroupBroadcastLoadSupport(const VMITargetCapabilityRegistry &capabilities, + VMIGroupBroadcastLoadOp op, + std::string *reason = nullptr) const; + FailureOr getTruncFSupport(VMITruncFOp op, std::string *reason = nullptr) const; diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index 97d43cadbc..3c25a6f359 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -1597,6 +1597,31 @@ void VMIGroupSlotLoadOp::getEffects( effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); } +LogicalResult VMIGroupBroadcastLoadOp::verify() { + auto resultType = cast(getResult().getType()); + int64_t numGroups = getNumGroupsAttr().getInt(); + if (numGroups <= 0) + return emitOpError("requires num_groups to be positive"); + if (resultType.getElementCount() % numGroups != 0) + return emitOpError( + "requires num_groups to evenly divide result logical lane count"); + if (failed(verifyMemoryElementMatches(getOperation(), getSource().getType(), + resultType, "source"))) + return failure(); + if (auto resultLayout = resultType.getLayoutAttr()) { + if (resultLayout.isGroupSlots()) + return emitOpError( + "requires layout-assigned result to use a dense VMI layout"); + } + return verifyNumGroups(getOperation(), resultType, numGroups); +} + +void VMIGroupBroadcastLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable()); +} + LogicalResult VMIMaskedLoadOp::verify() { auto maskType = cast(getMask().getType()); auto passthruType = cast(getPassthru().getType()); diff --git a/lib/PTO/Transforms/PTOValidateVMIIR.cpp b/lib/PTO/Transforms/PTOValidateVMIIR.cpp index 6eae21dbeb..16a6b24393 100644 --- a/lib/PTO/Transforms/PTOValidateVMIIR.cpp +++ b/lib/PTO/Transforms/PTOValidateVMIIR.cpp @@ -563,6 +563,17 @@ LogicalResult verifyLayoutSemanticSupport(Operation *op, return success(); } + if (auto load = dyn_cast(op)) { + std::string reason; + if (failed( + supports.getGroupBroadcastLoadSupport(capabilities, load, &reason))) + return emitLayoutSupportContract( + op, diagOS, + "pto.vmi.group_broadcast_load has no registered layout support", + reason); + return success(); + } + if (auto store = dyn_cast(op)) { auto valueType = cast(store.getValue().getType()); VMILayoutAttr layout = valueType.getLayoutAttr(); diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 113529607a..bddede3bca 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -282,6 +282,59 @@ struct LayoutSolver { return VMILayoutAttr::getGroupSlots(ctx, numGroups, /*slots=*/1); } + bool isE2BGroupBroadcastLoadCandidate(VMIVRegType type, Type sourceType, + Value sourceGroupStride, + int64_t numGroups) { + if (numGroups <= 0 || type.getElementCount() % numGroups != 0) + return false; + int64_t groupSize = type.getElementCount() / numGroups; + if (numGroups % 8 != 0) + return false; + + if (!isa(sourceType)) + return false; + unsigned elementBits = getElementBitWidth(type.getElementType()); + if (elementBits != 16 && elementBits != 32) + return false; + int64_t directGroupSize = 256 / elementBits; + if (groupSize != directGroupSize && groupSize != 2 * directGroupSize) + return false; + std::optional strideValue = + getConstantIndexValue(sourceGroupStride); + if (!strideValue || *strideValue != 1) + return false; + + VMILayoutAttr existing = type.getLayoutAttr(); + if (!existing) + return true; + if (groupSize == directGroupSize) + return existing.isContiguous(); + return existing.isDeinterleaved() && existing.getFactor() == 2 && + existing.getBlockElems() == 1; + } + + bool isE2BGroupBroadcastLoadCandidate(VMIGroupBroadcastLoadOp op) { + return isE2BGroupBroadcastLoadCandidate( + cast(op.getResult().getType()), op.getSource().getType(), + op.getSourceGroupStride(), op.getNumGroupsAttr().getInt()); + } + + VMILayoutAttr getPreferredGroupBroadcastLoadLayout( + VMIGroupBroadcastLoadOp op) { + auto type = cast(op.getResult().getType()); + if (VMILayoutAttr existing = type.getLayoutAttr()) + return existing; + + if (!isE2BGroupBroadcastLoadCandidate(op)) + return getContiguousLayout(); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + int64_t groupSize = type.getElementCount() / numGroups; + int64_t directGroupSize = 256 / getElementBitWidth(type.getElementType()); + if (groupSize == directGroupSize) + return getContiguousLayout(); + return VMILayoutAttr::getDeinterleaved(ctx, 2, /*blockElems=*/1); + } + VMILayoutAttr getPreferredGroupBroadcastSourceLayout(Value value, int64_t numGroups) { auto type = dyn_cast(value.getType()); @@ -684,6 +737,43 @@ struct LayoutSolver { return success(); } + LogicalResult fuseGroupSlotBroadcastLoads() { + SmallVector broadcasts; + module.walk([&](VMIGroupBroadcastOp broadcast) { + auto load = broadcast.getSource().getDefiningOp(); + if (!load || !load.getResult().hasOneUse()) + return; + if (load.getNumGroupsAttr().getInt() != + broadcast.getNumGroupsAttr().getInt()) + return; + + if (!isE2BGroupBroadcastLoadCandidate( + cast(broadcast.getResult().getType()), + load.getSource().getType(), load.getSourceGroupStride(), + broadcast.getNumGroupsAttr().getInt())) + return; + broadcasts.push_back(broadcast); + }); + + OpBuilder builder(ctx); + for (VMIGroupBroadcastOp broadcast : broadcasts) { + auto load = broadcast.getSource().getDefiningOp(); + if (!load) + continue; + + builder.setInsertionPoint(broadcast); + auto fused = builder.create( + broadcast.getLoc(), broadcast.getResult().getType(), + load.getSource(), load.getOffset(), load.getSourceGroupStride(), + broadcast.getNumGroupsAttr()); + broadcast.getResult().replaceAllUsesWith(fused.getResult()); + broadcast.erase(); + if (load->use_empty()) + load.erase(); + } + return success(); + } + LogicalResult addConstraints() { WalkResult result = module.walk([&](Operation *op) -> WalkResult { if (auto maskAnd = dyn_cast(op)) { @@ -1253,6 +1343,13 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } + if (auto load = dyn_cast(op)) { + if (failed(setNaturalLayout( + load.getResult(), + getPreferredGroupBroadcastLoadLayout(load), op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } if (auto load = dyn_cast(op)) { auto resultType = cast(load.getResult().getType()); if (failed( @@ -1898,6 +1995,8 @@ struct LayoutSolver { } LogicalResult run() { + if (failed(fuseGroupSlotBroadcastLoads())) + return failure(); if (failed(commuteTruncFAfterGroupBroadcast())) return failure(); if (failed(collect())) diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index 6920b0ca38..c6b64f06b4 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -1020,6 +1020,77 @@ FailureOr VMILayoutSupport::getGroupBroadcastSupport( op.getNumGroupsAttr().getInt(), reason); } +FailureOr +VMILayoutSupport::getGroupBroadcastLoadSupport( + const VMITargetCapabilityRegistry &capabilities, + VMIGroupBroadcastLoadOp op, std::string *reason) const { + auto fail = + [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + auto resultType = cast(op.getResult().getType()); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0) + return fail("requires positive num_groups"); + if (resultType.getElementCount() % numGroups != 0) + return fail("requires num_groups to evenly divide result lane count"); + if (!capabilities.supportsDirectMemory(op.getSource().getType(), "source") + .isSupported()) + return fail("requires supported direct memory source"); + if (!isa(op.getSource().getType())) + return fail("requires !pto.ptr source for E2B lowering"); + + unsigned elementBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (elementBits != 16 && elementBits != 32) + return fail("E2B lowering currently supports only 16-bit and 32-bit " + "element types"); + int64_t directGroupSize = 256 / elementBits; + + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout) + return fail("E2B lowering requires assigned result layout"); + bool contiguousPacketLayout = layout.isContiguous(); + bool splitPacketLayout = layout.isDeinterleaved() && layout.getFactor() == 2 && + layout.getBlockElems() == 1; + if (!contiguousPacketLayout && !splitPacketLayout) + return fail("E2B lowering requires contiguous result layout for " + "direct group size or deinterleaved=2, block_elems=1 " + "result layout for split group size"); + + std::string fullChunkReason; + if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) + return fail(Twine("requires full result physical chunks; ") + + fullChunkReason); + + FailureOr lanesPerPart = + getDataLanesPerPart(resultType.getElementType()); + if (failed(lanesPerPart) || *lanesPerPart != (2048 / elementBits)) + return fail("E2B lowering requires one full 256-byte vreg per physical " + "part"); + + int64_t groupSize = resultType.getElementCount() / numGroups; + if (contiguousPacketLayout && groupSize != directGroupSize) + return fail("E2B contiguous lowering requires logical group size matching " + "the element-width direct packet size"); + if (splitPacketLayout && groupSize != 2 * directGroupSize) + return fail("E2B deinterleaved=2 lowering requires logical group size " + "matching the element-width split packet size"); + if (numGroups % 8 != 0) + return fail("E2B lowering requires num_groups to be a multiple of 8"); + + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (!stride || *stride != 1) + return fail("E2B lowering requires constant unit source_group_stride"); + + return VMIGroupBroadcastLoadSupport{ + VMIGroupBroadcastLoadSupportKind::E2BVlds}; +} + FailureOr VMILayoutSupport::getGroupBroadcastSupport( const VMITargetCapabilityRegistry &capabilities, VMIVRegType sourceType, VMIVRegType resultType, int64_t numGroups, std::string *reason) const { diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index e230cc0526..e5ddd0a366 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -1367,6 +1367,15 @@ LogicalResult checkSupportedGroupSlotLoadShape( return success(); } +LogicalResult checkSupportedGroupBroadcastLoadShape( + const VMITargetCapabilityRegistry &capabilities, VMIGroupBroadcastLoadOp op, + std::string *reason) { + VMILayoutSupport supports; + if (failed(supports.getGroupBroadcastLoadSupport(capabilities, op, reason))) + return failure(); + return success(); +} + LogicalResult checkSupportedGroupStoreShape(const VMITargetCapabilityRegistry &capabilities, VMIGroupStoreOp op, std::string *reason) { @@ -5383,6 +5392,132 @@ struct OneToNVMIMaskedStoreOpPattern } }; +struct OneToNVMIGroupBroadcastLoadOpPattern + : OneToNOpConversionPattern { + using OneToNOpConversionPattern< + VMIGroupBroadcastLoadOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(VMIGroupBroadcastLoadOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto resultVMIType = cast(op.getResult().getType()); + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + bool contiguousPacketLayout = layout && layout.isContiguous(); + bool splitPacketLayout = layout && layout.isDeinterleaved() && + layout.getFactor() == 2 && + layout.getBlockElems() == 1; + if (!contiguousPacketLayout && !splitPacketLayout) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B lowering requires " + "contiguous result layout for direct group size or " + "deinterleaved=2, block_elems=1 result layout for split " + "group size"); + + unsigned elementBits = + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()); + if (elementBits != 16 && elementBits != 32) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B lowering requires b16 or b32 " + "element type"); + int64_t directGroupSize = 256 / elementBits; + StringRef e2bDist = elementBits == 16 ? "E2B_B16" : "E2B_B32"; + + int64_t numGroups = op.getNumGroupsAttr().getInt(); + if (numGroups <= 0 || resultVMIType.getElementCount() % numGroups != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load requires valid num_groups"); + int64_t groupSize = resultVMIType.getElementCount() / numGroups; + if (numGroups % 8 != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B lowering requires num_groups " + "multiple of 8"); + if (contiguousPacketLayout && groupSize != directGroupSize) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B contiguous lowering requires " + "element-width direct group size"); + if (splitPacketLayout && groupSize != 2 * directGroupSize) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B deinterleaved=2 lowering requires " + "element-width split group size"); + + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + if (!stride || *stride != 1) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B lowering requires constant unit " + "source_group_stride"); + + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "group_broadcast_load source must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "group_broadcast_load offset must convert to one value", + rewriter); + if (failed(source) || failed(offset)) + return failure(); + if (!isa((*source).getType())) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B lowering requires !pto.ptr source"); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + FailureOr chunksPerPart = getDataChunksInPart(resultVMIType, 0); + if (failed(chunksPerPart) || *chunksPerPart <= 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load requires known chunks per part"); + int64_t factor = layout.getFactor(); + for (int64_t part = 1; part < factor; ++part) { + FailureOr currentChunks = + getDataChunksInPart(resultVMIType, part); + if (failed(currentChunks) || *currentChunks != *chunksPerPart) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load requires uniform chunks per part"); + } + if (static_cast(resultTypes.size()) != + factor * *chunksPerPart) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load physical arity mismatch"); + if (*chunksPerPart != numGroups / 8) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load expected one E2B packet per 8 groups in " + "each part"); + + SmallVector packets; + packets.reserve(*chunksPerPart); + for (int64_t chunk = 0; chunk < *chunksPerPart; ++chunk) { + Type packetType = resultTypes[chunk]; + auto vregType = dyn_cast(packetType); + if (!vregType) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load result must be vreg"); + Value packetOffset = + createChunkOffset(op.getLoc(), *offset, chunk * 8, rewriter); + packets.push_back( + rewriter + .create(op.getLoc(), packetType, + /*updated_base=*/Type{}, *source, packetOffset, + rewriter.getStringAttr(e2bDist)) + .getResult()); + } + + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t part = 0; part < factor; ++part) { + for (int64_t chunk = 0; chunk < *chunksPerPart; ++chunk) { + int64_t flatIndex = part * *chunksPerPart + chunk; + if (resultTypes[flatIndex] != resultTypes[chunk]) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load E2B reused packet type mismatch"); + results.push_back(packets[chunk]); + } + } + + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } +}; + struct OneToNVMIStrideLoadOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern::OneToNOpConversionPattern; @@ -7879,8 +8014,9 @@ void populateVMIOneToNConversionPatterns( OneToNVMIMaskBinaryOpPattern, OneToNVMIMaskUnaryOpPattern, OneToNVMILoadOpPattern, OneToNVMIDeinterleaveLoadOpPattern, OneToNVMIGroupLoadOpPattern, - OneToNVMIGroupSlotLoadOpPattern, OneToNVMIStrideLoadOpPattern, - OneToNVMIMaskedLoadOpPattern, OneToNVMIGatherOpPattern, + OneToNVMIGroupSlotLoadOpPattern, OneToNVMIGroupBroadcastLoadOpPattern, + OneToNVMIStrideLoadOpPattern, OneToNVMIMaskedLoadOpPattern, + OneToNVMIGatherOpPattern, OneToNVMIExpandLoadOpPattern, OneToNVMIStoreOpPattern, OneToNVMIInterleaveStoreOpPattern, OneToNVMIGroupStoreOpPattern, OneToNVMIMaskedStoreOpPattern, OneToNVMIStrideStoreOpPattern, @@ -8712,6 +8848,21 @@ verifySupportedVMIToVPTOOps(ModuleOp module, << reason << ")"; return WalkResult::interrupt(); } + if (auto load = dyn_cast(op)) { + std::string reason; + if (succeeded(checkSupportedGroupBroadcastLoadShape(capabilities, load, + &reason))) + return WalkResult::advance(); + load.emitError() + << kVMIDiagUnsupportedPrefix + << "pto.vmi.group_broadcast_load currently lowers through E2B " + "only for b16/b32 contiguous direct group size or " + "deinterleaved=2/block_elems=1 split group size full result " + "chunks, num_groups multiple of 8, unit source_group_stride, " + "and supported UB pointer source (" + << reason << ")"; + return WalkResult::interrupt(); + } if (auto load = dyn_cast(op)) { if (enableStableGatherMaskedLoad) { load.emitError() diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_load_e2b_b16.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_load_e2b_b16.pto new file mode 100644 index 0000000000..2f881b2c71 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_load_e2b_b16.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_broadcast_load_e2b_b16( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<256xbf16> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_broadcast_load %src[%off], %c1 + {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<256xbf16> + return %out : !pto.vmi.vreg<256xbf16> + } + + func.func @vmi_layout_assignment_group_broadcast_load_e2b_b32( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<64xf32> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_broadcast_load %src[%off], %c1 + {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_broadcast_load_e2b_b16 +// CHECK-SAME: (%[[SRC:.*]]: !pto.ptr, %[[OFF:.*]]: index) +// CHECK: %[[E2B:.*]] = pto.vlds %[[SRC]][%[[OFF]]] {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xbf16> +// CHECK: return %[[E2B]], %[[E2B]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16> + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_broadcast_load_e2b_b32 +// CHECK-SAME: (%[[SRC32:.*]]: !pto.ptr, %[[OFF32:.*]]: index) +// CHECK: %[[E2B32:.*]] = pto.vlds %[[SRC32]][%[[OFF32]]] {dist = "E2B_B32"} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: return %[[E2B32]] : !pto.vreg<64xf32> diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_load_e2b_b16.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_load_e2b_b16.pto new file mode 100644 index 0000000000..3de1463671 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_load_e2b_b16.pto @@ -0,0 +1,74 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b16( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<256xf16> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 16} + : !pto.ptr -> !pto.vmi.vreg<16xf16> + %out = pto.vmi.group_broadcast %slots {num_groups = 16} + : !pto.vmi.vreg<16xf16> -> !pto.vmi.vreg<256xf16> + return %out : !pto.vmi.vreg<256xf16> + } + + func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b16_deint2( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<256xbf16> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xbf16> + %out = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xbf16> -> !pto.vmi.vreg<256xbf16> + return %out : !pto.vmi.vreg<256xbf16> + } + + func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b32( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<64xf32> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %out = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<64xf32> + return %out : !pto.vmi.vreg<64xf32> + } + + func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b32_deint2( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<128xf32> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %out = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> + return %out : !pto.vmi.vreg<128xf32> + } +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b16 +// CHECK-SAME: (%[[SRC:.*]]: !pto.ptr, %[[OFF:.*]]: index) +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[E2B0:.*]] = pto.vlds %[[SRC]][%[[OFF]]] {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: %[[OFF8:.*]] = arith.addi %[[OFF]], %[[C8]] : index +// CHECK: %[[E2B1:.*]] = pto.vlds %[[SRC]][%[[OFF8]]] {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: return %[[E2B0]], %[[E2B1]] : !pto.vreg<128xf16>, !pto.vreg<128xf16> + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b16_deint2 +// CHECK-SAME: (%[[SRC2:.*]]: !pto.ptr, %[[OFF2:.*]]: index) +// CHECK: %[[E2B2:.*]] = pto.vlds %[[SRC2]][%[[OFF2]]] {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xbf16> +// CHECK: return %[[E2B2]], %[[E2B2]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16> + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b32 +// CHECK-SAME: (%[[SRC3:.*]]: !pto.ptr, %[[OFF3:.*]]: index) +// CHECK: %[[E2B3:.*]] = pto.vlds %[[SRC3]][%[[OFF3]]] {dist = "E2B_B32"} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: return %[[E2B3]] : !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b32_deint2 +// CHECK-SAME: (%[[SRC4:.*]]: !pto.ptr, %[[OFF4:.*]]: index) +// CHECK: %[[E2B4:.*]] = pto.vlds %[[SRC4]][%[[OFF4]]] {dist = "E2B_B32"} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: return %[[E2B4]], %[[E2B4]] : !pto.vreg<64xf32>, !pto.vreg<64xf32> diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_no_e2b.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_no_e2b.pto new file mode 100644 index 0000000000..da550785b5 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_no_e2b.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_broadcast_no_e2b( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<256xf32> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %out = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + return %out : !pto.vmi.vreg<256xf32> + } + +} + +// CHECK-LABEL: func.func @vmi_layout_assignment_group_slot_broadcast_no_e2b +// CHECK-NOT: E2B_B16 +// CHECK-NOT: E2B_B32 +// CHECK: pto.vselr +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_partial_packet_invalid.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_partial_packet_invalid.pto new file mode 100644 index 0000000000..edcddb8d50 --- /dev/null +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_partial_packet_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_layout_assignment_group_slot_broadcast_partial_packet_invalid( + %src: !pto.ptr, %off: index) -> !pto.vmi.vreg<64xf16> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 4} + : !pto.ptr -> !pto.vmi.vreg<4xf16> + %out = pto.vmi.group_broadcast %slots {num_groups = 4} + : !pto.vmi.vreg<4xf16> -> !pto.vmi.vreg<64xf16> + return %out : !pto.vmi.vreg<64xf16> + } +} + +// CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.group_broadcast has no registered layout support +// CHECK-SAME: requires full result physical chunks +// CHECK-NOT: E2B_B16 diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16.pto new file mode 100644 index 0000000000..41c2304d23 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16.pto @@ -0,0 +1,43 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_load_e2b_b16( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xbf16, #pto.vmi.layout> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_broadcast_load %src[%off], %c1 + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<256xbf16, #pto.vmi.layout> + return %out : !pto.vmi.vreg<256xbf16, #pto.vmi.layout> + } + + func.func @vmi_to_vpto_group_broadcast_load_e2b_b32( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> { + %c1 = arith.constant 1 : index + %out = pto.vmi.group_broadcast_load %src[%off], %c1 + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return %out : !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } +} + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_load_e2b_b16 +// CHECK-SAME: (%[[SRC:.*]]: !pto.ptr, %[[OFF:.*]]: index) +// CHECK: %[[E2B:.*]] = pto.vlds %[[SRC]][%[[OFF]]] {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xbf16> +// CHECK: return %[[E2B]], %[[E2B]] : !pto.vreg<128xbf16>, !pto.vreg<128xbf16> + +// CHECK-LABEL: func.func @vmi_to_vpto_group_broadcast_load_e2b_b32 +// CHECK-SAME: (%[[SRC32:.*]]: !pto.ptr, %[[OFF32:.*]]: index) +// CHECK: %[[E2B32:.*]] = pto.vlds %[[SRC32]][%[[OFF32]]] {dist = "E2B_B32"} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: return %[[E2B32]] : !pto.vreg<64xf32> diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid.pto new file mode 100644 index 0000000000..3e942b0737 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid.pto @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s + +module { + func.func @vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid( + %src: !pto.ptr, %off: index, %stride: index) + -> !pto.vmi.vreg<256xbf16, #pto.vmi.layout> { + %out = pto.vmi.group_broadcast_load %src[%off], %stride + {num_groups = 8} + : !pto.ptr + -> !pto.vmi.vreg<256xbf16, #pto.vmi.layout> + return %out : !pto.vmi.vreg<256xbf16, #pto.vmi.layout> + } +} + +// CHECK: VMI-UNSUPPORTED: +// CHECK: pto.vmi.group_broadcast_load currently lowers through E2B +// CHECK: E2B lowering requires constant unit source_group_stride From c6a4eed083d4eb53731969ce4994629377fd8122 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 29 Jun 2026 15:43:15 +0800 Subject: [PATCH 40/54] Rename VMI layout fold pass --- docs/designs/vmi-implementation-manual.md | 16 ++--- docs/designs/vmi-introduction.md | 6 +- .../vmi-layout-assignment-implementation.md | 14 ++-- .../vmi-layout-assignment-lowering-design.md | 2 +- docs/designs/vmi-layout-lowering-cases.md | 2 +- include/PTO/Transforms/Passes.h | 2 +- include/PTO/Transforms/Passes.td | 16 ++--- lib/PTO/Transforms/CMakeLists.txt | 2 +- ...outFoldConsumers.cpp => VMILayoutFold.cpp} | 70 ++++++++++++++++--- ..._deint4.pto => vmi_layout_fold_deint4.pto} | 16 ++--- test/lit/vmi/vmi_layout_fold_load.pto | 62 ++++++++++++++++ ...e.pto => vmi_layout_fold_masked_store.pto} | 10 +-- ...rs_store.pto => vmi_layout_fold_store.pto} | 10 +-- test/lit/vmi/vmi_ptoas_cli_pipeline.pto | 4 +- tools/ptoas/ptoas.cpp | 2 +- 15 files changed, 175 insertions(+), 59 deletions(-) rename lib/PTO/Transforms/{VMILayoutFoldConsumers.cpp => VMILayoutFold.cpp} (67%) rename test/lit/vmi/{vmi_layout_fold_consumers_deint4.pto => vmi_layout_fold_deint4.pto} (84%) create mode 100644 test/lit/vmi/vmi_layout_fold_load.pto rename test/lit/vmi/{vmi_layout_fold_consumers_masked_store.pto => vmi_layout_fold_masked_store.pto} (86%) rename test/lit/vmi/{vmi_layout_fold_consumers_store.pto => vmi_layout_fold_store.pto} (86%) diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 8824d605b4..76f6d966f7 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -126,7 +126,7 @@ pipeline: pto-validate-vmi-ir vmi-layout-assignment canonicalize/cse -vmi-layout-fold-consumers +vmi-layout-fold canonicalize/cse vmi-layout-rematerialize canonicalize/cse @@ -182,7 +182,7 @@ vmi-layout-assignment: module-level per-SSA-value constraint solver。先收集等价类、producer natural layout 和 consumer request, 再把结果写回 VMI type/helper op。它可以使用 IRRewriter 改 IR,但不以 TypeConverter 为主模型。 -vmi-layout-fold-consumers / vmi-layout-rematerialize / vmi-layout-sink-materialization: +vmi-layout-fold / vmi-layout-rematerialize / vmi-layout-sink-materialization: legal-to-legal VMI optimization passes。它们只消费 layout-assigned VMI IR,并继续产出 layout-assigned VMI IR;所有新选择必须体现在 current op、type 或 helper IR 中。 @@ -238,7 +238,7 @@ Layout support query layer: 一个 lowering pattern 自己使用的分支应该留在该 pattern 内。 Layout optimization layer: - lib/PTO/Transforms/VMILayoutFoldConsumers.cpp + lib/PTO/Transforms/VMILayoutFold.cpp lib/PTO/Transforms/VMILayoutRematerialize.cpp lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp lib/PTO/Transforms/VMILegalizeArithSelect.cpp @@ -362,12 +362,12 @@ lib/PTO/Transforms/VMILayoutAssignment.cpp hide chosen layout in a pass-private side table infer external VMI ABI -lib/PTO/Transforms/VMILayoutFoldConsumers.cpp +lib/PTO/Transforms/VMILayoutFold.cpp lib/PTO/Transforms/VMILayoutRematerialize.cpp lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp lib/PTO/Transforms/VMILegalizeArithSelect.cpp pass: - VMILayoutFoldConsumersPass + VMILayoutFoldPass VMILayoutRematerializePass VMILayoutSinkMaterializationPass VMILegalizeArithSelectPass @@ -437,8 +437,8 @@ source file pass primary lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-ir Operation::walk + recursive type/attr scan lib/PTO/Transforms/PTOValidateVMIIR.cpp pto-validate-vmi-layout-ir Operation::walk + recursive type/attr scan lib/PTO/Transforms/VMILayoutAssignment.cpp vmi-layout-assignment module-level union-find solver + IRRewriter -lib/PTO/Transforms/VMILayoutFoldConsumers.cpp - vmi-layout-fold-consumers Pattern-free local IR rewrite +lib/PTO/Transforms/VMILayoutFold.cpp + vmi-layout-fold Pattern-free local IR rewrite lib/PTO/Transforms/VMILayoutRematerialize.cpp vmi-layout-rematerialize Pattern-free local IR rewrite lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp @@ -1186,7 +1186,7 @@ raw VMI producer -> pto-validate-vmi-ir -> vmi-layout-assignment -> canonicalize/cse - -> vmi-layout-fold-consumers + -> vmi-layout-fold -> canonicalize/cse -> vmi-layout-rematerialize -> canonicalize/cse diff --git a/docs/designs/vmi-introduction.md b/docs/designs/vmi-introduction.md index e1cc5974f1..089120a4f8 100644 --- a/docs/designs/vmi-introduction.md +++ b/docs/designs/vmi-introduction.md @@ -178,7 +178,7 @@ group-slot control-flow/function boundary pto-validate-vmi-ir -> vmi-layout-assignment -> canonicalize/cse - -> vmi-layout-fold-consumers + -> vmi-layout-fold -> canonicalize/cse -> vmi-layout-rematerialize -> canonicalize/cse @@ -488,7 +488,7 @@ pto.vmi.store %y, %out0 // wants contiguous 一个 SSA value 只能属于一个 data layout 等价类。若两个 use 不能共同满足, baseline assignment 保留一个等价类 layout,并在不匹配 use 前插 -`ensure_layout`。后续 `vmi-layout-fold-consumers`、`vmi-layout-rematerialize` +`ensure_layout`。后续 `vmi-layout-fold`、`vmi-layout-rematerialize` 和 `vmi-layout-sink-materialization` 可以在显式 helper op 上做优化,但 `vmi-to-vpto` 不读取隐藏 plan 或 sibling user。 @@ -526,7 +526,7 @@ body,当前需要显式 ABI materialization 设计,因此 layout assignment 这个阶段之后,IR 不再依赖隐藏 plan;后续 pass 和 `vmi-to-vpto` 都只读取 type 上的 layout 和显式 `ensure_*` helper。 -### 3.3 `vmi-layout-fold-consumers` +### 3.3 `vmi-layout-fold` 当 consumer 可以直接保持同样的外部效果时,把显式 materialization 折进 consumer。 diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index d3778ab8fa..3e01e56048 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -14,7 +14,7 @@ Recommended pass pipeline: pto-validate-vmi-ir -> vmi-layout-assignment // hard legalization baseline -> canonicalize/cse - -> vmi-layout-fold-consumers // optional optimization + -> vmi-layout-fold // optional optimization -> canonicalize/cse -> vmi-layout-rematerialize // optional optimization -> canonicalize/cse @@ -50,7 +50,7 @@ canonicalize/cse: remove dead helpers and merge identical cloned producers where MLIR legality permits -vmi-layout-fold-consumers: +vmi-layout-fold: fold use-site materialization into consumers that can directly consume the source layout while preserving the same logical effect example: ensure_layout(deinterleaved=2 -> contiguous) feeding store may become @@ -597,7 +597,7 @@ bitcast: This includes contiguous, deinterleaved, and identical group_slots layouts. ``` -`vmi-layout-fold-consumers`, rematerialization, sink/hoist, and private +`vmi-layout-fold`, rematerialization, sink/hoist, and private function specialization passes consume explicit helper IR. They may replace helpers with cheaper equivalent IR, but they must not introduce hidden lowering plans that `vmi-to-vpto` has to rediscover from producer/user context. @@ -2033,13 +2033,13 @@ runtime SIM: test/vpto/cases/vmi/widen-f16-to-f32-store-reduce ``` -Current checked-in lit coverage for the first `vmi-layout-fold-consumers` +Current checked-in lit coverage for the first `vmi-layout-fold` optimization is: ```text -test/lit/vmi/vmi_layout_fold_consumers_store.pto -test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto -test/lit/vmi/vmi_layout_fold_consumers_deint4.pto +test/lit/vmi/vmi_layout_fold_store.pto +test/lit/vmi/vmi_layout_fold_masked_store.pto +test/lit/vmi/vmi_layout_fold_deint4.pto ``` Current checked-in lit coverage for the first `vmi-layout-rematerialize` diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index bc16e3bb20..77ba7447ee 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -11,7 +11,7 @@ VMI surface IR -> pto-validate-vmi-ir -> vmi-layout-assignment // hard legalization baseline -> canonicalize/cse - -> vmi-layout-fold-consumers // optional optimization + -> vmi-layout-fold // optional optimization -> canonicalize/cse -> vmi-layout-rematerialize // optional optimization -> canonicalize/cse diff --git a/docs/designs/vmi-layout-lowering-cases.md b/docs/designs/vmi-layout-lowering-cases.md index 16a87696ab..9b26ecbde7 100644 --- a/docs/designs/vmi-layout-lowering-cases.md +++ b/docs/designs/vmi-layout-lowering-cases.md @@ -5516,7 +5516,7 @@ for i = 0..127: Optimization pass result: ```text -// vmi-layout-fold-consumers may remove both ensure_layout ops if the target +// vmi-layout-fold may remove both ensure_layout ops if the target // supports store lowering that consumes deinterleaved=2 and writes contiguous // row-major memory. pto.vmi.store %t1, %out1[%off] diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 78cb8bc78e..4ac1eaa5bb 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -115,7 +115,7 @@ LogicalResult validateVMILayoutAssignedIR(ModuleOp module, std::unique_ptr createPTOValidateVMIIRPass(); std::unique_ptr createPTOValidateVMILayoutIRPass(); std::unique_ptr createVMILayoutAssignmentPass(); -std::unique_ptr createVMILayoutFoldConsumersPass(); +std::unique_ptr createVMILayoutFoldPass(); std::unique_ptr createVMILayoutRematerializePass(); std::unique_ptr createVMILayoutSinkMaterializationPass(); std::unique_ptr createVMILegalizeArithSelectPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index fdbe82b5bf..18444f8a36 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -841,16 +841,16 @@ def VMILayoutAssignment : Pass<"vmi-layout-assignment", "ModuleOp"> { "mlir::scf::SCFDialect"]; } -def VMILayoutFoldConsumers : Pass<"vmi-layout-fold-consumers", "ModuleOp"> { - let summary = "Fold VMI layout materialization into layout-aware consumers"; +def VMILayoutFold : Pass<"vmi-layout-fold", "ModuleOp"> { + let summary = "Fold VMI layout materializations"; let description = [{ - Optimizes legal layout-assigned VMI IR by replacing selected use-site - ensure_layout consumers with consumers that can directly lower from the - source layout while preserving the same logical effect. The pass does not - choose layouts by inspecting producer/user context for vmi-to-vpto; it only - rewrites explicit helper IR into an equivalent local-consumer form. + Optimizes legal layout-assigned VMI IR by folding selected ensure_layout + helpers into layout-aware producers or consumers while preserving the same + logical effect. The pass does not choose layouts by inspecting arbitrary + producer/user context for vmi-to-vpto; it only rewrites explicit helper IR + into equivalent local forms. }]; - let constructor = "mlir::pto::createVMILayoutFoldConsumersPass()"; + let constructor = "mlir::pto::createVMILayoutFoldPass()"; let dependentDialects = ["mlir::cf::ControlFlowDialect", "mlir::func::FuncDialect", "mlir::pto::PTODialect", diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 3f808c3072..ddf73ba356 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -38,7 +38,7 @@ add_mlir_dialect_library(PTOTransforms PTOValidateVMIIR.cpp VMILegalizeArithSelect.cpp VMILayoutAssignment.cpp - VMILayoutFoldConsumers.cpp + VMILayoutFold.cpp VMILayoutSupport.cpp VMILayoutRematerialize.cpp VMILayoutSinkMaterialization.cpp diff --git a/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp b/lib/PTO/Transforms/VMILayoutFold.cpp similarity index 67% rename from lib/PTO/Transforms/VMILayoutFoldConsumers.cpp rename to lib/PTO/Transforms/VMILayoutFold.cpp index ac7942d93a..9646041310 100644 --- a/lib/PTO/Transforms/VMILayoutFoldConsumers.cpp +++ b/lib/PTO/Transforms/VMILayoutFold.cpp @@ -8,7 +8,7 @@ // FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository // for the full text of the License. -//===- VMILayoutFoldConsumers.cpp - Fold VMI layout consumers ------------===// +//===- VMILayoutFold.cpp - Fold VMI layout materializations --------------===// //===----------------------------------------------------------------------===// #include "PTO/IR/PTO.h" @@ -27,7 +27,7 @@ namespace mlir { namespace pto { -#define GEN_PASS_DEF_VMILAYOUTFOLDCONSUMERS +#define GEN_PASS_DEF_VMILAYOUTFOLD #include "PTO/Transforms/Passes.h.inc" } // namespace pto } // namespace mlir @@ -37,6 +37,57 @@ using namespace mlir::pto; namespace { +static bool hasSameDataShapeAndElementType(VMIVRegType lhs, VMIVRegType rhs) { + return lhs && rhs && lhs.getElementCount() == rhs.getElementCount() && + lhs.getElementType() == rhs.getElementType(); +} + +static bool isFoldableLoadEnsure(VMIEnsureLayoutOp ensure) { + auto load = ensure.getSource().getDefiningOp(); + if (!load) + return false; + + auto sourceType = dyn_cast(ensure.getSource().getType()); + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!hasSameDataShapeAndElementType(sourceType, resultType)) + return false; + + VMILayoutSupport supports; + return succeeded(supports.canMaterializeDataLayout(sourceType, resultType)); +} + +static void tryFoldLoadEnsures( + VMILoadOp load, SmallVectorImpl &maybeDeadEnsures) { + auto sourceType = dyn_cast(load.getResult().getType()); + if (!sourceType) + return; + + VMIVRegType targetType; + SmallVector ensures; + for (OpOperand &use : load.getResult().getUses()) { + auto ensure = dyn_cast(use.getOwner()); + if (!ensure || use.getOperandNumber() != 0 || !isFoldableLoadEnsure(ensure)) + return; + + auto resultType = cast(ensure.getResult().getType()); + if (!targetType) { + targetType = resultType; + } else if (targetType != resultType) { + return; + } + ensures.push_back(ensure); + } + + if (ensures.empty() || targetType == sourceType) + return; + + load.getResult().setType(targetType); + for (VMIEnsureLayoutOp ensure : ensures) { + ensure.getResult().replaceAllUsesWith(load.getResult()); + maybeDeadEnsures.push_back(ensure); + } +} + static bool isFoldableStoreEnsure(VMIEnsureLayoutOp ensure) { auto sourceType = dyn_cast(ensure.getSource().getType()); auto resultType = dyn_cast(ensure.getResult().getType()); @@ -94,16 +145,19 @@ static void tryFoldEnsureLayoutIntoMaskedStore( maybeDeadMaskEnsures.push_back(maskEnsure); } -struct VMILayoutFoldConsumersPass - : public mlir::pto::impl::VMILayoutFoldConsumersBase< - VMILayoutFoldConsumersPass> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutFoldConsumersPass) +struct VMILayoutFoldPass + : public mlir::pto::impl::VMILayoutFoldBase { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VMILayoutFoldPass) void runOnOperation() override { ModuleOp module = getOperation(); SmallVector maybeDeadEnsures; SmallVector maybeDeadMaskEnsures; + module.walk([&](VMILoadOp load) { + tryFoldLoadEnsures(load, maybeDeadEnsures); + }); + module.walk([&](Operation *op) { if (auto store = dyn_cast(op)) tryFoldEnsureLayoutIntoOperand(store.getValueMutable(), @@ -126,6 +180,6 @@ struct VMILayoutFoldConsumersPass } // namespace -std::unique_ptr mlir::pto::createVMILayoutFoldConsumersPass() { - return std::make_unique(); +std::unique_ptr mlir::pto::createVMILayoutFoldPass() { + return std::make_unique(); } diff --git a/test/lit/vmi/vmi_layout_fold_consumers_deint4.pto b/test/lit/vmi/vmi_layout_fold_deint4.pto similarity index 84% rename from test/lit/vmi/vmi_layout_fold_consumers_deint4.pto rename to test/lit/vmi/vmi_layout_fold_deint4.pto index 84ba3b5b1e..6cd26b29e4 100644 --- a/test/lit/vmi/vmi_layout_fold_consumers_deint4.pto +++ b/test/lit/vmi/vmi_layout_fold_deint4.pto @@ -6,11 +6,11 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: pto-test-opt %s -vmi-layout-fold-consumers | FileCheck %s --check-prefix=FOLD -// RUN: pto-test-opt %s -vmi-layout-fold-consumers -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { - func.func @vmi_layout_fold_consumers_store_deint4( + func.func @vmi_layout_fold_store_deint4( %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %dst: !pto.ptr, %offset: index) { @@ -23,7 +23,7 @@ module { return } - func.func @vmi_layout_fold_consumers_masked_store_deint4( + func.func @vmi_layout_fold_masked_store_deint4( %value: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<256xb32, #pto.vmi.layout>, %dst: !pto.ptr, @@ -42,7 +42,7 @@ module { } } -// FOLD-LABEL: func.func @vmi_layout_fold_consumers_store_deint4( +// FOLD-LABEL: func.func @vmi_layout_fold_store_deint4( // FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> // FOLD-NOT: pto.vmi.ensure_layout // FOLD: pto.vmi.store %[[VALUE]] @@ -50,7 +50,7 @@ module { // FOLD-NOT: pto.vmi.ensure_layout // FOLD: return -// FOLD-LABEL: func.func @vmi_layout_fold_consumers_masked_store_deint4( +// FOLD-LABEL: func.func @vmi_layout_fold_masked_store_deint4( // FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<256xf32, #pto.vmi.layout> // FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<256xb32, #pto.vmi.layout> // FOLD-NOT: pto.vmi.ensure_layout @@ -63,7 +63,7 @@ module { // FOLD-NOT: pto.vmi.ensure_mask_layout // FOLD: return -// LOWER-LABEL: func.func @vmi_layout_fold_consumers_store_deint4( +// LOWER-LABEL: func.func @vmi_layout_fold_store_deint4( // LOWER: pto.vintlv // LOWER: pto.vintlv // LOWER: pto.vintlv @@ -73,7 +73,7 @@ module { // LOWER: pto.vsts // LOWER: pto.vsts -// LOWER-LABEL: func.func @vmi_layout_fold_consumers_masked_store_deint4( +// LOWER-LABEL: func.func @vmi_layout_fold_masked_store_deint4( // LOWER: pto.vintlv // LOWER: pto.vintlv // LOWER: pto.vintlv diff --git a/test/lit/vmi/vmi_layout_fold_load.pto b/test/lit/vmi/vmi_layout_fold_load.pto new file mode 100644 index 0000000000..804c522df1 --- /dev/null +++ b/test/lit/vmi/vmi_layout_fold_load.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_load_all_ensures( + %src: !pto.ptr, %off: index) + -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %load = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %split0 = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %split1 = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %split0, %split1 + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } + + func.func @vmi_layout_fold_load_keeps_mixed_use( + %src: !pto.ptr, %off: index) + -> (!pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>) { + %load = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + %split = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<128xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> + return %load, %split + : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout> + } +} + +// FOLD-LABEL: func.func @vmi_layout_fold_load_all_ensures( +// FOLD: %[[LOAD:.*]] = pto.vmi.load +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: return %[[LOAD]], %[[LOAD]] + +// FOLD-LABEL: func.func @vmi_layout_fold_load_keeps_mixed_use( +// FOLD: %[[MIXED_LOAD:.*]] = pto.vmi.load +// FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// FOLD: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[MIXED_LOAD]] +// FOLD: return %[[MIXED_LOAD]], %[[SPLIT]] + +// LOWER-LABEL: func.func @vmi_layout_fold_load_all_ensures( +// LOWER: %[[LOW:.*]], %[[HIGH:.*]] = pto.vldsx2 +// LOWER-SAME: "DINTLV_B32" +// LOWER: return %[[LOW]], %[[HIGH]], %[[LOW]], %[[HIGH]] diff --git a/test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto b/test/lit/vmi/vmi_layout_fold_masked_store.pto similarity index 86% rename from test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto rename to test/lit/vmi/vmi_layout_fold_masked_store.pto index 8f31b78f7b..4fc8cbee83 100644 --- a/test/lit/vmi/vmi_layout_fold_consumers_masked_store.pto +++ b/test/lit/vmi/vmi_layout_fold_masked_store.pto @@ -6,11 +6,11 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: pto-test-opt %s -vmi-layout-fold-consumers | FileCheck %s --check-prefix=FOLD -// RUN: pto-test-opt %s -vmi-layout-fold-consumers -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { - func.func @vmi_layout_fold_consumers_masked_store( + func.func @vmi_layout_fold_masked_store( %value: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %mask: !pto.vmi.mask<128xb32, #pto.vmi.layout>, %dst: !pto.ptr, @@ -29,7 +29,7 @@ module { } } -// FOLD-LABEL: func.func @vmi_layout_fold_consumers_masked_store( +// FOLD-LABEL: func.func @vmi_layout_fold_masked_store( // FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> // FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<128xb32, #pto.vmi.layout> // FOLD-NOT: pto.vmi.ensure_layout @@ -42,7 +42,7 @@ module { // FOLD-NOT: pto.vmi.ensure_mask_layout // FOLD: return -// LOWER-LABEL: func.func @vmi_layout_fold_consumers_masked_store( +// LOWER-LABEL: func.func @vmi_layout_fold_masked_store( // LOWER-SAME: %[[V0:[^,]+]]: !pto.vreg<64xf32> // LOWER-SAME: %[[V1:[^,]+]]: !pto.vreg<64xf32> // LOWER-SAME: %[[M0:[^,]+]]: !pto.mask diff --git a/test/lit/vmi/vmi_layout_fold_consumers_store.pto b/test/lit/vmi/vmi_layout_fold_store.pto similarity index 86% rename from test/lit/vmi/vmi_layout_fold_consumers_store.pto rename to test/lit/vmi/vmi_layout_fold_store.pto index e8249eec06..484b1c636b 100644 --- a/test/lit/vmi/vmi_layout_fold_consumers_store.pto +++ b/test/lit/vmi/vmi_layout_fold_store.pto @@ -6,11 +6,11 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-fold-consumers | FileCheck %s --check-prefix=FOLD -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-fold-consumers -vmi-to-vpto | FileCheck %s --check-prefix=LOWER +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER module { - func.func @vmi_layout_fold_consumers_store( + func.func @vmi_layout_fold_store( %src: !pto.vmi.vreg<128xf16>, %scale: f32, %out1: !pto.ptr, @@ -32,7 +32,7 @@ module { } -// FOLD-LABEL: func.func @vmi_layout_fold_consumers_store( +// FOLD-LABEL: func.func @vmi_layout_fold_store( // FOLD-SAME: %[[SRC:.*]]: !pto.vmi.vreg<128xf16, #pto.vmi.layout> // FOLD: %[[SCALE:.*]] = pto.vmi.broadcast // FOLD-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> @@ -49,7 +49,7 @@ module { // FOLD-NOT: pto.vmi.ensure_layout // FOLD: return -// LOWER-LABEL: func.func @vmi_layout_fold_consumers_store( +// LOWER-LABEL: func.func @vmi_layout_fold_store( // LOWER: %[[SCALE0:.*]] = pto.vdup // LOWER: %[[SCALE1:.*]] = pto.vdup // LOWER: %[[WIDE0:.*]] = pto.vcvt diff --git a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto index e49dba60c3..dd28fbe21e 100644 --- a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto +++ b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto @@ -23,7 +23,7 @@ module attributes {pto.target_arch = "a5"} { return } - func.func @vmi_ptoas_cli_fold_consumers_pipeline( + func.func @vmi_ptoas_cli_fold_pipeline( %src: !pto.ptr, %dst: !pto.ptr, %offset: index) { @@ -47,7 +47,7 @@ module attributes {pto.target_arch = "a5"} { // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast -// CHECK-LABEL: func.func @vmi_ptoas_cli_fold_consumers_pipeline +// CHECK-LABEL: func.func @vmi_ptoas_cli_fold_pipeline // CHECK: pto.vlds // CHECK: pto.vcvt {{.*}} {part = "EVEN"} // CHECK: pto.vcvt {{.*}} {part = "ODD"} diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 0639dcd7e8..0f30adfb9c 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1809,7 +1809,7 @@ static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { pm.addPass(pto::createVMILayoutAssignmentPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - pm.addPass(pto::createVMILayoutFoldConsumersPass()); + pm.addPass(pto::createVMILayoutFoldPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); pm.addPass(pto::createVMILayoutRematerializePass()); From 460ae8841d109a1b09eb01364d18c23b39a0dd87 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 29 Jun 2026 16:48:14 +0800 Subject: [PATCH 41/54] Implement VMI relation-aware rematerialization --- .../vmi-layout-assignment-implementation.md | 10 +- .../vmi-layout-assignment-lowering-design.md | 2 + ...ayout-relation-rematerialization-design.md | 239 ++++++++++ ...lation-rematerialization-implementation.md | 408 ++++++++++++++++++ include/PTO/IR/VMIOps.td | 96 ++--- include/PTO/Transforms/Passes.td | 9 +- include/PTO/Transforms/VMILayoutSupport.h | 6 + lib/PTO/Transforms/VMILayoutFold.cpp | 38 +- lib/PTO/Transforms/VMILayoutRematerialize.cpp | 211 ++++++++- lib/PTO/Transforms/VMILayoutSupport.cpp | 98 ++++- .../vmi_layout_gate_extf_support_invalid.pto | 2 +- .../vmi/vmi_layout_rematerialize_relation.pto | 131 ++++++ test/lit/vmi/vmi_op_verifier_basic.pto | 27 +- 13 files changed, 1174 insertions(+), 103 deletions(-) create mode 100644 docs/designs/vmi-layout-relation-rematerialization-design.md create mode 100644 docs/designs/vmi-layout-relation-rematerialization-implementation.md create mode 100644 test/lit/vmi/vmi_layout_rematerialize_relation.pto diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 3e01e56048..4c8d52a97d 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -18,6 +18,8 @@ pto-validate-vmi-ir -> canonicalize/cse -> vmi-layout-rematerialize // optional optimization -> canonicalize/cse + -> vmi-layout-fold // optional optimization over remat-exposed helpers + -> canonicalize/cse -> vmi-layout-sink-materialization // optional optimization -> canonicalize/cse -> vmi-legalize-arith-select @@ -66,8 +68,12 @@ vmi-layout-rematerialize: replace explicit ensure_* helpers with cloned cheap layout-polymorphic producers when the clone directly creates the requested result type current implementation: splat pto.vmi.constant, pto.vmi.broadcast, - pto.vmi.iota, pto.vmi.create_mask, pto.vmi.create_group_mask, and - pto.vmi.constant_mask + pto.vmi.iota, selected layout-transparent data ops, widening + pto.vmi.ext{f,si,ui}, pto.vmi.create_mask, pto.vmi.create_group_mask, and + pto.vmi.constant_mask. Relation-aware remat rewrites result-side + ensure_layout through layout-transparent producers and widening ext + producers, leaving any newly exposed producer-side helpers for the following + vmi-layout-fold. not included in the first implementation: load, group_load, masked_load, group_slot_load, and group_broadcast; those require separate memory, execution-count, or source-layout proof before they can be rematerialized diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 77ba7447ee..98881fe6a8 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -15,6 +15,8 @@ VMI surface IR -> canonicalize/cse -> vmi-layout-rematerialize // optional optimization -> canonicalize/cse + -> vmi-layout-fold // optional optimization over remat-exposed helpers + -> canonicalize/cse -> vmi-layout-sink-materialization // optional optimization -> canonicalize/cse -> optional later layout optimization passes diff --git a/docs/designs/vmi-layout-relation-rematerialization-design.md b/docs/designs/vmi-layout-relation-rematerialization-design.md new file mode 100644 index 0000000000..dc0a74907c --- /dev/null +++ b/docs/designs/vmi-layout-relation-rematerialization-design.md @@ -0,0 +1,239 @@ +# VMI Layout Relation-Aware Rematerialization Design + +本文描述 VMI layout optimization 中 relation-aware rematerialization 的设计。 +目标是让 `vmi-layout-assignment` 只产生 legal baseline IR,把跨 layout +relation 的优化放到显式 `ensure_layout` 上完成。 + +## 1. Motivation + +`vmi-layout-assignment` 已经负责三件 hard legalization 工作: + +```text +1. 为每个 VMI value 选择 concrete layout +2. 在不匹配的 use-site 插入 ensure_layout / ensure_mask_layout +3. 保证 vmi-to-vpto 只需要 local lowering information +``` + +对 `ext` 这类 width-changing op,assignment 的 baseline 可以保守选择: + +```text +ext f16 -> f32: + source = contiguous + result = deinterleaved=2 +``` + +如果下游 `truncf f32 -> f8` 要求 source 为 `deinterleaved=4`,assignment 会 +显式插入: + +```text +%e = pto.vmi.extf %x + : !vreg<..., layout> + -> !vreg<..., layout> + +%e4 = pto.vmi.ensure_layout %e + : !vreg<..., layout> + -> !vreg<..., layout> +``` + +这个 IR 已经合法,但不是最优。优化 pass 可以从显式 helper 出发,把 relation +应用到 producer: + +```text +ensure_layout(ext(src), resultLayout) + => ext(ensure_layout(src, derivedSourceLayout)) +``` + +这样 assignment 不需要做 consumer-driven global propagation,也不需要在多 +consumer 冲突时引入 cost model。 + +## 2. Goals + +```text +1. assignment 保持 hard legalization baseline,不做 ext relation propagation。 +2. relation-aware optimization 从显式 ensure_layout 出发。 +3. 多 consumer 冲突由 use-site helper + rematerialization 解决。 +4. vmi-to-vpto 仍只消费当前 op 的 operand/result layout,不扫描上下文。 +5. 变换必须是局部、确定、可验证的 IR rewrite。 +``` + +非目标: + +```text +1. 不做 ComputeY1 专用 pattern。 +2. 不在 assignment 中实现全局 cost model。 +3. 不通过 vmi-to-vpto 猜 producer/consumer relation。 +4. 第一阶段不做 trunc/narrow relation remat。 +``` + +## 3. Optimization Model + +relation-aware remat 以 `ensure_layout` 为唯一触发点: + +```text +%wanted = pto.vmi.ensure_layout %source : sourceLayout -> targetLayout +``` + +如果 `%source` 的 producer 可以在 `targetLayout` 或 relation 派生出的 operand +layout 下重新创建等价结果,则用 cloned producer 替换 helper。 + +### 3.1 Layout-Transparent Producer Remat + +对 layout-transparent elementwise op: + +```text +ensure_layout(op(a, b), L) + => op(ensure_layout(a, L), ensure_layout(b, L)) +``` + +适用对象包括纯 elementwise data ops: + +```text +addf/addi/subf/subi/mulf/muli/divf/minf/maxf +andi/ori/xori/shli/shrui +negf/absf/absi/sqrt/exp/ln/relu/not +fma +select, when data operands and mask layout requirements can be kept explicit +``` + +第一阶段可以先覆盖 ComputeY1 需要的 `mulf`,但实现形态应按 op family 泛化。 + +### 3.2 Widen Ext Relation Remat + +对 widening `ext`: + +```text +ensure_layout(ext(src), resultLayout) + => ext(ensure_layout(src, sourceLayout)) +``` + +其中: + +```text +resultFactor = sourceFactor * widenFactor +``` + +例子: + +```text +ext f16 -> f32, widenFactor = 2 +target result layout = deinterleaved=4 +derived source layout = deinterleaved=2 +``` + +`deinterleaved=1` 等价于 contiguous。 + +### 3.3 Producer Fold After Remat + +relation remat 可能暴露 producer-side helper: + +```text +ensure_layout(load(...), deinterleaved=2) +``` + +这类 helper 应由 `vmi-layout-fold` 吸收到 producer 或 consumer: + +```text +load contiguous + ensure_layout to deinterleaved=2 + => load result deinterleaved=2 +``` + +因此推荐优化 pipeline 在 remat 后再次运行 fold: + +```text +vmi-layout-assignment + -> canonicalize/cse + -> vmi-layout-fold + -> canonicalize/cse + -> vmi-layout-rematerialize + -> canonicalize/cse + -> vmi-layout-fold + -> canonicalize/cse + -> vmi-layout-sink-materialization + -> canonicalize/cse +``` + +## 4. Multi-Consumer Conflict + +如果一个 `ext` result 有两个 consumer: + +```text +consumer A requires deinterleaved=2 +consumer B requires deinterleaved=4 +``` + +assignment 不需要判断哪个更优。它可以选择稳定 baseline,例如 `deinterleaved=2`, +并为另一个 use 插入 helper: + +```text +%e2 = pto.vmi.extf %x : contiguous -> deinterleaved=2 +consumer_a(%e2) + +%e4 = pto.vmi.ensure_layout %e2 : deinterleaved=2 -> deinterleaved=4 +consumer_b(%e4) +``` + +remat 再把第二个 use 优化成 cloned producer: + +```text +%x2 = pto.vmi.ensure_layout %x : contiguous -> deinterleaved=2 +%e4 = pto.vmi.extf %x2 : deinterleaved=2 -> deinterleaved=4 +consumer_b(%e4) +``` + +原 `%e2` 仍服务 `consumer_a`。这样不需要 assignment 做全局 cost selection。 + +## 5. ComputeY1 Shape + +baseline assignment 可能产生: + +```text +%x32 = extf %x16 // result deinterleaved=2 +%s32 = extf %scale16 // result deinterleaved=2 +%m = mulf %x32, %s32 // result deinterleaved=2 +%m4 = ensure_layout %m // deinterleaved=2 -> deinterleaved=4 +%y = truncf %m4 +``` + +remat/fold 后目标 IR: + +```text +%x16_d2 = load ... // folded deinterleaved=2 load +%x32_d4 = extf %x16_d2 // deinterleaved=2 -> deinterleaved=4 + +%scale16_d2 = group_broadcast_load ... // folded/assigned deinterleaved=2 +%scale32_d4 = extf %scale16_d2 // deinterleaved=2 -> deinterleaved=4 + +%m4 = mulf %x32_d4, %scale32_d4 +%y = truncf %m4 +``` + +关键点: + +```text +1. truncf 只通过 ensure_layout 表达自己的 source layout requirement。 +2. remat 不需要识别 quant 语义。 +3. ext relation 是 local rule。 +4. load/group_broadcast_load 的物理优化由 fold 或 producer capability 处理。 +``` + +## 6. Lowering Contract + +`vmi-to-vpto` 的 contract 不变: + +```text +1. 不扫描 ext 的 users。 +2. 不扫描 producer chain 来猜 layout。 +3. 只根据当前 op 的 operand/result layout lower。 +``` + +relation-aware remat 必须在 `vmi-to-vpto` 前把 IR 显式改写为: + +```text +%x = pto.vmi.load ... -> !vreg<..., layout> +%e = pto.vmi.extf %x + : !vreg<..., layout> + -> !vreg<..., layout> +``` + +之后 lowering 只消费这个 local shape。 + diff --git a/docs/designs/vmi-layout-relation-rematerialization-implementation.md b/docs/designs/vmi-layout-relation-rematerialization-implementation.md new file mode 100644 index 0000000000..605964799b --- /dev/null +++ b/docs/designs/vmi-layout-relation-rematerialization-implementation.md @@ -0,0 +1,408 @@ +# VMI Layout Relation-Aware Rematerialization Implementation Plan + +本文是 `vmi-layout-relation-rematerialization-design.md` 的实现计划。目标是 +扩展现有 `vmi-layout-rematerialize` / `vmi-layout-fold` 优化,让 assignment +保持 legal baseline,并从显式 `ensure_layout` 中恢复更好的 producer layout。 + +## 1. Current Baseline + +当前 pipeline 中相关 pass: + +```text +vmi-layout-assignment: + chooses concrete layouts + inserts ensure_layout / ensure_mask_layout / ensure_mask_granularity + +vmi-layout-fold: + folds selected ensure_layout helpers into layout-aware producers/consumers + current coverage includes store-side fold, load -> ensure_layout producer + fold, and inverse nested ensure_layout fold + +vmi-layout-rematerialize: + replaces ensure_* around cheap construction producers + current data coverage: splat constant, broadcast, iota + current mask coverage: create_mask, create_group_mask, constant_mask + +vmi-layout-sink-materialization: + sinks matching operand-side helpers through pure elementwise ops + it does not currently rewrite result-side ensure_layout(op(...), L) +``` + +ComputeY1-like IR currently remains suboptimal because assignment emits: + +```text +ensure_layout(mulf(ext(...), ext(...)), deinterleaved=4) +``` + +but remat does not yet: + +```text +1. hoist result-side ensure_layout through mulf +2. rematerialize ext under a requested result layout +3. expose foldable load/group_broadcast_load helpers +``` + +## 2. Support APIs + +Add support-layer helpers in `VMILayoutSupport`. + +### 2.1 Widen Relation Query + +```cpp +FailureOr getWidenSourceLayoutForResultLayout( + VMIVRegType sourceType, + VMIVRegType resultType, + VMILayoutAttr requestedResultLayout, + std::string *reason = nullptr) const; +``` + +Semantics: + +```text +1. source/result lane count must match. +2. result element width must be an integer multiple of source element width. +3. first implementation supports widenFactor 2 and 4. +4. requestedResultLayout must be contiguous or deinterleaved(block_elems=1). +5. requested result factor F must be divisible by widenFactor K. +6. derived source factor is F / K. +7. derived source factor 1 means contiguous. +8. derived source/result layout pair must be accepted by ext support gates. +``` + +Examples: + +```text +f16 -> f32, requested result deinterleaved=4 + => source deinterleaved=2 + +f16 -> f32, requested result deinterleaved=2 + => source contiguous + +f8 -> f32, requested result deinterleaved=4 + => source contiguous +``` + +### 2.2 Ext Support Gates + +Update `getExtFSupport`, `getExtSISupport`, and `getExtUISupport` so they accept +relation-rematerialized local shapes: + +```text +source layout: + contiguous or deinterleaved(S, block_elems=1) + +result layout: + deinterleaved(S * widenFactor, block_elems=1) +``` + +Keep group_slots integer extension behavior unchanged. + +Reject: + +```text +1. result layout that is not deinterleaved for dense ext. +2. block_elems != 1 in this first implementation. +3. source/result arity that does not satisfy resultArity = factor * sourceArity. +4. unsupported element width relation. +``` + +`vmi-to-vpto` ext lowering already works from physical source/result arity. If +support admits `source deinterleaved=2 -> result deinterleaved=4`, lowering must +be covered by tests. + +## 3. Rematerialize Pass Changes + +Extend `VMILayoutRematerialize.cpp` around `VMIEnsureLayoutOp`. + +Recommended ordering for one helper: + +```text +try relation-aware ext remat +try result-side layout-transparent producer remat +try existing cheap construction remat +``` + +The pass should use a helper worklist. When one rewrite creates new +`ensure_layout` helpers, enqueue them so the same pass can continue locally. + +### 3.1 Ext Remat Pattern + +Match: + +```text +%wanted = pto.vmi.ensure_layout %old +%old = pto.vmi.extf %src +``` + +where `%wanted` has `requestedResultType`. + +Rewrite: + +```text +derivedSourceLayout = + support.getWidenSourceLayoutForResultLayout(srcType, requestedResultType, + requestedResultLayout) + +%src2 = materialize source to derivedSourceLayout +%new = pto.vmi.extf %src2 : derivedSourceType -> requestedResultType +replace %wanted with %new +``` + +Equivalent patterns are needed for: + +```text +pto.vmi.extf +pto.vmi.extsi +pto.vmi.extui +``` + +The source materialization step should: + +```text +1. reuse %src if it already has derivedSourceLayout. +2. create pto.vmi.ensure_layout otherwise. +3. enqueue the new helper for further remat/fold opportunities. +``` + +### 3.2 Layout-Transparent Result Helper Remat + +Match: + +```text +%wanted = pto.vmi.ensure_layout %old +%old = pto.vmi.mulf %lhs, %rhs +``` + +Rewrite: + +```text +%lhs2 = ensure_layout %lhs : lhsLayout -> requestedLayout +%rhs2 = ensure_layout %rhs : rhsLayout -> requestedLayout +%new = pto.vmi.mulf %lhs2, %rhs2 : requestedLayout +replace %wanted with %new +``` + +Initial op coverage: + +```text +mulf +addf/addi/subf/subi/muli/divf/minf/maxf +andi/ori/xori/shli/shrui +negf/absf/absi/sqrt/exp/ln/relu/not +fma +``` + +Optional later coverage: + +```text +cmpf/cmpi: + result is mask, so this belongs with ensure_mask_layout. + +select: + requires coordinated data and mask layout/granularity handling. +``` + +The pass must preserve op attributes exactly. + +### 3.3 Existing Cheap Producer Remat + +Keep current behavior for: + +```text +splat pto.vmi.constant +pto.vmi.broadcast +pto.vmi.iota +pto.vmi.create_mask +pto.vmi.create_group_mask +pto.vmi.constant_mask +``` + +These remain direct remat cases and do not require relation queries. + +## 4. Fold Pass Interaction + +Relation remat may create producer-side helpers: + +```text +ensure_layout(load(...), deinterleaved=2) +ensure_layout(group_broadcast_load(...), deinterleaved=2) +``` + +`vmi-layout-fold` should absorb these when the producer can directly materialize +the requested layout. + +Existing load fold should use producer capability, not helper materialization +capability. A load may directly produce a requested contiguous or +deinterleaved=2/4 block_elems=1 result layout even when the helper conversion +from the old load layout to the requested layout would not be a legal register +materialization. + +Add fold coverage if missing for: + +```text +group_broadcast_load result layout requested as deinterleaved=2/block_elems=1 +group_slot_broadcast_load result layout requested as deinterleaved=2/block_elems=1 +``` + +The fold pass must still be local: + +```text +load/group_broadcast_load + ensure_layout + => cloned/retyped producer with requested result layout +``` + +It must not inspect downstream `ext` or `trunc`. + +## 5. Pipeline + +Use a pipeline with fold after remat: + +```text +vmi-layout-assignment + -> canonicalize/cse + -> vmi-layout-fold + -> canonicalize/cse + -> vmi-layout-rematerialize + -> canonicalize/cse + -> vmi-layout-fold + -> canonicalize/cse + -> vmi-layout-sink-materialization + -> canonicalize/cse + -> pto-validate-vmi-layout-ir + -> vmi-to-vpto +``` + +The first fold handles helpers already emitted by assignment. The second fold +handles helpers exposed by relation-aware remat. + +If later result-side remat and operand-side sink need to alternate for longer +chains, the driver may repeat: + +```text +vmi-layout-rematerialize +canonicalize/cse +vmi-layout-fold +canonicalize/cse +``` + +Keep the first implementation single-pass unless tests prove a fixed point is +needed. + +## 6. Tests + +Add focused lit tests. + +### 6.1 Direct Ext Remat + +Input shape: + +```text +load f16 +extf f16 -> f32 +ensure_layout ext result deinterleaved=2 -> deinterleaved=4 +truncf f32 -> f8 +``` + +Check after: + +```text +vmi-layout-rematerialize +``` + +```text +extf source is deinterleaved=2 +extf result is deinterleaved=4 +old ensure_layout is gone +``` + +Check after: + +```text +vmi-layout-rematerialize -vmi-layout-fold -vmi-to-vpto +``` + +```text +load uses deinterleaved load lowering when fold is available +extf lowers from local source/result arity +``` + +### 6.2 Elementwise Result Helper Remat + +Input shape: + +```text +extf lhs -> deinterleaved=2 +extf rhs -> deinterleaved=2 +mulf lhs, rhs -> deinterleaved=2 +ensure_layout mulf result -> deinterleaved=4 +truncf +``` + +Check: + +```text +mulf is cloned/rebuilt with deinterleaved=4 operands/results +each ext is rematerialized as source deinterleaved=2 -> result deinterleaved=4 +no ensure_layout remains between mulf and truncf +``` + +### 6.3 Multi-Consumer Conflict + +Input shape: + +```text +ext result deinterleaved=2 +consumer A uses deinterleaved=2 +consumer B has ensure_layout to deinterleaved=4 +``` + +Check: + +```text +original ext remains for consumer A +new cloned ext feeds consumer B +no global layout selection is required +``` + +### 6.4 ComputeY1 + +Run: + +```text +pto-test-opt compute_y1_to_fp8_fp16_vmi.pto \ + -vmi-layout-assignment \ + -vmi-layout-rematerialize \ + -vmi-layout-fold \ + -vmi-to-vpto +``` + +Expected: + +```text +x load can become deinterleaved=2 and lower through deinterleaved load support +scale path can keep the E2B-compatible deinterleaved layout +mulf/truncf path has no deinterleaved=2 -> deinterleaved=4 helper immediately +before truncf +``` + +## 7. Non-Goals And Follow-Ups + +Do not implement in this change: + +```text +1. assignment relation propagation. +2. global layout cost model. +3. trunc/narrow relation remat. +4. cloning memory loads in remat without going through explicit fold support. +5. context-sensitive vmi-to-vpto lowering. +``` + +Follow-ups: + +```text +1. Add narrow relation remat for selected trunc patterns after widen is stable. +2. Add select/cmp mask-aware result helper remat. +3. Consider a fixed-point layout optimization pipeline if long chains need it. +4. Move repeated op-family cloning utilities into a shared helper if the pass + grows beyond the first ext/elementwise implementation. +``` diff --git a/include/PTO/IR/VMIOps.td b/include/PTO/IR/VMIOps.td index eb8ad6e98c..152b712339 100644 --- a/include/PTO/IR/VMIOps.td +++ b/include/PTO/IR/VMIOps.td @@ -42,14 +42,14 @@ def PTO_PhysicalVMIPartTypeConstraint : AnyTypeOf< class VMI_Op traits = []> : PTO_Op<"vmi." # mnemonic, traits>; -def VMIConstantOp : VMI_Op<"constant"> { +def VMIConstantOp : VMI_Op<"constant", [Pure]> { let summary = "VMI logical vector constant"; let arguments = (ins AnyAttr:$value); let results = (outs VMI_VRegTypeConstraint:$result); let hasVerifier = 1; } -def VMIBroadcastOp : VMI_Op<"broadcast"> { +def VMIBroadcastOp : VMI_Op<"broadcast", [Pure]> { let summary = "Broadcast one scalar or 1-lane VMI vector to a VMI logical vector"; let arguments = (ins AnyType:$value); let results = (outs VMI_VRegTypeConstraint:$result); @@ -57,7 +57,7 @@ def VMIBroadcastOp : VMI_Op<"broadcast"> { let assemblyFormat = "$value attr-dict `:` type($value) `->` type($result)"; } -def VMIIotaOp : VMI_Op<"iota"> { +def VMIIotaOp : VMI_Op<"iota", [Pure]> { let summary = "Create a VMI logical index vector from a scalar base"; let arguments = (ins AnyTypeOf<[AnyInteger, AnyFloat], "integer/float scalar">:$base, @@ -68,7 +68,7 @@ def VMIIotaOp : VMI_Op<"iota"> { let assemblyFormat = "$base attr-dict `:` type($base) `->` type($result)"; } -def VMICreateMaskOp : VMI_Op<"create_mask"> { +def VMICreateMaskOp : VMI_Op<"create_mask", [Pure]> { let summary = "Create a VMI logical prefix predicate mask"; let arguments = (ins Index:$active_lanes); let results = (outs VMI_MaskTypeConstraint:$result); @@ -76,7 +76,7 @@ def VMICreateMaskOp : VMI_Op<"create_mask"> { let assemblyFormat = "$active_lanes attr-dict `:` type($active_lanes) `->` type($result)"; } -def VMICreateGroupMaskOp : VMI_Op<"create_group_mask"> { +def VMICreateGroupMaskOp : VMI_Op<"create_group_mask", [Pure]> { let summary = "Create a VMI logical grouped predicate mask"; let description = [{ Creates a mask where lane i is active iff @@ -92,14 +92,14 @@ def VMICreateGroupMaskOp : VMI_Op<"create_group_mask"> { let assemblyFormat = "$active_elems_per_group attr-dict `:` type($active_elems_per_group) `->` type($result)"; } -def VMIConstantMaskOp : VMI_Op<"constant_mask"> { +def VMIConstantMaskOp : VMI_Op<"constant_mask", [Pure]> { let summary = "VMI logical predicate mask constant"; let arguments = (ins AnyAttr:$value); let results = (outs VMI_MaskTypeConstraint:$result); let hasVerifier = 1; } -def VMIMaskAndOp : VMI_Op<"mask_and"> { +def VMIMaskAndOp : VMI_Op<"mask_and", [Pure]> { let summary = "VMI logical predicate mask and"; let arguments = (ins VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); let results = (outs VMI_MaskTypeConstraint:$result); @@ -107,7 +107,7 @@ def VMIMaskAndOp : VMI_Op<"mask_and"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIMaskOrOp : VMI_Op<"mask_or"> { +def VMIMaskOrOp : VMI_Op<"mask_or", [Pure]> { let summary = "VMI logical predicate mask or"; let arguments = (ins VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); let results = (outs VMI_MaskTypeConstraint:$result); @@ -115,7 +115,7 @@ def VMIMaskOrOp : VMI_Op<"mask_or"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIMaskXOrOp : VMI_Op<"mask_xor"> { +def VMIMaskXOrOp : VMI_Op<"mask_xor", [Pure]> { let summary = "VMI logical predicate mask xor"; let arguments = (ins VMI_MaskTypeConstraint:$lhs, VMI_MaskTypeConstraint:$rhs); let results = (outs VMI_MaskTypeConstraint:$result); @@ -123,7 +123,7 @@ def VMIMaskXOrOp : VMI_Op<"mask_xor"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIMaskNotOp : VMI_Op<"mask_not"> { +def VMIMaskNotOp : VMI_Op<"mask_not", [Pure]> { let summary = "VMI logical predicate mask not"; let arguments = (ins VMI_MaskTypeConstraint:$source); let results = (outs VMI_MaskTypeConstraint:$result); @@ -131,7 +131,7 @@ def VMIMaskNotOp : VMI_Op<"mask_not"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIAddFOp : VMI_Op<"addf"> { +def VMIAddFOp : VMI_Op<"addf", [Pure]> { let summary = "VMI floating-point elementwise add"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -139,7 +139,7 @@ def VMIAddFOp : VMI_Op<"addf"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIAddIOp : VMI_Op<"addi"> { +def VMIAddIOp : VMI_Op<"addi", [Pure]> { let summary = "VMI integer elementwise add"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -147,7 +147,7 @@ def VMIAddIOp : VMI_Op<"addi"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMISubFOp : VMI_Op<"subf"> { +def VMISubFOp : VMI_Op<"subf", [Pure]> { let summary = "VMI floating-point elementwise subtract"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -155,7 +155,7 @@ def VMISubFOp : VMI_Op<"subf"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMISubIOp : VMI_Op<"subi"> { +def VMISubIOp : VMI_Op<"subi", [Pure]> { let summary = "VMI integer elementwise subtract"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -163,7 +163,7 @@ def VMISubIOp : VMI_Op<"subi"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIMulFOp : VMI_Op<"mulf"> { +def VMIMulFOp : VMI_Op<"mulf", [Pure]> { let summary = "VMI floating-point elementwise multiply"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -171,7 +171,7 @@ def VMIMulFOp : VMI_Op<"mulf"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIMulIOp : VMI_Op<"muli"> { +def VMIMulIOp : VMI_Op<"muli", [Pure]> { let summary = "VMI integer elementwise multiply"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -179,7 +179,7 @@ def VMIMulIOp : VMI_Op<"muli"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIFmaOp : VMI_Op<"fma"> { +def VMIFmaOp : VMI_Op<"fma", [Pure]> { let summary = "VMI fused floating-point multiply-add"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs, VMI_VRegTypeConstraint:$acc); @@ -188,7 +188,7 @@ def VMIFmaOp : VMI_Op<"fma"> { let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result)"; } -def VMIDivFOp : VMI_Op<"divf"> { +def VMIDivFOp : VMI_Op<"divf", [Pure]> { let summary = "VMI floating-point elementwise divide"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -196,7 +196,7 @@ def VMIDivFOp : VMI_Op<"divf"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIMinFOp : VMI_Op<"minf"> { +def VMIMinFOp : VMI_Op<"minf", [Pure]> { let summary = "VMI floating-point elementwise minimum"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -204,7 +204,7 @@ def VMIMinFOp : VMI_Op<"minf"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIMaxFOp : VMI_Op<"maxf"> { +def VMIMaxFOp : VMI_Op<"maxf", [Pure]> { let summary = "VMI floating-point elementwise maximum"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -212,7 +212,7 @@ def VMIMaxFOp : VMI_Op<"maxf"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMINegFOp : VMI_Op<"negf"> { +def VMINegFOp : VMI_Op<"negf", [Pure]> { let summary = "VMI floating-point elementwise negate"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -220,7 +220,7 @@ def VMINegFOp : VMI_Op<"negf"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIAbsFOp : VMI_Op<"absf"> { +def VMIAbsFOp : VMI_Op<"absf", [Pure]> { let summary = "VMI floating-point elementwise absolute value"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -228,7 +228,7 @@ def VMIAbsFOp : VMI_Op<"absf"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIAbsIOp : VMI_Op<"absi"> { +def VMIAbsIOp : VMI_Op<"absi", [Pure]> { let summary = "VMI integer elementwise absolute value"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -236,7 +236,7 @@ def VMIAbsIOp : VMI_Op<"absi"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMISqrtOp : VMI_Op<"sqrt"> { +def VMISqrtOp : VMI_Op<"sqrt", [Pure]> { let summary = "VMI floating-point elementwise square root"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -244,7 +244,7 @@ def VMISqrtOp : VMI_Op<"sqrt"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIExpOp : VMI_Op<"exp"> { +def VMIExpOp : VMI_Op<"exp", [Pure]> { let summary = "VMI floating-point elementwise exponential"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -252,7 +252,7 @@ def VMIExpOp : VMI_Op<"exp"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMILnOp : VMI_Op<"ln"> { +def VMILnOp : VMI_Op<"ln", [Pure]> { let summary = "VMI floating-point elementwise natural logarithm"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -260,7 +260,7 @@ def VMILnOp : VMI_Op<"ln"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIReluOp : VMI_Op<"relu"> { +def VMIReluOp : VMI_Op<"relu", [Pure]> { let summary = "VMI floating-point elementwise ReLU"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -268,7 +268,7 @@ def VMIReluOp : VMI_Op<"relu"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIAndIOp : VMI_Op<"andi"> { +def VMIAndIOp : VMI_Op<"andi", [Pure]> { let summary = "VMI integer elementwise bitwise and"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -276,7 +276,7 @@ def VMIAndIOp : VMI_Op<"andi"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIOrIOp : VMI_Op<"ori"> { +def VMIOrIOp : VMI_Op<"ori", [Pure]> { let summary = "VMI integer elementwise bitwise or"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -284,7 +284,7 @@ def VMIOrIOp : VMI_Op<"ori"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIXOrIOp : VMI_Op<"xori"> { +def VMIXOrIOp : VMI_Op<"xori", [Pure]> { let summary = "VMI integer elementwise bitwise xor"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -292,7 +292,7 @@ def VMIXOrIOp : VMI_Op<"xori"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIShLIOp : VMI_Op<"shli"> { +def VMIShLIOp : VMI_Op<"shli", [Pure]> { let summary = "VMI integer elementwise left shift"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -300,7 +300,7 @@ def VMIShLIOp : VMI_Op<"shli"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMIShRUIOp : VMI_Op<"shrui"> { +def VMIShRUIOp : VMI_Op<"shrui", [Pure]> { let summary = "VMI unsigned integer elementwise right shift"; let arguments = (ins VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_VRegTypeConstraint:$result); @@ -308,7 +308,7 @@ def VMIShRUIOp : VMI_Op<"shrui"> { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMINotOp : VMI_Op<"not"> { +def VMINotOp : VMI_Op<"not", [Pure]> { let summary = "VMI integer elementwise bitwise not"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -316,7 +316,7 @@ def VMINotOp : VMI_Op<"not"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMICmpFOp : VMI_Op<"cmpf"> { +def VMICmpFOp : VMI_Op<"cmpf", [Pure]> { let summary = "VMI floating-point elementwise compare"; let arguments = (ins StrAttr:$predicate, VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_MaskTypeConstraint:$result); @@ -324,7 +324,7 @@ def VMICmpFOp : VMI_Op<"cmpf"> { let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMICmpIOp : VMI_Op<"cmpi"> { +def VMICmpIOp : VMI_Op<"cmpi", [Pure]> { let summary = "VMI integer elementwise compare"; let arguments = (ins StrAttr:$predicate, VMI_VRegTypeConstraint:$lhs, VMI_VRegTypeConstraint:$rhs); let results = (outs VMI_MaskTypeConstraint:$result); @@ -332,7 +332,7 @@ def VMICmpIOp : VMI_Op<"cmpi"> { let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)"; } -def VMISelectOp : VMI_Op<"select"> { +def VMISelectOp : VMI_Op<"select", [Pure]> { let summary = "VMI elementwise select"; let arguments = (ins VMI_MaskTypeConstraint:$mask, VMI_VRegTypeConstraint:$true_value, VMI_VRegTypeConstraint:$false_value); @@ -474,7 +474,7 @@ def VMIDhistOp : VMIHistogramOp<"dhist", def VMIChistOp : VMIHistogramOp<"chist", "VMI full 256-bin cumulative histogram over unsigned 8-bit source lanes">; -def VMIExtFOp : VMI_Op<"extf"> { +def VMIExtFOp : VMI_Op<"extf", [Pure]> { let summary = "VMI floating-point elementwise extension"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -482,7 +482,7 @@ def VMIExtFOp : VMI_Op<"extf"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMITruncFOp : VMI_Op<"truncf"> { +def VMITruncFOp : VMI_Op<"truncf", [Pure]> { let summary = "VMI floating-point elementwise truncation"; let arguments = (ins VMI_VRegTypeConstraint:$source, OptionalAttr:$rounding); @@ -491,7 +491,7 @@ def VMITruncFOp : VMI_Op<"truncf"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIFPToSIOp : VMI_Op<"fptosi"> { +def VMIFPToSIOp : VMI_Op<"fptosi", [Pure]> { let summary = "VMI floating-point to signed integer elementwise conversion"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -499,7 +499,7 @@ def VMIFPToSIOp : VMI_Op<"fptosi"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMISIToFPOp : VMI_Op<"sitofp"> { +def VMISIToFPOp : VMI_Op<"sitofp", [Pure]> { let summary = "VMI signed integer to floating-point elementwise conversion"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -507,7 +507,7 @@ def VMISIToFPOp : VMI_Op<"sitofp"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIExtSIOp : VMI_Op<"extsi"> { +def VMIExtSIOp : VMI_Op<"extsi", [Pure]> { let summary = "VMI signed integer elementwise extension"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -515,7 +515,7 @@ def VMIExtSIOp : VMI_Op<"extsi"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIExtUIOp : VMI_Op<"extui"> { +def VMIExtUIOp : VMI_Op<"extui", [Pure]> { let summary = "VMI unsigned integer elementwise extension"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -523,7 +523,7 @@ def VMIExtUIOp : VMI_Op<"extui"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMITruncIOp : VMI_Op<"trunci"> { +def VMITruncIOp : VMI_Op<"trunci", [Pure]> { let summary = "VMI saturating integer elementwise truncation"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -531,7 +531,7 @@ def VMITruncIOp : VMI_Op<"trunci"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIBitcastOp : VMI_Op<"bitcast"> { +def VMIBitcastOp : VMI_Op<"bitcast", [Pure]> { let summary = "VMI bitwise vector reinterpretation"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -679,7 +679,7 @@ def VMIScatterOp : VMI_Op<"scatter", [DeclareOpInterfaceMethods { +def VMIShuffleOp : VMI_Op<"shuffle", [Pure]> { let summary = "VMI static lane shuffle"; let arguments = (ins VMI_VRegTypeConstraint:$source, DenseI64ArrayAttr:$indices); let results = (outs VMI_VRegTypeConstraint:$result); @@ -701,7 +701,7 @@ def VMIChannelMergeOp : VMI_Op<"channel_merge"> { let hasVerifier = 1; } -def VMIEnsureLayoutOp : VMI_Op<"ensure_layout"> { +def VMIEnsureLayoutOp : VMI_Op<"ensure_layout", [Pure]> { let summary = "Internal VMI data layout materialization helper"; let arguments = (ins VMI_VRegTypeConstraint:$source); let results = (outs VMI_VRegTypeConstraint:$result); @@ -709,7 +709,7 @@ def VMIEnsureLayoutOp : VMI_Op<"ensure_layout"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIEnsureMaskLayoutOp : VMI_Op<"ensure_mask_layout"> { +def VMIEnsureMaskLayoutOp : VMI_Op<"ensure_mask_layout", [Pure]> { let summary = "Internal VMI mask layout materialization helper"; let arguments = (ins VMI_MaskTypeConstraint:$source); let results = (outs VMI_MaskTypeConstraint:$result); @@ -717,7 +717,7 @@ def VMIEnsureMaskLayoutOp : VMI_Op<"ensure_mask_layout"> { let assemblyFormat = "$source attr-dict `:` type($source) `->` type($result)"; } -def VMIEnsureMaskGranularityOp : VMI_Op<"ensure_mask_granularity"> { +def VMIEnsureMaskGranularityOp : VMI_Op<"ensure_mask_granularity", [Pure]> { let summary = "Internal VMI mask granularity materialization helper"; let arguments = (ins VMI_MaskTypeConstraint:$source); let results = (outs VMI_MaskTypeConstraint:$result); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 18444f8a36..9f23689230 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -862,10 +862,11 @@ def VMILayoutRematerialize : Pass<"vmi-layout-rematerialize", "ModuleOp"> { let summary = "Rematerialize cheap VMI producers at layout helpers"; let description = [{ Optimizes legal layout-assigned VMI IR by replacing selected ensure_layout, - ensure_mask_layout, and ensure_mask_granularity helpers with cloned cheap - producers that directly create the requested result type. The pass is - deliberately limited to pure construction ops, so memory, control-flow, and - mask-tail proofs remain explicit in the IR. + ensure_mask_layout, and ensure_mask_granularity helpers with cloned + producers that directly create the requested result type. The pass covers + pure construction ops, selected layout-transparent data ops, and dense + widening ext relation rematerialization. Memory, control-flow, and mask-tail + proofs remain explicit in the IR. }]; let constructor = "mlir::pto::createVMILayoutRematerializePass()"; let dependentDialects = ["mlir::cf::ControlFlowDialect", diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h index adff29dff9..e957a69058 100644 --- a/include/PTO/Transforms/VMILayoutSupport.h +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -251,6 +251,12 @@ class VMILayoutSupport { getPreferredCastLayoutFact(VMIVRegType sourceType, VMIVRegType resultType, std::string *reason = nullptr) const; + FailureOr + getWidenSourceLayoutForResultLayout(VMIVRegType sourceType, + VMIVRegType resultType, + VMILayoutAttr requestedResultLayout, + std::string *reason = nullptr) const; + FailureOr getGroupSlotLoadSupport(const VMITargetCapabilityRegistry &capabilities, VMIGroupSlotLoadOp op, diff --git a/lib/PTO/Transforms/VMILayoutFold.cpp b/lib/PTO/Transforms/VMILayoutFold.cpp index 9646041310..a592786f5b 100644 --- a/lib/PTO/Transforms/VMILayoutFold.cpp +++ b/lib/PTO/Transforms/VMILayoutFold.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" #include "PTO/IR/VMIUtils.h" #include "PTO/Transforms/Passes.h" #include "PTO/Transforms/VMILayoutSupport.h" @@ -42,6 +43,21 @@ static bool hasSameDataShapeAndElementType(VMIVRegType lhs, VMIVRegType rhs) { lhs.getElementType() == rhs.getElementType(); } +static bool isLoadProducerLayout(VMIVRegType type) { + if (!type) + return false; + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout) + return false; + if (layout.isContiguous()) + return true; + if (!layout.isDeinterleaved() || layout.getBlockElems() != 1 || + (layout.getFactor() != 2 && layout.getFactor() != 4)) + return false; + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + return elementBits == 8 || elementBits == 16 || elementBits == 32; +} + static bool isFoldableLoadEnsure(VMIEnsureLayoutOp ensure) { auto load = ensure.getSource().getDefiningOp(); if (!load) @@ -52,8 +68,7 @@ static bool isFoldableLoadEnsure(VMIEnsureLayoutOp ensure) { if (!hasSameDataShapeAndElementType(sourceType, resultType)) return false; - VMILayoutSupport supports; - return succeeded(supports.canMaterializeDataLayout(sourceType, resultType)); + return isLoadProducerLayout(resultType); } static void tryFoldLoadEnsures( @@ -88,6 +103,21 @@ static void tryFoldLoadEnsures( } } +static void +tryFoldNestedEnsureLayout(VMIEnsureLayoutOp ensure, + SmallVectorImpl &maybeDeadEnsures) { + auto inner = ensure.getSource().getDefiningOp(); + if (!inner) + return; + + if (inner.getSource().getType() != ensure.getResult().getType()) + return; + + ensure.getResult().replaceAllUsesWith(inner.getSource()); + maybeDeadEnsures.push_back(ensure); + maybeDeadEnsures.push_back(inner); +} + static bool isFoldableStoreEnsure(VMIEnsureLayoutOp ensure) { auto sourceType = dyn_cast(ensure.getSource().getType()); auto resultType = dyn_cast(ensure.getResult().getType()); @@ -158,6 +188,10 @@ struct VMILayoutFoldPass tryFoldLoadEnsures(load, maybeDeadEnsures); }); + module.walk([&](VMIEnsureLayoutOp ensure) { + tryFoldNestedEnsureLayout(ensure, maybeDeadEnsures); + }); + module.walk([&](Operation *op) { if (auto store = dyn_cast(op)) tryFoldEnsureLayoutIntoOperand(store.getValueMutable(), diff --git a/lib/PTO/Transforms/VMILayoutRematerialize.cpp b/lib/PTO/Transforms/VMILayoutRematerialize.cpp index 4f230d4189..5a3d1e48ec 100644 --- a/lib/PTO/Transforms/VMILayoutRematerialize.cpp +++ b/lib/PTO/Transforms/VMILayoutRematerialize.cpp @@ -13,6 +13,7 @@ #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" +#include "PTO/Transforms/VMILayoutSupport.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -21,6 +22,8 @@ #include "mlir/Pass/Pass.h" #include "llvm/ADT/STLExtras.h" +#include + namespace mlir { namespace pto { #define GEN_PASS_DEF_VMILAYOUTREMATERIALIZE @@ -41,6 +44,150 @@ static bool hasConcreteLayout(VMIMaskType type) { return type && static_cast(type.getLayoutAttr()); } +static Value materializeDataLayout(Value value, VMIVRegType resultType, + Location loc, OpBuilder &builder) { + auto sourceType = dyn_cast(value.getType()); + if (!sourceType || sourceType == resultType) + return value; + + return builder.create(loc, resultType, value).getResult(); +} + +template +static std::optional rematerializeWidenExt(ExtOp op, + VMIVRegType resultType, + Location loc, + OpBuilder &builder) { + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType || !hasConcreteLayout(resultType)) + return std::nullopt; + + VMILayoutSupport supports; + FailureOr sourceLayout = + supports.getWidenSourceLayoutForResultLayout( + sourceType, resultType, resultType.getLayoutAttr()); + if (failed(sourceLayout)) + return std::nullopt; + + auto rematSourceType = + VMIVRegType::get(sourceType.getContext(), sourceType.getElementCount(), + sourceType.getElementType(), *sourceLayout); + Value rematSource = materializeDataLayout(op.getSource(), rematSourceType, + loc, builder); + return builder.create(loc, resultType, rematSource).getResult(); +} + +static std::optional +rematerializeBinaryDataOp(Operation *op, VMIVRegType resultType, Location loc, + OpBuilder &builder) { + auto rebuild = [&](auto typedOp) -> std::optional { + auto lhsType = dyn_cast(typedOp.getLhs().getType()); + auto rhsType = dyn_cast(typedOp.getRhs().getType()); + if (!lhsType || !rhsType) + return std::nullopt; + auto lhsResultType = + VMIVRegType::get(lhsType.getContext(), lhsType.getElementCount(), + lhsType.getElementType(), resultType.getLayoutAttr()); + auto rhsResultType = + VMIVRegType::get(rhsType.getContext(), rhsType.getElementCount(), + rhsType.getElementType(), resultType.getLayoutAttr()); + Value lhs = + materializeDataLayout(typedOp.getLhs(), lhsResultType, loc, builder); + Value rhs = + materializeDataLayout(typedOp.getRhs(), rhsResultType, loc, builder); + return builder + .create>(loc, resultType, lhs, rhs) + .getResult(); + }; + + if (auto addf = dyn_cast(op)) + return rebuild(addf); + if (auto addi = dyn_cast(op)) + return rebuild(addi); + if (auto subf = dyn_cast(op)) + return rebuild(subf); + if (auto subi = dyn_cast(op)) + return rebuild(subi); + if (auto mulf = dyn_cast(op)) + return rebuild(mulf); + if (auto muli = dyn_cast(op)) + return rebuild(muli); + if (auto divf = dyn_cast(op)) + return rebuild(divf); + if (auto minf = dyn_cast(op)) + return rebuild(minf); + if (auto maxf = dyn_cast(op)) + return rebuild(maxf); + if (auto andi = dyn_cast(op)) + return rebuild(andi); + if (auto ori = dyn_cast(op)) + return rebuild(ori); + if (auto xori = dyn_cast(op)) + return rebuild(xori); + if (auto shli = dyn_cast(op)) + return rebuild(shli); + if (auto shrui = dyn_cast(op)) + return rebuild(shrui); + return std::nullopt; +} + +static std::optional +rematerializeUnaryDataOp(Operation *op, VMIVRegType resultType, Location loc, + OpBuilder &builder) { + auto rebuild = [&](auto typedOp) -> std::optional { + auto sourceType = dyn_cast(typedOp.getSource().getType()); + if (!sourceType) + return std::nullopt; + auto sourceResultType = VMIVRegType::get( + sourceType.getContext(), sourceType.getElementCount(), + sourceType.getElementType(), resultType.getLayoutAttr()); + Value source = materializeDataLayout(typedOp.getSource(), sourceResultType, + loc, builder); + return builder + .create>(loc, resultType, source) + .getResult(); + }; + + if (auto negf = dyn_cast(op)) + return rebuild(negf); + if (auto absf = dyn_cast(op)) + return rebuild(absf); + if (auto absi = dyn_cast(op)) + return rebuild(absi); + if (auto sqrt = dyn_cast(op)) + return rebuild(sqrt); + if (auto exp = dyn_cast(op)) + return rebuild(exp); + if (auto ln = dyn_cast(op)) + return rebuild(ln); + if (auto relu = dyn_cast(op)) + return rebuild(relu); + if (auto notOp = dyn_cast(op)) + return rebuild(notOp); + return std::nullopt; +} + +static std::optional +rematerializeFma(VMIFmaOp fma, VMIVRegType resultType, Location loc, + OpBuilder &builder) { + auto lhsType = dyn_cast(fma.getLhs().getType()); + auto rhsType = dyn_cast(fma.getRhs().getType()); + auto accType = dyn_cast(fma.getAcc().getType()); + if (!lhsType || !rhsType || !accType) + return std::nullopt; + auto makeType = [&](VMIVRegType type) { + return VMIVRegType::get(type.getContext(), type.getElementCount(), + type.getElementType(), resultType.getLayoutAttr()); + }; + Value lhs = materializeDataLayout(fma.getLhs(), makeType(lhsType), loc, + builder); + Value rhs = materializeDataLayout(fma.getRhs(), makeType(rhsType), loc, + builder); + Value acc = materializeDataLayout(fma.getAcc(), makeType(accType), loc, + builder); + return builder.create(loc, resultType, lhs, rhs, acc).getResult(); +} + static std::optional rematerializeDataProducer(Value value, VMIVRegType resultType, Location loc, @@ -48,6 +195,22 @@ static std::optional rematerializeDataProducer(Value value, if (!hasConcreteLayout(resultType)) return std::nullopt; + if (auto extf = value.getDefiningOp()) + return rematerializeWidenExt(extf, resultType, loc, builder); + if (auto extsi = value.getDefiningOp()) + return rematerializeWidenExt(extsi, resultType, loc, builder); + if (auto extui = value.getDefiningOp()) + return rematerializeWidenExt(extui, resultType, loc, builder); + + if (Operation *op = value.getDefiningOp()) { + if (auto fma = dyn_cast(op)) + return rematerializeFma(fma, resultType, loc, builder); + if (auto result = rematerializeBinaryDataOp(op, resultType, loc, builder)) + return result; + if (auto result = rematerializeUnaryDataOp(op, resultType, loc, builder)) + return result; + } + if (auto constant = value.getDefiningOp()) { auto denseAttr = dyn_cast(constant.getValue()); if (denseAttr && denseAttr.isSplat()) @@ -138,29 +301,33 @@ struct VMILayoutRematerializePass void runOnOperation() override { ModuleOp module = getOperation(); - SmallVector helpers; - module.walk([&](Operation *op) { - if (isa(op)) - helpers.push_back(op); - }); - - for (Operation *op : helpers) { - if (op->getBlock() == nullptr) - continue; - - if (auto ensure = dyn_cast(op)) { - tryReplaceDataEnsure(ensure); - continue; + bool changed = true; + while (changed) { + changed = false; + SmallVector helpers; + module.walk([&](Operation *op) { + if (isa(op)) + helpers.push_back(op); + }); + + for (Operation *op : helpers) { + if (op->getBlock() == nullptr) + continue; + + if (auto ensure = dyn_cast(op)) { + changed |= tryReplaceDataEnsure(ensure); + continue; + } + + if (auto ensure = dyn_cast(op)) { + changed |= tryReplaceMaskEnsure(ensure); + continue; + } + + if (auto ensure = dyn_cast(op)) + changed |= tryReplaceMaskEnsure(ensure); } - - if (auto ensure = dyn_cast(op)) { - tryReplaceMaskEnsure(ensure); - continue; - } - - if (auto ensure = dyn_cast(op)) - tryReplaceMaskEnsure(ensure); } } }; diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index c6b64f06b4..ac0627c29f 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -476,6 +476,50 @@ FailureOr VMILayoutSupport::getPreferredCastLayoutFact( return fail("supports only 8/16-bit <-> 32-bit dense cast layout facts"); } +FailureOr +VMILayoutSupport::getWidenSourceLayoutForResultLayout( + VMIVRegType sourceType, VMIVRegType resultType, + VMILayoutAttr requestedResultLayout, std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + if (!requestedResultLayout) + return fail("requires requested result layout"); + if (sourceType.getElementCount() != resultType.getElementCount()) + return fail("requires source/result lane count to match"); + if (requestedResultLayout.isGroupSlots()) + return fail("dense widen relation does not support group_slots layout"); + if (!requestedResultLayout.isContiguous() && + (!requestedResultLayout.isDeinterleaved() || + requestedResultLayout.getBlockElems() != 1)) + return fail("requires contiguous or deinterleaved result layout with " + "block_elems=1"); + + FailureOr fact = + getPreferredCastLayoutFact(sourceType, resultType, reason); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Widen2x && + fact->kind != VMICastLayoutKind::Widen4x)) + return fail("requires supported 8/16-bit to 32-bit widen cast"); + + int64_t resultFactor = requestedResultLayout.isDeinterleaved() + ? requestedResultLayout.getFactor() + : 1; + if (resultFactor % fact->factor != 0) + return fail("requested result layout factor is not divisible by widen " + "factor"); + + int64_t sourceFactor = resultFactor / fact->factor; + if (sourceFactor == 1) + return VMILayoutAttr::getContiguous(sourceType.getContext()); + if (sourceFactor == 2 || sourceFactor == 4) + return VMILayoutAttr::getDeinterleaved(sourceType.getContext(), + sourceFactor, /*blockElems=*/1); + return fail("derived source layout factor is unsupported"); +} + FailureOr VMILayoutSupport::getContiguousStoreSupport(VMIVRegType valueType, std::string *reason) const { @@ -1252,10 +1296,13 @@ VMILayoutSupport::getExtFSupport(VMIExtFOp op, std::string *reason) const { failed(resultArity)) return fail("requires assigned source/result layouts and computable " "physical arity"); - if (!sourceLayout.isContiguous() || !resultLayout.isDeinterleaved() || + if (!resultLayout.isDeinterleaved() || resultLayout.getBlockElems() != 1 || + !(sourceLayout.isContiguous() || + (sourceLayout.isDeinterleaved() && + sourceLayout.getBlockElems() == 1)) || !resultType.getElementType().isF32()) - return fail("requires contiguous source layout and deinterleaved f32 " - "result layout"); + return fail("requires contiguous or deinterleaved source layout and " + "deinterleaved f32 result layout with block_elems=1"); FailureOr fact = getPreferredCastLayoutFact(sourceType, resultType, reason); @@ -1264,17 +1311,19 @@ VMILayoutSupport::getExtFSupport(VMIExtFOp op, std::string *reason) const { return fail("unsupported extf source element width, result factor, or " "physical arity"); - if (fact->kind == VMICastLayoutKind::Widen2x && - resultLayout.getFactor() == fact->factor && - *resultArity == fact->factor * *sourceArity) + int64_t sourceFactor = + sourceLayout.isDeinterleaved() ? sourceLayout.getFactor() : 1; + if (resultLayout.getFactor() != sourceFactor * fact->factor || + *resultArity != fact->factor * *sourceArity) + return fail("unsupported extf source/result layout factor or physical " + "arity"); + + if (fact->kind == VMICastLayoutKind::Widen2x) return VMIExtFSupport{VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32}; - if (fact->kind == VMICastLayoutKind::Widen4x && - resultLayout.getFactor() == fact->factor && - *resultArity == fact->factor * *sourceArity) + if (fact->kind == VMICastLayoutKind::Widen4x) return VMIExtFSupport{VMIExtFSupportKind::ContiguousF8ToDeinterleaved4F32}; - return fail("unsupported extf source element width, result factor, or " - "physical arity"); + return fail("unsupported extf source element width"); } template @@ -1324,11 +1373,14 @@ static FailureOr getExtISupportImpl(OpT op, "16-bit"); } - if (!sourceLayout.isContiguous() || !resultLayout.isDeinterleaved() || + if (!resultLayout.isDeinterleaved() || resultLayout.getBlockElems() != 1 || + !(sourceLayout.isContiguous() || + (sourceLayout.isDeinterleaved() && + sourceLayout.getBlockElems() == 1)) || !isa(sourceType.getElementType()) || !isa(resultType.getElementType())) - return fail("requires contiguous integer source layout and deinterleaved " - "integer result layout"); + return fail("requires contiguous or deinterleaved integer source layout " + "and deinterleaved integer result layout with block_elems=1"); FailureOr fact = VMILayoutSupport().getPreferredCastLayoutFact(sourceType, resultType, @@ -1338,17 +1390,19 @@ static FailureOr getExtISupportImpl(OpT op, return fail("unsupported integer extension source/result element width, " "result factor, or physical arity"); - if (fact->kind == VMICastLayoutKind::Widen2x && - resultLayout.getFactor() == fact->factor && - *resultArity == fact->factor * *sourceArity) + int64_t sourceFactor = + sourceLayout.isDeinterleaved() ? sourceLayout.getFactor() : 1; + if (resultLayout.getFactor() != sourceFactor * fact->factor || + *resultArity != fact->factor * *sourceArity) + return fail("unsupported integer extension source/result layout factor or " + "physical arity"); + + if (fact->kind == VMICastLayoutKind::Widen2x) return VMIExtISupport{VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32}; - if (fact->kind == VMICastLayoutKind::Widen4x && - resultLayout.getFactor() == fact->factor && - *resultArity == fact->factor * *sourceArity) + if (fact->kind == VMICastLayoutKind::Widen4x) return VMIExtISupport{VMIExtISupportKind::ContiguousI8ToDeinterleaved4I32}; - return fail("unsupported integer extension source/result element width, " - "result factor, or physical arity"); + return fail("unsupported integer extension source/result element width"); } FailureOr diff --git a/test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto b/test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto index 4e14381743..fbabafa3e3 100644 --- a/test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto +++ b/test/lit/vmi/vmi_layout_gate_extf_support_invalid.pto @@ -12,7 +12,7 @@ module { func.func @vmi_layout_gate_extf_support_invalid( %source: !pto.vmi.vreg<128xf16, #pto.vmi.layout>) { // CHECK: VMI-LAYOUT-CONTRACT: pto.vmi.extf has no registered layout support - // CHECK-SAME: requires contiguous source layout and deinterleaved f32 result layout + // CHECK-SAME: requires contiguous or deinterleaved source layout and deinterleaved f32 result layout with block_elems=1 // CHECK: note: see current operation: %{{.*}} = "pto.vmi.extf" %out = pto.vmi.extf %source : !pto.vmi.vreg<128xf16, #pto.vmi.layout> diff --git a/test/lit/vmi/vmi_layout_rematerialize_relation.pto b/test/lit/vmi/vmi_layout_rematerialize_relation.pto new file mode 100644 index 0000000000..bc631893d5 --- /dev/null +++ b/test/lit/vmi/vmi_layout_rematerialize_relation.pto @@ -0,0 +1,131 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-rematerialize -canonicalize | FileCheck %s --check-prefix=REMAT +// RUN: pto-test-opt %s -vmi-layout-rematerialize -canonicalize -vmi-layout-fold -canonicalize -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_rematerialize_direct_ext( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> { + %x16 = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %x32_d2 = pto.vmi.extf %x16 + : !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %x32_d4 = pto.vmi.ensure_layout %x32_d2 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %y = pto.vmi.truncf %x32_d4 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return %y : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + } + + func.func @vmi_layout_rematerialize_mulf_chain( + %lhs: !pto.ptr, %rhs: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> { + %lhs16 = pto.vmi.load %lhs[%off] + : !pto.ptr + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %rhs16 = pto.vmi.load %rhs[%off] + : !pto.ptr + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %lhs32_d2 = pto.vmi.extf %lhs16 + : !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %rhs32_d2 = pto.vmi.extf %rhs16 + : !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %mul_d2 = pto.vmi.mulf %lhs32_d2, %rhs32_d2 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout>, + !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %mul_d4 = pto.vmi.ensure_layout %mul_d2 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %y = pto.vmi.truncf %mul_d4 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return %y : !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + } + + func.func @vmi_layout_rematerialize_ext_multi_consumer( + %src: !pto.ptr, %off: index) + -> (!pto.vmi.vreg<256xf16, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout>) { + %x16 = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %x32_d2 = pto.vmi.extf %x16 + : !pto.vmi.vreg<256xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %y16 = pto.vmi.truncf %x32_d2 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> + %x32_d4 = pto.vmi.ensure_layout %x32_d2 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + %y8 = pto.vmi.truncf %x32_d4 + : !pto.vmi.vreg<256xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + return %y16, %y8 + : !pto.vmi.vreg<256xf16, #pto.vmi.layout>, + !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> + } +} + +// REMAT-LABEL: func.func @vmi_layout_rematerialize_direct_ext( +// REMAT: %[[X16:.*]] = pto.vmi.load +// REMAT: %[[X16_D2:.*]] = pto.vmi.ensure_layout %[[X16]] +// REMAT-SAME: !pto.vmi.vreg<256xf16, #pto.vmi.layout> -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// REMAT: %[[X32_D4:.*]] = pto.vmi.extf %[[X16_D2]] +// REMAT-SAME: !pto.vmi.vreg<256xf16, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// REMAT: pto.vmi.truncf %[[X32_D4]] +// REMAT-NOT: !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +// REMAT-LABEL: func.func @vmi_layout_rematerialize_mulf_chain( +// REMAT-DAG: %[[LHS_D2:.*]] = pto.vmi.ensure_layout +// REMAT-DAG: %[[RHS_D2:.*]] = pto.vmi.ensure_layout +// REMAT-DAG: %[[LHS_D4:.*]] = pto.vmi.extf %[[LHS_D2]] +// REMAT-DAG: %[[RHS_D4:.*]] = pto.vmi.extf %[[RHS_D2]] +// REMAT: %[[MUL_D4:.*]] = pto.vmi.mulf %[[LHS_D4]], %[[RHS_D4]] +// REMAT-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// REMAT: pto.vmi.truncf %[[MUL_D4]] +// REMAT-NOT: !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +// REMAT-LABEL: func.func @vmi_layout_rematerialize_ext_multi_consumer( +// REMAT: %[[X16:.*]] = pto.vmi.load +// REMAT: %[[X32_D2:.*]] = pto.vmi.extf %[[X16]] +// REMAT-SAME: !pto.vmi.vreg<256xf16, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// REMAT: pto.vmi.truncf %[[X32_D2]] +// REMAT: %[[X16_D2:.*]] = pto.vmi.ensure_layout %[[X16]] +// REMAT: %[[X32_D4:.*]] = pto.vmi.extf %[[X16_D2]] +// REMAT-SAME: !pto.vmi.vreg<256xf16, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// REMAT: pto.vmi.truncf %[[X32_D4]] +// REMAT-NOT: !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_layout_rematerialize_direct_ext( +// LOWER: pto.vldsx2 +// LOWER-SAME: "DINTLV_B16" +// LOWER-COUNT-2: pto.vcvt {{.*}} {part = "EVEN"} +// LOWER-COUNT-2: pto.vcvt {{.*}} {part = "ODD"} +// LOWER: pto.vcvt {{.*}} {part = "P0" +// LOWER: pto.vcvt {{.*}} {part = "P1" +// LOWER: pto.vcvt {{.*}} {part = "P2" +// LOWER: pto.vcvt {{.*}} {part = "P3" +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_layout_rematerialize_mulf_chain( +// LOWER-COUNT-2: pto.vldsx2 +// LOWER-SAME: "DINTLV_B16" +// LOWER: pto.vmul +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_op_verifier_basic.pto b/test/lit/vmi/vmi_op_verifier_basic.pto index 4ff2199aa7..38575c8a47 100644 --- a/test/lit/vmi/vmi_op_verifier_basic.pto +++ b/test/lit/vmi/vmi_op_verifier_basic.pto @@ -13,7 +13,18 @@ module { %ptr: !pto.ptr, %layouted: !pto.vmi.vreg<128xf32, #pto.vmi.layout>, %mask_b16: !pto.vmi.mask<128xb16, #pto.vmi.layout>, - %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) { + %mask_b32: !pto.vmi.mask<128xb32, #pto.vmi.layout>) + -> (!pto.vmi.vreg<128xf32>, + !pto.vmi.mask<128xpred>, + !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf16>, + !pto.vmi.vreg<4xf32>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %f32 = arith.constant 1.000000e+00 : f32 @@ -84,7 +95,19 @@ module { %icmp = pto.vmi.cmpi "slt", %iv0, %iv0 : !pto.vmi.vreg<128xi32>, !pto.vmi.vreg<128xi32> -> !pto.vmi.mask<128xpred> - return + return %add, %cmp, %sel, %trunc, %merged, %layouted_trunc, %mask_layout, + %mask_granularity, %packed, %iadd, %icmp + : !pto.vmi.vreg<128xf32>, + !pto.vmi.mask<128xpred>, + !pto.vmi.vreg<128xf32>, + !pto.vmi.vreg<128xf16>, + !pto.vmi.vreg<4xf32>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.mask<128xb32, #pto.vmi.layout>, + !pto.vmi.vreg<128xf32, #pto.vmi.layout>, + !pto.vmi.vreg<128xi32>, + !pto.vmi.mask<128xpred> } } From 5baa23883dc033551f6e29b4a2117d9bfe451436 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 29 Jun 2026 17:48:07 +0800 Subject: [PATCH 42/54] Optimize equivalent VPTO vcvt normalization --- .../vmi-layout-assignment-implementation.md | 2 - .../vmi-layout-assignment-lowering-design.md | 4 - ...ayout-relation-rematerialization-design.md | 3 - ...lation-rematerialization-implementation.md | 2 - include/PTO/Transforms/Passes.h | 1 + include/PTO/Transforms/Passes.td | 15 +++ lib/PTO/Transforms/CMakeLists.txt | 1 + .../VPTONormalizeEquivalentVcvt.cpp | 96 +++++++++++++++++++ .../vpto/vpto_normalize_equivalent_vcvt.pto | 92 ++++++++++++++++++ tools/ptoas/ptoas.cpp | 6 +- 10 files changed, 209 insertions(+), 13 deletions(-) create mode 100644 lib/PTO/Transforms/VPTONormalizeEquivalentVcvt.cpp create mode 100644 test/lit/vpto/vpto_normalize_equivalent_vcvt.pto diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 4c8d52a97d..20f6ac92ee 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -14,8 +14,6 @@ Recommended pass pipeline: pto-validate-vmi-ir -> vmi-layout-assignment // hard legalization baseline -> canonicalize/cse - -> vmi-layout-fold // optional optimization - -> canonicalize/cse -> vmi-layout-rematerialize // optional optimization -> canonicalize/cse -> vmi-layout-fold // optional optimization over remat-exposed helpers diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 98881fe6a8..73e2a7a118 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -11,16 +11,12 @@ VMI surface IR -> pto-validate-vmi-ir -> vmi-layout-assignment // hard legalization baseline -> canonicalize/cse - -> vmi-layout-fold // optional optimization - -> canonicalize/cse -> vmi-layout-rematerialize // optional optimization -> canonicalize/cse -> vmi-layout-fold // optional optimization over remat-exposed helpers -> canonicalize/cse -> vmi-layout-sink-materialization // optional optimization -> canonicalize/cse - -> optional later layout optimization passes - -> canonicalize/cse -> vmi-legalize-arith-select -> pto-validate-vmi-layout-ir -> layout-assigned and optimized VMI IR diff --git a/docs/designs/vmi-layout-relation-rematerialization-design.md b/docs/designs/vmi-layout-relation-rematerialization-design.md index dc0a74907c..8e59b91b65 100644 --- a/docs/designs/vmi-layout-relation-rematerialization-design.md +++ b/docs/designs/vmi-layout-relation-rematerialization-design.md @@ -141,8 +141,6 @@ load contiguous + ensure_layout to deinterleaved=2 ```text vmi-layout-assignment - -> canonicalize/cse - -> vmi-layout-fold -> canonicalize/cse -> vmi-layout-rematerialize -> canonicalize/cse @@ -236,4 +234,3 @@ relation-aware remat 必须在 `vmi-to-vpto` 前把 IR 显式改写为: ``` 之后 lowering 只消费这个 local shape。 - diff --git a/docs/designs/vmi-layout-relation-rematerialization-implementation.md b/docs/designs/vmi-layout-relation-rematerialization-implementation.md index 605964799b..540d7bc5e2 100644 --- a/docs/designs/vmi-layout-relation-rematerialization-implementation.md +++ b/docs/designs/vmi-layout-relation-rematerialization-implementation.md @@ -259,8 +259,6 @@ Use a pipeline with fold after remat: ```text vmi-layout-assignment - -> canonicalize/cse - -> vmi-layout-fold -> canonicalize/cse -> vmi-layout-rematerialize -> canonicalize/cse diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 4ac1eaa5bb..cc73821a30 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -101,6 +101,7 @@ std::unique_ptr createPTOFusionLoadStoreElisionPass(); std::unique_ptr createPTOFlattenFusionRegionPass(); std::unique_ptr createVPTOPtrNormalizePass(); std::unique_ptr createVPTOPtrCastCleanupPass(); +std::unique_ptr createVPTONormalizeEquivalentVcvtPass(); LogicalResult validateVPTOAuthoringIR(ModuleOp module, llvm::raw_ostream *diagOS = nullptr); LogicalResult validateVPTOEmissionIR(ModuleOp module, diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 9f23689230..f000a50060 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -1082,4 +1082,19 @@ def VPTOPtrCastCleanup "mlir::memref::MemRefDialect"]; } +def VPTONormalizeEquivalentVcvt + : Pass<"vpto-normalize-equivalent-vcvt", "ModuleOp"> { + let summary = "Normalize equivalent VPTO vcvt part selections"; + let description = [{ + Rewrites `pto.vcvt` operations whose `EVEN` and `ODD` part selections are + provably equivalent into the canonical `EVEN` form. The pass currently + recognizes all-true masked narrow-to-wide conversions from VPTO values with + pair-wise equivalent input lanes, such as scalar/vector broadcasts and + selected broadcast load distributions. A following CSE pass can then merge + duplicate conversions. + }]; + let constructor = "mlir::pto::createVPTONormalizeEquivalentVcvtPass()"; + let dependentDialects = ["mlir::pto::PTODialect"]; +} + #endif // MLIR_DIALECT_PTO_PASSES diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index ddf73ba356..116d4f10cb 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -30,6 +30,7 @@ add_mlir_dialect_library(PTOTransforms VPTOLLVMEmitterHelper.cpp VPTOPtrNormalize.cpp VPTOPtrCastCleanup.cpp + VPTONormalizeEquivalentVcvt.cpp VPTOExpandWrapperOps.cpp PTOVPTOPtrBoundary.cpp VPTOBufferMaterialization.cpp diff --git a/lib/PTO/Transforms/VPTONormalizeEquivalentVcvt.cpp b/lib/PTO/Transforms/VPTONormalizeEquivalentVcvt.cpp new file mode 100644 index 0000000000..22dbbfc094 --- /dev/null +++ b/lib/PTO/Transforms/VPTONormalizeEquivalentVcvt.cpp @@ -0,0 +1,96 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" +#include "PTO/Transforms/Passes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VPTONORMALIZEEQUIVALENTVCVT +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static bool isOddPart(StringRef part) { + return part == "ODD" || part == "PART_ODD"; +} + +static bool isAllTrueMask(Value mask) { + if (auto op = mask.getDefiningOp()) + return op.getPattern() == "PAT_ALL"; + if (auto op = mask.getDefiningOp()) + return op.getPattern() == "PAT_ALL"; + if (auto op = mask.getDefiningOp()) + return op.getPattern() == "PAT_ALL"; + return false; +} + +static bool isPairEquivalentLoadDist(StringRef dist) { + return dist == "BRC_B8" || dist == "BRC_B16" || dist == "BRC_B32" || + dist == "US_B8" || dist == "US_B16" || dist == "E2B_B16" || + dist == "E2B_B32"; +} + +static bool hasEvenOddEquivalentLanes(Value value) { + if (value.getDefiningOp()) + return true; + + auto load = value.getDefiningOp(); + if (!load || value != load.getResult()) + return false; + + std::optional dist = load.getDist(); + return dist && isPairEquivalentLoadDist(*dist); +} + +static bool isNarrowToWideVcvt(VcvtOp op) { + auto inputType = dyn_cast(op.getInput().getType()); + auto resultType = dyn_cast(op.getResult().getType()); + if (!inputType || !resultType) + return false; + + unsigned inputBits = getPTOStorageElemBitWidth(inputType.getElementType()); + unsigned resultBits = getPTOStorageElemBitWidth(resultType.getElementType()); + return inputBits != 0 && resultBits != 0 && inputBits < resultBits; +} + +struct VPTONormalizeEquivalentVcvtPass + : public pto::impl::VPTONormalizeEquivalentVcvtBase< + VPTONormalizeEquivalentVcvtPass> { + void runOnOperation() override { + StringAttr even = StringAttr::get(&getContext(), "EVEN"); + + getOperation().walk([&](VcvtOp op) { + std::optional part = op.getPart(); + if (!part || !isOddPart(*part)) + return; + if (!isNarrowToWideVcvt(op)) + return; + if (!isAllTrueMask(op.getMask())) + return; + if (!hasEvenOddEquivalentLanes(op.getInput())) + return; + + op.setPartAttr(even); + }); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVPTONormalizeEquivalentVcvtPass() { + return std::make_unique(); +} diff --git a/test/lit/vpto/vpto_normalize_equivalent_vcvt.pto b/test/lit/vpto/vpto_normalize_equivalent_vcvt.pto new file mode 100644 index 0000000000..4b610b1ec9 --- /dev/null +++ b/test/lit/vpto/vpto_normalize_equivalent_vcvt.pto @@ -0,0 +1,92 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vpto-normalize-equivalent-vcvt -canonicalize -cse | FileCheck %s + +module { + func.func @e2b_load(%src: !pto.ptr, %off: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %load = pto.vlds %src[%off] {dist = "E2B_B16"} + : !pto.ptr -> !pto.vreg<128xf16> + %even = pto.vcvt %load, %mask {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vcvt %load, %mask {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + return %even, %odd : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @brc_load(%src: !pto.ptr, %off: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %load = pto.vlds %src[%off] {dist = "BRC_B16"} + : !pto.ptr -> !pto.vreg<128xf16> + %even = pto.vcvt %load, %mask {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vcvt %load, %mask {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + return %even, %odd : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @scalar_broadcast(%seed: f16) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %broadcast = pto.vbr %seed : f16 -> !pto.vreg<128xf16> + %even = pto.vcvt %broadcast, %mask {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vcvt %broadcast, %mask {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + return %even, %odd : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @normal_load_is_not_changed(%src: !pto.ptr, %off: index) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.pset_b16 "PAT_ALL" : !pto.mask + %load = pto.vlds %src[%off] + : !pto.ptr -> !pto.vreg<128xf16> + %even = pto.vcvt %load, %mask {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vcvt %load, %mask {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + return %even, %odd : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } + + func.func @masked_broadcast_is_not_changed(%seed: f16) + -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { + %mask = pto.pset_b16 "PAT_VL1" : !pto.mask + %broadcast = pto.vbr %seed : f16 -> !pto.vreg<128xf16> + %even = pto.vcvt %broadcast, %mask {part = "EVEN"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + %odd = pto.vcvt %broadcast, %mask {part = "ODD"} + : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + return %even, %odd : !pto.vreg<64xf32>, !pto.vreg<64xf32> + } +} + +// CHECK-LABEL: func.func @e2b_load +// CHECK: %[[CVT:.*]] = pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vcvt +// CHECK: return %[[CVT]], %[[CVT]] + +// CHECK-LABEL: func.func @brc_load +// CHECK: %[[CVT:.*]] = pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vcvt +// CHECK: return %[[CVT]], %[[CVT]] + +// CHECK-LABEL: func.func @scalar_broadcast +// CHECK: %[[CVT:.*]] = pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vcvt +// CHECK: return %[[CVT]], %[[CVT]] + +// CHECK-LABEL: func.func @normal_load_is_not_changed +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> + +// CHECK-LABEL: func.func @masked_broadcast_is_not_changed +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 0f30adfb9c..536d1537f5 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1545,6 +1545,7 @@ static void prepareVPTOForEmission(PassManager &pm) { kernelModulePM.addPass(createCSEPass()); kernelModulePM.addPass(pto::createVPTOPtrNormalizePass()); kernelModulePM.addPass(pto::createVPTOPtrCastCleanupPass()); + kernelModulePM.addPass(pto::createVPTONormalizeEquivalentVcvtPass()); kernelModulePM.addPass(createReconcileUnrealizedCastsPass()); kernelModulePM.addNestedPass( createVPTOExpandWrapperOpsPass()); @@ -1809,10 +1810,10 @@ static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { pm.addPass(pto::createVMILayoutAssignmentPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - pm.addPass(pto::createVMILayoutFoldPass()); + pm.addPass(pto::createVMILayoutRematerializePass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); - pm.addPass(pto::createVMILayoutRematerializePass()); + pm.addPass(pto::createVMILayoutFoldPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); pm.addPass(pto::createVMILayoutSinkMaterializationPass()); @@ -1821,6 +1822,7 @@ static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { pm.addPass(pto::createVMILegalizeArithSelectPass()); pm.addPass(pto::createPTOValidateVMILayoutIRPass()); pm.addPass(pto::createVMIToVPTOPass()); + pm.addPass(pto::createVPTONormalizeEquivalentVcvtPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); if (failed(applyConfiguredPassManagerCLOptions(pm, From a9c80097c55621bb0789c1b366a9b2137d57f62a Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 29 Jun 2026 19:24:32 +0800 Subject: [PATCH 43/54] Run public LICM in VMI pipeline --- test/lit/vmi/vmi_ptoas_cli_licm.pto | 39 +++++++++++++++++++++++++++++ tools/ptoas/ptoas.cpp | 2 ++ 2 files changed, 41 insertions(+) create mode 100644 test/lit/vmi/vmi_ptoas_cli_licm.pto diff --git a/test/lit/vmi/vmi_ptoas_cli_licm.pto b/test/lit/vmi/vmi_ptoas_cli_licm.pto new file mode 100644 index 0000000000..b6f7261bce --- /dev/null +++ b/test/lit/vmi/vmi_ptoas_cli_licm.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - | FileCheck %s + +module attributes {pto.target_arch = "a5"} { + module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind} { + func.func @vmi_ptoas_cli_licm( + %src: !pto.ptr, + %dst: !pto.ptr, + %count: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + pto.vecscope { + scf.for %i = %c0 to %count step %c1 { + %x16 = pto.vmi.load %src[%i] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + pto.vmi.store %x32, %dst[%i] + : !pto.vmi.vreg<128xf32>, !pto.ptr + } + } + return + } + } +} + +// CHECK-LABEL: func.func @vmi_ptoas_cli_licm +// CHECK: pto.vecscope +// CHECK: %[[MASK:.*]] = pto.pset_b16 "PAT_ALL" : !pto.mask +// CHECK: scf.for +// CHECK-NOT: pto.pset_b16 +// CHECK: pto.vcvt {{.*}}, %[[MASK]] diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 536d1537f5..0c69864f66 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -37,6 +37,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Target/Cpp/CppEmitter.h" +#include "mlir/Transforms/Passes.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/FileSystem.h" // [Fix] Required for OF_None @@ -1823,6 +1824,7 @@ static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { pm.addPass(pto::createPTOValidateVMILayoutIRPass()); pm.addPass(pto::createVMIToVPTOPass()); pm.addPass(pto::createVPTONormalizeEquivalentVcvtPass()); + pm.addPass(createLoopInvariantCodeMotionPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); if (failed(applyConfiguredPassManagerCLOptions(pm, From f6f5aaa049879f5ca473cf9d1471e72fb740feaf Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 29 Jun 2026 19:50:35 +0800 Subject: [PATCH 44/54] Support VMI u8 to u16 integer extension --- lib/PTO/Transforms/VMILayoutSupport.cpp | 6 ++++-- lib/PTO/Transforms/VMIToVPTO.cpp | 20 +++++++++++--------- test/lit/vmi/vmi_to_vpto_integer_casts.pto | 21 +++++++++++++++++++++ 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index ac0627c29f..2d17ed90b7 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -437,7 +437,8 @@ FailureOr VMILayoutSupport::getPreferredCastLayoutFact( fact.sourceBits = sourceBits; fact.resultBits = resultBits; - if (resultBits == 32 && sourceBits == 16) { + if ((sourceBits == 8 || sourceBits == 16) && + resultBits == sourceBits * 2) { fact.kind = VMICastLayoutKind::Widen2x; fact.factor = 2; fact.sourceLayout = VMILayoutAttr::getContiguous(ctx); @@ -473,7 +474,8 @@ FailureOr VMILayoutSupport::getPreferredCastLayoutFact( return fact; } - return fail("supports only 8/16-bit <-> 32-bit dense cast layout facts"); + return fail("supports only 8/16-bit integer widening and 32-bit integer " + "narrowing dense cast layout facts"); } FailureOr diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index e5ddd0a366..b45a792aac 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -7265,9 +7265,8 @@ struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { auto resultVRegType = dyn_cast(resultType); if (!resultVRegType || !isa(resultVRegType.getElementType()) || - (resultVRegTypes.empty() ? pto::getPTOStorageElemBitWidth( - resultVRegType.getElementType()) != 32 - : resultVRegType != resultVRegTypes.front())) + (!resultVRegTypes.empty() && + resultVRegType != resultVRegTypes.front())) return rewriter.notifyMatchFailure( op, "unsupported physical integer extension result type"); resultVRegTypes.push_back(resultVRegType); @@ -7275,13 +7274,16 @@ struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { unsigned sourceBits = pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + unsigned resultBits = pto::getPTOStorageElemBitWidth( + resultVRegTypes.front().getElementType()); ArrayRef parts; int64_t factor = 0; - if (sourceBits == 16 && resultTypes.size() == 2 * sourceParts.size()) { + if (resultBits == sourceBits * 2 && + resultTypes.size() == 2 * sourceParts.size()) { static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; parts = kEvenOddParts; factor = 2; - } else if (sourceBits == 8 && + } else if (resultBits == sourceBits * 4 && resultTypes.size() == 4 * sourceParts.size()) { static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; parts = kPacked4Parts; @@ -9394,8 +9396,8 @@ verifySupportedVMIToVPTOOps(ModuleOp module, extsi.emitError() << kVMIDiagUnsupportedPrefix << "pto.vmi.extsi supports contiguous signed/signless 8-bit or " - "16-bit integer physical source chunks to 32-bit integer " - "deinterleaved=4/2 results, or matching " + "16-bit integer physical source chunks to 2x/4x wider integer " + "deinterleaved results, or matching " "group_slots(num_groups=G, slots=8) 8/16-bit integer source to " "32-bit integer result (" << reason << ")"; @@ -9410,8 +9412,8 @@ verifySupportedVMIToVPTOOps(ModuleOp module, extui.emitError() << kVMIDiagUnsupportedPrefix << "pto.vmi.extui supports contiguous unsigned 8-bit or 16-bit " - "integer physical source chunks to unsigned 32-bit integer " - "deinterleaved=4/2 results, or matching " + "integer physical source chunks to 2x/4x wider unsigned integer " + "deinterleaved results, or matching " "group_slots(num_groups=G, slots=8) 8/16-bit integer source to " "32-bit integer result (" << reason << ")"; diff --git a/test/lit/vmi/vmi_to_vpto_integer_casts.pto b/test/lit/vmi/vmi_to_vpto_integer_casts.pto index caf4a09525..56a42fb2b6 100644 --- a/test/lit/vmi/vmi_to_vpto_integer_casts.pto +++ b/test/lit/vmi/vmi_to_vpto_integer_casts.pto @@ -27,6 +27,18 @@ module { !pto.vreg<64xui32>, !pto.vreg<64xui32> } + func.func @vmi_to_vpto_extui_u8_to_u16( + %input: !pto.vmi.vreg<256xui8, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) { + %wide = pto.vmi.extui %input + : !pto.vmi.vreg<256xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + %even, %odd = "pto.vmi.unpack"(%wide) + : (!pto.vmi.vreg<256xui16, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) + return %even, %odd : !pto.vreg<128xui16>, !pto.vreg<128xui16> + } + func.func @vmi_to_vpto_trunci_i32_to_ui8( %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>) -> !pto.vreg<256xui8> { @@ -109,6 +121,15 @@ module { // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast +// CHECK-LABEL: func.func @vmi_to_vpto_extui_u8_to_u16( +// CHECK-SAME: %[[INPUT:.*]]: !pto.vreg<256xui8> +// CHECK: %[[MASK:.*]] = pto.pset_b8 "PAT_ALL" : !pto.mask +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "EVEN"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: pto.vcvt %[[INPUT]], %[[MASK]] {part = "ODD"} : !pto.vreg<256xui8>, !pto.mask -> !pto.vreg<128xui16> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + // CHECK-LABEL: func.func @vmi_to_vpto_trunci_i32_to_ui8( // CHECK: %[[P0:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P0", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> // CHECK: %[[P1:.*]] = pto.vcvt {{.*}}, {{.*}} {part = "P1", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<256xui8> From b913aa790e3db333e17fbf773ac78f916e63766f Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Mon, 29 Jun 2026 20:41:43 +0800 Subject: [PATCH 45/54] Optimize VMI trunci layout rematerialization --- lib/PTO/Transforms/VMILayoutRematerialize.cpp | 63 ++++++++++++- lib/PTO/Transforms/VMILayoutSupport.cpp | 42 +++++---- lib/PTO/Transforms/VMIToVPTO.cpp | 92 +++++++++++-------- .../vmi/vmi_layout_rematerialize_relation.pto | 20 ++++ test/lit/vmi/vmi_to_vpto_integer_casts.pto | 25 +++++ .../vmi_to_vpto_trunci_i8_signed_invalid.pto | 4 +- 6 files changed, 190 insertions(+), 56 deletions(-) diff --git a/lib/PTO/Transforms/VMILayoutRematerialize.cpp b/lib/PTO/Transforms/VMILayoutRematerialize.cpp index 5a3d1e48ec..be4842ad5c 100644 --- a/lib/PTO/Transforms/VMILayoutRematerialize.cpp +++ b/lib/PTO/Transforms/VMILayoutRematerialize.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" #include "PTO/Transforms/Passes.h" #include "PTO/Transforms/VMILayoutSupport.h" @@ -277,6 +278,63 @@ static bool tryReplaceDataEnsure(VMIEnsureLayoutOp ensure) { return true; } +static bool tryRematerializeTruncIThroughSourceEnsure(VMITruncIOp trunc) { + auto resultType = dyn_cast(trunc.getResult().getType()); + if (!resultType || !hasConcreteLayout(resultType)) + return false; + + auto ensure = trunc.getSource().getDefiningOp(); + if (!ensure) + return false; + + auto originalSourceType = dyn_cast(ensure.getSource().getType()); + if (!originalSourceType || !hasConcreteLayout(originalSourceType)) + return false; + VMILayoutAttr originalSourceLayout = originalSourceType.getLayoutAttr(); + if (!originalSourceLayout.isDeinterleaved() || + originalSourceLayout.getBlockElems() != 1) + return false; + + VMILayoutSupport supports; + FailureOr fact = + supports.getPreferredCastLayoutFact(originalSourceType, resultType); + if (failed(fact) || (fact->kind != VMICastLayoutKind::Narrow2x && + fact->kind != VMICastLayoutKind::Narrow4x)) + return false; + if (originalSourceLayout.getFactor() % fact->factor != 0) + return false; + + unsigned resultBits = + pto::getPTOStorageElemBitWidth(resultType.getElementType()); + if (resultBits == 8 && + !cast(resultType.getElementType()).isUnsigned()) + return false; + + int64_t rematResultFactor = originalSourceLayout.getFactor() / fact->factor; + VMILayoutAttr rematResultLayout = + rematResultFactor == 1 + ? VMILayoutAttr::getContiguous(resultType.getContext()) + : VMILayoutAttr::getDeinterleaved(resultType.getContext(), + rematResultFactor, + /*blockElems=*/1); + auto rematResultType = + VMIVRegType::get(resultType.getContext(), resultType.getElementCount(), + resultType.getElementType(), rematResultLayout); + if (rematResultType == resultType) + return false; + + OpBuilder builder(trunc); + Value remat = + builder.create(trunc->getLoc(), rematResultType, + ensure.getSource()) + .getResult(); + Value replacement = + materializeDataLayout(remat, resultType, trunc->getLoc(), builder); + trunc.getResult().replaceAllUsesWith(replacement); + trunc.erase(); + return true; +} + template static bool tryReplaceMaskEnsure(EnsureOp ensure) { auto resultType = dyn_cast(ensure.getResult().getType()); @@ -307,7 +365,7 @@ struct VMILayoutRematerializePass SmallVector helpers; module.walk([&](Operation *op) { if (isa(op)) + VMIEnsureMaskGranularityOp, VMITruncIOp>(op)) helpers.push_back(op); }); @@ -327,6 +385,9 @@ struct VMILayoutRematerializePass if (auto ensure = dyn_cast(op)) changed |= tryReplaceMaskEnsure(ensure); + + if (auto trunc = dyn_cast(op)) + changed |= tryRematerializeTruncIThroughSourceEnsure(trunc); } } } diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index 2d17ed90b7..f1fd1097fb 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -456,7 +456,8 @@ FailureOr VMILayoutSupport::getPreferredCastLayoutFact( return fact; } - if (sourceBits == 32 && resultBits == 16) { + if ((resultBits == 8 || resultBits == 16) && + sourceBits == resultBits * 2) { fact.kind = VMICastLayoutKind::Narrow2x; fact.factor = 2; fact.sourceLayout = @@ -1458,32 +1459,41 @@ VMILayoutSupport::getTruncISupport(VMITruncIOp op, std::string *reason) const { return VMITruncISupport{VMITruncISupportKind::GroupSlots1I32ToNarrow}; } - if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || - *resultArity != 1) - return fail("requires integer deinterleaved source and contiguous " - "integer result"); - FailureOr fact = getPreferredCastLayoutFact(sourceType, resultType, reason); if (failed(fact) || (fact->kind != VMICastLayoutKind::Narrow2x && fact->kind != VMICastLayoutKind::Narrow4x)) return fail("unsupported deinterleaved trunci factor, arity, result " - "element width, or result signedness; 32-bit to 8-bit integer " - "narrowing requires unsigned i8 result"); - - if (fact->kind == VMICastLayoutKind::Narrow2x && - sourceLayout.getFactor() == fact->factor && *sourceArity == fact->factor) + "element width, or result signedness; 8-bit integer narrowing " + "requires unsigned i8 result"); + + if (!sourceLayout.isDeinterleaved() || sourceLayout.getBlockElems() != 1 || + !(resultLayout.isContiguous() || + (resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 1))) + return fail("requires integer deinterleaved source and contiguous or " + "deinterleaved integer result with block_elems=1"); + + int64_t resultFactor = + resultLayout.isDeinterleaved() ? resultLayout.getFactor() : 1; + if (sourceLayout.getFactor() != resultFactor * fact->factor || + *sourceArity != *resultArity * fact->factor) + return fail("unsupported deinterleaved trunci source/result layout factor " + "or physical arity"); + + if (resultBits == 8 && + !cast(resultType.getElementType()).isUnsigned()) + return fail("8-bit integer narrowing requires unsigned i8 result"); + + if (fact->kind == VMICastLayoutKind::Narrow2x) return VMITruncISupport{ VMITruncISupportKind::Deinterleaved2I32ToContiguousI16}; - if (fact->kind == VMICastLayoutKind::Narrow4x && - sourceLayout.getFactor() == fact->factor && - *sourceArity == fact->factor && - cast(resultType.getElementType()).isUnsigned()) + if (fact->kind == VMICastLayoutKind::Narrow4x) return VMITruncISupport{ VMITruncISupportKind::Deinterleaved4I32ToContiguousI8}; return fail("unsupported deinterleaved trunci factor, arity, result element " - "width, or result signedness; 32-bit to 8-bit integer narrowing " + "width, or result signedness; 8-bit integer narrowing " "requires unsigned i8 result"); } diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index b45a792aac..d774545de3 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -7405,36 +7405,47 @@ struct OneToNVMITruncIOpPattern : OneToNOpConversionPattern { return success(); } - if ((sourceParts.size() != 2 && sourceParts.size() != 4) || - resultTypes.size() != 1) + if (sourceParts.empty() || resultTypes.empty()) return rewriter.notifyMatchFailure( - op, "only 32-bit integer deinterleaved=2/4 to 16/8-bit contiguous " - "trunci is supported"); + op, "trunci requires non-empty physical source and result parts"); auto sourceType0 = dyn_cast(sourceParts.front().getType()); - auto resultType = dyn_cast(resultTypes.front()); + auto resultType0 = dyn_cast(resultTypes.front()); if (!sourceType0 || !isa(sourceType0.getElementType()) || - !resultType || !isa(resultType.getElementType())) + !resultType0 || !isa(resultType0.getElementType())) return rewriter.notifyMatchFailure( op, "unsupported physical trunci source/result type"); for (Value sourcePart : sourceParts) { auto sourceType = dyn_cast(sourcePart.getType()); if (!sourceType || sourceType != sourceType0) return rewriter.notifyMatchFailure( - op, "trunci source physical parts must have matching 32-bit " - "integer type"); + op, "trunci source physical parts must have matching integer type"); + } + for (Type resultType : resultTypes) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || resultVRegType != resultType0) + return rewriter.notifyMatchFailure( + op, "trunci result physical parts must have matching integer type"); } - if (pto::getPTOStorageElemBitWidth(sourceType0.getElementType()) != 32) - return rewriter.notifyMatchFailure( - op, "trunci source physical element width must be 32-bit"); + unsigned sourceBits = + pto::getPTOStorageElemBitWidth(sourceType0.getElementType()); unsigned resultBits = - pto::getPTOStorageElemBitWidth(resultType.getElementType()); + pto::getPTOStorageElemBitWidth(resultType0.getElementType()); + if (sourceBits == 0 || resultBits == 0 || sourceBits % resultBits != 0) + return rewriter.notifyMatchFailure( + op, "unsupported physical trunci source/result width relation"); + int64_t factor = sourceBits / resultBits; + if ((factor != 2 && factor != 4) || + sourceParts.size() != resultTypes.size() * factor) + return rewriter.notifyMatchFailure( + op, "unsupported physical trunci source/result arity relation"); + ArrayRef parts; - if (sourceParts.size() == 2 && resultBits == 16) { + if (factor == 2) { static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; parts = kEvenOddParts; - } else if (sourceParts.size() == 4 && resultBits == 8) { + } else if (factor == 4) { static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; parts = kPacked4Parts; } else { @@ -7445,30 +7456,38 @@ struct OneToNVMITruncIOpPattern : OneToNOpConversionPattern { FailureOr sourceMask = createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); FailureOr resultMask = - createAllTrueMaskForVReg(op.getLoc(), resultType, rewriter); + createAllTrueMaskForVReg(op.getLoc(), resultType0, rewriter); if (failed(sourceMask) || failed(resultMask)) return rewriter.notifyMatchFailure(op, "failed to build trunci masks"); StringAttr sat = rewriter.getStringAttr("SAT"); - SmallVector partials; - partials.reserve(parts.size()); - for (auto [sourcePart, part] : llvm::zip_equal(sourceParts, parts)) { - partials.push_back(rewriter - .create(op.getLoc(), resultType, - sourcePart, *sourceMask, - /*rnd=*/nullptr, sat, - rewriter.getStringAttr(part)) - .getResult()); - } - - Value merged = partials.front(); - for (Value partial : llvm::drop_begin(partials)) - merged = rewriter - .create(op.getLoc(), resultType, merged, partial, - *resultMask) - .getResult(); + SmallVector results; + results.reserve(resultTypes.size()); + for (int64_t resultIndex = 0, resultCount = resultTypes.size(); + resultIndex < resultCount; ++resultIndex) { + Type resultType = resultTypes[resultIndex]; + SmallVector partials; + partials.reserve(parts.size()); + for (int64_t partIndex = 0; partIndex < factor; ++partIndex) { + Value sourcePart = sourceParts[resultIndex * factor + partIndex]; + partials.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, + *sourceMask, /*rnd=*/nullptr, sat, + rewriter.getStringAttr(parts[partIndex])) + .getResult()); + } + + Value merged = partials.front(); + for (Value partial : llvm::drop_begin(partials)) + merged = rewriter + .create(op.getLoc(), resultType, merged, partial, + *resultMask) + .getResult(); + results.push_back(merged); + } - rewriter.replaceOp(op, merged, adaptor.getResultMapping()); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); return success(); } }; @@ -9427,10 +9446,9 @@ verifySupportedVMIToVPTOOps(ModuleOp module, trunci.emitError() << kVMIDiagUnsupportedPrefix - << "pto.vmi.trunci supports only 32-bit integer deinterleaved=2 " - "source parts to one contiguous 16-bit integer result chunk, " - "32-bit integer deinterleaved=4 source parts to one contiguous " - "8-bit integer result chunk, or 32-bit integer " + << "pto.vmi.trunci supports integer deinterleaved source layouts " + "whose factor is the 2x/4x narrowing multiple of the contiguous " + "or deinterleaved result layout factor, or 32-bit integer " "group_slots(num_groups=G, slots=1 or 8) to 8/16-bit integer " "group_slots(num_groups=G, slots=1 or 8) (" << reason << ")"; diff --git a/test/lit/vmi/vmi_layout_rematerialize_relation.pto b/test/lit/vmi/vmi_layout_rematerialize_relation.pto index bc631893d5..b9174c2527 100644 --- a/test/lit/vmi/vmi_layout_rematerialize_relation.pto +++ b/test/lit/vmi/vmi_layout_rematerialize_relation.pto @@ -79,6 +79,18 @@ module { : !pto.vmi.vreg<256xf16, #pto.vmi.layout>, !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> } + + func.func @vmi_layout_rematerialize_trunci_source_ensure( + %x32_d4: !pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> { + %x32_d2 = pto.vmi.ensure_layout %x32_d4 + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + %y16 = pto.vmi.trunci %x32_d2 + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + return %y16 : !pto.vmi.vreg<256xui16, #pto.vmi.layout> + } } // REMAT-LABEL: func.func @vmi_layout_rematerialize_direct_ext( @@ -111,6 +123,14 @@ module { // REMAT: pto.vmi.truncf %[[X32_D4]] // REMAT-NOT: !pto.vmi.vreg<256xf32, #pto.vmi.layout> -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// REMAT-LABEL: func.func @vmi_layout_rematerialize_trunci_source_ensure( +// REMAT-SAME: %[[X32_D4:.*]]: !pto.vmi.vreg<256xi32, #pto.vmi.layout> +// REMAT: %[[Y16_D2:.*]] = pto.vmi.trunci %[[X32_D4]] +// REMAT-SAME: !pto.vmi.vreg<256xi32, #pto.vmi.layout> -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// REMAT: pto.vmi.ensure_layout %[[Y16_D2]] +// REMAT-SAME: !pto.vmi.vreg<256xui16, #pto.vmi.layout> -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> +// REMAT-NOT: !pto.vmi.vreg<256xi32, #pto.vmi.layout> -> !pto.vmi.vreg<256xi32, #pto.vmi.layout> + // LOWER-LABEL: func.func @vmi_layout_rematerialize_direct_ext( // LOWER: pto.vldsx2 // LOWER-SAME: "DINTLV_B16" diff --git a/test/lit/vmi/vmi_to_vpto_integer_casts.pto b/test/lit/vmi/vmi_to_vpto_integer_casts.pto index 56a42fb2b6..d65b028c70 100644 --- a/test/lit/vmi/vmi_to_vpto_integer_casts.pto +++ b/test/lit/vmi/vmi_to_vpto_integer_casts.pto @@ -51,6 +51,18 @@ module { return %p : !pto.vreg<256xui8> } + func.func @vmi_to_vpto_trunci_i32_d4_to_ui16_d2( + %wide: !pto.vmi.vreg<256xi32, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) { + %narrow = pto.vmi.trunci %wide + : !pto.vmi.vreg<256xi32, #pto.vmi.layout> + -> !pto.vmi.vreg<256xui16, #pto.vmi.layout> + %low, %high = "pto.vmi.unpack"(%narrow) + : (!pto.vmi.vreg<256xui16, #pto.vmi.layout>) + -> (!pto.vreg<128xui16>, !pto.vreg<128xui16>) + return %low, %high : !pto.vreg<128xui16>, !pto.vreg<128xui16> + } + func.func @vmi_to_vpto_fptosi_f32_to_i32( %input: !pto.vmi.vreg<256xf32, #pto.vmi.layout>) -> (!pto.vreg<64xi32>, !pto.vreg<64xi32>, @@ -142,6 +154,19 @@ module { // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast +// CHECK-LABEL: func.func @vmi_to_vpto_trunci_i32_d4_to_ui16_d2( +// CHECK-SAME: %[[P0:.*]]: !pto.vreg<64xi32>, %[[P1:.*]]: !pto.vreg<64xi32>, %[[P2:.*]]: !pto.vreg<64xi32>, %[[P3:.*]]: !pto.vreg<64xi32> +// CHECK: %[[MASK:.*]] = pto.pset_b32 "PAT_ALL" : !pto.mask +// CHECK: %[[R0_EVEN:.*]] = pto.vcvt %[[P0]], %[[MASK]] {part = "EVEN", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: %[[R0_ODD:.*]] = pto.vcvt %[[P1]], %[[MASK]] {part = "ODD", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: pto.vor %[[R0_EVEN]], %[[R0_ODD]] +// CHECK: %[[R1_EVEN:.*]] = pto.vcvt %[[P2]], %[[MASK]] {part = "EVEN", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: %[[R1_ODD:.*]] = pto.vcvt %[[P3]], %[[MASK]] {part = "ODD", sat = "SAT"} : !pto.vreg<64xi32>, !pto.mask -> !pto.vreg<128xui16> +// CHECK: pto.vor %[[R1_EVEN]], %[[R1_ODD]] +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast + // CHECK-LABEL: func.func @vmi_to_vpto_fptosi_f32_to_i32( // CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> // CHECK: pto.vcvt {{.*}} {rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xi32> diff --git a/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto b/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto index 145ef2a7b9..c87af13167 100644 --- a/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_trunci_i8_signed_invalid.pto @@ -25,5 +25,5 @@ module { } // CHECK: VMI-UNSUPPORTED -// CHECK: pto.vmi.trunci supports only -// CHECK: 32-bit to 8-bit integer narrowing requires unsigned i8 result +// CHECK: pto.vmi.trunci supports integer deinterleaved source layouts +// CHECK: 8-bit integer narrowing requires unsigned i8 result From 4765122be233a1c3abafadcbeec144218a9ac12b Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 30 Jun 2026 11:39:29 +0800 Subject: [PATCH 46/54] Generalize VMI dense lane-stride layouts --- .../vmi-lane-stride-generalization-design.md | 784 ++++++++ ...ne-stride-generalization-implementation.md | 1583 +++++++++++++++++ .../vmi-layout-assignment-implementation.md | 9 +- .../vmi-layout-assignment-lowering-design.md | 6 + include/PTO/IR/VMIAttrs.td | 19 +- include/PTO/Transforms/VMILayoutSupport.h | 12 + .../PTO/Transforms/VMITargetCapabilities.h | 8 +- lib/PTO/IR/VMI.cpp | 106 +- lib/PTO/IR/VPTO.cpp | 4 + lib/PTO/Transforms/PTOValidateVPTOIR.cpp | 10 + lib/PTO/Transforms/VMILayoutAssignment.cpp | 68 +- lib/PTO/Transforms/VMILayoutFold.cpp | 20 +- lib/PTO/Transforms/VMILayoutSupport.cpp | 304 +++- lib/PTO/Transforms/VMIToVPTO.cpp | 629 ++++++- .../vmi/vmi_lane_stride_dense_load_store.pto | 197 ++ test/lit/vmi/vmi_lane_stride_masked_store.pto | 85 + ..._layout_assignment_dense_f16_f32_store.pto | 14 +- ..._layout_assignment_f32_f8_store_reduce.pto | 22 +- .../vmi_layout_assignment_f8_compute_f8.pto | 16 +- ...ignment_group_broadcast_multi_consumer.pto | 10 +- ...ut_assignment_group_load_block8_truncf.pto | 6 +- ...out_assignment_group_reduce_maxf_quant.pto | 11 +- ...roup_reduce_s16_truncf_broadcast_store.pto | 8 +- .../vmi/vmi_layout_assignment_load_truncf.pto | 31 +- ...ignment_mask_granularity_f32_f16_store.pto | 1 - ...ment_packed_group_slots_truncf_invalid.pto | 4 +- .../vmi_layout_assignment_truncf_ensure.pto | 19 +- test/lit/vmi/vmi_ptoas_cli_pipeline.pto | 9 +- test/lit/vmi/vmi_to_vpto_quant_dequant.pto | 23 +- test/lit/vmi/vmi_to_vpto_quant_fp8.pto | 18 +- ...vpto_truncf_fp8_128_contiguous_invalid.pto | 23 +- test/lit/vpto/vmi_truncf_hif8.pto | 10 +- 32 files changed, 3856 insertions(+), 213 deletions(-) create mode 100644 docs/designs/vmi-lane-stride-generalization-design.md create mode 100644 docs/designs/vmi-lane-stride-generalization-implementation.md create mode 100644 test/lit/vmi/vmi_lane_stride_dense_load_store.pto create mode 100644 test/lit/vmi/vmi_lane_stride_masked_store.pto diff --git a/docs/designs/vmi-lane-stride-generalization-design.md b/docs/designs/vmi-lane-stride-generalization-design.md new file mode 100644 index 0000000000..a839e337c3 --- /dev/null +++ b/docs/designs/vmi-lane-stride-generalization-design.md @@ -0,0 +1,784 @@ +# VMI Lane-Stride Layout Generalization Design + +本文定义 `lane_stride` 从 group-slot 专用属性泛化为 VMI layout 的通用 +物理 lane 映射轴。目标不是只优化 `64xf16 -> 64xf32`,而是给 dense +value、group-slot value、类型转换、broadcast materialization 和 load/store +rematerialization 提供统一表达。 + +## 1. Problem + +当前文档对 `lane_stride` 的语义是: + +```text +logical lane-sized physical slot 之间有固定间距 +``` + +但实现只允许它出现在: + +```text +#pto.vmi.layout +``` + +并且现有 helper 会把 `ui8 lane_stride=4` 这类 group-slot lowering 映射为 +b32 carrier。这导致两个问题: + +1. dense value 无法表达“64 个 f16 logical lanes 放在一个 128xf16 物理向量 + 的偶数 lane 上”。 +2. `lane_stride` 的 layout 语义和 group-slot carrier lowering 被混在一起。 + +泛化后必须保持以下边界: + +```text +lane_stride: + layout lane map, does not change logical element type + +carrier packing: + one lowering strategy for selected group-slot integer stores +``` + +## 2. Semantic Model + +### 2.1 Dense Layout + +Dense layout 仍然表示每个 logical lane 都有语义值。第一阶段只增加 +`lane_stride` 一个新轴: + +```text +deinterleave factor F +block elems B +lane stride LS +``` + +建议 surface spelling: + +```text +#pto.vmi.layout +#pto.vmi.layout + +#pto.vmi.layout +#pto.vmi.layout +``` + +Defaults: + +```text +F = 1 for contiguous +B = 1 +LS = 1 +``` + +Dense lane map: + +```text +logical lane i + +block q = i / B +in-block lane r = i % B +part p = q % F +part block t = q / F + +dense lane index in part = t * B + r +physical part p, physical lane dense lane index * LS +``` + +The current stage intentionally describes only phase-zero strided dense layouts. +For `lane_stride = 2`, that means semantic lanes occupy even physical lanes. + +An optional future `lane_offset` or `lane_phase` field is useful only after the +IR has a concrete zero-copy view or producer whose logical lane `i` is +intentionally represented at physical lane `2 * i + 1` or another non-zero +phase. The current stage has no such producer. The field should +not be added just because the target has a `vcvt ODD` instruction. + +`vcvt ODD` is needed in two different situations: + +```text +1. Full conversion of a packed contiguous source. + Example: contiguous f16 -> deinterleaved=2 f32 uses EVEN and ODD. + This is not an odd-phase dense source layout; it is the normal multi-part + lowering of a packed source. + +2. Single-part conversion of a future zero-copy odd-lane view. + Example: if a logical deinterleave/extract result were represented as + f16 lane_stride=2, lane_offset=1 instead of being compacted, then converting + that view to f32 contiguous would use ODD. This requires an explicit VMI + producer or consumer contract; current-stage dense stride does not + create such values. +``` + +The current design implements case 1 with existing conversion lowering and case +2 only as a non-goal extension. The useful dense-stride optimization in this +stage uses phase-zero layout and therefore selects `EVEN` for `W=2`. + +### 2.2 Deinterleaved vs Lane Stride + +Use `deinterleaved` when multiple semantic residue classes or physical parts of +the same dense logical value are all present. + +Use dense `lane_stride` when one semantic stream is stored sparsely inside each +physical part and the skipped lanes have no semantic value for this VMI value. + +Decision rule: + +```text +all residue classes are semantic: + use deinterleaved + +only one phase-zero residue class is semantic: + use lane_stride + +multiple parts are semantic and each part is internally strided: + use deinterleaved + lane_stride +``` + +Examples: + +```text +contiguous f16 -> f32 full dense widen: + source lanes 0,1,2,3,... are all semantic + result naturally has even/odd conversion parts + use result deinterleaved=2 + +64xf16 -> 64xf32 where the f32 consumer wants contiguous: + the vcvt layout support may request source lane_stride=2 + if the source producer/rematerialization can satisfy that request, source + lanes 0,2,4,... become semantic and lanes 1,3,5,... are holes for this value + extf result can then be contiguous through one EVEN conversion + +group-reduce or dense consumer that needs two/four logical fragments: + the fragments are semantic parts of the same dense value + use deinterleaved=2/4, not lane_stride +``` + +Do not use `lane_stride` to describe a full packed value that happens to need an +ODD conversion part. Do not use `deinterleaved` to describe holes inside one +physical part. + +Important distinction: + +```text +one hardware vcvt output: + always one contiguous VPTO output register + +VMI ext result layout: + describes how one or more hardware output registers map back to logical lane + order +``` + +For `W=2`, with logical f16 lanes named by their logical indices: + +```text +source contiguous: + physical lanes: 0, 1, 2, 3, 4, 5, ... + vcvt EVEN output carries logical lanes 0, 2, 4, ... + vcvt ODD output carries logical lanes 1, 3, 5, ... + VMI result layout is deinterleaved=2 unless another materialization + interleaves the two outputs. + +source lane_stride=2: + physical lanes: 0, _, 1, _, 2, _, ... + vcvt EVEN output carries logical lanes 0, 1, 2, ... + VMI result layout is contiguous. +``` + +So "vcvt output is contiguous" does not by itself mean the VMI `extf` result is +contiguous. The result layout depends on the logical lane mapping of the source +layout and the selected conversion parts. + +### 2.3 Group-Slot Layout + +Group-slot layout remains non-dense. Only `G` group result slots have semantic +values: + +```text +#pto.vmi.layout +#pto.vmi.layout +``` + +Existing mapping is preserved: + +```text +slot_block(g) = g / K +slot_lane(g) = (g % K) * LS +``` + +This remains a group-slot placement property. It does not make non-slot lanes +semantic. Existing `ui8 lane_stride=4` to b32 carrier lowering is still legal, +but it is not the definition of `lane_stride`. + +Group-slot `lane_offset` is not needed in the current stage. It should remain +out of scope unless a real group-slot producer needs non-zero phase. + +## 3. Physical Capacity + +`lane_stride` increases the number of physical lane slots needed by a dense +part, but it does not change the VMI logical element type. + +For one dense physical part in the current stage: + +```text +logical lanes in this part = M +required physical lanes = (M - 1) * LS + 1 +``` + +The number of VPTO physical registers for each part is: + +```text +ceil(required physical lanes / lanes_per_vpto_register(T)) +``` + +Total physical arity: + +```text +deinterleave factor F * registers per part +``` + +Example: + +```text +!vmi.vreg<64xf16, contiguous, lane_stride=2> + +lanes_per_vpto_register(f16) = 128 +required physical lanes = 63 * 2 + 1 = 127 +physical arity = 1 +``` + +The 64 logical f16 lanes occupy physical f16 lanes `0, 2, 4, ... 126` of one +`!pto.vreg<128xf16>`. The other lanes are undefined unless another layout +value gives them semantics. + +Some lowerings represent the same lane map with wider carrier slots instead of +logical-element lanes. For example, a b16 value with `lane_stride=2` may be +lowered as the low b16 element of each b32 carrier slot when using +`UNPK_B16`/`PK_B32` or register pack/unpack materialization. This does not +change the VMI logical element type; it is a VPTO lowering representation choice. + +## 4. Type And Operation Generalization + +The design is element-type agnostic. Dense `lane_stride` applies to any VMI +element type whose physical VPTO lane count is known: + +```text +f8, f16, bf16, f32 +i8, ui8, i16, ui16, i32, ui32 +pred masks at an explicit predicate granularity +``` + +An op may support a strided dense layout only when its VPTO lowering can +preserve the lane map. Unsupported combinations are rejected by layout support +queries, not silently repaired in `vmi-to-vpto`. + +### 4.1 VPTO Pack/Unpack Support Boundary + +Dense `lane_stride` is not a generic VPTO load/store operand. It is supported +only when the lane map matches a concrete VPTO distribution or register +materializer. + +Direct compact memory support: + +| Dense lane_stride | compact load | compact store | +|---:|---|---| +| 2, b8 | `vlds UNPK_B8` | `vsts PK_B16` | +| 2, b16 | `vlds UNPK_B16` | `vsts PK_B32` | +| 2, b32 | `vlds UNPK_B32` | `vsts PK_B64` | +| 4, b8 | `vlds UNPK4` | `vsts PK4_B32` | +| 4, b16/b32 | no direct dist | no direct dist | + +Direct scalar broadcast load target capability: + +```text +lane_stride=2/4, b8/b16/b32: + vlds BRC_B8/B16/B32 +``` + +The current stage does not add a VMI scalar broadcast-load op. BRC is therefore +a target capability for a separate scalar broadcast-load semantic, not part of +the current `vmi.load -> ensure_layout` compact-stream fold. + +Register fallback between contiguous and dense `lane_stride` should use the +register-side counterpart of these distributions: + +```text +contiguous -> lane_stride: + vsunpack/vzunpack-style placement into wider slots + +lane_stride -> contiguous: + vpack-style extraction from wider slots +``` + +`vintlv`/`vdintlv` remain the materializers for two-stream +interleave/deinterleave layouts; they are not the primary fallback for dense +`lane_stride`. + +### 4.2 Layout-Transparent Dense Ops + +Layout-transparent dense ops include ordinary elementwise arithmetic and +select-like ops when every dense data operand/result has the same layout: + +```text +add/mul/fma/min/max/select: + operands and result require identical dense layout key + key includes F, B, and LS +``` + +No physical shuffle is implied by these ops. + +### 4.3 Widening Conversion + +Let a widening conversion increase element storage width by ratio `W`: + +```text +f16 -> f32: W = 2 +bf16 -> f32: W = 2 +i16 -> i32: W = 2 +ui16 -> ui32: W = 2 +f8 -> f32: W = 4 +i8 -> i32: W = 4 +ui8 -> ui32: W = 4 +ui8 -> ui16: W = 2 +``` + +For a phase-zero source dense layout with `lane_stride = LS`, a single hardware +conversion part is sufficient when: + +```text +LS % W == 0 +``` + +The selected hardware part in the current stage is: + +```text +part = 0 +``` + +The result layout after conversion is: + +```text +result lane_stride = LS / W +``` + +For a future phase-aware layout with `lane_offset = O`, the generic relation is: + +```text +part = O % W +result lane_stride = LS / W +result lane_offset = (O - part) / W +``` + +That future relation should be enabled only when a real odd/non-zero-phase VMI +producer or consumer exists. + +Examples: + +```text +f16 source: contiguous, lane_stride=2 +extf to f32: + use vcvt EVEN + result contiguous + +f16 source: contiguous, lane_stride=4 +extf to f32: + use vcvt EVEN + result contiguous, lane_stride=2 +``` + +If `LS < W` or `LS % W != 0`, the conversion may need multiple hardware parts +and may naturally produce a deinterleaved result. The current contiguous source +case is the common example: + +```text +f16 source: contiguous, lane_stride=1 +extf to f32: + use vcvt EVEN and vcvt ODD + result deinterleaved=2 +``` + +Assignment chooses one preferred fact for the op before lowering. Consumer +requests are handled by the existing use-site materialization path after the +op's assigned result layout is fixed. + +The preferred direction for this optimization is not "notice the input is +already strided". The conversion op can be the layout-entry point and compute a +single preferred layout fact for the current op instance: + +```text +baseline fact: + source contiguous + result deinterleaved=W + cost: W conversion parts + +lane-stride fact: + source lane_stride=W + result contiguous + hardware conversion parts: one + source layout request: explicit +``` + +In the current single-preference framework, `ext` should publish one preferred +fact. The lane-stride fact is an op-local preference: assignment records the +required source/result relation in the IR and inserts `ensure_layout` at the +source use if the producer is not already in that layout. Later +rematerialization or fold passes may remove that helper when a concrete producer +rewrite exists; otherwise the helper is either lowered by a registered +contiguous/lane-stride materializer or rejected before `vmi-to-vpto`. + +This keeps the optimization in layout assignment/rematerialization, not in a +late `vmi-to-vpto` peephole, and stays within the existing single-preference +assignment model. + +### 4.4 Narrowing Conversion + +Narrowing is the inverse relation. If source element width is `W` times the +result element width, a single hardware narrowing part can produce a +phase-zero strided result when: + +```text +result lane_stride = source lane_stride * W +part = 0 +``` + +This covers more than f32-to-f16. The same relation applies to: + +```text +f32 -> f16/bf16 +i32 -> i16/i8 +ui32 -> ui16/ui8 +ui16 -> ui8 +``` + +The exact supported parts are target-op dependent. The layout assignment layer +should ask the op support interface whether a given source/result layout pair is +legal, rather than encoding type-specific shortcuts. + +### 4.5 Broadcast Materialization + +Broadcast remains a logical operation. `lane_stride` only describes the chosen +materialized layout. + +Scalar or group broadcast can materialize to a dense layout only when the +broadcast lowering or rematerialization support query accepts that lane map: + +```text +logical broadcast: + lane i gets value group(i) + +materialized layout: + lane i is stored at physical lane map(i) +``` + +This keeps E2B-style optimizations in the layout/rematerialization layer. A +group broadcast load may choose a dense strided layout when that layout directly +matches a consumer or a target instruction. If another consumer needs a +different layout, rematerialization may clone the broadcast or insert +`ensure_layout`. + +`group_broadcast_load` is also a VMI semantic, not an E2B semantic. It means: + +```text +for each logical group g: + load one scalar from source[offset + g * source_group_stride] + broadcast that scalar to all lanes in group g +``` + +E2B is a target lowering choice for the subset where that logical memory pattern, +the group size, the element width, and the assigned result layout match the E2B +packet semantics. Other lowering strategies may implement the same VMI +operation, so support queries should report "E2B is applicable" instead of +rewriting the VMI meaning to "this op is E2B". + +### 4.6 Masked Lane-Stride Stores + +Masks are logical predicates. A `masked_store` mask bit denotes whether a +logical element participates in the store; it is not automatically a predicate +for the physical lane slot that happens to carry that element after layout +assignment. + +For dense `lane_stride`, this distinction matters. With `lane_stride=2`, +logical lane `i` is carried in physical lane `2*i`. A packed store then +compacts those even physical lanes into a contiguous memory stream. A user mask +that is still contiguous cannot be passed directly to that packed store, because +the packed-store predicate is interpreted after the value lanes have been +compacted. + +A direct masked compact store is therefore legal only when the compiler has +assigned the value and mask the same lane map. That may happen because the mask +producer can directly produce the requested lane map, because assignment inserts +a mask `ensure_layout`, or because rematerialization rebuilds the mask producer +for that lane map. Without that compiler-derived proof, assignment should keep +a layout that the existing masked-store path can lower, even if the corresponding +unmasked store could use a dense lane-stride `PK` instruction. + +## 5. Assignment And Optimization Boundary + +The assignment pipeline should keep the existing responsibility split: + +```text +layout assignment: + collect consumer requests + ask producer/op support + assign explicit layout attrs + insert ensure_layout for use-local conflicts + +rematerialization: + clone cheap producers for incompatible use-site layouts + replace ensure_layout(producer) when producer can directly create target layout + +layout fold: + erase or fuse materialization helpers when the producer already has the + requested lane map + +vmi-to-vpto: + lower explicit assigned layouts only + no hidden layout selection policy +``` + +Dense `lane_stride` is therefore an assigned layout fact, not a lowering-side +pattern. An entry op such as `extf` may prefer it from the conversion ratio +alone; producer-specific rewrites are handled later by fold/rematerialization +passes over explicit helpers. The selected layout is fixed before +`vmi-to-vpto`, and `vmi-to-vpto` does not rediscover the preference. + +## 6. End-To-End Case Walkthroughs + +These cases are the intended test for the design. They show when dense +`lane_stride` is useful and when it should lose to the existing deinterleaved +plan. + +The logical programs in this section are pre-assignment VMI and do not carry +concrete layouts. Layouts shown under "baseline plan" or "lane-stride plan" are +possible assignment results, not layouts written in the input program. + +### 6.1 Contiguous Load, Ext, Contiguous Store + +Logical program: + +```text +%x16 = vmi.load %in : 64xf16 +%x32 = vmi.extf %x16 : 64xf16 -> 64xf32 +vmi.store %x32, %out : dense contiguous memory effect +``` + +Baseline plan: + +```text +load result: + contiguous f16 + +ext relation: + source contiguous f16 + result deinterleaved=2 f32 + lower: vcvt EVEN + vcvt ODD + +store: + needs contiguous f32 + requires result materialization deinterleaved=2 -> contiguous +``` + +Lane-stride plan: + +```text +load result: + lane_stride=2 f16 + +ext relation: + source lane_stride=2 f16 + result contiguous f32 + lower: vcvt EVEN + +store: + consumes contiguous f32 directly +``` + +The load side then has two concrete outcomes: + +```text +accepted direct load fold: + the original load has only the lane-stride use + compact load semantics match a supported UNPK dist + vmi-layout-fold changes the VMI load result layout in place + +no direct load fold: + keep the explicit source ensure_layout + lower it through register pack/unpack if that materialization is supported + otherwise keep the baseline contiguous-source/deinterleaved-result relation +``` + +This case proves that `extf` can be the layout-entry point, while `load` support +is still decided by the load/ensure fold or by the explicit materialization +helper. + +### 6.2 Broadcast, Ext, Contiguous Store + +Logical program: + +```text +%b16 = vmi.broadcast %s : 1xf16 -> 64xf16 +%b32 = vmi.extf %b16 : 64xf16 -> 64xf32 +vmi.store %b32, %out +``` + +Baseline plan: + +```text +broadcast materializes contiguous f16 +ext produces deinterleaved=2 f32 through EVEN + ODD +store materializes deinterleaved=2 -> contiguous +``` + +Lane-stride plan: + +```text +broadcast rematerializes directly as lane_stride=2 f16 +ext produces contiguous f32 through one EVEN +store consumes contiguous f32 +``` + +Here the lane-stride plan is accepted because broadcast is a rematerializable +producer: it can be rebuilt with the requested physical lane map instead of +requiring a register layout conversion. This is the kind of producer where +`vcvt` should drive a source `lane_stride=2` request. + +### 6.3 Ext Feeding A Deinterleaved Consumer + +Logical program: + +```text +%x16 = producer : 128xf16 +%x32 = vmi.extf %x16 : 128xf16 -> 128xf32 +%r = vmi.group_reduce %x32 // requests deinterleaved=2 +``` + +Baseline plan: + +```text +source contiguous f16 +result deinterleaved=2 f32 +consumer consumes result directly +``` + +Lane-stride plan: + +```text +source lane_stride=2 f16 +result contiguous f32 +consumer then needs contiguous -> deinterleaved=2 materialization +``` + +The baseline plan should win. A lane-stride fact is not useful when it creates a +layout the consumer does not want; for full chunks it may not reduce the +conversion count either. + +### 6.4 One Ext Result Feeding Store And Reduce + +Logical program: + +```text +%x16 = cheap_or_expensive_producer : 128xf16 +%x32 = vmi.extf %x16 : 128xf16 -> 128xf32 +vmi.store %x32, %out // requests contiguous +vmi.group_reduce %x32 // requests deinterleaved=2 +``` + +If `%x16` is not cheap to rematerialize: + +```text +assign ext result deinterleaved=2 for the reduce +insert ensure_layout at the store use +``` + +If `%x16` and `extf` are cheap to rematerialize: + +```text +shared path: + source contiguous -> ext result deinterleaved=2 -> reduce + +store-only remat path: + rematerialized source lane_stride=2 -> ext result contiguous -> store +``` + +This is a rematerialization decision, not a local `vcvt` peephole. + +### 6.5 Group Broadcast Load Feeding Ext + +Logical program: + +```text +%g16 = vmi.group_broadcast_load %scale : logical dense 64xf16 +%g32 = vmi.extf %g16 : 64xf16 -> 64xf32 +consumer requests contiguous %g32 +``` + +The lane-stride plan is accepted only if the group broadcast load lowering can +emit the requested lane map directly: + +```text +group broadcast load result lane_stride=2 f16 +ext result contiguous f32 +``` + +If the broadcast load can only produce contiguous or deinterleaved packets for +the target element width, assignment should keep those layouts and let later +materialization/rematerialization handle the consumer conflict. Dense +`lane_stride` is a requestable layout, not a guarantee that every producer can +create it. + +## 7. Compatibility Rules + +Two dense layouts are identical only if all lane-map fields match: + +```text +F, B, LS +``` + +Two dense layouts may be related by an explicit materialization only if a +registered relation can lower the map conversion. Examples: + +```text +contiguous <-> deinterleaved=2 +deinterleaved=2 <-> deinterleaved=4 when supported by existing intlv/dintlv +contiguous <-> contiguous, lane_stride=2 when pack/unpack materialization or +producer rematerialization supports it +``` + +The baseline assignment must not assume an arbitrary dense-to-dense +`ensure_layout` is free or legal. Unsupported materializations should fail in +verification or remain unselected by support queries. + +## 8. Non-Goals + +This design does not: + +1. Turn memory layout into strided memory semantics. Dense VMI `lane_stride` + describes register materialization, not GM/UB address stride. +2. Make non-slot lanes of group-slot layouts semantic. +3. Require every VPTO op to support every strided layout. +4. Encode `64xf16 -> 64xf32` as a one-off `vcvt EVEN` peephole. + +## 9. First Useful Optimization + +The motivating case becomes one instance of the generic rule: + +```text +source: + requested as !vmi.vreg<64xf16, contiguous, lane_stride=2> + +op: + extf f16 -> f32, W=2 + +result: + !vmi.vreg<64xf32, contiguous> + +lowering: + one vcvt EVEN +``` + +The same mechanism also covers: + +```text +bf16 -> f32 with phase-zero lane_stride=2 +ui8 -> ui16 with lane_stride=2 +ui8 -> ui32 with lane_stride=4 +f8 -> f32 with lane_stride=4 +narrowing conversions that intentionally produce phase-zero strided results +broadcast materialization into a consumer-required strided dense layout +``` diff --git a/docs/designs/vmi-lane-stride-generalization-implementation.md b/docs/designs/vmi-lane-stride-generalization-implementation.md new file mode 100644 index 0000000000..4faeb7bf4e --- /dev/null +++ b/docs/designs/vmi-lane-stride-generalization-implementation.md @@ -0,0 +1,1583 @@ +# VMI Lane-Stride Layout Generalization Implementation Plan + +本文给出 `lane_stride` 泛化的实现路径。设计目标是把 lane-strided dense +layout 作为一等 layout fact 固化、传播、rematerialize 和 lower,而不是在 +`vmi-to-vpto` 中识别单个 `64xf16 -> 64xf32` pattern。 + +## 1. Implementation Principles + +1. `lane_stride` is a lane-map field. +2. Dense `lane_stride` does not change the VMI logical element type. +3. Group-slot carrier packing is a separate lowering helper. +4. Layout assignment decides layout before `vmi-to-vpto`. +5. `vmi-to-vpto` only lowers explicit assigned layout attrs. + +Pre-existing baseline before this design: + +```text +dense contiguous/deinterleaved layouts: + did not carry lane_stride + +regular VMI load/store: + did not support dense lane_stride + support contiguous and selected deinterleaved lowering/materialization paths + +VPTO load/store: + pto.vlds/pto.vsts have a dist string and the VPTO surface supports several + distribution families, but there is no generic lane_stride operand + +group-slot lane_stride: + already existed and was used by selected group-store packed-byte lowering +``` + +Any dense lane-stride load/store support must enter explicitly by mapping a VMI +lane-stride layout to a specific supported VPTO dist family or materialization +sequence. It must not be inferred in `vmi-to-vpto` from a one-off producer or +consumer pattern. + +Current stage status: + +| Area | Status | Notes | +|---|---|---| +| Dense layout attrs | Supported | Dense contiguous/deinterleaved layouts carry `lane_stride`; group-slot carrier layout remains separate. | +| Direct compact load/store | Supported for selected phase-zero maps | LS=2 b8/b16/b32 through `UNPK_B8/B16/B32` and `PK_B16/B32/B64`; LS=4 b8 through `UNPK4` and `PK4_B32`. | +| Load/store layout folds | Supported with one-load/one-store preservation | `load -> ensure_layout(lane_stride)` rewrites the original load layout when all uses agree; `ensure_layout(lane_stride -> contiguous) -> store` lets the VMI store consume the lane-stride value. | +| Dense widening ext | Supported | Source lane_stride=W can lower to a single `vcvt` part when the cast relation matches. | +| Dense narrowing trunc | Supported for ordinary dense store paths | Source contiguous, result lane_stride=W, then direct compact store when supported. | +| Masked compact store | Partially supported | Legal only when value and mask have the same lane map and the mask can be compacted for the selected store dist. | +| Masked trunc tail | Not optimized yet | Keep the existing legal path until mask lane-stride assignment/materialization is available. | +| Register fallback | Partially supported | Only same-physical-arity contiguous `<->` lane_stride paths with legal pack/unpack carriers. Arity-changing fallback is not in scope for this stage. | +| Group broadcast load | Supported only through specific strategies | `group_broadcast_load` remains a VMI semantic; E2B is one strategy with exact shape/layout constraints. | + +Remaining design/implementation work from this discussion is intentionally +limited to two areas: + +| Area | Work to settle | Required proof before enabling | +|---|---|---| +| Masked store | Let `masked_store` request the same lane map for value and mask, or keep the existing legal path when the mask cannot be assigned/rematerialized into that lane map. | No path may lower a lane-stride value with a stale contiguous user mask; lowering must compact the assigned mask into the packed-store predicate. | +| Group broadcast load | Keep `group_broadcast_load` as a VMI logical operation and make E2B only one support/lowering strategy selected by shape, element width, stride, and assigned result layout. | A failed E2B match must mean "this lowering strategy is unavailable", not "the VMI op is invalid" unless no fallback strategy is registered. | + +Known support boundaries that are not part of this discussion's remaining-work +queue: + +```text +b32 contiguous <-> lane_stride register fallback through generic vpack/vunpack +generic scalar broadcast-load VMI semantic for BRC +dense lane-stride masked_load +arity-changing register fallback +LS=4 b16/b32 direct compact load/store +LS > 4 direct compact load/store +non-zero lane_offset / lane_phase +ordinary load cloning/rematerialization without safe-read proof +global cost search across conflicting consumer layouts +partial-chunk dense lane-stride direct memory beyond the current full-chunk gate +``` + +### 1.1 VPTO Dist Capability Boundary + +VPTO already exposes fixed distribution families that can implement specific +layout-producing or layout-consuming memory operations: + +```text +vlds: + NORM + BRC_B8/B16/B32 + US_B8/B16 + DS_B8/B16 + UNPK_B8/B16/B32 + BRC_BLK + E2B_B16/B32 + UNPK4 + SPLT4CHN + SPLT2CHN_B8/B16 + +vldsx2: + BDINTLV + DINTLV_B8/B16/B32 + +vsts: + NORM_B8/B16/B32 + 1PT_B8/B16/B32 + PK_B16/B32/B64 + PK4_B32 + MRG4CHN_B8 + MRG2CHN_B8/B16 + +vstsx2: + INTLV_B8/B16/B32 +``` + +These are not equivalent to an arbitrary dense `lane_stride` operand: + +```text +DINTLV/INTLV: + two-stream deinterleave/interleave memory operation + maps naturally to VMI deinterleaved layouts, not to one sparse semantic stream + +US/DS: + fixed 2x upsample/downsample load families for b8/b16 + can serve selected lane-map producers when the semantic mapping matches exactly + +UNPK/PK/PK4: + fixed slot-pack/slot-unpack memory families + directly express selected dense lane_stride layouts such as b16 LS=2 and + b8 LS=4, but not arbitrary LS=N + +BRC/E2B/BRC_BLK: + fixed broadcast or group-expansion load families + useful when logical broadcast plus assigned layout matches the family + +MRG/SPLT: + fixed channel merge/split families + useful only for matching channel layouts +``` + +So VPTO has enough surface area to support selected dense lane-stride memory +optimizations, but VMI must model them as explicit support cases: + +```text +VMI layout fact + op semantics + element width + -> exact VPTO dist family + or materialization/rematerialization sequence + or unsupported +``` + +Concrete support matrix for dense phase-zero `lane_stride`: + +| Dense lane_stride | Compact stream load -> dense LS | Single-scalar broadcast load -> dense LS | Dense LS -> compact stream store | +|---:|---|---|---| +| 2 | direct for b8/b16/b32 through `vlds UNPK_B8/B16/B32` | target dist exists as `vlds BRC_B8/B16/B32`; needs a separate single-scalar broadcast-load VMI semantic | direct for b8 through `vsts PK_B16`, b16 through `vsts PK_B32`, and b32 through `vsts PK_B64` | +| 4 | direct for b8 through `vlds UNPK4` | target dist exists as `vlds BRC_B8/B16/B32`; needs a separate single-scalar broadcast-load VMI semantic | direct for b8 through `vsts PK4_B32` | + +| VMI memory semantic | Element width | VPTO op/dist | VMI result layout | Direct dense `lane_stride` support | +|---|---:|---|---|---| +| load one scalar and every logical lane uses it | b8 | `vlds BRC_B8` | any dense phase-zero lane map | target dist exists; needs a separate VMI scalar broadcast-load semantic | +| load one scalar and every logical lane uses it | b16 | `vlds BRC_B16` | any dense phase-zero lane map | target dist exists; needs a separate VMI scalar broadcast-load semantic | +| load one scalar and every logical lane uses it | b32 | `vlds BRC_B32` | any dense phase-zero lane map | target dist exists; needs a separate VMI scalar broadcast-load semantic | +| load compact stream `x[i]` into semantic lane `2*i` | b8 | `vlds UNPK_B8` | `contiguous, lane_stride=2` | yes | +| load compact stream `x[i]` into semantic lane `2*i` | b16 | `vlds UNPK_B16` | `contiguous, lane_stride=2` | yes | +| load compact stream `x[i]` into semantic lane `2*i` | b32 | `vlds UNPK_B32` | `contiguous, lane_stride=2` | yes | +| load compact stream `x[i]` into semantic lane `4*i` | b8 | `vlds UNPK4` | `contiguous, lane_stride=4` | yes | +| load compact stream `x[i]` into semantic lane `4*i` | b16/b32 | none | `contiguous, lane_stride=4` | no direct VPTO dist | +| load compact stream `x[i]` into semantic lane `K*i`, `K > 4` | any | none | `contiguous, lane_stride=K` | no direct VPTO dist | +| load memory `x[2*i]` into logical lane `i` | b8 | `vlds DS_B8` | contiguous | no; this is memory downsample | +| load memory `x[2*i]` into logical lane `i` | b16 | `vlds DS_B16` | contiguous | no; this is memory downsample | +| load alternating memory stream into even/odd logical streams | b8 | `vldsx2 DINTLV_B8` | two compact streams or deinterleaved=2 | no; not one sparse stream | +| load alternating memory stream into even/odd logical streams | b16 | `vldsx2 DINTLV_B16` | two compact streams or deinterleaved=2 | no; not one sparse stream | +| load alternating memory stream into even/odd logical streams | b32 | `vldsx2 DINTLV_B32` | two compact streams or deinterleaved=2 | no; not one sparse stream | +| store semantic lane `2*i` as compact memory `x[i]` | b8 | `vsts PK_B16` | source `contiguous, lane_stride=2` | yes | +| store semantic lane `2*i` as compact memory `x[i]` | b16 | `vsts PK_B32` | source `contiguous, lane_stride=2` | yes | +| store semantic lane `2*i` as compact memory `x[i]` | b32 | `vsts PK_B64` | source `contiguous, lane_stride=2` | yes | +| store semantic lane `4*i` as compact memory `x[i]` | b8 | `vsts PK4_B32` | source `contiguous, lane_stride=4` | yes | +| store semantic lane `4*i` as compact memory `x[i]` | b16/b32 | none | source `contiguous, lane_stride=4` | no direct VPTO dist | +| store semantic lane `K*i` as compact memory `x[i]`, `K > 4` | any | none | source `contiguous, lane_stride=K` | no direct VPTO dist | +| store two compact streams as alternating memory | b8 | `vstsx2 INTLV_B8` | two compact streams or deinterleaved=2 | no; not one sparse stream | +| store two compact streams as alternating memory | b16 | `vstsx2 INTLV_B16` | two compact streams or deinterleaved=2 | no; not one sparse stream | +| store two compact streams as alternating memory | b32 | `vstsx2 INTLV_B32` | two compact streams or deinterleaved=2 | no; not one sparse stream | + +Masked compact stores have an extra legality rule. The `vmi.masked_store` +predicate is a logical-lane predicate, while a VPTO packed store consumes a +predicate in the compacted store coordinate after the sparse lanes have been +packed. Therefore a lane-stride value cannot be paired with an unrelated +contiguous mask and lowered directly to `PK`/`PK4`. + +The direct masked-store path is legal only when all of these hold: + +```text +value source layout == mask source layout +value/mask physical arity matches +mask granularity matches the logical value element width before compaction +target has a predicate compaction path for the packed-store dist +``` + +For example, an f16 value with `lane_stride=2` places logical lanes in even +physical lanes. If a user mask remains contiguous, mask bit `i` still denotes +logical lane `i`, not physical lane `2*i`. Passing that mask directly to +`vsts PK_B32` would gate the wrong compact positions for tail or sparse masks. +The current legal path requires the mask to carry the same lane map as the value +and then compacts it with predicate unpack operations before emitting the +packed store. Ordinary unmasked `vmi.store` is different: lowering creates the +compact prefix predicate itself, so there is no user mask to reinterpret. + +Until masked-store assignment can request and prove the same lane map for value +and mask, assignment must keep masked-tail narrowing on an existing legal path +instead of choosing a lane-stride trunc result solely because the store could +otherwise use `PK`. + +Concrete implementation plan for lane-stride `masked_store`: + +```text +lib/PTO/Transforms/VMILayoutAssignment.cpp + +1. Replace the current VMIMaskedStoreOp consumer request: + requestDataUse(value, contiguous) + requestMaskUse(mask, contiguous, elementGranularity) + + with a helper: + getPreferredMaskedStoreUseRequest(store) + -> optional {valueLayout, maskLayout, maskGranularity} + +2. The helper must be support-query driven: + candidateValueLayout = getPreferredDenseStoreUseLayout(store.value) + if candidateValueLayout is not dense lane_stride: + return none + + sourceValueType = value type with candidateValueLayout + resultValueType = value type with contiguous layout + sourceMaskType = mask type with: + elementCount = value lanes + granularity = getMaskGranularityForElement(value element type) + layout = candidateValueLayout + resultMaskType = same mask granularity with contiguous layout + + require: + canFoldContiguousMaskedStoreMaterialization( + sourceValueType, sourceMaskType, + resultValueType, resultMaskType) + + return {candidateValueLayout, candidateValueLayout, maskGranularity} + +3. The store request becomes: + if helper returns a request: + requestDataUse(store.value, request.valueLayout) + requestMaskUse(store.mask, request.maskLayout, request.maskGranularity) + else: + requestDataUse(store.value, contiguous) + requestMaskUse(store.mask, contiguous, elementGranularity) + +4. Replace the coarse hasMaskedStoreUse(trunc.result) guard with a support + predicate. A trunc result should stay on the conservative path only when a + masked_store use cannot request the same lane_stride value/mask layout. + Do not make trunc inspect mask producers directly; masked_store remains the + consumer that owns the joint value/mask request. +``` + +The intended dataflow after assignment is: + +```text +%n = vmi.trunc* %wide + : source contiguous -> result contiguous, lane_stride = W + +%m_ls = vmi.ensure_mask_layout %m + : mask contiguous -> mask contiguous, lane_stride = W + +vmi.masked_store %n, %dst[%off], %m_ls +``` + +or, if the mask producer can already produce the requested lane map: + +```text +%m_ls = mask producer result + : !vmi.mask<..., layout = contiguous, lane_stride = W> + +vmi.masked_store %n, %dst[%off], %m_ls +``` + +The assignment pass does not emit VPTO predicate compaction. It only creates +the local VMI proof that value and mask have the same lane map. Existing later +stages then do the mechanical work: + +```text +vmi-layout-fold: + may fold ensure_layout(value) + ensure_mask_layout(mask) into masked_store + only through canFoldContiguousMaskedStoreMaterialization + +vmi-to-vpto: + sees valueLayout == maskLayout + calls createDenseLaneStrideStorePredicate + emits LOWER punpack on the mask + emits vsts PK_B16/PK_B32/PK4_B32 as selected by value element width/layout +``` + +Required masked-store tests: + +```text +assignment positive: + truncf/trunci -> masked_store where value LS=2 b16 or LS=4 b8 is supported + CHECK value result layout is lane_stride + CHECK mask use is requested as the same lane_stride, with ensure_mask_layout + when the original mask is contiguous + +fold positive: + ensure_layout(value lane_stride -> contiguous) and + ensure_mask_layout(mask lane_stride -> contiguous) feeding masked_store + CHECK masked_store consumes the lane_stride value/mask directly + +lowering positive: + lane_stride value + same-lane-map mask feeding masked_store + CHECK mask compaction uses punpack + CHECK store dist is PK_B32 for b16 LS=2 and PK4_B32 for b8 LS=4 + +fallback: + mask cannot be assigned/materialized to the candidate lane_stride + CHECK masked_store keeps contiguous value/mask request + CHECK no PK/PK4 masked compact store is emitted with a stale contiguous mask +``` + +The remaining VPTO dist tokens are fixed non-lane-stride operations: + +```text +UNPK_B8/B16/B32: + compact load into one element per 16/32/64-bit slot, giving lane_stride=2 for + b8/b16/b32 dense values + +UNPK4: + compact load into one b8 element per 32-bit slot, giving lane_stride=4 for b8 + +PK_B16/B32/B64 and PK4_B32: + compact store from one active low element per 16/32/64-bit slot. PK_B32 is + exactly the direct compact store for a b16 value with lane_stride=2, and + PK4_B32 is exactly the direct compact store for a b8 value with lane_stride=4 + +MRG4CHN_B8 and MRG2CHN_B8/B16: + fixed channel merge stores, not generic lane_stride stores + +SPLT4CHN and SPLT2CHN_B8/B16: + fixed channel split loads, not generic lane_stride loads + +BRC_BLK and E2B_B16/B32: + usable only after their exact block/group expansion semantic is modeled as a + VMI broadcast producer; do not count them as generic dense lane_stride load +``` + +### 1.2 Contiguous/Lane-Stride Fallback Materialization + +Direct load/store support is preferred. When a value already lives in VPTO +registers and a consumer requires the other layout, `ensure_layout` provides the +fallback conversion between contiguous and dense phase-zero `lane_stride`. + +For `contiguous -> lane_stride`, use register unpack placement when the VPTO +surface supports the required carrier type: + +```text +LS=2: + use vzunpack/vsunpack-style widening placement + b8 contiguous -> b16 slots with low b8 semantic + b16 contiguous -> b32 slots with low b16 semantic + b32 contiguous -> b64 slots with low b32 semantic + +LS=4: + for b8, apply two LS=2 unpack placements: + b8 contiguous -> b16 slots -> b32 slots with low b8 semantic +``` + +For `lane_stride -> contiguous`, use register pack when the VPTO surface +supports the required carrier type: + +```text +LS=2: + use vpack-style narrowing placement + low b8 from each b16 slot -> b8 contiguous + low b16 from each b32 slot -> b16 contiguous + low b32 from each b64 slot -> b32 contiguous + +LS=4: + for b8, apply two LS=2 pack placements: + low b8 from each b32 slot -> b16 slots -> b8 contiguous +``` + +This is the register-side counterpart of `UNPK`/`PK` memory distributions. Do +not use `vintlv`/`vdintlv` as the primary fallback for dense `lane_stride`; those +belong to two-stream interleave/deinterleave layouts. + +Current checked-in VPTO coverage: + +```text +register pack: + vpack supports integer 32 -> u16 and integer 16 -> u8 + so b16 LS=2 -> contiguous and b8 LS=2/4 -> contiguous are directly covered + when the VMI source/result physical arity is the same + b32 LS=2 -> contiguous needs 64 -> 32 pack support or another materializer + +register unpack: + vsunpack/vzunpack support integer widening by 2x + so integer b8/b16 contiguous -> LS=2 and b8 contiguous -> LS=4 are covered + when the VMI source/result physical arity is the same + +floating-point lane_stride: + b8/b16 FloatType values use bit-preserving vbitcast to unsigned integer + carriers around the same pack/unpack sequence; non-FloatType low precision + types need a VPTO vbitcast contract before enabling this fallback + +arity-changing lane_stride materialization: + contiguous -> lane_stride can be expressed as multiple unpack parts, and + lane_stride -> contiguous needs an explicit multi-part merge/pack plan. + The current stage rejects those helpers instead of guessing a cross + physical-chunk materialization. +``` + +This fallback is a materialization cost, not a layout preference. Assignment may +insert the `ensure_layout`; later folding/rematerialization should remove it when +the producer or consumer has direct support: + +```text +load -> ensure_layout(lane_stride) + fold into a VMI load whose result has the requested lane_stride; vmi-to-vpto + later lowers that load to UNPK when the element width and stride match. + BRC remains the target dist for a separate scalar broadcast-load VMI semantic. + +ensure_layout(lane_stride) -> store + fold into a VMI store that directly consumes the lane_stride value; vmi-to-vpto + later lowers that store to PK/PK4 when the element width and stride match + +ordinary producer -> ensure_layout(contiguous <-> lane_stride) + lower to register pack/unpack materialization when the element width is + supported +``` + +### 1.3 Pass Responsibilities + +Dense `lane_stride` should use the existing helper-driven layout pipeline. Do +not add a separate global candidate solver for the current stage. + +```text +pto-validate-vmi-ir: + verify surface syntax before assignment + reject malformed dense lane_stride attrs once the parser accepts them + keep lane_offset unavailable in the public attr + +vmi-layout-assignment: + assign explicit dense layouts, including lane_stride, on VMI value types + use op support queries to choose local cast relations: + widening can request source lane_stride=W and result contiguous + narrowing can request source contiguous and result lane_stride=W + keep unsupported or conflicting uses legal by inserting ensure_layout + serialize all decisions as type attrs or helper ops + do not clone producers, fold memory ops, or solve a global cost problem + +canonicalize/cse: + remove dead helpers and merge identical rematerialized values when normal MLIR + canonicalization can prove equivalence + no lane_stride-specific decision logic + +vmi-layout-rematerialize: + consume producer -> ensure_layout shapes + clone/rematerialize cheap producers directly in the requested lane_stride + layout when the producer support query says it can create that layout + examples: scalar broadcast, splat constants, iota, layout-transparent chains, + widening ext, and supported mask producers + do not rematerialize ordinary loads unless the load form has an explicit + safe-read proof and direct UNPK lowering support + +vmi-layout-fold: + consume helper-adjacent producer/consumer shapes + fold ensure_layout(lane_stride) feeding store into a VMI store that directly + consumes the lane_stride value when the support table has a direct compact + store lowering; this is still a VMI store, not a VPTO PK op + fold load -> ensure_layout when the load can directly produce the requested + lane map with UNPK and the rewrite preserves one load at the original + program point + fold identity lane-map conversions + leave unsupported conversions as explicit ensure_layout for validation or + vmi-to-vpto materialization + +vmi-layout-sink-materialization: + move ensure_layout across pure layout-transparent ops when all operands/results + can keep one identical dense lane map + reduce duplicated contiguous <-> lane_stride materializations + do not sink through cast, load, store, reduce, group_broadcast, or control flow + +pto-validate-vmi-layout-ir: + verify every dense value has a supported layout attr + verify ensure_layout has a supported materialization path: + identity + contiguous <-> lane_stride through register pack/unpack when supported + existing contiguous <-> deinterleaved relations + verify direct layout-aware load/store choices: + LS=2 b8/b16/b32 through UNPK/PK + LS=4 b8 through UNPK4/PK4 + BRC only after a scalar broadcast-load VMI semantic is modeled + reject unsupported direct cases such as LS=4 b16/b32 compact load/store + +vmi-to-vpto: + lower only from assigned type attrs, helper ops, and op attributes + emit direct vlds/vsts dist for UNPK/PK-supported memory cases + lower surviving contiguous <-> lane_stride ensure_layout through register + pack/unpack materialization when the VPTO verifier supports the carrier path + lower widening/narrowing casts according to the assigned source/result + lane_stride relation and concrete vcvt part + emit diagnostics instead of inventing hidden layout conversions +``` + +Implementation impact by pass/component: + +| Component or pass | Lane-stride implementation work | +|---|---| +| `VMILayoutAttr` ODS/C++ helpers | Yes. Add dense `laneStride` storage, parse/print, verifier, equality, lane-map helpers, and keep it separate from group-slot carrier packing. | +| VMI type physicalization helpers | Yes. Compute dense physical arity from `laneStride`; expose carrier-slot lowering helpers for pack/unpack paths without changing the VMI logical element type. | +| `VMILayoutSupport` / target capability helpers | Yes. Add support queries for dense `lane_stride` layouts, cast layout relations, direct UNPK/PK memory support, and contiguous `<->` lane-stride materialization support. BRC remains target capability for a separate scalar broadcast-load semantic. | +| `pto-validate-vmi-ir` | No lane-stride-specific pass algorithm. It relies on attr/op verifier updates; keep the existing surface-IR validation role. | +| `vmi-layout-assignment` | Yes. Assign dense lane-stride layouts when support queries choose them; insert `ensure_layout` for incompatible uses; serialize all decisions in types/helpers. | +| `canonicalize/cse` between VMI passes | No implementation. It remains ordinary cleanup for dead helpers and identical rematerialized producers. | +| `vmi-layout-rematerialize` | Yes. Teach producer rematerialization to create requested dense lane-stride layouts for cheap/safe producers. Do not add ordinary load remat without safe-read proof. | +| `vmi-layout-fold` | Yes. Fold `ensure_layout` into layout-aware VMI consumers, especially stores that can consume lane_stride and later lower to `PK/PK4`; fold `load -> ensure_layout` into a direct layout-aware load when it can preserve one load at the original program point; fold identity lane-map conversions. | +| `vmi-layout-sink-materialization` | Minimal generic update. It should compare dense layout keys including `laneStride` and reuse existing layout-transparent sinking; do not add cast/load/store/reduce-specific lane-stride patterns here. | +| `vmi-legalize-arith-select` | No implementation. Lane stride does not change scalar-condition select legalization. | +| `pto-validate-vmi-layout-ir` | Yes. Reject unsupported assigned layouts/helpers before lowering, including unsupported LS=4 b16/b32 compact load/store and unsupported register pack/unpack materializations. | +| `vmi-to-vpto` | Yes. Lower assigned dense lane-stride layouts, direct `UNPK/PK` memory cases, register pack/unpack `ensure_layout`, and lane-stride-aware ext/trunc lowering. | +| VPTO op verifier/emitter | Only if needed by the selected support matrix. Existing `vlds/vsts` dist tokens are already present; extending register fallback to b32 or floating-point carriers requires verifier/emitter support for the corresponding pack/unpack or bitcast form. | +| Lower VPTO/backend passes after `vmi-to-vpto` | No lane-stride-specific implementation. They see ordinary VPTO ops and existing dist tokens. | + +Any pass not listed above should not implement lane-stride-specific logic in the +current stage. New behavior must enter through the explicit layout attr, +support queries, helper ops, validation, or `vmi-to-vpto` lowering. + +Current-stage component checklist: + +This checklist records the components that participate in the current-stage +lane-stride implementation. It is not the remaining-work queue; remaining work +is limited to the masked-store and group-broadcast-load items above. + +```text +include/PTO/IR/VMIAttrs.td +lib/PTO/IR/VMI.cpp + add laneStride storage for dense contiguous/deinterleaved layouts + keep group-slot laneStride parse/print compatibility + add getContiguous(ctx, laneStride) and getDeinterleaved(..., laneStride) + split helpers into isDenseLaneStrided(), isGroupSlotLaneStrided(), + getLaneStride(), and exact dense lane-map equality helpers + update attr verifier so laneStride > 0 and lane_offset is not accepted + +lib/PTO/IR/VMI.cpp +lib/PTO/Transforms/VMIToVPTO.cpp + replace the current "hasLaneStride implies unsigned carrier widening" helper + with: + logical-element physicalization for ordinary dense VPTO values + selected carrier-slot physicalization for pack/unpack materializations + existing group-slot packed-byte carrier lowering + +include/PTO/Transforms/VMILayoutSupport.h +lib/PTO/Transforms/VMILayoutSupport.cpp + extend VMIContiguousStoreSupportKind with dense lane-stride PK/PK4 cases + extend VMILayoutMaterializationSupportKind with: + ContiguousToLaneStrideViaUnpack + LaneStrideToContiguousViaPack + LaneStrideToLaneStrideViaContiguous, only if needed + update getPreferredCastLayoutFact: + baseline facts remain contiguous <-> deinterleaved + add optional preferred dense lane-stride fact from the conversion ratio W + and the target single-part cast support; do not inspect source producers + here + update getWidenSourceLayoutForResultLayout for dense lane_stride result/source + update getContiguousStoreSupport and canFoldContiguousStoreMaterialization for + LS=2 b8/b16/b32 -> PK_B16/B32/B64 + LS=4 b8 -> PK4_B32 + update canMaterializeDataLayout for contiguous <-> dense lane_stride through + register pack/unpack when the element/carrier path is supported + +lib/PTO/Transforms/VMILayoutAssignment.cpp + teach natural/preferred layout collection to accept dense lane_stride facts + from VMILayoutSupport + keep conflict handling unchanged: insert ensure_layout at mismatched uses + do not add producer cloning, memory folding, or global cost selection here + +lib/PTO/Transforms/VMILayoutRematerialize.cpp + allow cheap producers to be cloned with dense lane_stride result types when + VMILayoutSupport says the producer can directly create that lane map + keep ordinary load/group_load/masked_load blocked until a safe-read proof is + added for the specific direct UNPK lowering + +lib/PTO/Transforms/VMILayoutFold.cpp + add producer-side fold for load -> ensure_layout: + replace the load result layout with the ensure target layout when the load + has no other incompatible uses and VMILayoutSupport has direct UNPK + support + erase the ensure_layout and keep a single load at the original program point + do not clone ordinary loads in this fold + add fold for ensure_layout(lane_stride -> contiguous) feeding pto.vmi.store or + pto.vmi.masked_store into a VMI store that consumes the lane_stride source + directly; this pass does not emit or model VPTO PK. VMIToVPTO later selects + the corresponding PK/PK4 store lowering from the assigned VMI store contract + masked_store direct fold additionally requires the mask to carry the same + dense lane_stride layout and a compactable element-width granularity: + LS=2 b8/b16 and LS=4 b8 are supported through LOWER punpack mask compaction + LS=2 b32 is left as explicit materialization until b32 lane-stride mask + compaction is specified and implemented + a contiguous user mask is not enough, even if the value layout can be + compact-stored; assignment/rematerialization must first derive the same + lane map for the mask + fold exact dense lane-map identity helpers + do not fold unsupported LS=4 b16/b32 cases + +lib/PTO/Transforms/VMILayoutSinkMaterialization.cpp + include laneStride in dense layout equality/support checks + reuse existing layout-transparent sinking logic + do not add lane-stride-specific sinking through casts or memory ops + +lib/PTO/Transforms/PTOValidateVMIIR.cpp + no new lane-stride algorithm + validation changes should come from attr/op verifier and VMILayoutSupport + diagnostics at the layout gate + +lib/PTO/Transforms/VMIToVPTO.cpp + update OneToN physical type conversion for dense laneStride and carrier slots + lower direct compact loads: + LS=2 b8/b16/b32 -> vlds UNPK_B8/B16/B32 + LS=4 b8 -> vlds UNPK4 + lower direct compact stores: + LS=2 b8/b16/b32 -> vsts PK_B16/B32/B64 + LS=4 b8 -> vsts PK4_B32 + lower direct compact masked_stores: + LS=2 b8/b16 -> LOWER punpack mask compaction + vsts PK_B16/B32 + LS=4 b8 -> two LOWER punpack steps + vsts PK4_B32 + LS=2 b32 -> no direct masked compact store until b32 lane-stride mask + compaction is specified and implemented + lower surviving ensure_layout contiguous <-> lane_stride through vpack and + vsunpack/vzunpack when the carrier path is legal + lower lane-stride-aware ext by selecting the concrete vcvt part from + the assigned source/result relation + lower lane-stride-aware trunc by selecting the concrete vcvt part from + the assigned source/result relation + +lib/PTO/IR/VPTO.cpp +lib/PTO/Transforms/VPTOLLVMEmitter.cpp +lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp + no change for existing vlds/vsts dist tokens + extend vpack/vsunpack/vzunpack verifier/emitter only if the first implemented + fallback needs currently unsupported b64->b32 or floating-point carrier paths + +test/lit/vmi + add parser/verifier tests for dense laneStride attrs + add assignment tests for ext lane-stride facts + add fold/remat/sink tests for helper-driven rewrites + add vmi-to-vpto checks for UNPK/PK and vpack/unpack fallback + add negative tests for unsupported LS=4 b16/b32 compact load/store +``` + +Load/`ensure_layout` fold algorithm: + +```text +input shape: + %x0 = pto.vmi.load ... : !vmi.vreg + %x1 = pto.vmi.ensure_layout %x0 + : !vmi.vreg + -> !vmi.vreg + +preconditions: + the load result has no other use that requires the old layout + the load semantics are a compact logical stream + VMILayoutSupport says the target layout has a direct load lowering: + compact stream: + LS=2 b8/b16/b32 -> UNPK_B8/B16/B32 + LS=4 b8 -> UNPK4 + masks/passthroughs, if present, already have compatible assigned layouts + +rewrite: + replace the original load op in place, or create the replacement load at the + same insertion point and erase the old load + the replacement load result type is the ensure target type + all ensure users use the replacement load result + erase the ensure_layout + +output shape: + %x = pto.vmi.load ... : !vmi.vreg + +lowering: + vmi-to-vpto emits the corresponding vlds UNPK dist +``` + +This fold changes the assigned result layout of the existing load; it does not +clone the load at the helper use-site. If the original load has both contiguous +and lane-stride consumers, the fold must leave the helper in place unless a +separate rematerialization step has a safe-read proof to clone the load. + +### 1.4 Scenario Ownership + +Each optimization scenario has exactly one owning pass. Other passes may verify +or lower the resulting explicit IR, but should not solve the same rewrite again. + +| Scenario | Example shape | Owning pass | Non-owners | +|---|---|---|---| +| Assign a layout request | `ext -> store` where store wants contiguous | `vmi-layout-assignment` inserts explicit layouts/helpers | Assignment does not clone, fold, or lower | +| Direct load produces requested lane map | `load(contiguous) -> ensure_layout(lane_stride=2)` | `vmi-layout-fold` rewrites the original load result layout when UNPK support exists | Remat must not clone this load without safe-read proof | +| Direct store consumes lane map | `ensure_layout(lane_stride -> contiguous) -> store` | `vmi-layout-fold` rewrites the VMI store to consume the lane_stride source directly when direct compact-store support exists | `vmi-to-vpto` emits the actual `vsts PK/PK4` | +| Cheap producer can produce target layout | `broadcast -> ensure_layout(lane_stride=2)` | `vmi-layout-rematerialize` rebuilds broadcast with lane-stride result | Fold does not rebuild arbitrary producers | +| Widening ext can move materialization to cheap source | `ext -> ensure_layout(contiguous)` with source broadcast/load-fold case | `vmi-layout-rematerialize` rebuilds ext with required source lane stride | Assignment only creates the helper; fold only handles the load subcase | +| Layout-transparent op has ensured operands | `ensure(a), ensure(b) -> add` | `vmi-layout-sink-materialization` sinks matching helpers to the result | Remat handles the opposite shape `add -> ensure` | +| Surviving supported helper | `ensure_layout(contiguous <-> lane_stride)` after optimizations | `vmi-to-vpto` lowers to register pack/unpack | Earlier passes are allowed to leave it explicit | +| Unsupported helper or layout | `lane_stride=4 b16 compact store` | `pto-validate-vmi-layout-ir` rejects before lowering | `vmi-to-vpto` should not invent a repair | +| Multi-consumer value with incompatible layouts | one load feeds contiguous user and lane-stride user | baseline keeps helper; optional remat only with safe-read proof | Fold must not silently duplicate memory effects | + +Examples: + +```text +load fold, owned by vmi-layout-fold: + before: + %x0 = pto.vmi.load ... : contiguous + %x1 = pto.vmi.ensure_layout %x0 : contiguous -> lane_stride=2 + after: + %x1 = pto.vmi.load ... : lane_stride=2 + vmi-to-vpto: + %x1 = pto.vlds ... {dist = "UNPK_B8/B16/B32 or UNPK4"} + +store fold, owned by vmi-layout-fold: + before: + %x_c = pto.vmi.ensure_layout %x_ls : lane_stride=2 -> contiguous + pto.vmi.store %x_c, %dst + after: + pto.vmi.store %x_ls, %dst // VMI store consumes lane_stride source + vmi-to-vpto: + pto.vsts %x_ls, %dst {dist = "PK_B16/B32/B64 or PK4_B32"} + +broadcast remat, owned by vmi-layout-rematerialize: + before: + %b0 = pto.vmi.broadcast %s : contiguous + %b1 = pto.vmi.ensure_layout %b0 : contiguous -> lane_stride=2 + after: + %b1 = pto.vmi.broadcast %s : lane_stride=2 + +elementwise sink, owned by vmi-layout-sink-materialization: + before: + %a1 = ensure_layout %a0 -> lane_stride=2 + %b1 = ensure_layout %b0 -> lane_stride=2 + %c1 = pto.vmi.addf %a1, %b1 + after: + %c0 = pto.vmi.addf %a0, %b0 + %c1 = ensure_layout %c0 -> lane_stride=2 +``` + +## 2. IR Attribute Changes + +### 2.1 Extend `VMILayoutAttr` + +Current storage reuses `blockElems` as group-slot `lane_stride`. Generalization +should first split lane stride from block elems: + +```c++ +kind +factor +blockElems +slots +laneStride +``` + +Meaning by kind: + +```text +contiguous: + factor = 1 + blockElems = 1 + slots = 0 + laneStride >= 1 + +deinterleaved: + factor = F + blockElems = B + slots = 0 + laneStride >= 1 + +num_groups: + factor = G + blockElems = 1 + slots = K + laneStride >= 1 +``` + +Do not add a public `laneOffset` field in the current stage. The targeted +optimization only needs phase-zero strided dense layouts. A future +phase field is justified only when there is a concrete VMI value whose logical +lane map is intentionally non-zero-phase, for example a zero-copy +deinterleave/extract view that keeps the odd lanes in place or a narrowing +conversion whose consumer explicitly requires an odd-lane result. + +Recommended helpers: + +```c++ +bool isDense() const; +bool hasDenseLaneStride() const; +bool hasGroupSlotLaneStride() const; +int64_t getLaneStride() const; +VMILayoutAttr withLaneStride(int64_t stride) const; +``` + +Keep old constructor defaults source-compatible where possible: + +```c++ +getContiguous(ctx) +getDeinterleaved(ctx, factor, blockElems = 1, + laneStride = 1) +getGroupSlots(ctx, numGroups, slots = 0, laneStride = 1) +``` + +### 2.2 Parser And Printer + +Accepted dense spellings: + +```text +#pto.vmi.layout +#pto.vmi.layout + +#pto.vmi.layout +#pto.vmi.layout +#pto.vmi.layout +``` + +Existing group-slot spellings remain valid: + +```text +#pto.vmi.layout +#pto.vmi.layout +``` + +Printing omits defaults: + +```text +lane_stride = 1 is omitted +``` + +### 2.3 Verifier + +Verifier rules: + +```text +all layouts: + laneStride > 0 + +contiguous: + factor == 1 + blockElems == 1 + slots == 0 + +deinterleaved: + factor in supported dense factors + blockElems > 0 + slots == 0 + +num_groups: + factor > 0 + slots >= 0 + blockElems == 1 +``` + +The verifier should not require every strided layout to fit one VPTO register. +Fit depends on the VMI type shape and element type, so it belongs in type +physicalization and op support checks. + +## 3. Physicalization Helpers + +### 3.1 Separate Element Carrier From Lane Map + +Replace the current shared helper shape: + +```c++ +getVMIPhysicalElementType(type) +``` + +with two concepts: + +```c++ +getVMILogicalStorageElementType(type) +getVMIPhysicalCarrierElementType(type, loweringKind) +``` + +Dense lane-strided values keep the VMI logical element type. The lowering may +represent the same lane map either as logical-element lanes or as wider carrier +slots when the selected VPTO instruction is a pack/unpack family. + +Logical-element lane representation: + +```text +!vmi.vreg + -> !pto.vreg<128xf16> physical register +``` + +Carrier-slot representation for pack/unpack lowering: + +```text +!vmi.vreg + -> low ui16 in each ui32 slot for vpack/PK_B32-style lowering + +!vmi.vreg + -> low ui8 in each ui32 slot for PK4_B32-style lowering +``` + +Group-slot packed stores also request a wider carrier in the specific lowering +path: + +```text +!vmi.vreg + group_store -> b32 carrier + PK4_B32 +``` + +Do not let dense `hasLaneStride()` imply unsigned-integer carrier widening +globally. Carrier widening is a property of a selected materialization or +load/store lowering, not of the VMI logical type itself. + +### 3.2 Physical Arity + +Add a dense lane-map helper: + +```c++ +struct DenseLaneMap { + int64_t deinterleaveFactor; + int64_t blockElems; + int64_t laneStride; +}; + +int64_t getPhysicalLaneForDenseLogicalLane(DenseLaneMap map, + int64_t logicalLane); +``` + +For a VMI vreg type: + +```text +lanesPerVPTO = getVPTOPhysicalLanes(elementType) +lanesInDensePart = ceil(N / F) with block-aware distribution +requiredLanes = O + (lanesInDensePart - 1) * LS + 1 +registersPerDensePart = ceil(requiredLanes / lanesPerVPTO) +physicalArity = F * registersPerDensePart +``` + +For the current stage, require full block divisibility for dense +deinterleaved strided layouts, matching existing direct lowering restrictions: + +```text +N % (F * B) == 0 +``` + +Relaxing tail handling is outside the current stage and should be enabled only +with an explicit materialization/lowering proof. + +## 4. Layout Support Interface + +Extend support queries to include dense strided layouts: + +```text +supportsResultLayout(op, resultIndex, layout) +supportsOperandLayout(op, operandIndex, layout) +supportsLayoutRelation(op, operandLayouts, resultLayouts) +``` + +The important change is relation support. Some ops are not independently +described by "operand supports layout X" and "result supports layout Y"; they +support specific pairs. + +Examples: + +```text +elementwise: + all dense operands/results must use identical dense layout key + +extf/extui/extsi: + source/result layouts must satisfy a widening relation + +truncf/trunci: + dense narrowing may request source contiguous and result lane_stride=W, where + W is the storage-width narrowing factor; masked-store consumers stay on the + existing legal deinterleaved-to-contiguous path until mask lane-stride + assignment/materialization is available + +broadcast/group_broadcast: + result may use a dense layout only when the materialization lowering has an + explicit support case for that lane map + +load: + default result contiguous + producer rematerialization may create selected strided layouts if a direct + load/mask sequence can produce that lane map + +store: + memory effect is contiguous unless the op is an explicit logical interleave + store; a strided source requires store lowering support or ensure_layout +``` + +Assignment should still insert `ensure_layout` for incompatible use-local +requests. Rematerialization/fold can later remove it. + +### 4.1 Current Framework Fit + +The existing assignment pass already has use-site requests. For example, +`pto.vmi.store` requests a contiguous source operand, and assignment can insert +`ensure_layout` when the stored value is assigned another layout. + +The dense-stride `ext` optimization should keep the same model: the cast op is +the layout-entry point and stores one preferred source/result relation. The +current preferred relation is: + +```text +extf: + request source contiguous + set result deinterleaved=W +``` + +The current stage keeps the existing single-preference framework +and let `ext` choose one fact for the current op: + +```text +baseline fact: + source contiguous + result deinterleaved=W + +lane-stride fact: + source lane_stride=W + result contiguous +``` + +The `ext` support query chooses between these facts from op-local information: + +```text +conversion ratio W +target support for one selected hardware conversion part +requested or preferred result layout for the current op instance +``` + +It does not inspect the defining source producer. If it selects the +lane-stride fact and the source is not already in that layout, assignment inserts +an explicit source `ensure_layout`. Later passes either discharge that helper by +rematerializing/folding a concrete producer, lower it with a registered +pack/unpack materializer, or let validation reject the unsupported relation. + +## 5. Widening Conversion Lowering + +Let: + +```text +W = result element storage bits / source element storage bits +``` + +For a dense source layout: + +```text +source lane_stride = LS +``` + +Single-part lowering is legal when: + +```text +LS % W == 0 +``` + +Then: + +```text +hardware part = 0 +result lane_stride = LS / W +``` + +The current stage only emits the zero-phase single-part conversion. +`vcvt ODD` remains necessary for full packed contiguous conversion, but that is +handled by the existing multi-part relation: + +```text +source contiguous, lane_stride=1 +result deinterleaved=W +``` + +Do not add a phase field merely to name that existing ODD instruction. Add a +phase field only when an assigned VMI layout needs to represent a concrete +zero-copy value/view already resident in odd/non-zero-phase lanes. + +The support query for the conversion should accept the pair only when the +requested result layout equals this computed result lane map, including +deinterleave/block fields. + +The support query should expose helpers for both legal facts, but assignment +chooses one immediately: + +```text +baseline fact: + source contiguous + result deinterleaved=W + lowering cost = W conversion parts + +lane-stride fact: + result contiguous + source same dense shape with lane_stride = W + lowering cost = one conversion part +``` + +For example, for `f16 -> f32`, the `extf` op can prefer +`source lane_stride=2 -> result contiguous` when the target has a single EVEN +conversion for that relation. The source producer is handled by the explicit +source `ensure_layout` and later fold/rematerialization; it is not part of the +cast support query. + +Current contiguous widening remains a separate legal relation: + +```text +source dense contiguous, lane_stride=1 +result deinterleaved=W, lane_stride=1 +``` + +Implementation steps: + +1. Factor conversion ratio calculation by storage bit width. +2. Add helper that computes baseline conversion count. +3. Add helper that computes lane-stride conversion count and required source + layout. +4. Teach `VMIToVPTO` conversion lowering to emit only the selected hardware + part when the relation is single-part. +5. Keep existing multi-part lowering for contiguous-to-deinterleaved cases. +6. Add diagnostics when an assigned conversion layout pair has no lowering. + +Hardware part names should be abstracted: + +```text +W=2: + part 0 -> EVEN + part 1 -> ODD + +W=4: + part 0..3 -> target-specific conversion part names or the existing sequence +``` + +Do not special-case f16/f32 in the matcher. The type only determines `W` and +the concrete VPTO conversion opcode. + +## 6. Narrowing Conversion Lowering + +Let: + +```text +W = source element storage bits / result element storage bits +``` + +For a single-part narrowing relation: + +```text +result lane_stride = source lane_stride * W +hardwarePart = 0 for the current stage +``` + +Implementation steps: + +1. Share ratio and lane-map helpers with widening. +2. Add support query for valid narrowing layout pairs. +3. Lower single-part narrowing directly when the target has a part-selecting + narrow instruction. +4. Preserve existing deinterleaved-to-contiguous narrowing for the packed full + result case. + +This is the same family as the recently discussed `d4 -> c -> d2 -> vcvt -> c` +optimization: if a cast op has a direct source/result layout relation, +assignment/rematerialization should expose that relation before lowering. + +## 7. Ensure-Layout And Rematerialization + +### 7.1 `ensure_layout` + +`ensure_layout` remains the explicit use-site materialization op. + +Verifier/lowering policy: + +```text +same source and target dense lane map: + fold away + +known dense relation: + lower contiguous <-> lane_stride through register pack/unpack when supported + lower contiguous/deinterleaved relations through existing intlv/dintlv paths + +producer can rematerialize target layout: + rematerialization should replace ensure_layout(producer) + +unknown relation: + reject before vmi-to-vpto +``` + +Avoid adding a generic "any dense layout to any dense layout" promise unless the +target really has a lowering for it. + +### 7.2 Rematerialization + +The current checked-in `vmi-layout-rematerialize` cheap producers are: + +```text +data: + VMIExtFOp / VMIExtSIOp / VMIExtUIOp when the source layout can be + materialized for the requested result relation + VMIFmaOp + binary layout-transparent ops: + addf/addi/subf/subi/mulf/muli/divf/minf/maxf/andi/ori/xori/shli/shrui + unary layout-transparent ops: + negf/absf/absi/sqrt/exp/ln/relu/not + VMIConstantOp only when the DenseElementsAttr is a splat + VMIBroadcastOp + VMIIotaOp + +mask: + VMICreateMaskOp + VMICreateGroupMaskOp + VMIConstantMaskOp + +special rewrite: + selected VMITruncIOp through a source ensure_layout when the cast relation is + a supported narrowing relation +``` + +Not included as cheap producers in the current pass: + +```text +load / masked_load / group_load / group_slot_load / group_broadcast / +group_broadcast_load / store / reduce / control-flow ops +``` + +Loads need a separate policy. `load -> ensure_layout` should be folded in +`vmi-layout-fold` when one original load can directly produce the requested +layout. A normal load should not be cloned/rematerialized unless a later safe-read +proof explicitly permits that clone. + +Relationship between cheap producers and dense `lane_stride`: + +```text +assignment: + creates the target layout request explicitly, usually as ensure_layout(... -> + lane_stride) or as a cast source/result relation + +rematerialize: + does not choose lane_stride as a preference + only consumes the explicit helper/request and rebuilds the producer with the + requested lane_stride result type when the producer is cheap and locally legal +``` + +Required rematerialize changes for dense `lane_stride`: + +```text +materializeDataLayout: + no special producer logic, but canMaterializeDataLayout must understand + contiguous <-> lane_stride through register pack/unpack + +splat constant / broadcast / iota: + rebuild the op with the requested lane_stride result type + lowering later materializes that layout directly or through ensure_layout + +layout-transparent unary/binary/fma: + rebuild the op with the requested lane_stride result type + materialize each operand to the same lane_stride layout before rebuilding + this relies on canMaterializeDataLayout for operand conversions + +widening ext: + update getWidenSourceLayoutForResultLayout so a requested result layout derives + the required source lane_stride: + result contiguous, W=2 -> source lane_stride=2 + result lane_stride=R, W=2 -> source lane_stride=2*R + remat then inserts/uses source ensure_layout and rebuilds ext with the + requested result layout + +trunci special rewrite: + extend the existing source-ensure rewrite to recognize lane_stride narrowing + relations, not only deinterleaved narrowing relations + +mask producers: + only participate after mask layout support defines the corresponding + lane-stride or predicate-granularity relation; otherwise unchanged +``` + +Example: + +```text +before remat: + %b0 = pto.vmi.broadcast %s : !vmi.vreg<64xf16, contiguous> + %b1 = pto.vmi.ensure_layout %b0 + : contiguous -> contiguous, lane_stride=2 + %y = pto.vmi.extf %b1 : f16 -> f32 + +after remat: + %b1 = pto.vmi.broadcast %s + : !vmi.vreg<64xf16, contiguous, lane_stride=2> + %y = pto.vmi.extf %b1 : f16 -> f32 +``` + +This removes a register layout materialization and lets `vmi-to-vpto` lower the +ext as the single selected conversion part. It is still driven by the explicit +layout request; remat does not inspect sibling consumers or choose lane_stride by +itself. + +Do lane-stride ext rematerialization only in these cases: + +```text +required shape: + ext result is followed by ensure_layout to a requested dense result layout + widening ratio W > 1 + the requested result lane_stride is R, where contiguous means R=1 + source lane_stride = R * W is supported + ext with that source layout can lower as one selected conversion part + +acceptance/safety gate: + the source-side lane_stride request must be discharged by a concrete local + rewrite, not merely moved from result side to source side + accepted cases: + the source already has the required lane_stride + the source producer is in the checked-in cheap producer list and can be + rebuilt with the required lane_stride + the source is load -> ensure_layout and vmi-layout-fold can replace it with + a single original-position layout-aware VMI load + a layout-transparent chain can be sunk/rematerialized until one of the above + concrete producer cases is reached + +do not apply: + result consumer already accepts the natural ext layout + source lane_stride = R * W is unsupported + source is an ordinary load with other incompatible consumers and no safe-read + proof to clone it + the rewrite only moves an expensive materialization from result side to source + side without exposing a direct lowering +``` + +Typical accepted shapes: + +```text +broadcast -> ext -> ensure_layout(contiguous) -> store + remat broadcast as lane_stride=W + ext lowers with one conversion part + no source-side ensure_layout remains + +load -> ensure_layout(lane_stride=W) -> ext -> store + fold load into a layout-aware VMI load + vmi-to-vpto later emits the matching UNPK dist + ext lowers with one conversion part + no extra load is cloned + +elementwise cheap chain -> ext -> ensure_layout(contiguous) + remat/sink the chain to lane_stride=W only when the chain reaches a concrete + cheap producer or direct load-fold case +``` + +## 8. Broadcast And E2B Interaction + +Do not encode E2B in `lane_stride`, and do not define +`vmi.group_broadcast_load` in terms of E2B. The VMI operation is a logical +fused memory operation: + +```text +for each logical group g: + scalar = source[offset + g * source_group_stride] + for each lane i in group g: + result[i] = scalar +``` + +The result layout is assigned separately. It may be contiguous, +deinterleaved, or dense lane-strided if the consumer asks for that lane map and +the target support table accepts it. E2B is only one VPTO lowering strategy for +a restricted subset of this logical operation. + +The layering should be: + +```text +logical group broadcast load + -> assigned dense layout, possibly lane-strided + -> support query chooses a lowering strategy + -> selected VPTO dist, if any +``` + +For the current E2B strategy, the support query checks: + +```text +source is direct memory +source_group_stride is constant 1 +num_groups is a multiple of 8 +element storage width +logical group size derived from num_groups +assigned result layout: + contiguous for the direct packet size + or deinterleaved=2, block_elems=1 for the split packet size +``` + +Then it may choose an E2B packet: + +```text +b16 contiguous: direct 1 -> 16 packet +b16 deinterleaved=2: two logical halves / 1 -> 32 reuse +b16 dense lane_stride=2: direct phase-zero strided consumer packet +b32 contiguous or strided: target-specific packet size +``` + +If those conditions do not hold, the operation is still a valid VMI semantic if +some other lowering exists, such as `group_slot_load + group_broadcast`, scalar +loads plus broadcast, or future target-specific broadcast-load support. The +failure is only "this E2B lowering strategy is not applicable", not "the VMI +operation means E2B". + +Concrete implementation plan for lane-stride `group_broadcast_load`: + +```text +include/PTO/Transforms/VMILayoutSupport.h +lib/PTO/Transforms/VMILayoutSupport.cpp + +1. Split semantic support from E2B strategy checks: + + getGroupBroadcastLoadSupport(capabilities, op) + try getE2BGroupBroadcastLoadSupport(capabilities, op) + if success: + return {kind = E2BVlds} + return failure("no registered group_broadcast_load lowering strategy; " + "E2B rejected because ...") + + getE2BGroupBroadcastLoadSupport(capabilities, op) + contains the current E2B constraints: + source is !pto.ptr direct memory + element width is b16 or b32 + source_group_stride is constant 1 + num_groups is a multiple of 8 + group size matches direct or split E2B packet size + result layout is contiguous or deinterleaved=2/block_elems=1 + result has full physical chunks + +2. Keep VMIGroupBroadcastLoadSupportKind strategy-specific: + E2BVlds means "lower this VMI semantic using E2B" + It must not be used as the definition of the VMI op. +``` + +```text +lib/PTO/Transforms/VMILayoutAssignment.cpp + +3. Rename strategy helpers so the direction is clear: + + isE2BGroupBroadcastLoadCandidate + -> isE2BGroupBroadcastLoadStrategyApplicable + + getPreferredGroupBroadcastLoadLayout + -> getPreferredE2BGroupBroadcastLoadLayout + +4. Fusion from group_slot_load + group_broadcast to group_broadcast_load remains + guarded by E2B applicability. If E2B is not applicable, do not create a + fused group_broadcast_load merely because the VMI semantic would be valid. + That avoids producing an op with no registered lowering strategy. + +5. Layout assignment for an explicit group_broadcast_load uses the support + query: + if E2B strategy applies: + assign the E2B-preferred result layout + else: + leave the op to validation unless a fallback strategy is added +``` + +```text +lib/PTO/Transforms/VMIToVPTO.cpp + +6. Replace duplicated local E2B legality checks with: + support = getGroupBroadcastLoadSupport(capabilities, op) + switch support.kind: + E2BVlds: + emit the existing E2B packet sequence + + The E2B lowering code may still assert/recheck structural invariants needed + for indexing, but user-facing diagnostics should come from the support query. + +7. Diagnostics must name the strategy: + good: "group_broadcast_load has no registered lowering strategy; E2B + rejected because source_group_stride is not constant 1" + bad: "group_broadcast_load requires constant unit source_group_stride" + + The second form is only valid inside an E2B-specific diagnostic. +``` + +Required group-broadcast-load tests: + +```text +E2B positive: + explicit group_broadcast_load with b16/b32, stride=1, matching group size, + and assigned contiguous/deinterleaved result layout + CHECK vmi-to-vpto emits E2B_B16/E2B_B32 + +E2B strategy rejection: + source_group_stride != 1, wrong group size, or unsupported element width + CHECK validation/lowering diagnostic says no registered lowering strategy and + reports E2B as the rejected strategy + +fusion guard: + group_slot_load + group_broadcast shape that is not E2B-applicable + CHECK assignment does not fuse it into group_broadcast_load + +semantic boundary: + explicit group_broadcast_load that is not E2B-applicable + CHECK failure wording does not redefine the op as E2B and does not imply the + logical VMI semantic itself is E2B +``` + +This keeps broadcast optimization generic across type width and layout, instead +of hardcoding one `ComputeY1ToFP8` scale pattern. + +## 9. Tests + +Use the following as the coverage matrix for current-stage support plus the +masked-store and group-broadcast-load follow-up items. It is not a separate +list of all remaining implementation work. + +Parser/verifier: + +```text +parse/print contiguous lane_stride +parse/print deinterleaved + block_elems + lane_stride +``` + +Physicalization: + +```text +64xf16 contiguous lane_stride=2 has one physical 128xf16 part +ui16 contiguous lane_stride=2 may lower through low ui16 in ui32 carrier slots +when the selected materialization is vpack/PK_B32 +ui8 contiguous lane_stride=4 may lower through low ui8 in ui32 carrier slots +when the selected materialization is PK4_B32 +65xf16 contiguous lane_stride=2 is rejected by direct full-chunk-only paths, or +covered only by an arity-changing materialization test outside this discussion +group-slot ui8 lane_stride=4 keeps existing carrier lowering behavior +``` + +Conversion lowering: + +```text +f16 lane_stride=2 -> f32 contiguous emits one EVEN conversion +bf16 lane_stride=2 -> f32 contiguous follows the same relation +ui8 lane_stride=2 -> ui16 contiguous follows W=2 +ui8 lane_stride=4 -> ui32 contiguous follows W=4 when target supports it +contiguous f16 -> deinterleaved=2 f32 still emits EVEN + ODD +ui16 lane_stride=2 -> contiguous can materialize with vpack 32->16 carrier path +ui8 lane_stride=4 -> contiguous can materialize with two vpack stages +``` + +Assignment/rematerialization: + +```text +extf records a strided dense source relation for a supported single-part +widening conversion +layout-transparent op propagates the same strided layout through operands/result +ensure_layout is folded when source and target lane maps match +rematerialization clones a cheap broadcast for two different dense layouts +``` + +End-to-end assignment cases: + +```text +contiguous load -> ext -> contiguous store: + uses lane_stride only when the source ensure_layout can be folded to the + original load, rematerialized from a cheap producer, or lowered by a supported + register materializer + +cheap broadcast -> ext -> contiguous store: + rematerializes broadcast as lane_stride=2 and lowers ext with one EVEN part + +producer -> ext -> deinterleaved reduce: + keeps source contiguous and result deinterleaved=2 + +cheap producer -> ext feeding both store and reduce: + keeps shared deinterleaved path for reduce and rematerializes a contiguous + result path for store only through the checked cheap-producer remat path + +group_broadcast_load -> ext -> contiguous consumer: + chooses lane_stride only if group_broadcast_load supports that lane map +``` + +Negative tests: + +```text +assigned ext layout pair where LS % W != 0 and no multi-part relation exists +ordinary dense op with mismatched lane_stride operands +store consuming strided dense layout without a supported store/materialization +``` + +## 10. Suggested Patch Order + +1. Add attr fields, parser/printer, verifier, and round-trip tests. +2. Split dense lane-map physicalization from group-slot carrier packing. +3. Update physical arity/unpack helpers for dense lane stride. +4. Extend support queries and assignment layout keys. +5. Implement widening single-part relation and tests. +6. Implement narrowing relation and tests. +7. Teach rematerialization/fold about exact dense lane-map equality. +8. Add broadcast/E2B recognition improvements that consume assigned lane maps. + +Each step should keep existing group-slot `lane_stride` tests passing. The first +functional optimization can be the `f16/bf16 lane_stride=2 -> f32 contiguous` +single-part conversion, but the IR and helper changes should already be generic +over type width and lane-map fields. diff --git a/docs/designs/vmi-layout-assignment-implementation.md b/docs/designs/vmi-layout-assignment-implementation.md index 20f6ac92ee..865adf413a 100644 --- a/docs/designs/vmi-layout-assignment-implementation.md +++ b/docs/designs/vmi-layout-assignment-implementation.md @@ -268,6 +268,12 @@ slot placement from producer or consumer context. regular gap between stored group slots. It is used for carrier-style packed stores such as `ui8` group slots lowered through b32 `PK4_B32`. +The current implementation treats this as a group-slot property. The dense +generalization is tracked separately in +`vmi-lane-stride-generalization-implementation.md`; it requires splitting dense +lane-map stride from group-slot carrier packing before `lane_stride` can be used +on `contiguous` or `deinterleaved` layouts. + ### 3.2 VMI Types Surface: @@ -544,7 +550,8 @@ Implementation-relevant layout facts: dense store: requests contiguous source. If the value is assigned deinterleaved, assignment inserts ensure_layout at the store use. A later optimization may - fold ensure_layout + store into a layout-aware store lowering. + fold ensure_layout + store into a layout-aware VMI store. `vmi-to-vpto` + later lowers that explicit store contract. data/mask helper materialization: identity conversions are always legal. diff --git a/docs/designs/vmi-layout-assignment-lowering-design.md b/docs/designs/vmi-layout-assignment-lowering-design.md index 73e2a7a118..98988eb667 100644 --- a/docs/designs/vmi-layout-assignment-lowering-design.md +++ b/docs/designs/vmi-layout-assignment-lowering-design.md @@ -289,6 +289,12 @@ deinterleaved=2, block_elems=8 are different layouts. They cannot be treated as compatible because `F` is the same. +See `vmi-lane-stride-generalization-design.md` for the planned extension that +allows dense layouts to carry `lane_stride` as an additional lane-map field. +That extension keeps dense lane stride separate from the existing group-slot +carrier lowering use case. Non-zero lane phase is left as a future extension +and is not required for the first dense-stride optimization. + ### 2.2 Group-Slot Layouts ```text diff --git a/include/PTO/IR/VMIAttrs.td b/include/PTO/IR/VMIAttrs.td index 111c63fa8a..b871f3b3d6 100644 --- a/include/PTO/IR/VMIAttrs.td +++ b/include/PTO/IR/VMIAttrs.td @@ -18,16 +18,19 @@ def VMILayoutAttr : PTO_Attr<"VMILayout", "vmi.layout"> { StringRefParameter<"layout kind">:$kind, "int64_t":$factor, "int64_t":$blockElems, - "int64_t":$slots + "int64_t":$slots, + "int64_t":$laneStride ); let hasCustomAssemblyFormat = 1; let genVerifyDecl = 1; let extraClassDeclaration = [{ - static VMILayoutAttr getContiguous(::mlir::MLIRContext *context); + static VMILayoutAttr getContiguous(::mlir::MLIRContext *context, + int64_t laneStride = 1); static VMILayoutAttr getDeinterleaved(::mlir::MLIRContext *context, int64_t factor, - int64_t blockElems = 1); + int64_t blockElems = 1, + int64_t laneStride = 1); static VMILayoutAttr getGroupSlots(::mlir::MLIRContext *context, int64_t numGroups, int64_t slots = 0, @@ -36,13 +39,15 @@ def VMILayoutAttr : PTO_Attr<"VMILayout", "vmi.layout"> { bool isContiguous() const { return getKind() == "contiguous"; } bool isDeinterleaved() const { return getKind() == "deinterleaved"; } bool isGroupSlots() const { return getKind() == "num_groups"; } + bool isDense() const { return isContiguous() || isDeinterleaved(); } int64_t getNumGroups() const { return getFactor(); } - bool hasLaneStride() const { - return isGroupSlots() && getBlockElems() != 1; + bool hasDenseLaneStride() const { + return isDense() && getLaneStride() != 1; } - int64_t getLaneStride() const { - return isGroupSlots() ? getBlockElems() : 1; + bool hasGroupSlotLaneStride() const { + return isGroupSlots() && getLaneStride() != 1; } + bool hasLaneStride() const { return getLaneStride() != 1; } }]; } diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h index e957a69058..19a7fc712a 100644 --- a/include/PTO/Transforms/VMILayoutSupport.h +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -25,6 +25,8 @@ class VMITargetCapabilityRegistry; enum class VMIContiguousStoreSupportKind { ContiguousVsts, + LaneStride2PackedVsts, + LaneStride4PackedVsts, Deinterleaved2Vstsx2, DeinterleavedMaterializeThenVsts, }; @@ -39,6 +41,8 @@ enum class VMILayoutMaterializationSupportKind { ContiguousToDeinterleaved, DeinterleavedToContiguous, DeinterleavedToDeinterleavedViaContiguous, + ContiguousToLaneStrideViaUnpack, + LaneStrideToContiguousViaPack, }; struct VMILayoutMaterializationSupport { @@ -152,6 +156,8 @@ struct VMIGroupBroadcastLoadSupport { enum class VMITruncFSupportKind { Deinterleaved2F32ToContiguousF16, Deinterleaved4F32ToContiguousF8, + ContiguousF32ToLaneStrideF16, + ContiguousF32ToLaneStrideF8, GroupSlots1F32ToF16, }; @@ -173,6 +179,8 @@ struct VMIExtFSupport { enum class VMITruncISupportKind { Deinterleaved2I32ToContiguousI16, Deinterleaved4I32ToContiguousI8, + ContiguousI32ToLaneStrideI16, + ContiguousI32ToLaneStrideI8, GroupSlots1I32ToNarrow, }; @@ -219,6 +227,10 @@ class VMILayoutSupport { canFoldContiguousStoreMaterialization(VMIVRegType sourceType, VMIVRegType resultType, std::string *reason = nullptr) const; + LogicalResult canFoldContiguousMaskedStoreMaterialization( + VMIVRegType sourceType, VMIMaskType maskSourceType, + VMIVRegType resultType, VMIMaskType maskResultType, + std::string *reason = nullptr) const; FailureOr getDataLayoutMaterializationSupport(VMIVRegType sourceType, diff --git a/include/PTO/Transforms/VMITargetCapabilities.h b/include/PTO/Transforms/VMITargetCapabilities.h index c9bded10f6..bb64f196e8 100644 --- a/include/PTO/Transforms/VMITargetCapabilities.h +++ b/include/PTO/Transforms/VMITargetCapabilities.h @@ -180,13 +180,17 @@ class VMITargetCapabilityRegistry { "requires assigned source/result layouts"); if (sourceLayout == resultLayout) return VMICapabilityResult::supported(); - if (sourceLayout.isContiguous() && resultLayout.isDeinterleaved() && + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && + resultLayout.isDeinterleaved() && resultLayout.getLaneStride() == 1 && (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) return VMICapabilityResult::supported(); - if (sourceLayout.isDeinterleaved() && resultLayout.isContiguous() && + if (sourceLayout.isDeinterleaved() && sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1 && (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) return VMICapabilityResult::supported(); if (sourceLayout.isDeinterleaved() && resultLayout.isDeinterleaved() && + sourceLayout.getLaneStride() == 1 && + resultLayout.getLaneStride() == 1 && (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4) && (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) return VMICapabilityResult::supported(); diff --git a/lib/PTO/IR/VMI.cpp b/lib/PTO/IR/VMI.cpp index 3c25a6f359..a0d4d3dabd 100644 --- a/lib/PTO/IR/VMI.cpp +++ b/lib/PTO/IR/VMI.cpp @@ -167,7 +167,7 @@ static FailureOr getLayoutBlockElems(Type type) { static FailureOr getVMIPhysicalElementType(VMIVRegType type) { Type elementType = type.getElementType(); VMILayoutAttr layout = type.getLayoutAttr(); - if (!layout || !layout.hasLaneStride()) + if (!layout || !layout.hasGroupSlotLaneStride()) return elementType; auto integerType = dyn_cast(elementType); @@ -195,6 +195,13 @@ static FailureOr getPhysicalLanesPerPart(Type type) { return failure(); } +static FailureOr getDenseLaneStride(Type type) { + FailureOr layout = getAssignedVMILayout(type); + if (failed(layout)) + return failure(); + return (*layout).isDense() ? (*layout).getLaneStride() : 1; +} + static int64_t getMaskGranularityBitWidth(StringRef granularity) { if (granularity == "b8") return 8; @@ -463,21 +470,24 @@ static int64_t getDenseLogicalLanesInPart(int64_t elementCount, int64_t factor, } // namespace -VMILayoutAttr VMILayoutAttr::getContiguous(MLIRContext *context) { - return VMILayoutAttr::get(context, "contiguous", 1, 1, 0); +VMILayoutAttr VMILayoutAttr::getContiguous(MLIRContext *context, + int64_t laneStride) { + return VMILayoutAttr::get(context, "contiguous", 1, 1, 0, laneStride); } VMILayoutAttr VMILayoutAttr::getDeinterleaved(MLIRContext *context, int64_t factor, - int64_t blockElems) { - return VMILayoutAttr::get(context, "deinterleaved", factor, blockElems, 0); + int64_t blockElems, + int64_t laneStride) { + return VMILayoutAttr::get(context, "deinterleaved", factor, blockElems, 0, + laneStride); } VMILayoutAttr VMILayoutAttr::getGroupSlots(MLIRContext *context, int64_t numGroups, int64_t slots, int64_t laneStride) { - return VMILayoutAttr::get(context, "num_groups", numGroups, laneStride, - slots); + return VMILayoutAttr::get(context, "num_groups", numGroups, 1, slots, + laneStride); } Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { @@ -486,22 +496,39 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { int64_t factor = 1; int64_t blockElems = 1; int64_t slots = 0; + int64_t laneStride = 1; if (failed(parser.parseLess()) || failed(parser.parseKeyword(&kind))) return {}; if (kind == "contiguous") { factor = 1; + while (succeeded(parser.parseOptionalComma())) { + StringRef field; + if (failed(parser.parseKeyword(&field)) || failed(parser.parseEqual()) || + field != "lane_stride" || failed(parser.parseInteger(laneStride))) { + parser.emitError(parser.getCurrentLocation(), + "expected 'lane_stride = '"); + return {}; + } + } } else if (kind == "deinterleaved") { if (failed(parser.parseEqual()) || failed(parser.parseInteger(factor))) return {}; - if (succeeded(parser.parseOptionalComma())) { + while (succeeded(parser.parseOptionalComma())) { StringRef field; - if (failed(parser.parseKeyword(&field)) || field != "block_elems" || - failed(parser.parseEqual()) || - failed(parser.parseInteger(blockElems))) { + if (failed(parser.parseKeyword(&field)) || failed(parser.parseEqual())) + return {}; + if (field == "block_elems") { + if (failed(parser.parseInteger(blockElems))) + return {}; + } else if (field == "lane_stride") { + if (failed(parser.parseInteger(laneStride))) + return {}; + } else { parser.emitError(parser.getCurrentLocation(), - "expected 'block_elems = '"); + "expected 'block_elems = ' or " + "'lane_stride = '"); return {}; } } @@ -516,7 +543,7 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { if (failed(parser.parseInteger(slots))) return {}; } else if (field == "lane_stride") { - if (failed(parser.parseInteger(blockElems))) + if (failed(parser.parseInteger(laneStride))) return {}; } else { parser.emitError(parser.getCurrentLocation(), @@ -536,21 +563,27 @@ Attribute VMILayoutAttr::parse(AsmParser &parser, Type) { return {}; return parser.getChecked(loc, parser.getContext(), kind, - factor, blockElems, slots); + factor, blockElems, slots, + laneStride); } void VMILayoutAttr::print(AsmPrinter &printer) const { printer << "<" << getKind(); - if (isDeinterleaved()) { + if (isContiguous()) { + if (getLaneStride() != 1) + printer << ", lane_stride = " << getLaneStride(); + } else if (isDeinterleaved()) { printer << " = " << getFactor(); if (getBlockElems() != 1) printer << ", block_elems = " << getBlockElems(); + if (getLaneStride() != 1) + printer << ", lane_stride = " << getLaneStride(); } else if (isGroupSlots()) { printer << " = " << getFactor(); if (getSlots() != 0) printer << ", slots = " << getSlots(); - if (getBlockElems() != 1) - printer << ", lane_stride = " << getBlockElems(); + if (getLaneStride() != 1) + printer << ", lane_stride = " << getLaneStride(); } printer << ">"; } @@ -558,7 +591,11 @@ void VMILayoutAttr::print(AsmPrinter &printer) const { LogicalResult VMILayoutAttr::verify(function_ref emitError, StringRef kind, int64_t factor, int64_t blockElems, - int64_t slots) { + int64_t slots, int64_t laneStride) { + if (laneStride <= 0) + return emitError() << "#pto.vmi.layout<" << kind + << "> requires lane_stride to be positive"; + if (kind == "contiguous") { if (factor != 1 || blockElems != 1 || slots != 0) return emitError() @@ -585,9 +622,9 @@ VMILayoutAttr::verify(function_ref emitError, if (factor <= 0) return emitError() << "#pto.vmi.layout requires num_groups to be positive"; - if (blockElems <= 0) + if (blockElems != 1) return emitError() << "#pto.vmi.layout requires lane_stride to be positive"; + << "> requires block_elems to be omitted"; if (slots < 0) return emitError() << "#pto.vmi.layout(value.getType()); + if (!type) + return getContiguousLayout(); + + VMILayoutAttr layout = getExplicitDataLayout(value); + if (!layout || !layout.hasDenseLaneStride()) + layout = type.getLayoutAttr(); + if (!layout || !layout.hasDenseLaneStride()) + return getContiguousLayout(); + + auto candidateType = + VMIVRegType::get(ctx, type.getElementCount(), type.getElementType(), + layout); + VMILayoutSupport supports; + if (succeeded(supports.getContiguousStoreSupport(candidateType))) + return layout; + return getContiguousLayout(); + } + VMILayoutAttr getDataLayout(Value value) { unsigned id = addDataValue(value); if (id == ~0u) @@ -443,6 +463,13 @@ struct LayoutSolver { return dataNodes[find(id)].naturalLayout; } + bool hasMaskedStoreUse(Value value) { + for (OpOperand &use : value.getUses()) + if (isa(use.getOwner()) && use.getOperandNumber() == 0) + return true; + return false; + } + bool hasCompatibleTruncFUseForGroupReduce(Value value, int64_t groupSize) { auto sourceType = dyn_cast(value.getType()); if (!sourceType || !sourceType.getElementType().isF32()) @@ -1251,11 +1278,22 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } - if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || - fact->kind == VMICastLayoutKind::Narrow4x)) + if (succeeded(fact) && hasMaskedStoreUse(truncf.getResult()) && + (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) { requestDataUse(truncf.getSourceMutable(), fact->sourceLayout); - VMILayoutAttr resultLayout = - succeeded(fact) ? fact->resultLayout : getContiguousLayout(); + if (failed(setNaturalLayout(truncf.getResult(), fact->resultLayout, + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + VMILayoutAttr resultLayout = getContiguousLayout(); + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) { + requestDataUse(truncf.getSourceMutable(), getContiguousLayout()); + resultLayout = + VMILayoutAttr::getContiguous(ctx, /*laneStride=*/fact->factor); + } if (failed(setNaturalLayout(truncf.getResult(), resultLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); @@ -1282,11 +1320,22 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } - if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || - fact->kind == VMICastLayoutKind::Narrow4x)) + if (succeeded(fact) && hasMaskedStoreUse(trunci.getResult()) && + (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) { requestDataUse(trunci.getSourceMutable(), fact->sourceLayout); - VMILayoutAttr resultLayout = - succeeded(fact) ? fact->resultLayout : getContiguousLayout(); + if (failed(setNaturalLayout(trunci.getResult(), fact->resultLayout, + op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + } + VMILayoutAttr resultLayout = getContiguousLayout(); + if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || + fact->kind == VMICastLayoutKind::Narrow4x)) { + requestDataUse(trunci.getSourceMutable(), getContiguousLayout()); + resultLayout = + VMILayoutAttr::getContiguous(ctx, /*laneStride=*/fact->factor); + } if (failed(setNaturalLayout(trunci.getResult(), resultLayout, op))) return WalkResult::interrupt(); return WalkResult::advance(); @@ -1362,7 +1411,8 @@ struct LayoutSolver { return WalkResult::advance(); } if (auto store = dyn_cast(op)) { - requestDataUse(store.getValueMutable(), getContiguousLayout()); + requestDataUse(store.getValueMutable(), + getPreferredDenseStoreUseLayout(store.getValue())); return WalkResult::advance(); } if (auto store = dyn_cast(op)) { diff --git a/lib/PTO/Transforms/VMILayoutFold.cpp b/lib/PTO/Transforms/VMILayoutFold.cpp index a592786f5b..253ab7c3fc 100644 --- a/lib/PTO/Transforms/VMILayoutFold.cpp +++ b/lib/PTO/Transforms/VMILayoutFold.cpp @@ -49,9 +49,18 @@ static bool isLoadProducerLayout(VMIVRegType type) { VMILayoutAttr layout = type.getLayoutAttr(); if (!layout) return false; - if (layout.isContiguous()) + if (layout.isContiguous() && layout.getLaneStride() == 1) return true; + if (layout.isContiguous() && layout.getLaneStride() == 2) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + return elementBits == 8 || elementBits == 16 || elementBits == 32; + } + if (layout.isContiguous() && layout.getLaneStride() == 4) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + return elementBits == 8; + } if (!layout.isDeinterleaved() || layout.getBlockElems() != 1 || + layout.getLaneStride() != 1 || (layout.getFactor() != 2 && layout.getFactor() != 4)) return false; unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); @@ -164,9 +173,12 @@ static void tryFoldEnsureLayoutIntoMaskedStore( if (sourceLayout != maskSourceLayout || !maskResultLayout.isContiguous()) return; - FailureOr sourceArity = getVMIPhysicalArity(sourceType); - FailureOr maskArity = getVMIPhysicalArity(maskSourceType); - if (failed(sourceArity) || failed(maskArity) || *sourceArity != *maskArity) + auto resultType = dyn_cast(ensure.getResult().getType()); + if (!resultType) + return; + VMILayoutSupport supports; + if (failed(supports.canFoldContiguousMaskedStoreMaterialization( + sourceType, maskSourceType, resultType, maskResultType))) return; store.getValueMutable().set(ensure.getSource()); diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index f1fd1097fb..4d93fa374c 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -71,6 +71,57 @@ static bool hasX2MemoryDistToken(Type elementType) { return elementBits == 8 || elementBits == 16 || elementBits == 32; } +static bool hasDenseLaneStride2PackedStore(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + return elementBits == 8 || elementBits == 16 || elementBits == 32; +} + +static bool hasDenseLaneStride4PackedStore(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + return elementBits == 8; +} + +static bool hasDenseLaneStridePackUnpackElement(Type elementType, + int64_t laneStride) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + if (elementBits == 0 || (!isa(elementType) && + !isa(elementType))) + return false; + if (laneStride == 2) + return elementBits == 8 || elementBits == 16; + if (laneStride == 4) + return elementBits == 8; + return false; +} + +static std::optional +getDenseLaneStrideMaskedStoreMaskGranularity(VMIVRegType valueType) { + VMILayoutAttr layout = valueType.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return std::nullopt; + unsigned elementBits = pto::getPTOStorageElemBitWidth(valueType.getElementType()); + if (layout.getLaneStride() == 2 && elementBits == 8) + return StringRef("b16"); + if (layout.getLaneStride() == 2 && elementBits == 16) + return StringRef("b32"); + if (layout.getLaneStride() == 4 && elementBits == 8) + return StringRef("b32"); + return std::nullopt; +} + +static StringRef getMaskGranularityForElementBits(unsigned elementBits) { + switch (elementBits) { + case 8: + return "b8"; + case 16: + return "b16"; + case 32: + return "b32"; + default: + return ""; + } +} + static std::optional getConstantIndexValue(Value value) { if (auto constant = value.getDefiningOp()) return constant.value(); @@ -123,8 +174,19 @@ static FailureOr getVMITypeChunksInPart(Type type, int64_t part) { part < 0 || part >= *factor) return failure(); + VMILayoutAttr layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayoutAttr(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayoutAttr(); + if (!layout) + return failure(); + int64_t logicalLanesInPart = (*elementCount + *factor - 1 - part) / *factor; - return ceilDivNonNegative(logicalLanesInPart, *lanesPerPart); + int64_t laneStride = layout.isDense() ? layout.getLaneStride() : 1; + int64_t physicalLanes = + logicalLanesInPart == 0 ? 0 : (logicalLanesInPart - 1) * laneStride + 1; + return ceilDivNonNegative(physicalLanes, *lanesPerPart); } static LogicalResult checkFullVMIPhysicalChunks(Type type, @@ -181,7 +243,7 @@ getContiguousMaterializationPartCount(Type type, std::string *reason) { if (!layout) return fail("requires assigned layout"); - if (layout.isContiguous()) + if (layout.isContiguous() && layout.getLaneStride() == 1) return *arity; if (!layout.isDeinterleaved() || (layout.getFactor() != 2 && layout.getFactor() != 4)) @@ -328,14 +390,17 @@ getLayoutMaterializationSupport(VMILayoutAttr sourceLayout, return VMILayoutMaterializationSupport{ VMILayoutMaterializationSupportKind::Identity}; if (sourceLayout.isContiguous() && resultLayout.isDeinterleaved() && + sourceLayout.getLaneStride() == 1 && resultLayout.getLaneStride() == 1 && (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) return VMILayoutMaterializationSupport{ VMILayoutMaterializationSupportKind::ContiguousToDeinterleaved}; if (sourceLayout.isDeinterleaved() && resultLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && resultLayout.getLaneStride() == 1 && (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4)) return VMILayoutMaterializationSupport{ VMILayoutMaterializationSupportKind::DeinterleavedToContiguous}; if (sourceLayout.isDeinterleaved() && resultLayout.isDeinterleaved() && + sourceLayout.getLaneStride() == 1 && resultLayout.getLaneStride() == 1 && (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4) && (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) return VMILayoutMaterializationSupport{ @@ -507,6 +572,11 @@ VMILayoutSupport::getWidenSourceLayoutForResultLayout( fact->kind != VMICastLayoutKind::Widen4x)) return fail("requires supported 8/16-bit to 32-bit widen cast"); + if (requestedResultLayout.isContiguous()) { + return VMILayoutAttr::getContiguous(sourceType.getContext(), + /*laneStride=*/fact->factor); + } + int64_t resultFactor = requestedResultLayout.isDeinterleaved() ? requestedResultLayout.getFactor() : 1; @@ -536,11 +606,27 @@ VMILayoutSupport::getContiguousStoreSupport(VMIVRegType valueType, VMILayoutAttr layout = valueType.getLayoutAttr(); if (!layout) return fail("requires assigned value layout"); - if (layout.isContiguous()) + if (layout.isContiguous() && layout.getLaneStride() == 1) return VMIContiguousStoreSupport{ VMIContiguousStoreSupportKind::ContiguousVsts}; + if (layout.isContiguous() && layout.getLaneStride() == 2) { + if (!hasDenseLaneStride2PackedStore(valueType.getElementType())) + return fail("requires 8/16/32-bit element type for dense lane_stride=2 " + "packed store"); + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::LaneStride2PackedVsts}; + } + if (layout.isContiguous() && layout.getLaneStride() == 4) { + if (!hasDenseLaneStride4PackedStore(valueType.getElementType())) + return fail("requires 8-bit element type for dense lane_stride=4 " + "packed store"); + return VMIContiguousStoreSupport{ + VMIContiguousStoreSupportKind::LaneStride4PackedVsts}; + } if (!layout.isDeinterleaved()) return fail("requires contiguous or deinterleaved value layout"); + if (layout.getLaneStride() != 1) + return fail("deinterleaved packed store requires lane_stride=1"); if (layout.getBlockElems() != 1) return fail("requires block_elems=1 deinterleaved value layout"); if (failed(checkFullDataPhysicalChunks(valueType, reason))) @@ -581,6 +667,60 @@ LogicalResult VMILayoutSupport::canFoldContiguousStoreMaterialization( return success(); } +LogicalResult VMILayoutSupport::canFoldContiguousMaskedStoreMaterialization( + VMIVRegType sourceType, VMIMaskType maskSourceType, + VMIVRegType resultType, VMIMaskType maskResultType, + std::string *reason) const { + if (sourceType.getElementType() != resultType.getElementType()) + return failWithReason("source/result element types must match", reason); + if (sourceType.getElementCount() != resultType.getElementCount()) + return failWithReason("source/result element counts must match", reason); + if (maskSourceType.getElementCount() != sourceType.getElementCount() || + maskResultType.getElementCount() != resultType.getElementCount()) + return failWithReason("value/mask element counts must match", reason); + if (maskSourceType.getGranularity() != maskResultType.getGranularity()) + return failWithReason("mask layout fold cannot change granularity", reason); + + VMILayoutAttr sourceLayout = sourceType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultType.getLayoutAttr(); + VMILayoutAttr maskSourceLayout = maskSourceType.getLayoutAttr(); + VMILayoutAttr maskResultLayout = maskResultType.getLayoutAttr(); + if (!sourceLayout || !resultLayout || !maskSourceLayout || !maskResultLayout) + return failWithReason("requires assigned value/mask layouts", reason); + if (!resultLayout.isContiguous() || !maskResultLayout.isContiguous()) + return failWithReason("result value/mask layouts must be contiguous", + reason); + if (sourceLayout != maskSourceLayout) + return failWithReason("source value/mask layouts must match", reason); + + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr maskArity = getVMIPhysicalArity(maskSourceType); + if (failed(sourceArity) || failed(maskArity) || *sourceArity != *maskArity) + return failWithReason("source value/mask physical arity must match", + reason); + + if (!sourceLayout.hasDenseLaneStride()) + return canFoldContiguousStoreMaterialization(sourceType, resultType, + reason); + + std::optional packedGranularity = + getDenseLaneStrideMaskedStoreMaskGranularity(sourceType); + if (!packedGranularity) + return failWithReason("dense lane_stride masked store supports only " + "LS=2 b8/b16 and LS=4 b8 compact masks", + reason); + + unsigned elementBits = pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + StringRef expectedSourceGranularity = + getMaskGranularityForElementBits(elementBits); + if (expectedSourceGranularity.empty() || + maskSourceType.getGranularity() != expectedSourceGranularity) + return failWithReason("mask granularity must match source element width", + reason); + + return success(); +} + FailureOr VMILayoutSupport::getDataLayoutMaterializationSupport( VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { @@ -600,12 +740,51 @@ VMILayoutSupport::getDataLayoutMaterializationSupport( VMILayoutAttr resultLayout = resultType.getLayoutAttr(); FailureOr support = getLayoutMaterializationSupport(sourceLayout, resultLayout, reason); - if (failed(support)) - return failure(); - if (failed(checkLayoutMaterializationShape( - sourceType, resultType, sourceLayout, resultLayout, reason))) - return failure(); - return support; + if (succeeded(support)) { + if (failed(checkLayoutMaterializationShape( + sourceType, resultType, sourceLayout, resultLayout, reason))) + return failure(); + return support; + } + + if (!sourceLayout || !resultLayout) + return fail("requires assigned source/result layouts"); + + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && resultLayout.getLaneStride() != 1) { + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity != *resultArity) + return fail("dense lane_stride register materialization currently " + "requires source and result to have the same physical arity"); + if (!hasDenseLaneStridePackUnpackElement(sourceType.getElementType(), + resultLayout.getLaneStride())) + return fail("requires bitcastable 8/16-bit element type for dense " + "lane_stride register unpack materialization"); + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::ContiguousToLaneStrideViaUnpack}; + } + + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() != 1 && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1) { + FailureOr sourceArity = getVMIPhysicalArity(sourceType); + FailureOr resultArity = getVMIPhysicalArity(resultType); + if (failed(sourceArity) || failed(resultArity)) + return fail("requires computable source/result physical arity"); + if (*sourceArity != *resultArity) + return fail("dense lane_stride register materialization currently " + "requires source and result to have the same physical arity"); + if (!hasDenseLaneStridePackUnpackElement(sourceType.getElementType(), + sourceLayout.getLaneStride())) + return fail("requires bitcastable 8/16-bit element type for dense " + "lane_stride register pack materialization"); + return VMILayoutMaterializationSupport{ + VMILayoutMaterializationSupportKind::LaneStrideToContiguousViaPack}; + } + + return failure(); } LogicalResult VMILayoutSupport::canMaterializeDataLayout( @@ -1255,9 +1434,8 @@ VMILayoutSupport::getTruncFSupport(VMITruncFOp op, std::string *reason) const { return VMITruncFSupport{VMITruncFSupportKind::GroupSlots1F32ToF16}; } - if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || - !sourceType.getElementType().isF32()) - return fail("requires f32 deinterleaved source and contiguous result"); + if (!sourceType.getElementType().isF32()) + return fail("requires f32 source"); FailureOr fact = getPreferredCastLayoutFact(sourceType, resultType, reason); @@ -1266,6 +1444,23 @@ VMILayoutSupport::getTruncFSupport(VMITruncFOp op, std::string *reason) const { return fail("unsupported deinterleaved truncf factor, arity, or result " "element width"); + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == fact->factor && + *sourceArity == *resultArity) { + if (fact->kind == VMICastLayoutKind::Narrow2x) + return VMITruncFSupport{ + VMITruncFSupportKind::ContiguousF32ToLaneStrideF16}; + if (fact->kind == VMICastLayoutKind::Narrow4x) + return VMITruncFSupport{ + VMITruncFSupportKind::ContiguousF32ToLaneStrideF8}; + } + + if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || + resultLayout.getLaneStride() != 1) + return fail("requires f32 deinterleaved source and contiguous result, or " + "contiguous source and lane_stride narrowing result"); + if (fact->kind == VMICastLayoutKind::Narrow2x && sourceLayout.getFactor() == fact->factor && *sourceArity == fact->factor * *resultArity) @@ -1299,13 +1494,6 @@ VMILayoutSupport::getExtFSupport(VMIExtFOp op, std::string *reason) const { failed(resultArity)) return fail("requires assigned source/result layouts and computable " "physical arity"); - if (!resultLayout.isDeinterleaved() || resultLayout.getBlockElems() != 1 || - !(sourceLayout.isContiguous() || - (sourceLayout.isDeinterleaved() && - sourceLayout.getBlockElems() == 1)) || - !resultType.getElementType().isF32()) - return fail("requires contiguous or deinterleaved source layout and " - "deinterleaved f32 result layout with block_elems=1"); FailureOr fact = getPreferredCastLayoutFact(sourceType, resultType, reason); @@ -1314,6 +1502,28 @@ VMILayoutSupport::getExtFSupport(VMIExtFOp op, std::string *reason) const { return fail("unsupported extf source element width, result factor, or " "physical arity"); + if (sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == fact->factor && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1 && + *sourceArity == *resultArity && resultType.getElementType().isF32()) { + if (fact->kind == VMICastLayoutKind::Widen2x) + return VMIExtFSupport{ + VMIExtFSupportKind::ContiguousF16ToDeinterleaved2F32}; + if (fact->kind == VMICastLayoutKind::Widen4x) + return VMIExtFSupport{ + VMIExtFSupportKind::ContiguousF8ToDeinterleaved4F32}; + } + + if (!resultLayout.isDeinterleaved() || resultLayout.getBlockElems() != 1 || + resultLayout.getLaneStride() != 1 || + !(sourceLayout.isContiguous() || + (sourceLayout.isDeinterleaved() && + sourceLayout.getBlockElems() == 1 && + sourceLayout.getLaneStride() == 1)) || + !resultType.getElementType().isF32()) + return fail("requires contiguous or deinterleaved source layout and " + "deinterleaved f32 result layout with block_elems=1"); + int64_t sourceFactor = sourceLayout.isDeinterleaved() ? sourceLayout.getFactor() : 1; if (resultLayout.getFactor() != sourceFactor * fact->factor || @@ -1349,6 +1559,10 @@ static FailureOr getExtISupportImpl(OpT op, return fail("requires assigned source/result layouts and computable " "physical arity"); + FailureOr fact = + VMILayoutSupport().getPreferredCastLayoutFact(sourceType, resultType, + reason); + if (sourceLayout.isGroupSlots() && resultLayout.isGroupSlots()) { if (!isa(sourceType.getElementType()) || !isa(resultType.getElementType())) @@ -1376,18 +1590,34 @@ static FailureOr getExtISupportImpl(OpT op, "16-bit"); } + if (succeeded(fact) && + (fact->kind == VMICastLayoutKind::Widen2x || + fact->kind == VMICastLayoutKind::Widen4x) && + sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == fact->factor && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1 && + *sourceArity == *resultArity && + isa(sourceType.getElementType()) && + isa(resultType.getElementType())) { + if (fact->kind == VMICastLayoutKind::Widen2x) + return VMIExtISupport{ + VMIExtISupportKind::ContiguousI16ToDeinterleaved2I32}; + if (fact->kind == VMICastLayoutKind::Widen4x) + return VMIExtISupport{ + VMIExtISupportKind::ContiguousI8ToDeinterleaved4I32}; + } + if (!resultLayout.isDeinterleaved() || resultLayout.getBlockElems() != 1 || + resultLayout.getLaneStride() != 1 || !(sourceLayout.isContiguous() || (sourceLayout.isDeinterleaved() && - sourceLayout.getBlockElems() == 1)) || + sourceLayout.getBlockElems() == 1 && + sourceLayout.getLaneStride() == 1)) || !isa(sourceType.getElementType()) || !isa(resultType.getElementType())) return fail("requires contiguous or deinterleaved integer source layout " "and deinterleaved integer result layout with block_elems=1"); - FailureOr fact = - VMILayoutSupport().getPreferredCastLayoutFact(sourceType, resultType, - reason); if (failed(fact) || (fact->kind != VMICastLayoutKind::Widen2x && fact->kind != VMICastLayoutKind::Widen4x)) return fail("unsupported integer extension source/result element width, " @@ -1468,11 +1698,31 @@ VMILayoutSupport::getTruncISupport(VMITruncIOp op, std::string *reason) const { "requires unsigned i8 result"); if (!sourceLayout.isDeinterleaved() || sourceLayout.getBlockElems() != 1 || - !(resultLayout.isContiguous() || + !((resultLayout.isContiguous() && resultLayout.getLaneStride() == 1) || (resultLayout.isDeinterleaved() && - resultLayout.getBlockElems() == 1))) - return fail("requires integer deinterleaved source and contiguous or " - "deinterleaved integer result with block_elems=1"); + resultLayout.getBlockElems() == 1 && + resultLayout.getLaneStride() == 1))) + if (!(sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == fact->factor)) + return fail("requires integer deinterleaved source and contiguous or " + "deinterleaved integer result with block_elems=1, or " + "contiguous source and lane_stride narrowing result"); + + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == fact->factor && + *sourceArity == *resultArity) { + if (resultBits == 8 && + !cast(resultType.getElementType()).isUnsigned()) + return fail("8-bit integer narrowing requires unsigned i8 result"); + if (fact->kind == VMICastLayoutKind::Narrow2x) + return VMITruncISupport{ + VMITruncISupportKind::ContiguousI32ToLaneStrideI16}; + if (fact->kind == VMICastLayoutKind::Narrow4x) + return VMITruncISupport{ + VMITruncISupportKind::ContiguousI32ToLaneStrideI8}; + } int64_t resultFactor = resultLayout.isDeinterleaved() ? resultLayout.getFactor() : 1; diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index d774545de3..ec6d5e4fac 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -54,6 +54,8 @@ namespace { std::optional getX2MemoryDistToken(Type elementType, StringRef prefix); +std::optional getDenseLaneStrideLoadDistToken(VMIVRegType type); +std::optional getDenseLaneStrideStoreDistToken(VMIVRegType type); bool isVMIType(Type type) { return isa(type); } @@ -247,7 +249,7 @@ materializeVMIToVPTO(OpBuilder &builder, TypeRange resultTypes, Value input, static FailureOr getVMIVRegPhysicalElementType(VMIVRegType type) { Type elementType = type.getElementType(); VMILayoutAttr layout = type.getLayoutAttr(); - if (!layout || !layout.hasLaneStride()) + if (!layout || !layout.hasGroupSlotLaneStride()) return elementType; auto integerType = dyn_cast(elementType); @@ -681,8 +683,19 @@ FailureOr getVMITypeChunksInPart(Type type, int64_t part) { part < 0 || part >= *factor) return failure(); + VMILayoutAttr layout; + if (auto vregType = dyn_cast(type)) + layout = vregType.getLayoutAttr(); + else if (auto maskType = dyn_cast(type)) + layout = maskType.getLayoutAttr(); + if (!layout) + return failure(); + int64_t logicalLanesInPart = (*elementCount + *factor - 1 - part) / *factor; - return ceilDivNonNegative(logicalLanesInPart, *lanesPerPart); + int64_t laneStride = layout.isDense() ? layout.getLaneStride() : 1; + int64_t physicalLanes = + logicalLanesInPart == 0 ? 0 : (logicalLanesInPart - 1) * laneStride + 1; + return ceilDivNonNegative(physicalLanes, *lanesPerPart); } LogicalResult checkFullVMIPhysicalChunks(Type type, std::string *reason) { @@ -797,7 +810,7 @@ FailureOr getContiguousMaterializationPartCount(Type type, auto layout = dyn_cast_or_null(layoutAttr); if (!layout) return fail("requires assigned layout"); - if (layout.isContiguous()) + if (layout.isContiguous() && layout.getLaneStride() == 1) return *arity; if (!layout.isDeinterleaved() || (layout.getFactor() != 2 && layout.getFactor() != 4)) @@ -1154,6 +1167,9 @@ checkSupportedLoadShape(const VMITargetCapabilityRegistry &capabilities, if (!accessPlan.targetCapability.isSupported()) return fail(accessPlan.targetCapability.reason); + if (getDenseLaneStrideLoadDistToken(type)) + return success(); + std::string fullChunkReason; if (succeeded(checkFullDataPhysicalChunks(type, &fullChunkReason))) return success(); @@ -1216,6 +1232,9 @@ checkSupportedStoreShape(const VMITargetCapabilityRegistry &capabilities, if (failed(checkSupportedMaskableVReg(capabilities, type, reason))) return failure(); + if (getDenseLaneStrideStoreDistToken(type)) + return success(); + std::string fullChunkReason; if (succeeded(checkFullDataPhysicalChunks(type, &fullChunkReason))) return success(); @@ -1231,7 +1250,7 @@ checkSupportedStoreShape(const VMITargetCapabilityRegistry &capabilities, return fail("requires assigned layout"); if (failed(getDataLanesPerPart(type.getElementType()))) return fail("requires known physical lanes per part"); - if (layout.isContiguous()) + if (layout.isContiguous() && layout.getLaneStride() == 1) return success(); std::string materializationReason; @@ -1852,6 +1871,22 @@ checkSupportedMaskedStoreShape(const VMITargetCapabilityRegistry &capabilities, if (failed(valueArity) || failed(maskArity) || *valueArity != *maskArity) return fail("requires matching value/mask physical arity"); + if (valueLayout.hasDenseLaneStride()) { + VMILayoutSupport supports; + auto contiguousValueType = + VMIVRegType::get(valueType.getContext(), valueType.getElementCount(), + valueType.getElementType(), + VMILayoutAttr::getContiguous(valueType.getContext())); + auto contiguousMaskType = + VMIMaskType::get(maskType.getContext(), maskType.getElementCount(), + maskType.getGranularity(), + VMILayoutAttr::getContiguous(maskType.getContext())); + if (succeeded(supports.canFoldContiguousMaskedStoreMaterialization( + valueType, maskType, contiguousValueType, contiguousMaskType, + reason))) + return success(); + } + std::string valueMaterializationReason; FailureOr valueParts = getContiguousMaterializationPartCount( valueType, &valueMaterializationReason); @@ -1883,6 +1918,24 @@ FailureOr getContiguousActiveDataLanes(VMIVRegType vmiType, return std::clamp(remaining, 0, *lanesPerPart); } +FailureOr getActiveDataLanesInPhysicalChunk(VMIVRegType vmiType, + int64_t chunk) { + FailureOr lanesPerPart = + getDataLanesPerPart(vmiType.getElementType()); + if (failed(lanesPerPart)) + return failure(); + + int64_t active = 0; + for (int64_t lane = 0; lane < *lanesPerPart; ++lane) { + FailureOr padding = isPaddingLane(vmiType, /*part=*/0, chunk, lane); + if (failed(padding)) + return failure(); + if (!*padding) + ++active; + } + return active; +} + FailureOr createContiguousStoreMask(Location loc, VMIVRegType vmiType, int64_t chunk, VRegType vregType, PatternRewriter &rewriter) { @@ -1935,6 +1988,55 @@ FailureOr createMaskedStorePredicate(Location loc, VMIVRegType vmiType, .getResult(); } +FailureOr createDenseLaneStrideStorePredicate( + Location loc, VMIVRegType vmiType, int64_t chunk, Value userMask, + StringRef targetGranularity, PatternRewriter &rewriter) { + auto sourceMaskType = dyn_cast(userMask.getType()); + if (!sourceMaskType) + return failure(); + auto targetMaskType = MaskType::get(rewriter.getContext(), targetGranularity); + Value compactMask = userMask; + VMILayoutAttr layout = vmiType.getLayoutAttr(); + if (!layout) + return failure(); + + auto lower = rewriter.getStringAttr("LOWER"); + StringRef sourceGranularity = sourceMaskType.getGranularity(); + if (layout.getLaneStride() == 2) { + compactMask = + rewriter.create(loc, targetMaskType, compactMask, lower) + .getResult(); + } else if (layout.getLaneStride() == 4 && sourceGranularity == "b8" && + targetGranularity == "b32") { + auto b16MaskType = MaskType::get(rewriter.getContext(), "b16"); + compactMask = + rewriter.create(loc, b16MaskType, compactMask, lower) + .getResult(); + compactMask = + rewriter.create(loc, targetMaskType, compactMask, lower) + .getResult(); + } else { + return failure(); + } + + FailureOr activeLanes = + getActiveDataLanesInPhysicalChunk(vmiType, chunk); + FailureOr maskLanes = getMaskLanesPerPart(targetGranularity); + if (failed(activeLanes) || failed(maskLanes)) + return failure(); + if (*activeLanes == *maskLanes) + return compactMask; + + FailureOr tailMask = createPrefixMaskForActiveLanes( + loc, targetMaskType, *activeLanes, rewriter); + FailureOr allTrue = createAllTrueMask(loc, targetMaskType, rewriter); + if (failed(tailMask) || failed(allTrue)) + return failure(); + return rewriter + .create(loc, targetMaskType, compactMask, *tailMask, *allTrue) + .getResult(); +} + FailureOr> computeShuffleForwardingSourceParts(VMIShuffleOp op, std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr> { @@ -2759,6 +2861,67 @@ std::optional getX2MemoryDistToken(Type elementType, return (Twine(prefix) + "_B" + Twine(elementBits)).str(); } +std::optional getDenseLaneStrideLoadDistToken(VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return std::nullopt; + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (layout.getLaneStride() == 2 && + (elementBits == 8 || elementBits == 16 || elementBits == 32)) + return (Twine("UNPK_B") + Twine(elementBits)).str(); + if (layout.getLaneStride() == 4 && elementBits == 8) + return std::string("UNPK4"); + return std::nullopt; +} + +std::optional getDenseLaneStrideStoreDistToken(VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return std::nullopt; + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (layout.getLaneStride() == 2 && elementBits == 8) + return std::string("PK_B16"); + if (layout.getLaneStride() == 2 && elementBits == 16) + return std::string("PK_B32"); + if (layout.getLaneStride() == 2 && elementBits == 32) + return std::string("PK_B64"); + if (layout.getLaneStride() == 4 && elementBits == 8) + return std::string("PK4_B32"); + return std::nullopt; +} + +std::optional getDenseLaneStrideStoreMaskGranularity( + VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return std::nullopt; + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (layout.getLaneStride() == 2 && elementBits == 8) + return StringRef("b16"); + if (layout.getLaneStride() == 2 && elementBits == 16) + return StringRef("b32"); + if (layout.getLaneStride() == 2 && elementBits == 32) + return StringRef("b32"); + if (layout.getLaneStride() == 4 && elementBits == 8) + return StringRef("b32"); + return std::nullopt; +} + +std::optional getDenseLaneStrideMaskedStoreMaskGranularity( + VMIVRegType type) { + VMILayoutAttr layout = type.getLayoutAttr(); + if (!layout || !layout.isContiguous()) + return std::nullopt; + unsigned elementBits = pto::getPTOStorageElemBitWidth(type.getElementType()); + if (layout.getLaneStride() == 2 && elementBits == 8) + return StringRef("b16"); + if (layout.getLaneStride() == 2 && elementBits == 16) + return StringRef("b32"); + if (layout.getLaneStride() == 4 && elementBits == 8) + return StringRef("b32"); + return std::nullopt; +} + std::optional getPointStoreDistToken(Type elementType) { unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); if (elementBits != 8 && elementBits != 16 && elementBits != 32) @@ -2851,6 +3014,165 @@ LogicalResult verifyIdentityPartForwarding(Operation *op, return success(); } +FailureOr getUnsignedCarrierVRegType(MLIRContext *ctx, + unsigned elementBits) { + if (elementBits != 8 && elementBits != 16 && elementBits != 32) + return failure(); + auto elementType = + IntegerType::get(ctx, elementBits, + IntegerType::SignednessSemantics::Unsigned); + return VRegType::get(ctx, 2048 / elementBits, elementType); +} + +FailureOr bitcastVReg(Location loc, Value value, Type resultType, + PatternRewriter &rewriter) { + if (value.getType() == resultType) + return value; + auto inputType = dyn_cast(value.getType()); + auto outputType = dyn_cast(resultType); + if (!inputType || !outputType) + return failure(); + return rewriter.create(loc, outputType, value).getResult(); +} + +FailureOr unpackToNextCarrier(Location loc, Value source, + unsigned sourceBits, int64_t partIndex, + PatternRewriter &rewriter) { + FailureOr resultType = + getUnsignedCarrierVRegType(rewriter.getContext(), sourceBits * 2); + if (failed(resultType)) + return failure(); + Value part = rewriter.create(loc, partIndex); + return rewriter.create(loc, *resultType, source, part).getResult(); +} + +FailureOr packToPreviousCarrier(Location loc, Value source, + unsigned resultBits, + PatternRewriter &rewriter) { + FailureOr resultType = + getUnsignedCarrierVRegType(rewriter.getContext(), resultBits); + if (failed(resultType)) + return failure(); + return rewriter + .create(loc, *resultType, source, + rewriter.getStringAttr("LOWER")) + .getResult(); +} + +FailureOr> materializeContiguousToLaneStride( + Operation *op, ValueRange sourceParts, TypeRange resultTypes, + Type elementType, int64_t laneStride, PatternRewriter &rewriter) { + if (sourceParts.size() != resultTypes.size()) { + (void)rewriter.notifyMatchFailure( + op, "dense lane_stride unpack materialization requires matching " + "source/result physical arity"); + return failure(); + } + + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + if ((laneStride != 2 && laneStride != 4) || + (laneStride == 4 && elementBits != 8) || + (elementBits != 8 && elementBits != 16)) { + (void)rewriter.notifyMatchFailure( + op, "unsupported dense lane_stride unpack carrier shape"); + return failure(); + } + + MLIRContext *ctx = rewriter.getContext(); + FailureOr inputCarrier = + getUnsignedCarrierVRegType(ctx, elementBits); + if (failed(inputCarrier)) + return failure(); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [resultIndex, resultType] : llvm::enumerate(resultTypes)) { + int64_t sourceIndex = resultIndex / laneStride; + if (sourceIndex >= static_cast(sourceParts.size())) + return failure(); + Value source = sourceParts[sourceIndex]; + FailureOr current = + bitcastVReg(op->getLoc(), source, *inputCarrier, rewriter); + if (failed(current)) + return failure(); + int64_t part = resultIndex % laneStride; + FailureOr unpacked = + unpackToNextCarrier(op->getLoc(), *current, elementBits, + laneStride == 4 ? part / 2 : part, rewriter); + if (failed(unpacked)) + return failure(); + current = *unpacked; + if (laneStride == 4) { + unpacked = + unpackToNextCarrier(op->getLoc(), *current, elementBits * 2, + part % 2, rewriter); + if (failed(unpacked)) + return failure(); + current = *unpacked; + } + FailureOr result = + bitcastVReg(op->getLoc(), *current, resultType, rewriter); + if (failed(result)) + return failure(); + results.push_back(*result); + } + return results; +} + +FailureOr> materializeLaneStrideToContiguous( + Operation *op, ValueRange sourceParts, TypeRange resultTypes, + Type elementType, int64_t laneStride, PatternRewriter &rewriter) { + if (sourceParts.size() != resultTypes.size()) { + (void)rewriter.notifyMatchFailure( + op, "dense lane_stride pack materialization requires matching " + "source/result physical arity"); + return failure(); + } + + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + if ((laneStride != 2 && laneStride != 4) || + (laneStride == 4 && elementBits != 8) || + (elementBits != 8 && elementBits != 16)) { + (void)rewriter.notifyMatchFailure( + op, "unsupported dense lane_stride pack carrier shape"); + return failure(); + } + + unsigned carrierBits = + static_cast(elementBits * static_cast(laneStride)); + FailureOr sourceCarrier = + getUnsignedCarrierVRegType(rewriter.getContext(), carrierBits); + if (failed(sourceCarrier)) + return failure(); + + SmallVector results; + results.reserve(sourceParts.size()); + for (auto [source, resultType] : llvm::zip_equal(sourceParts, resultTypes)) { + FailureOr current = + bitcastVReg(op->getLoc(), source, *sourceCarrier, rewriter); + if (failed(current)) + return failure(); + FailureOr packed = + packToPreviousCarrier(op->getLoc(), *current, carrierBits / 2, rewriter); + if (failed(packed)) + return failure(); + current = *packed; + if (laneStride == 4) { + packed = + packToPreviousCarrier(op->getLoc(), *current, elementBits, rewriter); + if (failed(packed)) + return failure(); + current = *packed; + } + FailureOr result = + bitcastVReg(op->getLoc(), *current, resultType, rewriter); + if (failed(result)) + return failure(); + results.push_back(*result); + } + return results; +} + FailureOr> materializeDataLayoutConversion( Operation *op, ValueRange sourceParts, TypeRange resultTypes, VMILayoutAttr sourceLayout, VMILayoutAttr resultLayout, @@ -2870,10 +3192,14 @@ FailureOr> materializeDataLayoutConversion( bool deint2ToContiguous = sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 2 && - resultLayout.isContiguous(); + sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == 1; bool contiguousToDeint2 = sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && resultLayout.isDeinterleaved() && - resultLayout.getFactor() == 2; + resultLayout.getFactor() == 2 && + resultLayout.getLaneStride() == 1; if (deint2ToContiguous || contiguousToDeint2) { if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || sourceParts.size() % 2 != 0) { @@ -2915,10 +3241,14 @@ FailureOr> materializeDataLayoutConversion( bool deint4ToContiguous = sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 4 && - resultLayout.isContiguous(); + sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == 1; bool contiguousToDeint4 = sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && resultLayout.isDeinterleaved() && - resultLayout.getFactor() == 4; + resultLayout.getFactor() == 4 && + resultLayout.getLaneStride() == 1; if (deint4ToContiguous || contiguousToDeint4) { if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || sourceParts.size() % 4 != 0) { @@ -2988,6 +3318,28 @@ FailureOr> materializeDataLayoutConversion( return results; } + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && resultLayout.getLaneStride() != 1) { + auto ensure = dyn_cast(op); + if (!ensure) + return failure(); + auto sourceType = cast(ensure.getSource().getType()); + return materializeContiguousToLaneStride( + op, sourceParts, resultTypes, sourceType.getElementType(), + resultLayout.getLaneStride(), rewriter); + } + + if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() != 1 && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1) { + auto ensure = dyn_cast(op); + if (!ensure) + return failure(); + auto sourceType = cast(ensure.getSource().getType()); + return materializeLaneStrideToContiguous( + op, sourceParts, resultTypes, sourceType.getElementType(), + sourceLayout.getLaneStride(), rewriter); + } + if (sourceLayout.isDeinterleaved() && resultLayout.isDeinterleaved() && (sourceLayout.getFactor() == 2 || sourceLayout.getFactor() == 4) && (resultLayout.getFactor() == 2 || resultLayout.getFactor() == 4)) { @@ -3068,10 +3420,14 @@ FailureOr> materializeMaskLayoutConversion( bool deint2ToContiguous = sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 2 && - resultLayout.isContiguous(); + sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == 1; bool contiguousToDeint2 = sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && resultLayout.isDeinterleaved() && - resultLayout.getFactor() == 2; + resultLayout.getFactor() == 2 && + resultLayout.getLaneStride() == 1; if (deint2ToContiguous || contiguousToDeint2) { if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || sourceParts.size() % 2 != 0) { @@ -3120,10 +3476,14 @@ FailureOr> materializeMaskLayoutConversion( bool deint4ToContiguous = sourceLayout.isDeinterleaved() && sourceLayout.getFactor() == 4 && - resultLayout.isContiguous(); + sourceLayout.getLaneStride() == 1 && + resultLayout.isContiguous() && + resultLayout.getLaneStride() == 1; bool contiguousToDeint4 = sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && resultLayout.isDeinterleaved() && - resultLayout.getFactor() == 4; + resultLayout.getFactor() == 4 && + resultLayout.getLaneStride() == 1; if (deint4ToContiguous || contiguousToDeint4) { if (sourceParts.size() != resultTypes.size() || sourceParts.empty() || sourceParts.size() % 4 != 0) { @@ -4098,13 +4458,40 @@ struct OneToNVMILoadOpPattern : OneToNOpConversionPattern { "load offset must convert to one value", rewriter); if (failed(source) || failed(offset)) return failure(); + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (std::optional dist = + getDenseLaneStrideLoadDistToken(resultVMIType)) { + SmallVector results; + results.reserve(resultTypes.size()); + int64_t semanticOffset = 0; + for (auto [index, resultType] : llvm::enumerate(resultTypes)) { + if (!isa(resultType)) + return rewriter.notifyMatchFailure(op, "load result must be vreg"); + Value chunkOffset = + createChunkOffset(op.getLoc(), *offset, semanticOffset, rewriter); + results.push_back(rewriter + .create(op.getLoc(), resultType, + /*updated_base=*/Type{}, *source, + chunkOffset, + rewriter.getStringAttr(*dist)) + .getResult()); + FailureOr activeLanes = + getActiveDataLanesInPhysicalChunk(resultVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute lane_stride load active lanes"); + semanticOffset += *activeLanes; + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + FailureOr lanesPerPart = verifyFullOrSafeReadVRegChunks( op, resultVMIType, op.getSource().getType(), *offset, rewriter); if (failed(lanesPerPart)) return failure(); - TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); - VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); if (resultLayout && resultLayout.isDeinterleaved() && resultLayout.getFactor() == 2) { std::optional dist = @@ -4822,6 +5209,43 @@ struct OneToNVMIStoreOpPattern : OneToNOpConversionPattern { return failure(); ValueRange valueParts = adaptor.getValue(); + if (std::optional dist = + getDenseLaneStrideStoreDistToken(valueVMIType)) { + std::optional maskGranularity = + getDenseLaneStrideStoreMaskGranularity(valueVMIType); + if (!maskGranularity) + return rewriter.notifyMatchFailure( + op, "unsupported lane_stride store mask granularity"); + int64_t semanticOffset = 0; + for (auto [index, value] : llvm::enumerate(valueParts)) { + auto vregType = dyn_cast(value.getType()); + if (!vregType) + return rewriter.notifyMatchFailure(op, "store value must be vreg"); + FailureOr activeLanes = + getActiveDataLanesInPhysicalChunk(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute lane_stride store active lanes"); + if (*activeLanes == 0) + continue; + auto maskType = MaskType::get(rewriter.getContext(), *maskGranularity); + FailureOr mask = createPrefixMaskForActiveLanes( + op.getLoc(), maskType, *activeLanes, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to create lane_stride store mask"); + Value chunkOffset = + createChunkOffset(op.getLoc(), *offset, semanticOffset, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, rewriter.getStringAttr(*dist), + *mask); + semanticOffset += *activeLanes; + } + rewriter.eraseOp(op); + return success(); + } + VMILayoutSupport localSupports; FailureOr storeSupport = localSupports.getContiguousStoreSupport(valueVMIType); @@ -5336,6 +5760,50 @@ struct OneToNVMIMaskedStoreOpPattern return rewriter.notifyMatchFailure( op, "masked_store value/mask physical arity mismatch"); + auto maskVMIType = cast(op.getMask().getType()); + if (std::optional dist = + getDenseLaneStrideStoreDistToken(valueVMIType)) { + std::optional maskGranularity = + getDenseLaneStrideMaskedStoreMaskGranularity(valueVMIType); + VMILayoutAttr valueLayout = valueVMIType.getLayoutAttr(); + VMILayoutAttr maskLayout = maskVMIType.getLayoutAttr(); + if (maskGranularity && valueLayout && maskLayout && + valueLayout == maskLayout) { + int64_t semanticOffset = 0; + for (auto [index, valueAndMask] : + llvm::enumerate(llvm::zip_equal(valueParts, maskParts))) { + auto [value, mask] = valueAndMask; + auto vregType = dyn_cast(value.getType()); + if (!vregType || !isa(mask.getType())) + return rewriter.notifyMatchFailure( + op, "lane_stride masked_store parts must be vreg/mask"); + FailureOr activeLanes = + getActiveDataLanesInPhysicalChunk(valueVMIType, index); + if (failed(activeLanes)) + return rewriter.notifyMatchFailure( + op, "failed to compute lane_stride masked_store active lanes"); + if (*activeLanes == 0) + continue; + FailureOr storeMask = createDenseLaneStrideStorePredicate( + op.getLoc(), valueVMIType, index, mask, *maskGranularity, + rewriter); + if (failed(storeMask)) + return rewriter.notifyMatchFailure( + op, "failed to compact lane_stride masked_store predicate"); + Value chunkOffset = + createChunkOffset(op.getLoc(), *offset, semanticOffset, rewriter); + rewriter.create(op.getLoc(), + /*updated_base=*/Type{}, value, *destination, + chunkOffset, rewriter.getStringAttr(*dist), + *storeMask); + semanticOffset += *activeLanes; + } + + rewriter.eraseOp(op); + return success(); + } + } + SmallVector contiguousValueTypes; contiguousValueTypes.reserve(valueParts.size()); for (Value value : valueParts) @@ -5346,7 +5814,6 @@ struct OneToNVMIMaskedStoreOpPattern if (failed(storeParts)) return failure(); - auto maskVMIType = cast(op.getMask().getType()); SmallVector contiguousMaskTypes; contiguousMaskTypes.reserve(maskParts.size()); for (Value mask : maskParts) @@ -6900,6 +7367,8 @@ struct OneToNVMIExtFOpPattern : OneToNOpConversionPattern { LogicalResult matchAndRewrite(VMIExtFOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { + auto sourceVMIType = cast(op.getSource().getType()); + auto resultVMIType = cast(op.getResult().getType()); ValueRange sourceParts = adaptor.getSource(); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); if (sourceParts.empty()) @@ -6930,6 +7399,34 @@ struct OneToNVMIExtFOpPattern : OneToNOpConversionPattern { unsigned sourceBits = pto::getPTOStorageElemBitWidth(sourceType.getElementType()); + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + if (sourceLayout && resultLayout && sourceLayout.isContiguous() && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1 && + ((sourceBits == 16 && sourceLayout.getLaneStride() == 2) || + (sourceBits == 8 && sourceLayout.getLaneStride() == 4)) && + resultTypes.size() == sourceParts.size()) { + StringRef part = sourceBits == 16 ? StringRef("EVEN") : StringRef("P0"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure(op, "failed to build extf seed mask"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultVRegTypes)) { + results.push_back(rewriter + .create(op.getLoc(), resultType, + sourcePart, *mask, + /*rnd=*/nullptr, /*sat=*/nullptr, + rewriter.getStringAttr(part)) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + ArrayRef parts; int64_t factor = 0; if (sourceBits == 16 && resultTypes.size() == 2 * sourceParts.size()) { @@ -7054,6 +7551,43 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { unsigned resultBits = pto::getPTOStorageElemBitWidth( resultVRegTypes.front().getElementType()); + if (sourceLayout && resultLayout && sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && resultLayout.isContiguous() && + resultLayout.getLaneStride() != 1 && + sourceParts.size() == resultTypes.size()) { + StringRef part; + if (resultBits == 16 && resultLayout.getLaneStride() == 2) + part = "EVEN"; + else if (resultBits == 8 && resultLayout.getLaneStride() == 4) + part = "P0"; + else + return rewriter.notifyMatchFailure( + op, "unsupported dense lane_stride truncf result layout"); + + FailureOr sourceMask = + createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); + if (failed(sourceMask)) + return rewriter.notifyMatchFailure(op, + "failed to build truncf masks"); + + StringAttr rnd = rewriter.getStringAttr( + getTruncFRoundMode(op, resultVRegTypes.front().getElementType())); + StringAttr sat = rewriter.getStringAttr("SAT"); + StringAttr partAttr = rewriter.getStringAttr(part); + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultVRegTypes)) { + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, + *sourceMask, rnd, sat, partAttr) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + ArrayRef parts; int64_t factor = 0; if (resultBits == 16 && sourceParts.size() == 2 * resultTypes.size()) { @@ -7276,6 +7810,36 @@ struct OneToNVMIExtIOpPattern : OneToNOpConversionPattern { pto::getPTOStorageElemBitWidth(sourceType.getElementType()); unsigned resultBits = pto::getPTOStorageElemBitWidth( resultVRegTypes.front().getElementType()); + if (sourceLayout && resultLayout && sourceLayout.isContiguous() && + resultLayout.isContiguous() && resultLayout.getLaneStride() == 1 && + ((resultBits == sourceBits * 2 && + sourceLayout.getLaneStride() == 2) || + (resultBits == sourceBits * 4 && + sourceLayout.getLaneStride() == 4)) && + resultTypes.size() == sourceParts.size()) { + StringRef part = + resultBits == sourceBits * 2 ? StringRef("EVEN") : StringRef("P0"); + FailureOr mask = + createAllTrueMaskForVReg(op.getLoc(), sourceType, rewriter); + if (failed(mask)) + return rewriter.notifyMatchFailure( + op, "failed to build integer extension seed mask"); + + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultVRegTypes)) { + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, *mask, + /*rnd=*/nullptr, /*sat=*/nullptr, + rewriter.getStringAttr(part)) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + ArrayRef parts; int64_t factor = 0; if (resultBits == sourceBits * 2 && @@ -7436,6 +8000,34 @@ struct OneToNVMITruncIOpPattern : OneToNOpConversionPattern { return rewriter.notifyMatchFailure( op, "unsupported physical trunci source/result width relation"); int64_t factor = sourceBits / resultBits; + if (sourceLayout && resultLayout && sourceLayout.isContiguous() && + sourceLayout.getLaneStride() == 1 && resultLayout.isContiguous() && + resultLayout.getLaneStride() == factor && + sourceParts.size() == resultTypes.size()) { + if (factor != 2 && factor != 4) + return rewriter.notifyMatchFailure( + op, "unsupported dense lane_stride trunci result layout"); + StringAttr part = rewriter.getStringAttr(factor == 2 ? "EVEN" : "P0"); + FailureOr sourceMask = + createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); + if (failed(sourceMask)) + return rewriter.notifyMatchFailure(op, "failed to build trunci masks"); + + StringAttr sat = rewriter.getStringAttr("SAT"); + SmallVector results; + results.reserve(resultTypes.size()); + for (auto [sourcePart, resultType] : + llvm::zip_equal(sourceParts, resultTypes)) { + results.push_back( + rewriter + .create(op.getLoc(), resultType, sourcePart, + *sourceMask, /*rnd=*/nullptr, sat, part) + .getResult()); + } + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + if ((factor != 2 && factor != 4) || sourceParts.size() != resultTypes.size() * factor) return rewriter.notifyMatchFailure( @@ -9014,9 +9606,10 @@ verifySupportedVMIToVPTOOps(ModuleOp module, auto sourceType = cast(ensure.getSource().getType()); auto resultType = cast(ensure.getResult().getType()); std::string reason; - if (succeeded(checkSupportedLayoutMaterialization( - capabilities, sourceType, resultType, sourceType.getLayoutAttr(), - resultType.getLayoutAttr(), &reason))) + VMILayoutSupport supports; + if (succeeded( + supports.canMaterializeDataLayout(sourceType, resultType, + &reason))) return WalkResult::advance(); emitEnsureLayoutMaterializationError(ensure, sourceType, resultType, diff --git a/test/lit/vmi/vmi_lane_stride_dense_load_store.pto b/test/lit/vmi/vmi_lane_stride_dense_load_store.pto new file mode 100644 index 0000000000..2cc9cb2625 --- /dev/null +++ b/test/lit/vmi/vmi_lane_stride_dense_load_store.pto @@ -0,0 +1,197 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_layout_fold_load_lane_stride( + %src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> { + %load = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + %strided = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + return %strided + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + } + + func.func @vmi_layout_fold_store_lane_stride( + %value: !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %compact = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + pto.vmi.store %compact, %dst[%off] + : !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_extf_lane_stride_even( + %value: !pto.vmi.vreg<64xf16, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> { + %wide = pto.vmi.extf %value + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> + return %wide : !pto.vmi.vreg<64xf32, #pto.vmi.layout> + } + + func.func @vmi_extui_lane_stride_even( + %value: !pto.vmi.vreg<64xui16, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xui32, #pto.vmi.layout> { + %wide = pto.vmi.extui %value + : !pto.vmi.vreg<64xui16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xui32, #pto.vmi.layout> + return %wide : !pto.vmi.vreg<64xui32, #pto.vmi.layout> + } + + func.func @vmi_layout_fold_store_lane_stride_b32( + %value: !pto.vmi.vreg<32xf32, #pto.vmi.layout>, + %dst: !pto.ptr, %off: index) { + %compact = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<32xf32, #pto.vmi.layout> + -> !pto.vmi.vreg<32xf32, #pto.vmi.layout> + pto.vmi.store %compact, %dst[%off] + : !pto.vmi.vreg<32xf32, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_layout_fold_load_store_lane_stride4_u8( + %src: !pto.ptr, %dst: !pto.ptr, %off: index) { + %load = pto.vmi.load %src[%off] + : !pto.ptr + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> + %strided = pto.vmi.ensure_layout %load + : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> + pto.vmi.store %strided, %dst[%off] + : !pto.vmi.vreg<64xui8, #pto.vmi.layout>, + !pto.ptr + return + } + + func.func @vmi_ensure_contiguous_to_lane_stride_f16( + %value: !pto.vmi.vreg<64xf16, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> { + %strided = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + return %strided + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + } + + func.func @vmi_ensure_lane_stride_to_contiguous_f16( + %value: !pto.vmi.vreg<64xf16, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> { + %compact = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + return %compact : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + } + + func.func @vmi_ensure_contiguous_to_lane_stride4_ui8( + %value: !pto.vmi.vreg<64xui8, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> { + %strided = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> + return %strided + : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + } + + func.func @vmi_ensure_lane_stride4_to_contiguous_ui8( + %value: !pto.vmi.vreg<64xui8, #pto.vmi.layout>) + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> { + %compact = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> + return %compact : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + } +} + +// FOLD-LABEL: func.func @vmi_layout_fold_load_lane_stride( +// FOLD: %[[LOAD:.*]] = pto.vmi.load +// FOLD-SAME: !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: return %[[LOAD]] + +// FOLD-LABEL: func.func @vmi_layout_fold_store_lane_stride( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD: pto.vmi.store %[[VALUE]] +// FOLD-SAME: !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout + +// FOLD-LABEL: func.func @vmi_extf_lane_stride_even( + +// LOWER-LABEL: func.func @vmi_layout_fold_load_lane_stride( +// LOWER: pto.vlds {{.*}} {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_layout_fold_store_lane_stride( +// LOWER: pto.vsts {{.*}} {dist = "PK_B32"} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_extf_lane_stride_even( +// LOWER: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// LOWER-NOT: part = "ODD" +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_extui_lane_stride_even( +// LOWER: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xui16>, !pto.mask -> !pto.vreg<64xui32> +// LOWER-NOT: part = "ODD" +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_layout_fold_store_lane_stride_b32( +// LOWER: pto.vsts {{.*}} {dist = "PK_B64"} : !pto.vreg<64xf32>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_layout_fold_load_store_lane_stride4_u8( +// LOWER: pto.vlds {{.*}} {dist = "UNPK4"} : !pto.ptr -> !pto.vreg<256xui8> +// LOWER: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256xui8>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_ensure_contiguous_to_lane_stride_f16( +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<128xf16> -> !pto.vreg<128xui16> +// LOWER: pto.vzunpack {{.*}} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<64xui32> -> !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_ensure_lane_stride_to_contiguous_f16( +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<128xf16> -> !pto.vreg<64xui32> +// LOWER: pto.vpack {{.*}} "LOWER" : !pto.vreg<64xui32> -> !pto.vreg<128xui16> +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<128xui16> -> !pto.vreg<128xf16> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_ensure_contiguous_to_lane_stride4_ui8( +// LOWER: pto.vzunpack {{.*}} : !pto.vreg<256xui8> -> !pto.vreg<128xui16> +// LOWER: pto.vzunpack {{.*}} : !pto.vreg<128xui16> -> !pto.vreg<64xui32> +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<64xui32> -> !pto.vreg<256xui8> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_ensure_lane_stride4_to_contiguous_ui8( +// LOWER: pto.vbitcast {{.*}} : !pto.vreg<256xui8> -> !pto.vreg<64xui32> +// LOWER: pto.vpack {{.*}} "LOWER" : !pto.vreg<64xui32> -> !pto.vreg<128xui16> +// LOWER: pto.vpack {{.*}} "LOWER" : !pto.vreg<128xui16> -> !pto.vreg<256xui8> +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_lane_stride_masked_store.pto b/test/lit/vmi/vmi_lane_stride_masked_store.pto new file mode 100644 index 0000000000..25d1c661fe --- /dev/null +++ b/test/lit/vmi/vmi_lane_stride_masked_store.pto @@ -0,0 +1,85 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-fold | FileCheck %s --check-prefix=FOLD +// RUN: pto-test-opt %s -vmi-layout-fold -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_lane_stride2_masked_store_f16( + %value: !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb16, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xf16, #pto.vmi.layout> + -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> + %mask_c = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<64xb16, #pto.vmi.layout> + -> !pto.vmi.mask<64xb16, #pto.vmi.layout> + pto.vmi.masked_store %value_c, %dst[%offset], %mask_c + : !pto.vmi.vreg<64xf16, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<64xb16, #pto.vmi.layout> + return + } + + func.func @vmi_lane_stride4_masked_store_ui8( + %value: !pto.vmi.vreg<64xui8, #pto.vmi.layout>, + %mask: !pto.vmi.mask<64xb8, #pto.vmi.layout>, + %dst: !pto.ptr, %offset: index) { + %value_c = pto.vmi.ensure_layout %value + : !pto.vmi.vreg<64xui8, #pto.vmi.layout> + -> !pto.vmi.vreg<64xui8, #pto.vmi.layout> + %mask_c = pto.vmi.ensure_mask_layout %mask + : !pto.vmi.mask<64xb8, #pto.vmi.layout> + -> !pto.vmi.mask<64xb8, #pto.vmi.layout> + pto.vmi.masked_store %value_c, %dst[%offset], %mask_c + : !pto.vmi.vreg<64xui8, #pto.vmi.layout>, + !pto.ptr, + !pto.vmi.mask<64xb8, #pto.vmi.layout> + return + } +} + +// FOLD-LABEL: func.func @vmi_lane_stride2_masked_store_f16( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb16, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: pto.vmi.masked_store %[[VALUE]] +// FOLD-SAME: %[[MASK]] +// FOLD-SAME: !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// FOLD-SAME: !pto.vmi.mask<64xb16, #pto.vmi.layout> + +// FOLD-LABEL: func.func @vmi_lane_stride4_masked_store_ui8( +// FOLD-SAME: %[[VALUE:.*]]: !pto.vmi.vreg<64xui8, #pto.vmi.layout> +// FOLD-SAME: %[[MASK:.*]]: !pto.vmi.mask<64xb8, #pto.vmi.layout> +// FOLD-NOT: pto.vmi.ensure_layout +// FOLD-NOT: pto.vmi.ensure_mask_layout +// FOLD: pto.vmi.masked_store %[[VALUE]] +// FOLD-SAME: %[[MASK]] +// FOLD-SAME: !pto.vmi.vreg<64xui8, #pto.vmi.layout> +// FOLD-SAME: !pto.vmi.mask<64xb8, #pto.vmi.layout> + +// LOWER-LABEL: func.func @vmi_lane_stride2_masked_store_f16( +// LOWER-SAME: %[[VALUE:[^,]+]]: !pto.vreg<128xf16> +// LOWER-SAME: %[[MASK:[^,]+]]: !pto.mask +// LOWER: %[[COMPACT:.*]] = pto.punpack %[[MASK]], "LOWER" : !pto.mask -> !pto.mask +// LOWER: pto.vsts %[[VALUE]], {{.*}} {dist = "PK_B32"} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. + +// LOWER-LABEL: func.func @vmi_lane_stride4_masked_store_ui8( +// LOWER-SAME: %[[VALUE:[^,]+]]: !pto.vreg<256xui8> +// LOWER-SAME: %[[MASK:[^,]+]]: !pto.mask +// LOWER: %[[MID:.*]] = pto.punpack %[[MASK]], "LOWER" : !pto.mask -> !pto.mask +// LOWER: %[[COMPACT:.*]] = pto.punpack %[[MID]], "LOWER" : !pto.mask -> !pto.mask +// LOWER: pto.vsts %[[VALUE]], {{.*}} {dist = "PK4_B32"} : !pto.vreg<256xui8>, !pto.ptr, !pto.mask +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto index 147245c484..b16082f89a 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto @@ -59,19 +59,19 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_f32_to_f16_store( // ASSIGN: %[[X32:.*]] = pto.vmi.load -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN-NOT: pto.vmi.ensure_layout // ASSIGN: %[[X16:.*]] = pto.vmi.truncf %[[X32]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[X16]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr // LOWER-LABEL: func.func @vmi_layout_assignment_dense_f32_to_f16_store( -// LOWER: pto.vldsx2 +// LOWER: pto.vlds // LOWER: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} -// LOWER: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} -// LOWER: pto.vor -// LOWER: pto.vsts +// LOWER-NOT: part = "ODD" +// LOWER-NOT: pto.vor +// LOWER: pto.vsts {{.*}} {dist = "PK_B32"} // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto index 783baac750..550eed8f28 100644 --- a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -35,30 +35,32 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( // ASSIGN: %[[X32:.*]] = pto.vmi.load -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> +// ASSIGN: %[[X32_REDUCE:.*]] = pto.vmi.ensure_layout %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> -// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32_REDUCE]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN: %[[X8:.*]] = pto.vmi.truncf %[[X32]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[X8]] // LOWER-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( -// LOWER-COUNT-2: pto.vldsx2 -// LOWER-COUNT-2: pto.vdintlv +// LOWER-COUNT-4: pto.vlds +// LOWER-COUNT-4: pto.vdintlv // LOWER-COUNT-4: pto.vcgadd // LOWER-COUNT-3: pto.vadd // LOWER: pto.vsts // LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} -// LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} -// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} -// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} -// LOWER-COUNT-3: pto.vor -// LOWER: pto.vsts +// LOWER-NOT: part = "P1" +// LOWER-NOT: part = "P2" +// LOWER-NOT: part = "P3" +// LOWER-NOT: pto.vor +// LOWER: pto.vsts {{.*}} {dist = "PK4_B32"} // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto b/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto index 0ce6b6b295..4b7496a8f8 100644 --- a/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto +++ b/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto @@ -41,8 +41,10 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[Y32:.*]] = pto.vmi.mulf %[[X32]], %[[SCALE]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[Y8:.*]] = pto.vmi.truncf %[[Y32]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: %[[Y32_CONTIG:.*]] = pto.vmi.ensure_layout %[[Y32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Y8:.*]] = pto.vmi.truncf %[[Y32_CONTIG]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[Y8]] // LOWER-LABEL: func.func @vmi_layout_assignment_f8_compute_f8( @@ -51,11 +53,11 @@ module { // LOWER-COUNT-4: pto.vdup // LOWER-COUNT-4: pto.vmul // LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} -// LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} -// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} -// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} -// LOWER-COUNT-3: pto.vor -// LOWER: pto.vsts +// LOWER-NOT: part = "P1" +// LOWER-NOT: part = "P2" +// LOWER-NOT: part = "P3" +// LOWER-NOT: pto.vor +// LOWER: pto.vsts {{.*}} {dist = "PK4_B32"} // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto index 400c10093a..2118b0f5d2 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto @@ -55,10 +55,8 @@ module { // ASSIGN: pto.vmi.group_store %[[YSUM]] // ASSIGN: %[[B_CAST:.*]] = pto.vmi.group_broadcast %[[SUM]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[B_CAST_SPLIT:.*]] = pto.vmi.ensure_layout %[[B_CAST]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[B_CAST_SPLIT]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[B_CAST]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[H]] // LOWER-LABEL: func.func @vmi_layout_assignment_group_broadcast_multi_consumer( @@ -73,7 +71,7 @@ module { // LOWER: pto.vselr // LOWER: pto.vselr // LOWER: pto.vcvt -// LOWER: pto.vor -// LOWER: pto.vsts +// LOWER-NOT: pto.vor +// LOWER: pto.vsts {{.*}} {dist = "PK_B32"} // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto index b96238db35..712eddda7e 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto @@ -38,9 +38,9 @@ module { // CHECK: pto.vsldb // CHECK: pto.vcgadd // CHECK: pto.vintlv -// CHECK: pto.vdintlv -// CHECK: pto.vcvt -// CHECK: pto.vsts +// CHECK-NOT: pto.vdintlv +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// CHECK: pto.vsts {{.*}} {dist = "PK_B32"} // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto index e4600d1f77..685987b72c 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto @@ -61,18 +61,17 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[Q:.*]] = pto.vmi.divf %[[X]], %[[SCALE_VEC]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[Q_SPLIT:.*]] = pto.vmi.ensure_layout %[[Q]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[Q8:.*]] = pto.vmi.truncf %[[Q_SPLIT]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: %[[Q8:.*]] = pto.vmi.truncf %[[Q]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_maxf_quant( // LOWER: pto.vcgmax // LOWER: pto.vmax // LOWER: pto.vsel // LOWER: pto.vdiv -// LOWER: pto.vdintlv -// LOWER: pto.vcvt +// LOWER-NOT: pto.vdintlv +// LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// LOWER: pto.vsts {{.*}} {dist = "PK4_B32"} // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto index a492118cad..d3a8f54b87 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto @@ -39,12 +39,10 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[B32:.*]] = pto.vmi.group_broadcast %[[SUM32]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[B32_SPLIT:.*]] = pto.vmi.ensure_layout %[[B32]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[B16:.*]] = pto.vmi.truncf %[[B32_SPLIT]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[B16:.*]] = pto.vmi.truncf %[[B32]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[B16]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store( // LOWER: pto.vcgadd diff --git a/test/lit/vmi/vmi_layout_assignment_load_truncf.pto b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto index ca1ee6c921..689e7b0836 100644 --- a/test/lit/vmi/vmi_layout_assignment_load_truncf.pto +++ b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto @@ -36,39 +36,34 @@ module { } // ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf( -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: %[[WIDE:.*]] = pto.vmi.load -// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN-NOT: pto.vmi.ensure_layout // ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> -// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_load_truncf( -// LOWER: %[[P0:.*]], %[[P1:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" -// LOWER: %[[EVEN:.*]] = pto.vcvt %[[P0]], {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} -// LOWER: %[[ODD:.*]] = pto.vcvt %[[P1]], {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} -// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] -// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> +// LOWER: %[[P0:.*]] = pto.vlds %arg0[%arg1] +// LOWER: %[[NARROW:.*]] = pto.vcvt %[[P0]], {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: return %[[NARROW]], {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16> // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. // ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf_multi_use( -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: %[[WIDE:.*]] = pto.vmi.load // ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[WIDE]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> -// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_load_truncf_multi_use( // LOWER: pto.vsts -// LOWER: pto.vdintlv -// LOWER: pto.vcvt -// LOWER: return {{.*}} : !pto.vreg<128xf16> +// LOWER: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: return {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16> // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto index 1d3a2f3d0b..44ae6a19c5 100644 --- a/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_mask_granularity_f32_f16_store.pto @@ -56,7 +56,6 @@ module { // LOWER: pto.vor // LOWER: pto.ppack // LOWER: pto.ppack -// LOWER: pto.por // LOWER: pto.vsts // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto index f7bc538518..f1be94a798 100644 --- a/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_packed_group_slots_truncf_invalid.pto @@ -19,8 +19,8 @@ module { {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> -> !pto.vmi.vreg<8xf32> - // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<8xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion - // CHECK: failed helper conversion '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' (unsupported source/result layout pair) + // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<8xf32, #pto.vmi.layout>'; pto.vmi.ensure_layout cannot materialize this conversion + // CHECK: failed helper conversion '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' -> '!pto.vmi.vreg<8xf32, #pto.vmi.layout>' (unsupported source/result layout pair) %h = pto.vmi.truncf %sum : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xf16> pto.vmi.group_store %h, %dst[%off], %c1 {num_groups = 8} diff --git a/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto b/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto index 141e85772b..bbf5afe97d 100644 --- a/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto +++ b/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto @@ -20,20 +20,15 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_truncf_ensure( // ASSIGN-SAME: %[[WIDE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: return %[[NARROW]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_truncf_ensure( -// LOWER-SAME: %[[D0:[^,]+]]: !pto.vreg<64xf32> -// LOWER-SAME: %[[D1:[^)]+]]: !pto.vreg<64xf32> -// LOWER: %[[P0:.*]], %[[P1:.*]] = pto.vdintlv %[[D0]], %[[D1]] -// LOWER: %[[EVEN:.*]] = pto.vcvt %[[P0]]{{.*}}part = "EVEN" -// LOWER: %[[ODD:.*]] = pto.vcvt %[[P1]]{{.*}}part = "ODD" -// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] -// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> +// LOWER-SAME: %[[D0:arg[0-9]+]]: !pto.vreg<64xf32> +// LOWER: %[[NARROW:.*]] = pto.vcvt %[[D0]]{{.*}}part = "EVEN" +// LOWER: return %[[NARROW]], {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16> // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto index dd28fbe21e..584c889233 100644 --- a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto +++ b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto @@ -48,11 +48,14 @@ module attributes {pto.target_arch = "a5"} { // CHECK-NOT: unrealized_conversion_cast // CHECK-LABEL: func.func @vmi_ptoas_cli_fold_pipeline -// CHECK: pto.vlds +// CHECK: pto.vlds {{.*}} {dist = "UNPK_B16"} +// CHECK: pto.vlds {{.*}} {dist = "UNPK_B16"} // CHECK: pto.vcvt {{.*}} {part = "EVEN"} -// CHECK: pto.vcvt {{.*}} {part = "ODD"} +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} +// CHECK-NOT: part = "ODD" // CHECK-NOT: pto.vintlv -// CHECK: pto.vstsx2 +// CHECK: pto.vsts +// CHECK: pto.vsts // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto index c3a1a0fede..e596488dc3 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto @@ -252,12 +252,16 @@ module { // CHECK-SAME: %[[QDST:[^,]+]]: !pto.ptr // CHECK: scf.for // CHECK: scf.for -// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> // CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK-NOT: part = "ODD" +// CHECK-NOT: pto.vor +// CHECK: pto.vsts {{.*}} {dist = "PK_B32"} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// CHECK: scf.if +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> // CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> // CHECK: pto.vor {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> -// CHECK: scf.if // CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask // CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask // CHECK: pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask @@ -287,24 +291,19 @@ module { // CHECK-SAME: %[[FQDST:[^,]+]]: !pto.ptr // CHECK: scf.for // CHECK: scf.for -// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-COUNT-4: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vdintlv // CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask +// CHECK: scf.if +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask // CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask -// CHECK: pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask -// CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask // CHECK: pto.ppack {{.*}} : !pto.mask -> !pto.mask -// CHECK: pto.por {{.*}} : !pto.mask, !pto.mask, !pto.mask -> !pto.mask // CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_quant_fp8.pto b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto index 01e92013cc..7ecee61a27 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_fp8.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto @@ -30,18 +30,14 @@ module { } // CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( -// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// CHECK: pto.vldsx2 {{.*}}, "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> -// CHECK: pto.vdintlv {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK-COUNT-4: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vdintlv // CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.vor {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask +// CHECK-NOT: part = "P1" +// CHECK-NOT: part = "P2" +// CHECK-NOT: part = "P3" +// CHECK-NOT: pto.vor +// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto index 5297123e5a..1a0f498140 100644 --- a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto @@ -6,20 +6,27 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s module { func.func @vmi_to_vpto_truncf_fp8_128_contiguous_invalid( - %input: !pto.vmi.vreg<128xf32>) { + %input: !pto.vmi.vreg<128xf32>, + %dst: !pto.ptr, + %off: index) { %packed = pto.vmi.truncf %input : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf8E4M3FN> + pto.vmi.store %packed, %dst[%off] + : !pto.vmi.vreg<128xf8E4M3FN>, !pto.ptr return } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf operand #0 has type {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> -// CHECK-SAME: but requires {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> -// CHECK-SAME: pto.vmi.ensure_layout cannot materialize this conversion -// CHECK: failed helper conversion {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> -// CHECK-SAME: {{'?}}!pto.vmi.vreg<128xf32, #pto.vmi.layout> -// CHECK-SAME: requires source and result to have the same physical arity +// CHECK-LABEL: func.func @vmi_to_vpto_truncf_fp8_128_contiguous_invalid( +// CHECK-SAME: %[[P0:.*]]: !pto.vreg<64xf32>, %[[P1:.*]]: !pto.vreg<64xf32> +// CHECK: %[[R0:.*]] = pto.vcvt %[[P0]], {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: %[[R1:.*]] = pto.vcvt %[[P1]], {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vsts %[[R0]], {{.*}} {dist = "PK4_B32"} +// CHECK: pto.vsts %[[R1]], {{.*}} {dist = "PK4_B32"} +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vpto/vmi_truncf_hif8.pto b/test/lit/vpto/vmi_truncf_hif8.pto index 260c43ad7a..a638759e01 100644 --- a/test/lit/vpto/vmi_truncf_hif8.pto +++ b/test/lit/vpto/vmi_truncf_hif8.pto @@ -10,14 +10,12 @@ // CHECK-LABEL: func.func @vmi_truncf_hif8_default_kernel // CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> -// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> -// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> -// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK-NOT: part = "P1" +// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask // CHECK-LABEL: func.func @vmi_truncf_hif8_hybrid_kernel // CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> -// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> -// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> -// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK-NOT: part = "P1" +// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { func.func @vmi_truncf_hif8_default_kernel(%src_gm: !pto.ptr, From 7a108098cef94e97b7af54669e618146ab5e27ee Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 30 Jun 2026 23:36:38 +0800 Subject: [PATCH 47/54] Support arity-driven VMI cast layouts --- docs/designs/vmi-implementation-manual.md | 39 +- .../vmi-lane-stride-generalization-design.md | 137 ++++++- ...ne-stride-generalization-implementation.md | 367 +++++++++++------- include/PTO/Transforms/VMILayoutSupport.h | 15 + lib/PTO/Transforms/VMILayoutAssignment.cpp | 49 +-- lib/PTO/Transforms/VMILayoutSupport.cpp | 112 +++++- lib/PTO/Transforms/VMIToVPTO.cpp | 20 +- test/lit/vmi/opt/README.md | 18 + .../lit/vmi/opt/compute_mrope_f16_vmi_opt.pto | 134 +++++++ .../opt/compute_y1_to_fp8_fp16_vmi_opt.pto | 144 +++++++ ..._layout_assignment_dense_f16_f32_store.pto | 78 +++- ..._layout_assignment_f32_f8_store_reduce.pto | 22 +- .../vmi_layout_assignment_f8_compute_f8.pto | 16 +- ...ignment_group_broadcast_multi_consumer.pto | 10 +- ...ut_assignment_group_load_block8_truncf.pto | 6 +- ...out_assignment_group_reduce_maxf_quant.pto | 14 +- ...roup_reduce_s16_truncf_broadcast_store.pto | 9 +- .../vmi/vmi_layout_assignment_load_truncf.pto | 30 +- ...signment_multi_return_conflict_invalid.pto | 2 +- .../vmi_layout_assignment_truncf_ensure.pto | 17 +- ...valid.pto => vmi_to_vpto_load_nonfull.pto} | 14 +- ...to => vmi_to_vpto_load_nonfull_memref.pto} | 14 +- ...load_safe_tail_memref_negative_offset.pto} | 12 +- .../vmi/vmi_to_vpto_memory_space_invalid.pto | 2 +- .../vmi/vmi_to_vpto_memref_layout_invalid.pto | 4 +- test/lit/vmi/vmi_to_vpto_quant_dequant.pto | 16 +- test/lit/vmi/vmi_to_vpto_quant_fp8.pto | 14 +- 27 files changed, 1021 insertions(+), 294 deletions(-) create mode 100644 test/lit/vmi/opt/README.md create mode 100644 test/lit/vmi/opt/compute_mrope_f16_vmi_opt.pto create mode 100644 test/lit/vmi/opt/compute_y1_to_fp8_fp16_vmi_opt.pto rename test/lit/vmi/{vmi_to_vpto_load_nonfull_invalid.pto => vmi_to_vpto_load_nonfull.pto} (62%) rename test/lit/vmi/{vmi_to_vpto_load_safe_tail_memref_invalid.pto => vmi_to_vpto_load_nonfull_memref.pto} (67%) rename test/lit/vmi/{vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto => vmi_to_vpto_load_safe_tail_memref_negative_offset.pto} (74%) diff --git a/docs/designs/vmi-implementation-manual.md b/docs/designs/vmi-implementation-manual.md index 76f6d966f7..98cd3e9ee5 100644 --- a/docs/designs/vmi-implementation-manual.md +++ b/docs/designs/vmi-implementation-manual.md @@ -1990,8 +1990,10 @@ vmi.store source layout deinterleaved=2: Do not generalize this to `deinterleaved=4` unless the two-level dist composition is proven against the ISA. The fallback for `deinterleaved=4` remains generic layout materialization plus ordinary memory ops. -Partial/tail load-style memory is legal only when the lowering can prove the full physical read footprint is safe. The -current direct path supports this limited proof: +Direct `vmi.load` is lowered as full VPTO physical reads when the source memory kind/layout is supported and the +element type has a known physical lane width, even for non-full logical vectors. Masked/expand/gather read-style +operations still require the lowering to prove that the full physical read footprint is safe, or to use a future +true masked/non-faulting fallback. The current proof handles: ```text source is a statically shaped memref @@ -1999,9 +2001,10 @@ offset is a constant non-negative index offset + physical_arity(result) * lanes_per_physical_part <= static memref element count ``` -When this proof holds, `vmi.load` may still issue full `pto.vlds` chunks. The extra padding lanes are -not logical VMI lanes and must remain unobservable through later VMI materialization rules. Pointer sources, dynamic -offsets, dynamic memrefs, and insufficient static footprints remain unsupported: +When this proof holds, masked/expand read-style operations may still issue full `pto.vlds` chunks. The extra padding +lanes are not logical VMI lanes and must remain unobservable through later VMI materialization rules. Pointer sources, +dynamic offsets, dynamic memrefs, and insufficient static footprints remain unsupported for those stricter read-style +operations: ```text VMI-UNSUPPORTED: pto.vmi. requires full physical chunks without padding lanes or a statically safe full-read @@ -3559,7 +3562,7 @@ Unsupported diagnostics: non-splat pto.vmi.constant: VMI-UNSUPPORTED: non-splat pto.vmi.constant requires a vreg immediate or scratch materialization plan - partial/tail pto.vmi.load: + unsupported partial/tail masked/expand read-style op: VMI-UNSUPPORTED: pto.vmi. requires full physical chunks without padding lanes or a statically safe full-read footprint (...; safe-read proof failed: ...) GM-backed direct pto.vmi.load/masked_load/expand_load: @@ -3987,9 +3990,10 @@ Slice 4 完成条件: 2. `f8 -> f32 -> add -> store` lowers with deinterleaved=4 and stores contiguous logical order. Covered by vmi_to_vpto_e2e_widen_add_store.pto. 3. Non-full memory physical arity and valid lane map are tested. - Covered by vmi_to_vpto_load_nonfull_invalid.pto, vmi_to_vpto_store_deint_invalid.pto, + Covered by vmi_to_vpto_load_nonfull.pto, vmi_to_vpto_load_nonfull_memref.pto, + vmi_to_vpto_store_deint_invalid.pto, vmi_to_vpto_load_safe_tail_memref.pto, - vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto, + vmi_to_vpto_load_safe_tail_memref_negative_offset.pto, vmi_to_vpto_masked_load_safe_tail_memref.pto, vmi_to_vpto_masked_load_safe_tail_memref_negative_offset_invalid.pto, vmi_to_vpto_expand_load_all_active.pto, @@ -4042,10 +4046,12 @@ Slice 4 完成条件: ## 7. Slice 5: Memory Padding -The Slice 4 direct path may lower full-footprint `load/store` when the -physical memory footprint is statically safe. Do not lower any partial, -padded, or out-of-bounds read-like operation as a plain `pto.vlds` until a -richer access plan proves it is safe. +The Slice 4 direct path lowers `pto.vmi.load` through plain `pto.vlds` when the +memory source itself is supported and the element type has a known physical lane +width. This includes non-full logical vectors; the operation is treated as a +direct full physical read of the selected VPTO chunk(s). Masked/expand/gather +read-like operations still use the richer access plan because their masks or +lane maps carry additional semantic constraints. Implement an internal `VMIMemoryAccessPlan`: @@ -4084,15 +4090,14 @@ currently routed through the plan: stable gather masked-load option covered by vmi_to_vpto_stable_gather_masked_load_todo_invalid.pto currently emits a TODO diagnostic instead of lowering through VGATHER2 - direct pto.vmi.load partial/tail safe full-read proof + direct pto.vmi.load source/layout capability check for full physical reads pto.vmi.masked_load partial/tail safe full-read proof pto.vmi.expand_load static all-active safe full-read proof - VMI-to-VPTO rewrite match guard for load full-or-safe reads + VMI-to-VPTO rewrite match guard for supported direct load sources pto.vmi.store direct write target decision with all-true writeMask kind pto.vmi.masked_store direct write target decision with explicit writeMask kind - unsafe partial/tail read fallback decision as RequiredUnavailable diagnostic - covered by vmi_to_vpto_load_nonfull_invalid.pto, - vmi_to_vpto_masked_load_nonfull_invalid.pto, and + unsafe masked/expand partial/tail read fallback decision as RequiredUnavailable diagnostic + covered by vmi_to_vpto_masked_load_nonfull_invalid.pto and vmi_to_vpto_expand_load_all_active_negative_offset_invalid.pto currently not implemented by the plan: diff --git a/docs/designs/vmi-lane-stride-generalization-design.md b/docs/designs/vmi-lane-stride-generalization-design.md index a839e337c3..e5134bede3 100644 --- a/docs/designs/vmi-lane-stride-generalization-design.md +++ b/docs/designs/vmi-lane-stride-generalization-design.md @@ -399,18 +399,71 @@ op's assigned result layout is fixed. The preferred direction for this optimization is not "notice the input is already strided". The conversion op can be the layout-entry point and compute a -single preferred layout fact for the current op instance: +single preferred layout fact for the current op instance. The choice must be +arity-driven, not special-cased by a spelling such as `64xf32`. + +For source/result logical lane count `N`, let: + +```text +natural result layout: + source dense factor F, lane_stride 1 + result dense factor F * W, lane_stride 1 + +compact result layout: + result keeps source dense factor F and uses lane_stride 1 + source uses lane_stride W inside the same dense factor F +``` + +The self-preferred widening rule is: + +```text +if physical_arity(compact result) < physical_arity(natural result) + and target supports the required source lane_stride relation: + choose compact result and request source lane_stride=W +else: + choose natural result deinterleaved by W +``` + +For ordinary contiguous `f16 -> f32` this gives: + +```text +64xf32: + compact arity = 1 + natural deinterleaved=2 arity = 2 + choose source lane_stride=2, result contiguous + +128xf32: + compact arity = 2 + natural deinterleaved=2 arity = 2 + choose natural result deinterleaved=2 + +256xf32: + compact arity = 4 + natural deinterleaved=2 arity = 4 + choose natural result deinterleaved=2 +``` + +If the source is already deinterleaved by `F`, the natural result factor is +`F * W`. For example, `deinterleaved=2 f16 -> f32` naturally produces +`deinterleaved=4 f32`. + +The same arity rule applies to other widening ratios and types. For example, +`ui8 -> ui32` has `W=4`; a lane-stride source is preferred only when the +contiguous result has fewer physical chunks than the natural +`deinterleaved=4` result and the target supports the `lane_stride=4` relation. + +The two layout facts are therefore: ```text baseline fact: source contiguous result deinterleaved=W - cost: W conversion parts + arity: physical_arity(result deinterleaved=W) lane-stride fact: source lane_stride=W result contiguous - hardware conversion parts: one + arity: physical_arity(result contiguous) source layout request: explicit ``` @@ -428,9 +481,9 @@ assignment model. ### 4.4 Narrowing Conversion -Narrowing is the inverse relation. If source element width is `W` times the -result element width, a single hardware narrowing part can produce a -phase-zero strided result when: +Narrowing uses the same arity-driven idea in the opposite direction. If source +element width is `W` times the result element width, a single hardware narrowing +part can produce a phase-zero strided result when: ```text result lane_stride = source lane_stride * W @@ -446,6 +499,66 @@ ui32 -> ui16/ui8 ui16 -> ui8 ``` +The natural narrowing relation is the inverse of natural widening: + +```text +source dense factor F * W, lane_stride 1 +result dense factor F, lane_stride 1 +``` + +The compact-store-oriented relation is: + +```text +source keeps dense factor F and lane_stride 1 +result keeps dense factor F and uses lane_stride W +``` + +Narrowing has the same candidate family as widening. The arity comparison is +made on the source side, because the compact relation keeps the source +contiguous while the natural relation may require a deinterleaved source. + +The self-preferred narrowing rule is: + +```text +if physical_arity(compact contiguous source) + < physical_arity(natural deinterleaved source) + and physical_arity(compact source) == physical_arity(strided result) + and target supports the source-contiguous/result-lane_stride relation: + choose source contiguous, result lane_stride=W +else: + choose natural deinterleaved-source to contiguous-result relation +``` + +Use-site requests may still select the strided relation when a later consumer +can directly consume it: + +```text +if a consumer requests result lane_stride=W + and target supports source-contiguous/result-lane_stride narrowing: + request source contiguous + set or rematerialize result lane_stride=W +``` + +For ordinary `f32 -> f16`: + +```text +64xf32 -> 64xf16: + natural source deinterleaved=2 arity = 2 + compact source contiguous arity = 1 + choose source contiguous, result lane_stride=2 + +128xf32 -> 128xf16: + natural source deinterleaved=2 arity = 2 + compact source contiguous arity = 2 + choose natural source deinterleaved=2, result contiguous +``` + +So trunc should not blindly create a lane-stride result for every narrowing. +It should apply the same arity/support checks as ext. A consumer may still +request a strided result when that layout is useful, such as an unmasked compact +store lowered with `PK`/`PK4`. For masked stores, the value and mask must share +the same lane map before a direct packed masked store is legal. + The exact supported parts are target-op dependent. The layout assignment layer should ask the op support interface whether a given source/result layout pair is legal, rather than encoding type-specific shortcuts. @@ -589,6 +702,11 @@ store: consumes contiguous f32 directly ``` +Assignment chooses the lane-stride plan for this shape because the contiguous +`64xf32` result uses one physical chunk while the natural deinterleaved result +uses two physical chunks. This decision is made by the cast arity rule, not by +a pattern that names `64xf32` directly. + The load side then has two concrete outcomes: ```text @@ -600,7 +718,7 @@ accepted direct load fold: no direct load fold: keep the explicit source ensure_layout lower it through register pack/unpack if that materialization is supported - otherwise keep the baseline contiguous-source/deinterleaved-result relation + otherwise validation rejects the unsupported assigned relation ``` This case proves that `extf` can be the layout-entry point, while `load` support @@ -665,8 +783,9 @@ consumer then needs contiguous -> deinterleaved=2 materialization ``` The baseline plan should win. A lane-stride fact is not useful when it creates a -layout the consumer does not want; for full chunks it may not reduce the -conversion count either. +layout the consumer does not want. The cast arity rule also does not prefer +lane_stride here: `128xf32` contiguous and `128xf32 deinterleaved=2` both use +two physical chunks. ### 6.4 One Ext Result Feeding Store And Reduce diff --git a/docs/designs/vmi-lane-stride-generalization-implementation.md b/docs/designs/vmi-lane-stride-generalization-implementation.md index 4faeb7bf4e..4a3b251a41 100644 --- a/docs/designs/vmi-lane-stride-generalization-implementation.md +++ b/docs/designs/vmi-lane-stride-generalization-implementation.md @@ -42,18 +42,19 @@ Current stage status: | Dense layout attrs | Supported | Dense contiguous/deinterleaved layouts carry `lane_stride`; group-slot carrier layout remains separate. | | Direct compact load/store | Supported for selected phase-zero maps | LS=2 b8/b16/b32 through `UNPK_B8/B16/B32` and `PK_B16/B32/B64`; LS=4 b8 through `UNPK4` and `PK4_B32`. | | Load/store layout folds | Supported with one-load/one-store preservation | `load -> ensure_layout(lane_stride)` rewrites the original load layout when all uses agree; `ensure_layout(lane_stride -> contiguous) -> store` lets the VMI store consume the lane-stride value. | -| Dense widening ext | Supported | Source lane_stride=W can lower to a single `vcvt` part when the cast relation matches. | -| Dense narrowing trunc | Supported for ordinary dense store paths | Source contiguous, result lane_stride=W, then direct compact store when supported. | +| Dense widening ext | Supported | `getPreferredCastLayoutFact` chooses the arity-reducing source `lane_stride=W` / result contiguous relation when it beats the natural deinterleaved result; otherwise it keeps the natural relation. | +| Dense narrowing trunc | Supported for dense natural paths | `getPreferredCastLayoutFact` uses the same arity rule in the inverse direction, so trunc keeps the natural deinterleaved-source / contiguous-result relation unless a compact relation actually reduces arity. | | Masked compact store | Partially supported | Legal only when value and mask have the same lane map and the mask can be compacted for the selected store dist. | | Masked trunc tail | Not optimized yet | Keep the existing legal path until mask lane-stride assignment/materialization is available. | | Register fallback | Partially supported | Only same-physical-arity contiguous `<->` lane_stride paths with legal pack/unpack carriers. Arity-changing fallback is not in scope for this stage. | | Group broadcast load | Supported only through specific strategies | `group_broadcast_load` remains a VMI semantic; E2B is one strategy with exact shape/layout constraints. | Remaining design/implementation work from this discussion is intentionally -limited to two areas: +limited to these areas: | Area | Work to settle | Required proof before enabling | |---|---|---| +| Cast assignment | Keep `getPreferredCastLayoutFact` as the single op-local preferred relation helper, but make it shape-aware: compute the natural relation, compute the compact lane-stride relation, and select compact only when physical arity improves. | `64xf16 -> 64xf32` chooses source `lane_stride=2` and result contiguous; `128/256xf16 -> f32` keep natural `deinterleaved=2`; dense trunc keeps the natural relation unless compact arity wins. | | Masked store | Let `masked_store` request the same lane map for value and mask, or keep the existing legal path when the mask cannot be assigned/rematerialized into that lane map. | No path may lower a lane-stride value with a stale contiguous user mask; lowering must compact the assigned mask into the packed-store predicate. | | Group broadcast load | Keep `group_broadcast_load` as a VMI logical operation and make E2B only one support/lowering strategy selected by shape, element width, stride, and assigned result layout. | A failed E2B match must mean "this lowering strategy is unavailable", not "the VMI op is invalid" unless no fallback strategy is registered. | @@ -204,55 +205,26 @@ and mask, assignment must keep masked-tail narrowing on an existing legal path instead of choosing a lane-stride trunc result solely because the store could otherwise use `PK`. -Concrete implementation plan for lane-stride `masked_store`: +Current-stage implementation: ```text lib/PTO/Transforms/VMILayoutAssignment.cpp -1. Replace the current VMIMaskedStoreOp consumer request: - requestDataUse(value, contiguous) - requestMaskUse(mask, contiguous, elementGranularity) - - with a helper: - getPreferredMaskedStoreUseRequest(store) - -> optional {valueLayout, maskLayout, maskGranularity} - -2. The helper must be support-query driven: - candidateValueLayout = getPreferredDenseStoreUseLayout(store.value) - if candidateValueLayout is not dense lane_stride: - return none - - sourceValueType = value type with candidateValueLayout - resultValueType = value type with contiguous layout - sourceMaskType = mask type with: - elementCount = value lanes - granularity = getMaskGranularityForElement(value element type) - layout = candidateValueLayout - resultMaskType = same mask granularity with contiguous layout - - require: - canFoldContiguousMaskedStoreMaterialization( - sourceValueType, sourceMaskType, - resultValueType, resultMaskType) - - return {candidateValueLayout, candidateValueLayout, maskGranularity} - -3. The store request becomes: - if helper returns a request: - requestDataUse(store.value, request.valueLayout) - requestMaskUse(store.mask, request.maskLayout, request.maskGranularity) - else: - requestDataUse(store.value, contiguous) - requestMaskUse(store.mask, contiguous, elementGranularity) +VMIMaskedStoreOp keeps the existing conservative request: + requestDataUse(value, contiguous) + requestMaskUse(mask, contiguous, elementGranularity) -4. Replace the coarse hasMaskedStoreUse(trunc.result) guard with a support - predicate. A trunc result should stay on the conservative path only when a - masked_store use cannot request the same lane_stride value/mask layout. - Do not make trunc inspect mask producers directly; masked_store remains the - consumer that owns the joint value/mask request. +trunc assignment does not inspect masked_store users and does not preserve a +special masked-store guard. It records the source/result relation returned by +getPreferredCastLayoutFact. If that conflicts with a masked_store contiguous +request, normal assignment conflict handling inserts the required +ensure_layout. ``` -The intended dataflow after assignment is: +Future lane-stride `masked_store` support must be added as an explicit +consumer-owned extension, not as a trunc special case. The future dataflow must +prove that value and mask share the same lane map before a packed masked store +is legal: ```text %n = vmi.trunc* %wide @@ -264,18 +236,8 @@ The intended dataflow after assignment is: vmi.masked_store %n, %dst[%off], %m_ls ``` -or, if the mask producer can already produce the requested lane map: - -```text -%m_ls = mask producer result - : !vmi.mask<..., layout = contiguous, lane_stride = W> - -vmi.masked_store %n, %dst[%off], %m_ls -``` - -The assignment pass does not emit VPTO predicate compaction. It only creates -the local VMI proof that value and mask have the same lane map. Existing later -stages then do the mechanical work: +That future extension would need the same local VMI proof before lowering can do +the mechanical predicate compaction: ```text vmi-layout-fold: @@ -289,25 +251,9 @@ vmi-to-vpto: emits vsts PK_B16/PK_B32/PK4_B32 as selected by value element width/layout ``` -Required masked-store tests: +Future negative tests should cover the fallback: ```text -assignment positive: - truncf/trunci -> masked_store where value LS=2 b16 or LS=4 b8 is supported - CHECK value result layout is lane_stride - CHECK mask use is requested as the same lane_stride, with ensure_mask_layout - when the original mask is contiguous - -fold positive: - ensure_layout(value lane_stride -> contiguous) and - ensure_mask_layout(mask lane_stride -> contiguous) feeding masked_store - CHECK masked_store consumes the lane_stride value/mask directly - -lowering positive: - lane_stride value + same-lane-map mask feeding masked_store - CHECK mask compaction uses punpack - CHECK store dist is PK_B32 for b16 LS=2 and PK4_B32 for b8 LS=4 - fallback: mask cannot be assigned/materialized to the candidate lane_stride CHECK masked_store keeps contiguous value/mask request @@ -439,8 +385,12 @@ pto-validate-vmi-ir: vmi-layout-assignment: assign explicit dense layouts, including lane_stride, on VMI value types use op support queries to choose local cast relations: - widening can request source lane_stride=W and result contiguous - narrowing can request source contiguous and result lane_stride=W + widening compares natural deinterleaved result arity with compact + contiguous result arity; when compact wins, request source lane_stride=W + and set result contiguous + narrowing supports the inverse relation; when arity or a supported consumer + request chooses a strided result, request source contiguous and set result + lane_stride=W keep unsupported or conflicting uses legal by inserting ensure_layout serialize all decisions as type attrs or helper ops do not clone producers, fold memory ops, or solve a global cost problem @@ -554,10 +504,11 @@ lib/PTO/Transforms/VMILayoutSupport.cpp LaneStrideToContiguousViaPack LaneStrideToLaneStrideViaContiguous, only if needed update getPreferredCastLayoutFact: - baseline facts remain contiguous <-> deinterleaved - add optional preferred dense lane-stride fact from the conversion ratio W - and the target single-part cast support; do not inspect source producers - here + keep an internal baseline natural relation for dense widening/narrowing + compute the compact lane-stride relation from the same conversion ratio + select compact only when source/result physical arities match and the + relevant arity is strictly smaller than the baseline relation + use the returned source/result layouts for both ext and trunc assignment update getWidenSourceLayoutForResultLayout for dense lane_stride result/source update getContiguousStoreSupport and canFoldContiguousStoreMaterialization for LS=2 b8/b16/b32 -> PK_B16/B32/B64 @@ -574,8 +525,8 @@ lib/PTO/Transforms/VMILayoutAssignment.cpp lib/PTO/Transforms/VMILayoutRematerialize.cpp allow cheap producers to be cloned with dense lane_stride result types when VMILayoutSupport says the producer can directly create that lane map - keep ordinary load/group_load/masked_load blocked until a safe-read proof is - added for the specific direct UNPK lowering + keep ordinary load/group_load/masked_load cloning blocked until a safe-read + proof is added for the specific rematerialized memory operation lib/PTO/Transforms/VMILayoutFold.cpp add producer-side fold for load -> ensure_layout: @@ -692,7 +643,8 @@ or lower the resulting explicit IR, but should not solve the same rewrite again. | Direct load produces requested lane map | `load(contiguous) -> ensure_layout(lane_stride=2)` | `vmi-layout-fold` rewrites the original load result layout when UNPK support exists | Remat must not clone this load without safe-read proof | | Direct store consumes lane map | `ensure_layout(lane_stride -> contiguous) -> store` | `vmi-layout-fold` rewrites the VMI store to consume the lane_stride source directly when direct compact-store support exists | `vmi-to-vpto` emits the actual `vsts PK/PK4` | | Cheap producer can produce target layout | `broadcast -> ensure_layout(lane_stride=2)` | `vmi-layout-rematerialize` rebuilds broadcast with lane-stride result | Fold does not rebuild arbitrary producers | -| Widening ext can move materialization to cheap source | `ext -> ensure_layout(contiguous)` with source broadcast/load-fold case | `vmi-layout-rematerialize` rebuilds ext with required source lane stride | Assignment only creates the helper; fold only handles the load subcase | +| Cast chooses arity-reducing relation | `64xf16 -> 64xf32` or a supported narrowing with smaller strided result | `vmi-layout-assignment` chooses the cast source/result layout relation | Remat only handles later use-site requests; fold only handles adjacent load/store helpers | +| Cast can move materialization to cheap source | `ext/trunc -> ensure_layout(requested layout)` with source broadcast/load-fold case | `vmi-layout-rematerialize` rebuilds the cast with the requested relation | Assignment may already choose the self-preferred relation; fold only handles the load/store subcase | | Layout-transparent op has ensured operands | `ensure(a), ensure(b) -> add` | `vmi-layout-sink-materialization` sinks matching helpers to the result | Remat handles the opposite shape `add -> ensure` | | Surviving supported helper | `ensure_layout(contiguous <-> lane_stride)` after optimizations | `vmi-to-vpto` lowers to register pack/unpack | Earlier passes are allowed to leave it explicit | | Unsupported helper or layout | `lane_stride=4 b16 compact store` | `pto-validate-vmi-layout-ir` rejects before lowering | `vmi-to-vpto` should not invent a repair | @@ -959,13 +911,18 @@ elementwise: all dense operands/results must use identical dense layout key extf/extui/extsi: - source/result layouts must satisfy a widening relation + source/result layouts must satisfy a widening relation. Assignment chooses + between the natural deinterleaved relation and the compact-result + lane-stride-source relation by comparing physical arity, not by matching a + concrete lane count such as 64. truncf/trunci: - dense narrowing may request source contiguous and result lane_stride=W, where - W is the storage-width narrowing factor; masked-store consumers stay on the - existing legal deinterleaved-to-contiguous path until mask lane-stride - assignment/materialization is available + source/result layouts must satisfy a narrowing relation. Assignment uses the + inverse relation conservatively: keep the natural deinterleaved-source to + contiguous-result relation unless arity or a supported consumer request + selects a strided result relation. Masked-store consumers may only use the + strided result relation when the value and mask can be assigned/materialized + to the same lane map. broadcast/group_broadcast: result may use a dense layout only when the materialization lowering has an @@ -984,15 +941,69 @@ store: Assignment should still insert `ensure_layout` for incompatible use-local requests. Rematerialization/fold can later remove it. -### 4.1 Current Framework Fit +### 4.1 Cast Relation Helper Shape + +Keep `getPreferredCastLayoutFact` as the assignment entry point for dense +widening and narrowing casts, but make the helper return the actual preferred +source/result relation for the current shape. Internally it first builds the +natural relation: + +```text +widen: + source contiguous + result deinterleaved=W + +narrow: + source deinterleaved=W + result contiguous +``` + +Then it computes the compact relation: + +```text +widen: + source contiguous, lane_stride=W + result contiguous + +narrow: + source contiguous + result contiguous, lane_stride=W +``` + +The compact relation is selected only when its source/result physical arities +match and it strictly reduces the relevant baseline arity: + +```text +widen: + physical_arity(compact result) < physical_arity(natural result) + +narrow: + physical_arity(compact source) < physical_arity(natural source) +``` + +If the compact relation does not win, the helper returns the natural relation. +`vmi-layout-assignment` calls this helper for `extf/extui/extsi` and +`truncf/trunci`, requests the returned source layout, and records the returned +result layout. + +The support query must validate the returned pair before assignment commits it: + +```text +supportsExtRelation(sourceTypeWithLayout, resultTypeWithLayout) +supportsTruncRelation(sourceTypeWithLayout, resultTypeWithLayout) +``` + +The validation step is a legality check, not a second optimizer. + +### 4.2 Current Framework Fit The existing assignment pass already has use-site requests. For example, `pto.vmi.store` requests a contiguous source operand, and assignment can insert `ensure_layout` when the stored value is assigned another layout. The dense-stride `ext` optimization should keep the same model: the cast op is -the layout-entry point and stores one preferred source/result relation. The -current preferred relation is: +the layout-entry point and stores one preferred source/result relation. The old +preferred relation was: ```text extf: @@ -1000,8 +1011,8 @@ extf: set result deinterleaved=W ``` -The current stage keeps the existing single-preference framework -and let `ext` choose one fact for the current op: +The current stage keeps the existing single-preference framework and lets +`ext` choose one fact for the current op: ```text baseline fact: @@ -1018,14 +1029,19 @@ The `ext` support query chooses between these facts from op-local information: ```text conversion ratio W target support for one selected hardware conversion part -requested or preferred result layout for the current op instance +physical arity of the natural result layout +physical arity of the compact contiguous result layout +requested result layout when a consumer materialization/remat path provides one ``` -It does not inspect the defining source producer. If it selects the -lane-stride fact and the source is not already in that layout, assignment inserts -an explicit source `ensure_layout`. Later passes either discharge that helper by -rematerializing/folding a concrete producer, lower it with a registered -pack/unpack materializer, or let validation reject the unsupported relation. +It does not inspect the defining source producer. If compact result arity is +strictly smaller than natural result arity and the target supports the +single-part relation, it selects the lane-stride fact. If it selects the +lane-stride fact and the source is not already in that layout, assignment +inserts an explicit source `ensure_layout`. Later passes either discharge that +helper by rematerializing/folding a concrete producer, lower it with a +registered pack/unpack materializer, or let validation reject the unsupported +relation. ## 5. Widening Conversion Lowering @@ -1078,19 +1094,31 @@ chooses one immediately: baseline fact: source contiguous result deinterleaved=W - lowering cost = W conversion parts + natural result arity = physical_arity(result deinterleaved=W) lane-stride fact: result contiguous source same dense shape with lane_stride = W - lowering cost = one conversion part + compact result arity = physical_arity(result contiguous) ``` -For example, for `f16 -> f32`, the `extf` op can prefer -`source lane_stride=2 -> result contiguous` when the target has a single EVEN -conversion for that relation. The source producer is handled by the explicit -source `ensure_layout` and later fold/rematerialization; it is not part of the -cast support query. +Assignment uses this deterministic rule: + +```text +if compact result arity < natural result arity + and the lane-stride fact is supported: + choose lane-stride fact +else: + choose baseline fact +``` + +For example, for `f16 -> f32`, the `extf` op chooses +`source lane_stride=2 -> result contiguous` for `64xf32`, because the compact +result has one physical chunk while the natural `deinterleaved=2` result has two +physical chunks. For `128xf32` and `256xf32`, both layouts have the same result +arity, so assignment chooses the natural `deinterleaved=2` result. The source +producer is handled by the explicit source `ensure_layout` and later +fold/rematerialization; it is not part of the cast support query. Current contiguous widening remains a separate legal relation: @@ -1102,9 +1130,9 @@ result deinterleaved=W, lane_stride=1 Implementation steps: 1. Factor conversion ratio calculation by storage bit width. -2. Add helper that computes baseline conversion count. -3. Add helper that computes lane-stride conversion count and required source - layout. +2. Add helper that computes the natural result layout and its physical arity. +3. Add helper that computes the compact result layout, required source + lane-stride layout, and compact result physical arity. 4. Teach `VMIToVPTO` conversion lowering to emit only the selected hardware part when the relation is single-part. 5. Keep existing multi-part lowering for contiguous-to-deinterleaved cases. @@ -1139,13 +1167,54 @@ result lane_stride = source lane_stride * W hardwarePart = 0 for the current stage ``` +The narrowing assignment relation is the inverse of widening, but it must not +blindly choose a lane-stride result. Build two facts: + +```text +baseline fact: + source deinterleaved=W + result contiguous + natural result arity = physical_arity(result contiguous) + +lane-stride fact: + source contiguous + result contiguous, lane_stride=W + strided result arity = physical_arity(result lane_stride=W) +``` + +Then choose a strided result only when it is justified: + +```text +if strided result arity < natural result arity + and the lane-stride fact is supported: + choose lane-stride fact +else if a consumer/requested result layout is the strided result + and the lane-stride fact is supported: + choose or rematerialize lane-stride fact +else: + choose baseline fact +``` + +This keeps trunc symmetric with ext while avoiding the earlier mistake of +producing lane_stride solely because the operation is a narrowing cast. A +consumer may still request or preserve a strided result. For example, an +ordinary store with direct `PK` support can consume a supported lane-stride +result, and rematerialization/fold may keep that relation. A masked store may +do so only when the mask can be assigned/materialized to the same lane map. + Implementation steps: -1. Share ratio and lane-map helpers with widening. -2. Add support query for valid narrowing layout pairs. -3. Lower single-part narrowing directly when the target has a part-selecting +1. Share ratio, dense-factor, lane-map, and physical-arity helpers with + widening. +2. Add helper that computes the natural source/result relation and result + arity. +3. Add helper that computes the strided-result relation and result arity. +4. Add support query for valid narrowing layout pairs. +5. Teach assignment/rematerialization to select the strided fact for explicit + result requests, direct compact-store consumers, or true arity reductions. +6. Lower single-part narrowing directly when the target has a part-selecting narrow instruction. -4. Preserve existing deinterleaved-to-contiguous narrowing for the packed full +7. Preserve existing deinterleaved-to-contiguous narrowing for the packed full result case. This is the same family as the recently discussed `d4 -> c -> d2 -> vcvt -> c` @@ -1201,8 +1270,8 @@ mask: VMIConstantMaskOp special rewrite: - selected VMITruncIOp through a source ensure_layout when the cast relation is - a supported narrowing relation + selected VMITruncFOp / VMITruncIOp through source/result ensure_layout when + the cast relation is a supported narrowing relation ``` Not included as cheap producers in the current pass: @@ -1222,7 +1291,9 @@ Relationship between cheap producers and dense `lane_stride`: ```text assignment: creates the target layout request explicitly, usually as ensure_layout(... -> - lane_stride) or as a cast source/result relation + lane_stride) or as a cast source/result relation. For casts, assignment may + itself choose the arity-reducing lane-stride relation; remat only reacts to + later use-site layout requests. rematerialize: does not choose lane_stride as a preference @@ -1254,9 +1325,19 @@ widening ext: remat then inserts/uses source ensure_layout and rebuilds ext with the requested result layout -trunci special rewrite: +narrowing trunc: + add getNarrowSourceLayoutForResultLayout or an equivalent relation helper. + For a requested result lane_stride=R and narrowing ratio W, derive the source + layout that can produce that result with a selected hardware part: + result lane_stride=W, W=2 -> source contiguous + result lane_stride=R, W=2 -> source lane_stride=R/W when divisible + remat then inserts/uses source ensure_layout and rebuilds trunc with the + requested result layout + +trunc source-ensure rewrite: extend the existing source-ensure rewrite to recognize lane_stride narrowing - relations, not only deinterleaved narrowing relations + relations for VMITruncFOp and VMITruncIOp, not only deinterleaved narrowing + relations mask producers: only participate after mask layout support defines the corresponding @@ -1283,15 +1364,16 @@ ext as the single selected conversion part. It is still driven by the explicit layout request; remat does not inspect sibling consumers or choose lane_stride by itself. -Do lane-stride ext rematerialization only in these cases: +Do lane-stride cast rematerialization only in these cases: ```text required shape: - ext result is followed by ensure_layout to a requested dense result layout - widening ratio W > 1 - the requested result lane_stride is R, where contiguous means R=1 - source lane_stride = R * W is supported - ext with that source layout can lower as one selected conversion part + cast result is followed by ensure_layout to a requested dense result layout + widening or narrowing ratio W > 1 + the requested source/result layout pair is accepted by the cast relation + helper + the cast with that source/result layout can lower as one selected conversion + part or the existing multi-part relation acceptance/safety gate: the source-side lane_stride request must be discharged by a concrete local @@ -1306,8 +1388,8 @@ acceptance/safety gate: concrete producer cases is reached do not apply: - result consumer already accepts the natural ext layout - source lane_stride = R * W is unsupported + result consumer already accepts the natural cast layout + requested cast layout relation is unsupported source is an ordinary load with other incompatible consumers and no safe-read proof to clone it the rewrite only moves an expensive materialization from result side to source @@ -1331,6 +1413,15 @@ load -> ensure_layout(lane_stride=W) -> ext -> store elementwise cheap chain -> ext -> ensure_layout(contiguous) remat/sink the chain to lane_stride=W only when the chain reaches a concrete cheap producer or direct load-fold case + +trunc -> ensure_layout(lane_stride=W) -> compact store + remat/rebuild trunc with the requested lane_stride result when the source + layout relation is supported + store fold may then consume the lane_stride result directly + +trunc -> ensure_layout(lane_stride=W) -> masked_store + only accepted after mask layout assignment can provide the same lane map for + the predicate; otherwise keep the conservative contiguous masked-store path ``` ## 8. Broadcast And E2B Interaction @@ -1522,6 +1613,10 @@ bf16 lane_stride=2 -> f32 contiguous follows the same relation ui8 lane_stride=2 -> ui16 contiguous follows W=2 ui8 lane_stride=4 -> ui32 contiguous follows W=4 when target supports it contiguous f16 -> deinterleaved=2 f32 still emits EVEN + ODD +f32 contiguous -> f16 lane_stride=2 emits the selected narrowing part when the +assigned relation is supported +f32 deinterleaved=2 -> f16 contiguous keeps the existing packed full-result +narrowing relation ui16 lane_stride=2 -> contiguous can materialize with vpack 32->16 carrier path ui8 lane_stride=4 -> contiguous can materialize with two vpack stages ``` @@ -1529,8 +1624,14 @@ ui8 lane_stride=4 -> contiguous can materialize with two vpack stages Assignment/rematerialization: ```text -extf records a strided dense source relation for a supported single-part -widening conversion +extf records a strided dense source relation when compact result arity is +smaller than natural result arity +extf 64xf16 -> 64xf32 chooses source lane_stride=2, result contiguous +extf 128xf16 -> 128xf32 chooses result deinterleaved=2 +extf 256xf16 -> 256xf32 chooses result deinterleaved=2 +truncf records a strided result relation only when the conservative +self-preference/support rule or a supported consumer request selects it; it +does not choose lane_stride solely because the op narrows layout-transparent op propagates the same strided layout through operands/result ensure_layout is folded when source and target lane maps match rematerialization clones a cheap broadcast for two different dense layouts @@ -1562,8 +1663,12 @@ Negative tests: ```text assigned ext layout pair where LS % W != 0 and no multi-part relation exists +assigned trunc layout pair where result lane_stride is not compatible with the +narrowing ratio ordinary dense op with mismatched lane_stride operands store consuming strided dense layout without a supported store/materialization +masked_store consuming lane_stride value with a stale contiguous user mask is +rejected or kept on the conservative contiguous path ``` ## 10. Suggested Patch Order @@ -1572,8 +1677,10 @@ store consuming strided dense layout without a supported store/materialization 2. Split dense lane-map physicalization from group-slot carrier packing. 3. Update physical arity/unpack helpers for dense lane stride. 4. Extend support queries and assignment layout keys. -5. Implement widening single-part relation and tests. -6. Implement narrowing relation and tests. +5. Implement widening arity-driven self-preference, single-part relation, and + tests. +6. Implement narrowing inverse relation support, consumer-request handling, and + tests. 7. Teach rematerialization/fold about exact dense lane-map equality. 8. Add broadcast/E2B recognition improvements that consume assigned lane maps. diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h index 19a7fc712a..7a15864928 100644 --- a/include/PTO/Transforms/VMILayoutSupport.h +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -36,6 +36,17 @@ struct VMIContiguousStoreSupport { VMIContiguousStoreSupportKind::ContiguousVsts; }; +enum class VMIContiguousLoadSupportKind { + ContiguousVlds, + LaneStride2UnpackedVlds, + LaneStride4UnpackedVlds, +}; + +struct VMIContiguousLoadSupport { + VMIContiguousLoadSupportKind kind = + VMIContiguousLoadSupportKind::ContiguousVlds; +}; + enum class VMILayoutMaterializationSupportKind { Identity, ContiguousToDeinterleaved, @@ -223,6 +234,10 @@ class VMILayoutSupport { getContiguousStoreSupport(VMIVRegType valueType, std::string *reason = nullptr) const; + FailureOr + getContiguousLoadSupport(VMIVRegType resultType, + std::string *reason = nullptr) const; + LogicalResult canFoldContiguousStoreMaterialization(VMIVRegType sourceType, VMIVRegType resultType, diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index fcba1d5637..627660ae98 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -463,13 +463,6 @@ struct LayoutSolver { return dataNodes[find(id)].naturalLayout; } - bool hasMaskedStoreUse(Value value) { - for (OpOperand &use : value.getUses()) - if (isa(use.getOwner()) && use.getOperandNumber() == 0) - return true; - return false; - } - bool hasCompatibleTruncFUseForGroupReduce(Value value, int64_t groupSize) { auto sourceType = dyn_cast(value.getType()); if (!sourceType || !sourceType.getElementType().isF32()) @@ -622,7 +615,19 @@ struct LayoutSolver { Operation *definingOp = value.getDefiningOp(); if (!definingOp) return false; - if (!isa(definingOp)) { + if (isa(definingOp)) { + if (requestedLayout && requestedLayout.hasDenseLaneStride()) { + auto type = dyn_cast(value.getType()); + if (!type) + return false; + auto candidateType = + VMIVRegType::get(ctx, type.getElementCount(), type.getElementType(), + requestedLayout); + VMILayoutSupport supports; + if (failed(supports.getContiguousLoadSupport(candidateType))) + return false; + } + } else { if (!requestedLayout || requestedLayout.isContiguous()) return false; if (!canProducerAdoptConsumerLayout(definingOp)) @@ -1278,21 +1283,11 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } - if (succeeded(fact) && hasMaskedStoreUse(truncf.getResult()) && - (fact->kind == VMICastLayoutKind::Narrow2x || - fact->kind == VMICastLayoutKind::Narrow4x)) { - requestDataUse(truncf.getSourceMutable(), fact->sourceLayout); - if (failed(setNaturalLayout(truncf.getResult(), fact->resultLayout, - op))) - return WalkResult::interrupt(); - return WalkResult::advance(); - } VMILayoutAttr resultLayout = getContiguousLayout(); if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || fact->kind == VMICastLayoutKind::Narrow4x)) { - requestDataUse(truncf.getSourceMutable(), getContiguousLayout()); - resultLayout = - VMILayoutAttr::getContiguous(ctx, /*laneStride=*/fact->factor); + resultLayout = fact->resultLayout; + requestDataUse(truncf.getSourceMutable(), fact->sourceLayout); } if (failed(setNaturalLayout(truncf.getResult(), resultLayout, op))) return WalkResult::interrupt(); @@ -1320,21 +1315,11 @@ struct LayoutSolver { return WalkResult::interrupt(); return WalkResult::advance(); } - if (succeeded(fact) && hasMaskedStoreUse(trunci.getResult()) && - (fact->kind == VMICastLayoutKind::Narrow2x || - fact->kind == VMICastLayoutKind::Narrow4x)) { - requestDataUse(trunci.getSourceMutable(), fact->sourceLayout); - if (failed(setNaturalLayout(trunci.getResult(), fact->resultLayout, - op))) - return WalkResult::interrupt(); - return WalkResult::advance(); - } VMILayoutAttr resultLayout = getContiguousLayout(); if (succeeded(fact) && (fact->kind == VMICastLayoutKind::Narrow2x || fact->kind == VMICastLayoutKind::Narrow4x)) { - requestDataUse(trunci.getSourceMutable(), getContiguousLayout()); - resultLayout = - VMILayoutAttr::getContiguous(ctx, /*laneStride=*/fact->factor); + resultLayout = fact->resultLayout; + requestDataUse(trunci.getSourceMutable(), fact->sourceLayout); } if (failed(setNaturalLayout(trunci.getResult(), resultLayout, op))) return WalkResult::interrupt(); diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index 4d93fa374c..d2113733d4 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -71,6 +71,16 @@ static bool hasX2MemoryDistToken(Type elementType) { return elementBits == 8 || elementBits == 16 || elementBits == 32; } +static bool hasDenseLaneStride2UnpackedLoad(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + return elementBits == 8 || elementBits == 16 || elementBits == 32; +} + +static bool hasDenseLaneStride4UnpackedLoad(Type elementType) { + unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); + return elementBits == 8; +} + static bool hasDenseLaneStride2PackedStore(Type elementType) { unsigned elementBits = pto::getPTOStorageElemBitWidth(elementType); return elementBits == 8 || elementBits == 16 || elementBits == 32; @@ -479,8 +489,9 @@ VMILayoutSupport::getPreferredGroupReduceLayoutFact(VMIVRegType sourceType, "2*VLaneElems, 4*VLaneElems, or full physical chunk multiples"); } -FailureOr VMILayoutSupport::getPreferredCastLayoutFact( - VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { +static FailureOr +getBaselineCastLayoutFact(VMIVRegType sourceType, VMIVRegType resultType, + std::string *reason) { auto fail = [&](const Twine &message) -> FailureOr { if (reason) *reason = message.str(); @@ -544,6 +555,69 @@ FailureOr VMILayoutSupport::getPreferredCastLayoutFact( "narrowing dense cast layout facts"); } +FailureOr VMILayoutSupport::getPreferredCastLayoutFact( + VMIVRegType sourceType, VMIVRegType resultType, std::string *reason) const { + FailureOr baseline = + getBaselineCastLayoutFact(sourceType, resultType, reason); + if (failed(baseline)) + return baseline; + + bool isWiden = baseline->kind == VMICastLayoutKind::Widen2x || + baseline->kind == VMICastLayoutKind::Widen4x; + bool isNarrow = baseline->kind == VMICastLayoutKind::Narrow2x || + baseline->kind == VMICastLayoutKind::Narrow4x; + if (!isWiden && !isNarrow) + return baseline; + + MLIRContext *ctx = sourceType.getContext(); + VMILayoutAttr compactSourceLayout = isWiden + ? VMILayoutAttr::getContiguous( + ctx, baseline->factor) + : VMILayoutAttr::getContiguous(ctx); + VMILayoutAttr compactResultLayout = isWiden + ? VMILayoutAttr::getContiguous(ctx) + : VMILayoutAttr::getContiguous( + ctx, baseline->factor); + VMIVRegType compactSourceType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), compactSourceLayout); + VMIVRegType compactResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), compactResultLayout); + FailureOr compactSourceArity = + getVMIPhysicalArity(compactSourceType); + FailureOr compactResultArity = + getVMIPhysicalArity(compactResultType); + if (failed(compactSourceArity) || failed(compactResultArity) || + *compactSourceArity != *compactResultArity) + return baseline; + + if (isWiden) { + VMIVRegType baselineResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), baseline->resultLayout); + FailureOr baselineResultArity = + getVMIPhysicalArity(baselineResultType); + if (failed(baselineResultArity) || + *compactResultArity >= *baselineResultArity) + return baseline; + } else { + VMIVRegType baselineSourceType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), baseline->sourceLayout); + FailureOr baselineSourceArity = + getVMIPhysicalArity(baselineSourceType); + if (failed(baselineSourceArity) || + *compactSourceArity >= *baselineSourceArity) + return baseline; + } + + VMICastLayoutFact compact = *baseline; + compact.sourceLayout = compactSourceLayout; + compact.resultLayout = compactResultLayout; + return compact; +} + FailureOr VMILayoutSupport::getWidenSourceLayoutForResultLayout( VMIVRegType sourceType, VMIVRegType resultType, @@ -593,6 +667,40 @@ VMILayoutSupport::getWidenSourceLayoutForResultLayout( return fail("derived source layout factor is unsupported"); } +FailureOr +VMILayoutSupport::getContiguousLoadSupport(VMIVRegType resultType, + std::string *reason) const { + auto fail = [&](const Twine &message) -> FailureOr { + if (reason) + *reason = message.str(); + return failure(); + }; + + VMILayoutAttr layout = resultType.getLayoutAttr(); + if (!layout) + return fail("requires assigned result layout"); + if (!layout.isContiguous()) + return fail("requires contiguous result layout"); + if (layout.getLaneStride() == 1) + return VMIContiguousLoadSupport{ + VMIContiguousLoadSupportKind::ContiguousVlds}; + if (layout.getLaneStride() == 2) { + if (!hasDenseLaneStride2UnpackedLoad(resultType.getElementType())) + return fail("requires 8/16/32-bit element type for dense lane_stride=2 " + "unpacked load"); + return VMIContiguousLoadSupport{ + VMIContiguousLoadSupportKind::LaneStride2UnpackedVlds}; + } + if (layout.getLaneStride() == 4) { + if (!hasDenseLaneStride4UnpackedLoad(resultType.getElementType())) + return fail("requires 8-bit element type for dense lane_stride=4 " + "unpacked load"); + return VMIContiguousLoadSupport{ + VMIContiguousLoadSupportKind::LaneStride4UnpackedVlds}; + } + return fail("requires lane_stride 1, 2, or 4 for contiguous load"); +} + FailureOr VMILayoutSupport::getContiguousStoreSupport(VMIVRegType valueType, std::string *reason) const { diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index ec6d5e4fac..70c817e799 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -1144,6 +1144,10 @@ FailureOr verifyFullOrSafeReadVRegChunks(Operation *op, return *lanesPerPart; } + lanesPerPart = getDataLanesPerPart(type.getElementType()); + if (succeeded(lanesPerPart)) + return *lanesPerPart; + (void)rewriter.notifyMatchFailure( op, Twine("memory lowering ") + fullChunkReason + "; safe full-read proof failed: " + safeReadProof.reason); @@ -1170,16 +1174,9 @@ checkSupportedLoadShape(const VMITargetCapabilityRegistry &capabilities, if (getDenseLaneStrideLoadDistToken(type)) return success(); - std::string fullChunkReason; - if (succeeded(checkFullDataPhysicalChunks(type, &fullChunkReason))) - return success(); - - if (accessPlan.safeReadProof.proven) - return success(); - requireUnavailableReadFallback(accessPlan); - return fail(Twine(fullChunkReason) + - "; safe-read proof failed: " + accessPlan.safeReadProof.reason + - "; fallback decision: " + accessPlan.fallbackDecision.reason); + if (failed(getDataLanesPerPart(type.getElementType()))) + return fail("requires element type with known physical lane width"); + return success(); } LogicalResult checkSupportedDeinterleaveLoadShape( @@ -9313,8 +9310,7 @@ verifySupportedVMIToVPTOOps(ModuleOp module, op->emitError() << kVMIDiagUnsupportedPrefix << opName - << " requires full physical chunks without padding lanes or a " - "statically safe full-read footprint (" + << " direct lowering requires a supported memory source (" << reason << ")"; return WalkResult::interrupt(); }; diff --git a/test/lit/vmi/opt/README.md b/test/lit/vmi/opt/README.md new file mode 100644 index 0000000000..c169dcee08 --- /dev/null +++ b/test/lit/vmi/opt/README.md @@ -0,0 +1,18 @@ + + +# VMI Optimization Shape Guards + +This directory contains end-to-end VMI optimization capability tests. + +These tests intentionally check generated VPTO instruction shape for representative +kernels. They are not generic correctness tests: a failure means the VMI pipeline +has likely regressed an optimization contract and should not be updated away +without an explicit replacement optimization decision. diff --git a/test/lit/vmi/opt/compute_mrope_f16_vmi_opt.pto b/test/lit/vmi/opt/compute_mrope_f16_vmi_opt.pto new file mode 100644 index 0000000000..21c0f52c2f --- /dev/null +++ b/test/lit/vmi/opt/compute_mrope_f16_vmi_opt.pto @@ -0,0 +1,134 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s | FileCheck %s --implicit-check-not=pto.vdintlv --implicit-check-not=pto.vintlv --implicit-check-not='part = "ODD"' + +// This is an optimization capability guard for the VMI ComputeMropeF16 path. +// Do not weaken the checks when the output shape regresses. The intended shape is: +// - 64xf16 loads feeding f16->f32 ext lower as UNPK_B16 loads plus one EVEN vcvt; +// - no ODD vcvt, vintlv, or vdintlv is needed for those ext paths; +// - f32->f16 trunc feeding masked_store lowers through EVEN vcvt plus vpack. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @ComputeMropeF16( + %x_ub_u: !pto.ptr, + %y_ub_u: !pto.ptr, + %cs_ub_u: !pto.ptr, + %curTokens: i32, + %num_heads: i32, + %num_heads_max: i32, + %head_size: i32, + %rotary_dim: i32, + %headAlign_fp16: i32, + %is_neox: i1) + attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2_i32 = arith.constant 2 : i32 + + %x_ub = pto.castptr %x_ub_u : !pto.ptr -> !pto.ptr + %y_ub = pto.castptr %y_ub_u : !pto.ptr -> !pto.ptr + %cs_ub = pto.castptr %cs_ub_u : !pto.ptr -> !pto.ptr + + %cur_tokens = arith.index_cast %curTokens : i32 to index + %num_heads_idx = arith.index_cast %num_heads : i32 to index + %num_heads_max_idx = arith.index_cast %num_heads_max : i32 to index + %head_align_idx = arith.index_cast %headAlign_fp16 : i32 to index + %rotary_dim_idx = arith.index_cast %rotary_dim : i32 to index + + %half_rotary_i32 = arith.divui %rotary_dim, %c2_i32 : i32 + %suffix_len_i32 = arith.subi %head_size, %rotary_dim : i32 + %has_suffix = arith.cmpi sgt, %head_size, %rotary_dim : i32 + + %half_rotary = arith.index_cast %half_rotary_i32 : i32 to index + %suffix_len = arith.index_cast %suffix_len_i32 : i32 to index + %half_mask = pto.vmi.create_mask %half_rotary : index -> !pto.vmi.mask<64xpred> + %suffix_mask = pto.vmi.create_mask %suffix_len : index -> !pto.vmi.mask<64xpred> + + %head_stride = arith.muli %num_heads_max_idx, %head_align_idx : index + + pto.vecscope { + scf.if %is_neox { + scf.for %ti = %c0 to %cur_tokens step %c1 { + %cs_off = arith.muli %ti, %rotary_dim_idx : index + %token_base = arith.muli %ti, %head_stride : index + %cs_sin_base = arith.addi %cs_off, %half_rotary : index + + scf.for %h = %c0 to %num_heads_idx step %c1 { + %head_off = arith.muli %h, %head_align_idx : index + %x_base = arith.addi %token_base, %head_off : index + %x2_base = arith.addi %x_base, %half_rotary : index + + %x1 = pto.vmi.load %x_ub[%x_base] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + %x2 = pto.vmi.load %x_ub[%x2_base] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + %cos = pto.vmi.load %cs_ub[%cs_off] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + %sin = pto.vmi.load %cs_ub[%cs_sin_base] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + + %x1_f32 = pto.vmi.extf %x1 + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + %x2_f32 = pto.vmi.extf %x2 + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + %cos_f32 = pto.vmi.extf %cos + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + %sin_f32 = pto.vmi.extf %sin + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + + %x1_cos = pto.vmi.mulf %x1_f32, %cos_f32 + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %x2_sin = pto.vmi.mulf %x2_f32, %sin_f32 + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %out1_f32 = pto.vmi.subf %x1_cos, %x2_sin + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %out1 = pto.vmi.truncf %out1_f32 + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf16> + pto.vmi.masked_store %out1, %y_ub[%x_base], %half_mask + : !pto.vmi.vreg<64xf16>, !pto.ptr, !pto.vmi.mask<64xpred> + + %x2_cos = pto.vmi.mulf %x2_f32, %cos_f32 + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %x1_sin = pto.vmi.mulf %x1_f32, %sin_f32 + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %out2_f32 = pto.vmi.addf %x2_cos, %x1_sin + : !pto.vmi.vreg<64xf32>, !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf32> + %out2 = pto.vmi.truncf %out2_f32 + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf16> + pto.vmi.masked_store %out2, %y_ub[%x2_base], %half_mask + : !pto.vmi.vreg<64xf16>, !pto.ptr, !pto.vmi.mask<64xpred> + + scf.if %has_suffix { + %suffix_base = arith.addi %x_base, %rotary_dim_idx : index + %suffix = pto.vmi.load %x_ub[%suffix_base] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + pto.vmi.masked_store %suffix, %y_ub[%suffix_base], %suffix_mask + : !pto.vmi.vreg<64xf16>, !pto.ptr, !pto.vmi.mask<64xpred> + } + } + } + } + } + return + } +} + +// CHECK-LABEL: func.func @ComputeMropeF16 +// CHECK-COUNT-4: pto.vlds {{.*}} {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK-COUNT-4: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vmul +// CHECK: pto.vsub +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vpack +// CHECK: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// CHECK: pto.vadd +// CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vpack +// CHECK: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<128xf16> diff --git a/test/lit/vmi/opt/compute_y1_to_fp8_fp16_vmi_opt.pto b/test/lit/vmi/opt/compute_y1_to_fp8_fp16_vmi_opt.pto new file mode 100644 index 0000000000..448c4d524d --- /dev/null +++ b/test/lit/vmi/opt/compute_y1_to_fp8_fp16_vmi_opt.pto @@ -0,0 +1,144 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s | FileCheck %s --implicit-check-not=pto.vdintlv --implicit-check-not=pto.vintlv --implicit-check-not=pto.vpack + +// This is an optimization capability guard for the VMI ComputeY1ToFP8 FP16 path. +// Do not weaken the checks when the output shape regresses. The intended shape is: +// - one E2B_B16 scale load per kernel, outside the block loop; +// - DINTLV_B16 x loads inside the block loop; +// - four f32 streams quantized through P0/P1/P2/P3 and merged by vor; +// - contiguous fp8 store, with no extra register layout materialization. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @ComputeY1ToFP8_fp16_e4m3_vmi( + %dataLen: i16, + %blockCount: i16, + %xAddr: !pto.ptr, + %mxScale1ReciprocalAddr: !pto.ptr, + %y1Addr: !pto.ptr, + %ubBlockSize: i16, + %vlForHalfNumber: i16) + attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %block_count = arith.index_cast %blockCount : i16 to index + %vl_half = arith.index_cast %vlForHalfNumber : i16 to index + %load_stride_y8 = arith.muli %vl_half, %c2 : index + + pto.vecscope { + %scale_f16 = pto.vmi.group_slot_load %mxScale1ReciprocalAddr[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf16> + %scale_f16_vec = pto.vmi.group_broadcast %scale_f16 {num_groups = 8} + : !pto.vmi.vreg<8xf16> -> !pto.vmi.vreg<256xf16> + %scale_fp32 = pto.vmi.extf %scale_f16_vec + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + + scf.for %i = %c0 to %block_count step %c1 { + %x_off = arith.muli %i, %load_stride_y8 : index + %y_off = arith.muli %i, %load_stride_y8 : index + + %x_f16 = pto.vmi.load %xAddr[%x_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x_fp32 = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %res_fp32 = pto.vmi.mulf %x_fp32, %scale_fp32 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + + %res_fp8 = pto.vmi.truncf %res_fp32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> + + pto.vmi.store %res_fp8, %y1Addr[%y_off] + : !pto.vmi.vreg<256xf8E4M3FN>, !pto.ptr + } + } + return + } + + func.func @ComputeY1ToFP8_fp16_e5m2_vmi( + %dataLen: i16, + %blockCount: i16, + %xAddr: !pto.ptr, + %mxScale1ReciprocalAddr: !pto.ptr, + %y1Addr: !pto.ptr, + %ubBlockSize: i16, + %vlForHalfNumber: i16) + attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %block_count = arith.index_cast %blockCount : i16 to index + %vl_half = arith.index_cast %vlForHalfNumber : i16 to index + %load_stride_y8 = arith.muli %vl_half, %c2 : index + + pto.vecscope { + %scale_f16 = pto.vmi.group_slot_load %mxScale1ReciprocalAddr[%c0], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf16> + %scale_f16_vec = pto.vmi.group_broadcast %scale_f16 {num_groups = 8} + : !pto.vmi.vreg<8xf16> -> !pto.vmi.vreg<256xf16> + %scale_fp32 = pto.vmi.extf %scale_f16_vec + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + + scf.for %i = %c0 to %block_count step %c1 { + %x_off = arith.muli %i, %load_stride_y8 : index + %y_off = arith.muli %i, %load_stride_y8 : index + + %x_f16 = pto.vmi.load %xAddr[%x_off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x_fp32 = pto.vmi.extf %x_f16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %res_fp32 = pto.vmi.mulf %x_fp32, %scale_fp32 + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + + %res_fp8 = pto.vmi.truncf %res_fp32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E5M2> + + pto.vmi.store %res_fp8, %y1Addr[%y_off] + : !pto.vmi.vreg<256xf8E5M2>, !pto.ptr + } + } + return + } +} + +// CHECK-LABEL: func.func @ComputeY1ToFP8_fp16_e4m3_vmi +// CHECK: pto.vlds {{.*}} {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: scf.for +// CHECK: pto.vldsx2 {{.*}} "DINTLV_B16" : !pto.ptr, index -> !pto.vreg<128xf16>, !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "ODD"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// CHECK: pto.vor +// CHECK: pto.vor +// CHECK: pto.vor +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask + +// CHECK-LABEL: func.func @ComputeY1ToFP8_fp16_e5m2_vmi +// CHECK: pto.vlds {{.*}} {dist = "E2B_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// CHECK: scf.for +// CHECK: pto.vldsx2 {{.*}} "DINTLV_B16" : !pto.ptr, index -> !pto.vreg<128xf16>, !pto.vreg<128xf16> +// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// CHECK: pto.vor +// CHECK: pto.vor +// CHECK: pto.vor +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E5M2>, !pto.ptr, !pto.mask diff --git a/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto index b16082f89a..25e135a14a 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto @@ -23,6 +23,36 @@ module { return } + func.func @vmi_layout_assignment_compact_f16_to_f32_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x16 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + pto.vmi.store %x32, %dst[%off] + : !pto.vmi.vreg<64xf32>, !pto.ptr + return + } + + func.func @vmi_layout_assignment_compact_f16_to_f32_masked_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %c64 = arith.constant 64 : index + %mask = pto.vmi.create_mask %c64 : index -> !pto.vmi.mask<64xpred> + %x16 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<64xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<64xf16> -> !pto.vmi.vreg<64xf32> + %y16 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<64xf32> -> !pto.vmi.vreg<64xf16> + pto.vmi.masked_store %y16, %dst[%off], %mask + : !pto.vmi.vreg<64xf16>, !pto.ptr, !pto.vmi.mask<64xpred> + return + } + func.func @vmi_layout_assignment_dense_f32_to_f16_store( %src: !pto.ptr, %dst: !pto.ptr, @@ -57,21 +87,55 @@ module { // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast +// ASSIGN-LABEL: func.func @vmi_layout_assignment_compact_f16_to_f32_store( +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: pto.vmi.store %[[X32]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf32, #pto.vmi.layout>, !pto.ptr + +// LOWER-LABEL: func.func @vmi_layout_assignment_compact_f16_to_f32_store( +// LOWER: pto.vlds {{.*}} {dist = "UNPK_B16"} : !pto.ptr -> !pto.vreg<128xf16> +// LOWER: pto.vcvt {{.*}} {part = "EVEN"} : !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<64xf32> +// LOWER-NOT: part = "ODD" +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast + +// ASSIGN-LABEL: func.func @vmi_layout_assignment_compact_f16_to_f32_masked_store( +// ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask +// ASSIGN-SAME: -> !pto.vmi.mask<64xb32, #pto.vmi.layout> +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf32, #pto.vmi.layout> +// ASSIGN: %[[Y16:.*]] = pto.vmi.truncf %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// ASSIGN: %[[Y16C:.*]] = pto.vmi.ensure_layout %[[Y16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> +// ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_granularity %[[MASK0]] +// ASSIGN-SAME: -> !pto.vmi.mask<64xb16, #pto.vmi.layout> +// ASSIGN: pto.vmi.masked_store %[[Y16C]] +// ASSIGN-SAME: !pto.vmi.vreg<64xf16, #pto.vmi.layout>, !pto.ptr, !pto.vmi.mask<64xb16, #pto.vmi.layout> + // ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_f32_to_f16_store( // ASSIGN: %[[X32:.*]] = pto.vmi.load -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN-NOT: pto.vmi.ensure_layout // ASSIGN: %[[X16:.*]] = pto.vmi.truncf %[[X32]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[X16]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr // LOWER-LABEL: func.func @vmi_layout_assignment_dense_f32_to_f16_store( -// LOWER: pto.vlds +// LOWER: pto.vldsx2 {{.*}} "DINTLV_B32" // LOWER: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} -// LOWER-NOT: part = "ODD" -// LOWER-NOT: pto.vor -// LOWER: pto.vsts {{.*}} {dist = "PK_B32"} +// LOWER: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: pto.vsts // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto index 550eed8f28..6ef79fb5f1 100644 --- a/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto +++ b/test/lit/vmi/vmi_layout_assignment_f32_f8_store_reduce.pto @@ -35,32 +35,30 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( // ASSIGN: %[[X32:.*]] = pto.vmi.load -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[MASK0:.*]] = pto.vmi.create_mask // ASSIGN-SAME: -> !pto.vmi.mask<256xb32, #pto.vmi.layout> -// ASSIGN: %[[X32_REDUCE:.*]] = pto.vmi.ensure_layout %[[X32]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[MASK:.*]] = pto.vmi.ensure_mask_layout %[[MASK0]] // ASSIGN-SAME: !pto.vmi.mask<256xb32, #pto.vmi.layout> -> !pto.vmi.mask<256xb32, #pto.vmi.layout> -// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32_REDUCE]], %[[MASK]] +// ASSIGN: %[[SUM:.*]] = pto.vmi.group_reduce_addf %[[X32]], %[[MASK]] // ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.group_store %[[SUM]] // ASSIGN: %[[X8:.*]] = pto.vmi.truncf %[[X32]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[X8]] // LOWER-LABEL: func.func @vmi_layout_assignment_f32_f8_store_reduce( -// LOWER-COUNT-4: pto.vlds -// LOWER-COUNT-4: pto.vdintlv +// LOWER-COUNT-2: pto.vldsx2 {{.*}} "DINTLV_B32" +// LOWER-COUNT-2: pto.vdintlv // LOWER-COUNT-4: pto.vcgadd // LOWER-COUNT-3: pto.vadd // LOWER: pto.vsts // LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} -// LOWER-NOT: part = "P1" -// LOWER-NOT: part = "P2" -// LOWER-NOT: part = "P3" -// LOWER-NOT: pto.vor -// LOWER: pto.vsts {{.*}} {dist = "PK4_B32"} +// LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: pto.vsts // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto b/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto index 4b7496a8f8..7e1d1b293f 100644 --- a/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto +++ b/test/lit/vmi/vmi_layout_assignment_f8_compute_f8.pto @@ -41,10 +41,8 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[Y32:.*]] = pto.vmi.mulf %[[X32]], %[[SCALE]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[Y32_CONTIG:.*]] = pto.vmi.ensure_layout %[[Y32]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[Y8:.*]] = pto.vmi.truncf %[[Y32_CONTIG]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: %[[Y8:.*]] = pto.vmi.truncf %[[Y32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[Y8]] // LOWER-LABEL: func.func @vmi_layout_assignment_f8_compute_f8( @@ -53,11 +51,11 @@ module { // LOWER-COUNT-4: pto.vdup // LOWER-COUNT-4: pto.vmul // LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} -// LOWER-NOT: part = "P1" -// LOWER-NOT: part = "P2" -// LOWER-NOT: part = "P3" -// LOWER-NOT: pto.vor -// LOWER: pto.vsts {{.*}} {dist = "PK4_B32"} +// LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: pto.vsts // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto index 2118b0f5d2..400c10093a 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_broadcast_multi_consumer.pto @@ -55,8 +55,10 @@ module { // ASSIGN: pto.vmi.group_store %[[YSUM]] // ASSIGN: %[[B_CAST:.*]] = pto.vmi.group_broadcast %[[SUM]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[B_CAST]] -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[B_CAST_SPLIT:.*]] = pto.vmi.ensure_layout %[[B_CAST]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[H:.*]] = pto.vmi.truncf %[[B_CAST_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[H]] // LOWER-LABEL: func.func @vmi_layout_assignment_group_broadcast_multi_consumer( @@ -71,7 +73,7 @@ module { // LOWER: pto.vselr // LOWER: pto.vselr // LOWER: pto.vcvt -// LOWER-NOT: pto.vor -// LOWER: pto.vsts {{.*}} {dist = "PK_B32"} +// LOWER: pto.vor +// LOWER: pto.vsts // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto index 712eddda7e..2953ab1989 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_load_block8_truncf.pto @@ -38,9 +38,11 @@ module { // CHECK: pto.vsldb // CHECK: pto.vcgadd // CHECK: pto.vintlv -// CHECK-NOT: pto.vdintlv +// CHECK: pto.vdintlv // CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} -// CHECK: pto.vsts {{.*}} {dist = "PK_B32"} +// CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// CHECK: pto.vor +// CHECK: pto.vsts // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto index 685987b72c..fddd344cf6 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_maxf_quant.pto @@ -61,17 +61,23 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> // ASSIGN: %[[Q:.*]] = pto.vmi.divf %[[X]], %[[SCALE_VEC]] // ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> -// ASSIGN: %[[Q8:.*]] = pto.vmi.truncf %[[Q]] -// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: %[[Q_SPLIT:.*]] = pto.vmi.ensure_layout %[[Q]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[Q8:.*]] = pto.vmi.truncf %[[Q_SPLIT]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf8E4M3FN, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_maxf_quant( // LOWER: pto.vcgmax // LOWER: pto.vmax // LOWER: pto.vsel // LOWER: pto.vdiv -// LOWER-NOT: pto.vdintlv +// LOWER: pto.vdintlv // LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} -// LOWER: pto.vsts {{.*}} {dist = "PK4_B32"} +// LOWER: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: pto.vsts // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto index d3a8f54b87..2e53e6d7a3 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store.pto @@ -39,10 +39,12 @@ module { // ASSIGN-SAME: -> !pto.vmi.vreg<8xf32, #pto.vmi.layout> // ASSIGN: %[[B32:.*]] = pto.vmi.group_broadcast %[[SUM32]] // ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[B16:.*]] = pto.vmi.truncf %[[B32]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[B32_SPLIT:.*]] = pto.vmi.ensure_layout %[[B32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[B16:.*]] = pto.vmi.truncf %[[B32_SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[B16]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr // LOWER-LABEL: func.func @vmi_layout_assignment_group_reduce_s16_truncf_broadcast_store( // LOWER: pto.vcgadd @@ -51,6 +53,7 @@ module { // LOWER: pto.vselr // LOWER: pto.vselr // LOWER: pto.vcvt +// LOWER: pto.vor // LOWER: pto.vsts // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_load_truncf.pto b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto index 689e7b0836..9d64cffec2 100644 --- a/test/lit/vmi/vmi_layout_assignment_load_truncf.pto +++ b/test/lit/vmi/vmi_layout_assignment_load_truncf.pto @@ -36,34 +36,40 @@ module { } // ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf( -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: %[[WIDE:.*]] = pto.vmi.load -// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN-NOT: pto.vmi.ensure_layout // ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> -// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_load_truncf( -// LOWER: %[[P0:.*]] = pto.vlds %arg0[%arg1] -// LOWER: %[[NARROW:.*]] = pto.vcvt %[[P0]], {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} -// LOWER: return %[[NARROW]], {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16> +// LOWER: %[[LOW:.*]], %[[HIGH:.*]] = pto.vldsx2 %arg0[%arg1], "DINTLV_B32" +// LOWER: %[[EVEN:.*]] = pto.vcvt %[[LOW]], {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} +// LOWER: %[[ODD:.*]] = pto.vcvt %[[HIGH]], {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] +// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. // ASSIGN-LABEL: func.func @vmi_layout_assignment_load_truncf_multi_use( -// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: %[[WIDE:.*]] = pto.vmi.load // ASSIGN-SAME: !pto.ptr -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> // ASSIGN: pto.vmi.store %[[WIDE]] // ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> -// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: return %[[NARROW]] : !pto.vmi.vreg<128xf16, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_load_truncf_multi_use( // LOWER: pto.vsts // LOWER: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} -// LOWER: return {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16> +// LOWER: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} +// LOWER: pto.vor +// LOWER: return {{.*}} : !pto.vreg<128xf16> // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto b/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto index 4e9b2885fd..9d2b6e35ea 100644 --- a/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto @@ -27,4 +27,4 @@ module { } } -// CHECK: VMI-LAYOUT-CONTRACT: conflicting natural layouts #pto.vmi.layout and #pto.vmi.layout +// CHECK: VMI-LAYOUT-CONTRACT: conflicting natural layouts #pto.vmi.layout and #pto.vmi.layout diff --git a/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto b/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto index bbf5afe97d..8908036648 100644 --- a/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto +++ b/test/lit/vmi/vmi_layout_assignment_truncf_ensure.pto @@ -20,15 +20,20 @@ module { // ASSIGN-LABEL: func.func @vmi_layout_assignment_truncf_ensure( // ASSIGN-SAME: %[[WIDE:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -// ASSIGN-NOT: pto.vmi.ensure_layout -// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[WIDE]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN: %[[SPLIT:.*]] = pto.vmi.ensure_layout %[[WIDE]] +// ASSIGN-SAME: -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[NARROW:.*]] = pto.vmi.truncf %[[SPLIT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf16, #pto.vmi.layout> // ASSIGN: return %[[NARROW]] -// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> +// ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout> // LOWER-LABEL: func.func @vmi_layout_assignment_truncf_ensure( // LOWER-SAME: %[[D0:arg[0-9]+]]: !pto.vreg<64xf32> -// LOWER: %[[NARROW:.*]] = pto.vcvt %[[D0]]{{.*}}part = "EVEN" -// LOWER: return %[[NARROW]], {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16> +// LOWER-SAME: %[[D1:arg[0-9]+]]: !pto.vreg<64xf32> +// LOWER: %[[LOW:.*]], %[[HIGH:.*]] = pto.vdintlv %[[D0]], %[[D1]] +// LOWER: %[[EVEN:.*]] = pto.vcvt %[[LOW]]{{.*}}part = "EVEN" +// LOWER: %[[ODD:.*]] = pto.vcvt %[[HIGH]]{{.*}}part = "ODD" +// LOWER: %[[NARROW:.*]] = pto.vor %[[EVEN]], %[[ODD]] +// LOWER: return %[[NARROW]] : !pto.vreg<128xf16> // LOWER-NOT: pto.vmi. // LOWER-NOT: !pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_load_nonfull_invalid.pto b/test/lit/vmi/vmi_to_vpto_load_nonfull.pto similarity index 62% rename from test/lit/vmi/vmi_to_vpto_load_nonfull_invalid.pto rename to test/lit/vmi/vmi_to_vpto_load_nonfull.pto index f87e3753ca..edb8f88cf3 100644 --- a/test/lit/vmi/vmi_to_vpto_load_nonfull_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_load_nonfull.pto @@ -6,10 +6,10 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_load_nonfull_invalid( + func.func @vmi_to_vpto_load_nonfull( %src: !pto.ptr, %offset: index) -> (!pto.vreg<64xf32>) { %value = pto.vmi.load %src[%offset] @@ -20,8 +20,8 @@ module { } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint -// CHECK-SAME: safe-read proof failed: requires constant index offset -// CHECK-SAME: fallback decision: partial/tail read needs a scratch, guarded, or true masked/non-faulting load fallback -// CHECK-SAME: scratch memory fallback resource allocation is not implemented -// CHECK-SAME: guarded memory fallback control-flow lowering is not implemented +// CHECK-LABEL: func.func @vmi_to_vpto_load_nonfull( +// CHECK: pto.vlds %arg0[%arg1] : !pto.ptr -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_invalid.pto b/test/lit/vmi/vmi_to_vpto_load_nonfull_memref.pto similarity index 67% rename from test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_invalid.pto rename to test/lit/vmi/vmi_to_vpto_load_nonfull_memref.pto index 07975ea70d..d1e1f27c94 100644 --- a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_load_nonfull_memref.pto @@ -6,10 +6,10 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_load_safe_tail_memref_invalid(%src: memref<100xf32>) + func.func @vmi_to_vpto_load_nonfull_memref(%src: memref<100xf32>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { %c0 = arith.constant 0 : index %value = pto.vmi.load %src[%c0] @@ -21,5 +21,11 @@ module { } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint -// CHECK-SAME: safe-read proof failed: full physical read footprint [0, 128) exceeds static memref element count 100 +// CHECK-LABEL: func.func @vmi_to_vpto_load_nonfull_memref( +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK: pto.vlds %arg0[%[[C0]]] : memref<100xf32> -> !pto.vreg<64xf32> +// CHECK: pto.vlds %arg0[%[[C64]]] : memref<100xf32> -> !pto.vreg<64xf32> +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset.pto similarity index 74% rename from test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto rename to test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset.pto index 863c2b4fa5..b444c3d1a8 100644 --- a/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_load_safe_tail_memref_negative_offset.pto @@ -6,10 +6,10 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: not pto-test-opt %s -vmi-to-vpto 2>&1 | FileCheck %s +// RUN: pto-test-opt %s -vmi-to-vpto | FileCheck %s module { - func.func @vmi_to_vpto_load_safe_tail_memref_negative_offset_invalid(%src: memref<132xf32>) + func.func @vmi_to_vpto_load_safe_tail_memref_negative_offset(%src: memref<132xf32>) -> (!pto.vreg<64xf32>, !pto.vreg<64xf32>) { %cm1 = arith.constant -1 : index %value = pto.vmi.load %src[%cm1] @@ -21,5 +21,9 @@ module { } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint -// CHECK-SAME: safe-read proof failed: requires non-negative offset +// CHECK-LABEL: func.func @vmi_to_vpto_load_safe_tail_memref_negative_offset( +// CHECK: pto.vlds +// CHECK: pto.vlds +// CHECK-NOT: pto.vmi. +// CHECK-NOT: !pto.vmi. +// CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto b/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto index 8d1485d965..a1749cffe1 100644 --- a/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_memory_space_invalid.pto @@ -16,7 +16,7 @@ module { } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load direct lowering requires a supported memory source // CHECK-SAME: source is GM-backed // CHECK-SAME: requires UB-backed memory diff --git a/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto b/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto index c3483450bc..ba91366ff2 100644 --- a/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_memref_layout_invalid.pto @@ -19,7 +19,7 @@ module { } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load direct lowering requires a supported memory source // CHECK-SAME: source memref layout is non-identity // CHECK-SAME: contiguous identity lane-to-address maps @@ -37,7 +37,7 @@ module { } } -// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load requires full physical chunks without padding lanes or a statically safe full-read footprint +// CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.load direct lowering requires a supported memory source // CHECK-SAME: source memref layout is non-identity // CHECK-SAME: contiguous identity lane-to-address maps // CHECK-SAME: memref.subview requires normalized base/offset/stride lane-to-address planning diff --git a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto index e596488dc3..e6dad5963a 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_dequant.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_dequant.pto @@ -252,12 +252,12 @@ module { // CHECK-SAME: %[[QDST:[^,]+]]: !pto.ptr // CHECK: scf.for // CHECK: scf.for -// CHECK: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}} "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> // CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> -// CHECK-NOT: part = "ODD" -// CHECK-NOT: pto.vor -// CHECK: pto.vsts {{.*}} {dist = "PK_B32"} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask +// CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vor {{.*}} : !pto.vreg<128xf16>, !pto.vreg<128xf16>, !pto.mask -> !pto.vreg<128xf16> +// CHECK: pto.vsts {{.*}} : !pto.vreg<128xf16>, !pto.ptr, !pto.mask // CHECK: scf.if // CHECK: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> // CHECK: pto.vcvt {{.*}} {part = "ODD", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<128xf16> @@ -291,11 +291,13 @@ module { // CHECK-SAME: %[[FQDST:[^,]+]]: !pto.ptr // CHECK: scf.for // CHECK: scf.for -// CHECK-COUNT-4: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK-NOT: pto.vdintlv +// CHECK: pto.vldsx2 {{.*}} "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vldsx2 {{.*}} "DINTLV_B32" : !pto.ptr, index -> !pto.vreg<64xf32>, !pto.vreg<64xf32> +// CHECK: pto.vdintlv +// CHECK: pto.vdintlv // CHECK: pto.vmul {{.*}} : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> // CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask // CHECK: scf.if // CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> // CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> diff --git a/test/lit/vmi/vmi_to_vpto_quant_fp8.pto b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto index 7ecee61a27..93c822a9f3 100644 --- a/test/lit/vmi/vmi_to_vpto_quant_fp8.pto +++ b/test/lit/vmi/vmi_to_vpto_quant_fp8.pto @@ -30,14 +30,14 @@ module { } // CHECK-LABEL: func.func @vmi_to_vpto_quant_matrix_f32_to_fp8( -// CHECK-COUNT-4: pto.vlds {{.*}} : !pto.ptr -> !pto.vreg<64xf32> -// CHECK-NOT: pto.vdintlv +// CHECK-COUNT-2: pto.vldsx2 {{.*}} "DINTLV_B32" +// CHECK: pto.vdintlv // CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK-NOT: part = "P1" -// CHECK-NOT: part = "P2" -// CHECK-NOT: part = "P3" -// CHECK-NOT: pto.vor -// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask +// CHECK: pto.vcvt {{.*}} {part = "P1", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vcvt {{.*}} {part = "P3", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> +// CHECK: pto.vor +// CHECK: pto.vsts {{.*}} : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast From 1b13d3cb98f8a633fb707253cc1ab0b65256e3b3 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Tue, 30 Jun 2026 23:58:06 +0800 Subject: [PATCH 48/54] Clarify PTO-Gym validation skill scope --- .codex/skills/pto-gym-vpto-validation/SKILL.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.codex/skills/pto-gym-vpto-validation/SKILL.md b/.codex/skills/pto-gym-vpto-validation/SKILL.md index 0e1451a614..721df32b6f 100644 --- a/.codex/skills/pto-gym-vpto-validation/SKILL.md +++ b/.codex/skills/pto-gym-vpto-validation/SKILL.md @@ -1,6 +1,6 @@ --- name: pto-gym-vpto-validation -description: Run PTO-Gym validation from this PTOAS repo. Use when the user asks to run PTO-Gym SIM or board validation from the current source tree. Always force PTOAS onto the VPTO LLVM path instead of relying on the repo default backend. +description: Run bundled PTO-Gym exercise/validation cases. Use when the user explicitly asks for PTO-Gym, 3rdparty/PTO-Gym, or the PTO-Gym validation scripts. Always force PTOAS onto the VPTO path instead of relying on the repo default backend. --- # PTO-Gym VPTO Validation @@ -8,20 +8,20 @@ description: Run PTO-Gym validation from this PTOAS repo. Use when the user asks Use this skill when the task is specifically about: - running `3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation.sh` - running `3rdparty/PTO-Gym/examples/pto/scripts/run_host_vpto_validation_parallel.sh` -- validating PTO-Gym cases from this PTOAS source tree +- validating bundled PTO-Gym exercise cases ## Required Rule When PTO-Gym is run from this repo, do not rely on the default PTOAS backend. Always pass PTOAS flags that force the VPTO LLVM path. -The current `ptoas` CLI spellings in this repo are `--pto-backend=vpto` and -`--vpto-emit-hivm-llvm`; do not shorten `--pto-backend` to `--backend`. +The current `ptoas` CLI spelling in this repo is `--pto-backend=vpto`; do not +shorten `--pto-backend` to `--backend`. Use: ```bash -PTOAS_FLAGS='--pto-backend=vpto --vpto-emit-hivm-llvm --pto-arch a5' +PTOAS_FLAGS='--pto-backend=vpto --pto-arch a5' ``` If the caller already provides `PTOAS_FLAGS`, make sure these options are still @@ -44,7 +44,7 @@ Typical simulator environment: source /home/mouliangyu/.local/ascend/beta.2/cann-9.0.0-beta.2/set_env.sh export ASCEND_HOME_PATH=/home/mouliangyu/.local/ascend/beta.2/cann-9.0.0-beta.2 export PTOAS_BIN=$PWD/build/tools/ptoas/ptoas -export PTOAS_FLAGS='--pto-backend=vpto --vpto-emit-hivm-llvm --pto-arch a5' +export PTOAS_FLAGS='--pto-backend=vpto --pto-arch a5' ``` ## Canonical Commands From ab5c8839aee56b6586641d74cb1c8b79cda91dc8 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 1 Jul 2026 00:50:08 +0800 Subject: [PATCH 49/54] Fix arity-driven VMI cast layout selection --- lib/PTO/Transforms/VMILayoutSupport.cpp | 21 ++++++++- ..._layout_assignment_dense_f16_f32_store.pto | 45 +++++++++++++++++++ ...signment_multi_return_conflict_invalid.pto | 2 +- test/lit/vmi/vmi_ptoas_cli_pipeline.pto | 10 ++--- ...vpto_truncf_fp8_128_contiguous_invalid.pto | 13 ++---- 5 files changed, 71 insertions(+), 20 deletions(-) diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index d2113733d4..fff22f6528 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -593,21 +593,33 @@ FailureOr VMILayoutSupport::getPreferredCastLayoutFact( return baseline; if (isWiden) { + VMIVRegType baselineSourceType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), baseline->sourceLayout); VMIVRegType baselineResultType = VMIVRegType::get(ctx, resultType.getElementCount(), resultType.getElementType(), baseline->resultLayout); + FailureOr baselineSourceArity = + getVMIPhysicalArity(baselineSourceType); FailureOr baselineResultArity = getVMIPhysicalArity(baselineResultType); - if (failed(baselineResultArity) || + if (failed(baselineSourceArity) || failed(baselineResultArity) || + *compactSourceArity > *baselineSourceArity || *compactResultArity >= *baselineResultArity) return baseline; } else { VMIVRegType baselineSourceType = VMIVRegType::get(ctx, sourceType.getElementCount(), sourceType.getElementType(), baseline->sourceLayout); + VMIVRegType baselineResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), baseline->resultLayout); FailureOr baselineSourceArity = getVMIPhysicalArity(baselineSourceType); - if (failed(baselineSourceArity) || + FailureOr baselineResultArity = + getVMIPhysicalArity(baselineResultType); + if (failed(baselineSourceArity) || failed(baselineResultArity) || + *compactResultArity > *baselineResultArity || *compactSourceArity >= *baselineSourceArity) return baseline; } @@ -647,6 +659,11 @@ VMILayoutSupport::getWidenSourceLayoutForResultLayout( return fail("requires supported 8/16-bit to 32-bit widen cast"); if (requestedResultLayout.isContiguous()) { + if (!fact->resultLayout.isContiguous() || + fact->resultLayout.getLaneStride() != + requestedResultLayout.getLaneStride()) + return fail("requested contiguous result layout is not the natural " + "compact widen result layout"); return VMILayoutAttr::getContiguous(sourceType.getContext(), /*laneStride=*/fact->factor); } diff --git a/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto index 25e135a14a..7238c10a42 100644 --- a/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto +++ b/test/lit/vmi/vmi_layout_assignment_dense_f16_f32_store.pto @@ -23,6 +23,19 @@ module { return } + func.func @vmi_layout_assignment_multichunk_f16_to_f32_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x16 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf16> + %x32 = pto.vmi.extf %x16 + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + pto.vmi.store %x32, %dst[%off] + : !pto.vmi.vreg<256xf32>, !pto.ptr + return + } + func.func @vmi_layout_assignment_compact_f16_to_f32_store( %src: !pto.ptr, %dst: !pto.ptr, @@ -65,6 +78,19 @@ module { : !pto.vmi.vreg<128xf16>, !pto.ptr return } + + func.func @vmi_layout_assignment_multichunk_f32_to_f16_store( + %src: !pto.ptr, + %dst: !pto.ptr, + %off: index) { + %x32 = pto.vmi.load %src[%off] + : !pto.ptr -> !pto.vmi.vreg<256xf32> + %x16 = pto.vmi.truncf %x32 + : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf16> + pto.vmi.store %x16, %dst[%off] + : !pto.vmi.vreg<256xf16>, !pto.ptr + return + } } // ASSIGN-LABEL: func.func @vmi_layout_assignment_dense_f16_to_f32_store( @@ -87,6 +113,16 @@ module { // LOWER-NOT: !pto.vmi. // LOWER-NOT: unrealized_conversion_cast +// ASSIGN-LABEL: func.func @vmi_layout_assignment_multichunk_f16_to_f32_store( +// ASSIGN: %[[X16:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// ASSIGN: %[[X32:.*]] = pto.vmi.extf %[[X16]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: %[[DENSE:.*]] = pto.vmi.ensure_layout %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[DENSE]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf32, #pto.vmi.layout>, !pto.ptr + // ASSIGN-LABEL: func.func @vmi_layout_assignment_compact_f16_to_f32_store( // ASSIGN: %[[X16:.*]] = pto.vmi.load // ASSIGN-SAME: -> !pto.vmi.vreg<64xf16, #pto.vmi.layout> @@ -130,6 +166,15 @@ module { // ASSIGN: pto.vmi.store %[[X16]] // ASSIGN-SAME: !pto.vmi.vreg<128xf16, #pto.vmi.layout>, !pto.ptr +// ASSIGN-LABEL: func.func @vmi_layout_assignment_multichunk_f32_to_f16_store( +// ASSIGN: %[[X32:.*]] = pto.vmi.load +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf32, #pto.vmi.layout> +// ASSIGN-NOT: pto.vmi.ensure_layout +// ASSIGN: %[[X16:.*]] = pto.vmi.truncf %[[X32]] +// ASSIGN-SAME: -> !pto.vmi.vreg<256xf16, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[X16]] +// ASSIGN-SAME: !pto.vmi.vreg<256xf16, #pto.vmi.layout>, !pto.ptr + // LOWER-LABEL: func.func @vmi_layout_assignment_dense_f32_to_f16_store( // LOWER: pto.vldsx2 {{.*}} "DINTLV_B32" // LOWER: pto.vcvt {{.*}} {part = "EVEN", rnd = "R", sat = "SAT"} diff --git a/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto b/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto index 9d2b6e35ea..4e9b2885fd 100644 --- a/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto +++ b/test/lit/vmi/vmi_layout_assignment_multi_return_conflict_invalid.pto @@ -27,4 +27,4 @@ module { } } -// CHECK: VMI-LAYOUT-CONTRACT: conflicting natural layouts #pto.vmi.layout and #pto.vmi.layout +// CHECK: VMI-LAYOUT-CONTRACT: conflicting natural layouts #pto.vmi.layout and #pto.vmi.layout diff --git a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto index 584c889233..0a136e8c0f 100644 --- a/test/lit/vmi/vmi_ptoas_cli_pipeline.pto +++ b/test/lit/vmi/vmi_ptoas_cli_pipeline.pto @@ -48,14 +48,10 @@ module attributes {pto.target_arch = "a5"} { // CHECK-NOT: unrealized_conversion_cast // CHECK-LABEL: func.func @vmi_ptoas_cli_fold_pipeline -// CHECK: pto.vlds {{.*}} {dist = "UNPK_B16"} -// CHECK: pto.vlds {{.*}} {dist = "UNPK_B16"} +// CHECK: pto.vlds // CHECK: pto.vcvt {{.*}} {part = "EVEN"} -// CHECK: pto.vcvt {{.*}} {part = "EVEN"} -// CHECK-NOT: part = "ODD" -// CHECK-NOT: pto.vintlv -// CHECK: pto.vsts -// CHECK: pto.vsts +// CHECK: pto.vcvt {{.*}} {part = "ODD"} +// CHECK: pto.vstsx2 // CHECK-NOT: pto.vmi. // CHECK-NOT: !pto.vmi. // CHECK-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto index 1a0f498140..f4c208aeda 100644 --- a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto @@ -6,7 +6,7 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s +// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s module { func.func @vmi_to_vpto_truncf_fp8_128_contiguous_invalid( @@ -21,12 +21,5 @@ module { } } -// CHECK-LABEL: func.func @vmi_to_vpto_truncf_fp8_128_contiguous_invalid( -// CHECK-SAME: %[[P0:.*]]: !pto.vreg<64xf32>, %[[P1:.*]]: !pto.vreg<64xf32> -// CHECK: %[[R0:.*]] = pto.vcvt %[[P0]], {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: %[[R1:.*]] = pto.vcvt %[[P1]], {{.*}} {part = "P0", rnd = "R", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256xf8E4M3FN> -// CHECK: pto.vsts %[[R0]], {{.*}} {dist = "PK4_B32"} -// CHECK: pto.vsts %[[R1]], {{.*}} {dist = "PK4_B32"} -// CHECK-NOT: pto.vmi. -// CHECK-NOT: !pto.vmi. -// CHECK-NOT: unrealized_conversion_cast +// CHECK: VMI{{-}}UNSUP{{P}}ORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' +// CHECK: pto.vmi.ensure_layout cannot materialize this conversion From de78a1a6c36a7cc68a7d96bfb6d09093677bf14e Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 1 Jul 2026 01:44:13 +0800 Subject: [PATCH 50/54] test(vmi): use group-slot result shapes in runtime cases --- .../broadcast-dense-group-users/kernel.pto | 4 +-- .../kernel.pto | 4 +-- .../kernel.pto | 4 +-- .../vmi/f32-to-f8-store-reduce/kernel.pto | 4 +-- .../group-broadcast-multi-consumer/kernel.pto | 10 +++--- .../group-load-s16-stride-store/kernel.pto | 4 +-- .../kernel.pto | 8 ++--- .../group-load-s32-stride-store/kernel.pto | 4 +-- .../vmi/group-reduce-basic-store/kernel.pto | 12 +++---- .../group-reduce-f16-addf-store/kernel.pto | 4 +-- .../group-reduce-f16-f8-mul-store/kernel.pto | 4 +-- .../kernel.pto | 4 +-- .../group-reduce-i32-addi-store/kernel.pto | 4 +-- .../group-reduce-i32-maxi-store/kernel.pto | 4 +-- .../kernel.pto | 4 +-- .../kernel.pto | 8 ++--- .../kernel.pto | 8 ++--- .../kernel.pto | 4 +-- .../kernel.pto | 4 +-- .../kernel.pto | 6 ++-- .../kernel.pto | 4 +-- .../kernel.pto | 8 ++--- .../group-reduce-s32-cf-join-store/kernel.pto | 4 +-- .../kernel.pto | 4 +-- .../kernel.pto | 4 +-- .../kernel.pto | 8 ++--- .../kernel.pto | 10 +++--- .../group-reduce-s64-tail-store/kernel.pto | 4 +-- .../group-reduce-s64-truncf-store/kernel.pto | 6 ++-- .../group-reduce-slot-add-store/kernel.pto | 20 +++++------ .../vmi/group-slots-cf-join-store/kernel.pto | 36 +++++++++---------- .../kernel.pto | 10 +++--- .../vmi/group-slots-scf-for-store/kernel.pto | 14 ++++---- .../masked-load-dense-group-users/kernel.pto | 4 +-- .../kernel.pto | 4 +-- .../kernel.pto | 4 +-- .../vmi/private-call-inline-store/kernel.pto | 4 +-- .../widen-f16-to-f32-store-reduce/kernel.pto | 4 +-- 38 files changed, 130 insertions(+), 130 deletions(-) diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto b/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto index 3881dfc10f..648a98b0a9 100644 --- a/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.pto @@ -49,9 +49,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.vreg<256xf32> %sum = pto.vmi.group_reduce_addf %prod, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto index 2d0dcd2c64..48ce7738a5 100644 --- a/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto +++ b/test/vpto/cases/vmi/dense-group-reduce-multi-consumer/kernel.pto @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr pto.vmi.store %x, %ub_copy[%c0] : !pto.vmi.vreg<256xf32>, !pto.ptr } diff --git a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto index 8e9ebed693..1c4918951e 100644 --- a/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto +++ b/test/vpto/cases/vmi/dynamic-create-group-mask-s32-reduce-store/kernel.pto @@ -45,9 +45,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind, !pto.ptr %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto b/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto index 6f68510ede..1a0f7f0d42 100644 --- a/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto +++ b/test/vpto/cases/vmi/f32-to-f8-store-reduce/kernel.pto @@ -39,9 +39,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<256xpred> %sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %x8 = pto.vmi.truncf %x32 : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf8E4M3FN> pto.vmi.store %x8, %ub_out8_f8[%c0] diff --git a/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto b/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto index 3c14b7fc38..f81b4dfd24 100644 --- a/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto +++ b/test/vpto/cases/vmi/group-broadcast-multi-consumer/kernel.pto @@ -37,19 +37,19 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<128xf32> %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> %b_for_mul = pto.vmi.group_broadcast %sum32 {num_groups = 8} - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> %y = pto.vmi.mulf %x, %b_for_mul : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %ysum, %ub_sum[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %b_for_cast = pto.vmi.group_broadcast %sum32 {num_groups = 8} - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> %h = pto.vmi.truncf %b_for_cast : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> pto.vmi.store %h, %ub_dense[%c0] : !pto.vmi.vreg<128xf16>, !pto.ptr diff --git a/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto b/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto index f28676f8d5..6de55bf7fb 100644 --- a/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto +++ b/test/vpto/cases/vmi/group-load-s16-stride-store/kernel.pto @@ -35,9 +35,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<128xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto index cf2aea21d7..e73c083e55 100644 --- a/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto +++ b/test/vpto/cases/vmi/group-load-s32-stride-broadcast-reduce/kernel.pto @@ -35,17 +35,17 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> %broadcast = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %scaled = pto.vmi.mulf %x, %broadcast : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> %scaled_sum = pto.vmi.group_reduce_addf %scaled, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %scaled_sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-load-s32-stride-store/kernel.pto b/test/vpto/cases/vmi/group-load-s32-stride-store/kernel.pto index 7afde7d6f5..609ebb6891 100644 --- a/test/vpto/cases/vmi/group-load-s32-stride-store/kernel.pto +++ b/test/vpto/cases/vmi/group-load-s32-stride-store/kernel.pto @@ -35,9 +35,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-basic-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-basic-store/kernel.pto index 4db72772c1..123ef977f1 100644 --- a/test/vpto/cases/vmi/group-reduce-basic-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-basic-store/kernel.pto @@ -54,25 +54,25 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<64xf32> %sum8 = pto.vmi.group_reduce_addf %x8, %mask8 {num_groups = 8, reassoc} : !pto.vmi.vreg<64xf32>, !pto.vmi.mask<64xpred> - -> !pto.vmi.vreg<64xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum8, %ub_dst8[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<64xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %mask16 = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> %x16 = pto.vmi.load %ub_src16[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> %sum16 = pto.vmi.group_reduce_addf %x16, %mask16 {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum16, %ub_dst16[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %mask32 = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> %x32 = pto.vmi.load %ub_src32[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> %sum32 = pto.vmi.group_reduce_addf %x32, %mask32 {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum32, %ub_dst32[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto index b8d274c280..f586b23278 100644 --- a/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-f16-addf-store/kernel.pto @@ -35,9 +35,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<128xf16> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf16>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf16> + -> !pto.vmi.vreg<8xf16> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf16>, !pto.ptr + : !pto.vmi.vreg<8xf16>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto index 9cedd97e60..34c6a1bf1e 100644 --- a/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-f16-f8-mul-store/kernel.pto @@ -47,13 +47,13 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<512xf32> %sum = pto.vmi.group_reduce_addf %src_f16_f32, %mask {num_groups = 2, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<2xf32> %src_f8 = pto.vmi.group_load %ub_f8[%c0], %c320 {num_groups = 2} : !pto.ptr -> !pto.vmi.vreg<512xf8E4M3FN> %src_f8_f32 = pto.vmi.extf %src_f8 : !pto.vmi.vreg<512xf8E4M3FN> -> !pto.vmi.vreg<512xf32> %sum_vec = pto.vmi.group_broadcast %sum {num_groups = 2} - : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + : !pto.vmi.vreg<2xf32> -> !pto.vmi.vreg<512xf32> %out = pto.vmi.mulf %sum_vec, %src_f8_f32 : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> pto.vmi.group_store %out, %ub_dst[%c0], %c320 {num_groups = 2} diff --git a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto index da95759e3c..e4a7d10eeb 100644 --- a/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-i16-extsi-i32-addi-store/kernel.pto @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<128xi32> %sum = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} : !pto.vmi.vreg<128xi32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xi32> + -> !pto.vmi.vreg<8xi32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xi32>, !pto.ptr + : !pto.vmi.vreg<8xi32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto index 783658e453..d311a4a932 100644 --- a/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-i32-addi-store/kernel.pto @@ -35,9 +35,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<64xi32> %sum = pto.vmi.group_reduce_addi %x, %mask {num_groups = 8} : !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> - -> !pto.vmi.vreg<64xi32> + -> !pto.vmi.vreg<8xi32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<64xi32>, !pto.ptr + : !pto.vmi.vreg<8xi32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/kernel.pto index f11fa15503..1c594cf5c0 100644 --- a/test/vpto/cases/vmi/group-reduce-i32-maxi-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-i32-maxi-store/kernel.pto @@ -35,9 +35,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<64xi32> %sum = pto.vmi.group_reduce_maxi %x, %mask {num_groups = 8} : !pto.vmi.vreg<64xi32>, !pto.vmi.mask<64xpred> - -> !pto.vmi.vreg<64xi32> + -> !pto.vmi.vreg<8xi32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<64xi32>, !pto.ptr + : !pto.vmi.vreg<8xi32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto index 97154d0dd6..04a1afda13 100644 --- a/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-i8-extsi-i32-addi-store/kernel.pto @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xi32> %sum = pto.vmi.group_reduce_addi %x32, %mask {num_groups = 8} : !pto.vmi.vreg<256xi32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xi32> + -> !pto.vmi.vreg<8xi32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xi32>, !pto.ptr + : !pto.vmi.vreg<8xi32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto index e41c4d656d..3f0243d8e1 100644 --- a/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s16-broadcast-reduce-store/kernel.pto @@ -33,17 +33,17 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<128xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> %b = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> %y = pto.vmi.mulf %x, %b : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto index 56f042af1e..eb6ebedee1 100644 --- a/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-broadcast-reduce-store/kernel.pto @@ -39,17 +39,17 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<128xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> %b = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> %y = pto.vmi.mulf %x, %b : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/kernel.pto index c07f2782fd..04af55f5bc 100644 --- a/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s16-group-mask-tail-store/kernel.pto @@ -39,9 +39,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<128xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/kernel.pto index b53a1a51ff..f22ce53896 100644 --- a/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s16-stride-group-mask-tail-store/kernel.pto @@ -39,9 +39,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<128xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto index 29193f5d6b..7063b4e5ef 100644 --- a/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s16-truncf-broadcast-store/kernel.pto @@ -35,11 +35,11 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<128xf32> %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> %sum16 = pto.vmi.truncf %sum32 - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf16> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xf16> %rows = pto.vmi.group_broadcast %sum16 {num_groups = 8} - : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf16> + : !pto.vmi.vreg<8xf16> -> !pto.vmi.vreg<128xf16> pto.vmi.store %rows, %ub_dst[%c0] : !pto.vmi.vreg<128xf16>, !pto.ptr } diff --git a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto index d21fb5efd2..f4c7b4f18a 100644 --- a/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s32-add-bias-store/kernel.pto @@ -38,9 +38,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.vreg<256xf32> %sum = pto.vmi.group_reduce_addf %biased, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto index f51fe89924..aa20ef0c55 100644 --- a/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s32-broadcast-reduce-store/kernel.pto @@ -33,17 +33,17 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> %b = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> %y = pto.vmi.mulf %x, %b : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> -> !pto.vmi.vreg<256xf32> %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %ysum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto index de08d084e6..271fa80ef7 100644 --- a/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s32-cf-join-store/kernel.pto @@ -47,9 +47,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/kernel.pto index 758691c5cf..0b1ea79141 100644 --- a/test/vpto/cases/vmi/group-reduce-s32-multitile-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s32-multitile-store/kernel.pto @@ -33,9 +33,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<512xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 16, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<16xf32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 16} - : !pto.vmi.vreg<512xf32>, !pto.ptr + : !pto.vmi.vreg<16xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto index 4e311c0703..3e78d88df0 100644 --- a/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s32-tail-full-tile-store/kernel.pto @@ -37,9 +37,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<256xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/kernel.pto index bcb027a753..1712ef8025 100644 --- a/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s64-broadcast-reduce-store/kernel.pto @@ -37,17 +37,17 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<512xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<8xf32> %b = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<512xf32> %y = pto.vmi.mulf %x, %b : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf32> %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %ysum, %ub_dst[%c0], %c8 {num_groups = 8} - : !pto.vmi.vreg<512xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto index 04338c1c1b..5765b56274 100644 --- a/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s64-slot-add-store/kernel.pto @@ -42,15 +42,15 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<512xpred> %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<512xf32> %rhs = pto.vmi.group_slot_load %ub_rhs[%c0], %c8 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<512xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<8xf32> %out = pto.vmi.addf %sum, %rhs - : !pto.vmi.vreg<512xf32>, !pto.vmi.vreg<512xf32> - -> !pto.vmi.vreg<512xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %out, %ub_dst[%c0], %c8 {num_groups = 8} - : !pto.vmi.vreg<512xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto index 5167c9198a..1073e351c8 100644 --- a/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s64-tail-store/kernel.pto @@ -36,9 +36,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<384xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 6, reassoc} : !pto.vmi.vreg<384xf32>, !pto.vmi.mask<384xpred> - -> !pto.vmi.vreg<384xf32> + -> !pto.vmi.vreg<6xf32> pto.vmi.group_store %sum, %ub_dst[%c0], %c8 {num_groups = 6} - : !pto.vmi.vreg<384xf32>, !pto.ptr + : !pto.vmi.vreg<6xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto index 6436738080..b909f1f66c 100644 --- a/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-s64-truncf-store/kernel.pto @@ -36,11 +36,11 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<512xf32> %sum32 = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<512xf32>, !pto.vmi.mask<512xpred> - -> !pto.vmi.vreg<512xf32> + -> !pto.vmi.vreg<8xf32> %sum16 = pto.vmi.truncf %sum32 - : !pto.vmi.vreg<512xf32> -> !pto.vmi.vreg<512xf16> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<8xf16> pto.vmi.group_store %sum16, %ub_dst[%c0], %c16 {num_groups = 8} - : !pto.vmi.vreg<512xf16>, !pto.ptr + : !pto.vmi.vreg<8xf16>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto b/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto index 291251e0bf..6cbe6b01fc 100644 --- a/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto +++ b/test/vpto/cases/vmi/group-reduce-slot-add-store/kernel.pto @@ -48,28 +48,28 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<128xpred> %x16 = pto.vmi.load %ub_src16[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> %rhs16 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> %sum16 = pto.vmi.group_reduce_addf %x16, %mask16 {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> %out16 = pto.vmi.addf %sum16, %rhs16 - : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> - -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %out16, %ub_dst16[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %mask32 = pto.vmi.create_mask %c256 : index -> !pto.vmi.mask<256xpred> %x32 = pto.vmi.load %ub_src32[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32> %rhs32 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<256xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> %sum32 = pto.vmi.group_reduce_addf %x32, %mask32 {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> %out32 = pto.vmi.addf %sum32, %rhs32 - : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> - -> !pto.vmi.vreg<256xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %out32, %ub_dst32[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto b/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto index 7fcdd382c8..5d66d5ff13 100644 --- a/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto +++ b/test/vpto/cases/vmi/group-slots-cf-join-store/kernel.pto @@ -42,45 +42,45 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<128xpred> - %reduce_join = scf.if %cond_true -> !pto.vmi.vreg<128xf32> { + %reduce_join = scf.if %cond_true -> !pto.vmi.vreg<8xf32> { %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> - scf.yield %sum : !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> + scf.yield %sum : !pto.vmi.vreg<8xf32> } else { %slot = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> - scf.yield %slot : !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> + scf.yield %slot : !pto.vmi.vreg<8xf32> } %bias0 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> %reduce_out = pto.vmi.addf %reduce_join, %bias0 - : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> - -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %reduce_out, %ub_dst_reduce[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr - %slot_join = scf.if %cond_false -> !pto.vmi.vreg<128xf32> { + %slot_join = scf.if %cond_false -> !pto.vmi.vreg<8xf32> { %x = pto.vmi.load %ub_src[%c0] : !pto.ptr -> !pto.vmi.vreg<128xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> - scf.yield %sum : !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> + scf.yield %sum : !pto.vmi.vreg<8xf32> } else { %slot = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> - scf.yield %slot : !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> + scf.yield %slot : !pto.vmi.vreg<8xf32> } %bias1 = pto.vmi.group_slot_load %ub_rhs[%c0], %c1 {num_groups = 8} - : !pto.ptr -> !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> %slot_out = pto.vmi.addf %slot_join, %bias1 - : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> - -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %slot_out, %ub_dst_slot[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto index 0660b1e0a3..636db6de38 100644 --- a/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto +++ b/test/vpto/cases/vmi/group-slots-fanout-store-broadcast/kernel.pto @@ -42,19 +42,19 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<128xf32> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr %b = pto.vmi.group_broadcast %sum {num_groups = 8} - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<128xf32> %y = pto.vmi.mulf %x, %b : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> %ysum = pto.vmi.group_reduce_addf %y, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %ysum, %ub_out[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto b/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto index 8ae0c03444..f0e6dc5e25 100644 --- a/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto +++ b/test/vpto/cases/vmi/group-slots-scf-for-store/kernel.pto @@ -37,9 +37,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind -> !pto.vmi.vreg<128xf32> + : !pto.ptr -> !pto.vmi.vreg<8xf32> %acc = scf.for %i = %c0 to %c2 step %c1 - iter_args(%arg = %acc0) -> (!pto.vmi.vreg<128xf32>) { + iter_args(%arg = %acc0) -> (!pto.vmi.vreg<8xf32>) { %x = pto.vmi.group_load %ub_src[%c0], %c16 {num_groups = 8} : !pto.ptr -> !pto.vmi.vreg<128xf32> %mask = pto.vmi.create_group_mask %c16 @@ -47,14 +47,14 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<128xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> %next = pto.vmi.addf %arg, %sum - : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> - -> !pto.vmi.vreg<128xf32> - scf.yield %next : !pto.vmi.vreg<128xf32> + : !pto.vmi.vreg<8xf32>, !pto.vmi.vreg<8xf32> + -> !pto.vmi.vreg<8xf32> + scf.yield %next : !pto.vmi.vreg<8xf32> } pto.vmi.group_store %acc, %ub_dst[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto b/test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto index 503068186e..e491e30698 100644 --- a/test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto +++ b/test/vpto/cases/vmi/masked-load-dense-group-users/kernel.pto @@ -42,9 +42,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind, !pto.ptr %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto index 37a10109ee..c07d3c503e 100644 --- a/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto +++ b/test/vpto/cases/vmi/masked-load-group-tail-s32-reduce-store/kernel.pto @@ -43,9 +43,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind, !pto.ptr %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto b/test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto index eb8f7f5e6a..4049b38720 100644 --- a/test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto +++ b/test/vpto/cases/vmi/private-call-argument-boundary-store/kernel.pto @@ -14,9 +14,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %out[%off], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr return } diff --git a/test/vpto/cases/vmi/private-call-inline-store/kernel.pto b/test/vpto/cases/vmi/private-call-inline-store/kernel.pto index 5f7beec943..5e713650bc 100644 --- a/test/vpto/cases/vmi/private-call-inline-store/kernel.pto +++ b/test/vpto/cases/vmi/private-call-inline-store/kernel.pto @@ -48,9 +48,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<256xpred> %sum = pto.vmi.group_reduce_addf %x, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<256xf32>, !pto.vmi.mask<256xpred> - -> !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<256xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr } pto.set_flag["PIPE_V", "PIPE_MTE3", "EVENT_ID0"] diff --git a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto index 9f3dfeabb4..9b926ac640 100644 --- a/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto +++ b/test/vpto/cases/vmi/widen-f16-to-f32-store-reduce/kernel.pto @@ -46,9 +46,9 @@ module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind !pto.vmi.mask<128xpred> %sum = pto.vmi.group_reduce_addf %x32, %mask {num_groups = 8, reassoc} : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> - -> !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<8xf32> pto.vmi.group_store %sum, %ub_sum[%c0], %c1 {num_groups = 8} - : !pto.vmi.vreg<128xf32>, !pto.ptr + : !pto.vmi.vreg<8xf32>, !pto.ptr pto.vmi.store %x32, %ub_dense[%c0] : !pto.vmi.vreg<128xf32>, !pto.ptr } From b8ea175f6d346a918ae5c91a6d52cbb4901ccb29 Mon Sep 17 00:00:00 2001 From: mouliangyu <21963576+mouliangyu@users.noreply.github.com> Date: Wed, 1 Jul 2026 12:33:31 +0800 Subject: [PATCH 51/54] Optimize VMI group broadcast load layout --- include/PTO/Transforms/Passes.h | 1 + include/PTO/Transforms/Passes.td | 18 + include/PTO/Transforms/VMILayoutSupport.h | 1 + lib/PTO/Transforms/CMakeLists.txt | 1 + lib/PTO/Transforms/VMILayoutAssignment.cpp | 153 ++--- lib/PTO/Transforms/VMILayoutSupport.cpp | 70 +- .../Transforms/VMIPreAssignmentCombine.cpp | 81 +++ lib/PTO/Transforms/VMIToVPTO.cpp | 621 ++++++++++++++---- ...ment_group_slot_broadcast_load_e2b_b16.pto | 2 +- ...ment_combine_group_slot_broadcast_load.pto | 78 +++ ..._broadcast_load_e2b_b16_stride_invalid.pto | 4 +- tools/ptoas/ptoas.cpp | 6 + 12 files changed, 760 insertions(+), 276 deletions(-) create mode 100644 lib/PTO/Transforms/VMIPreAssignmentCombine.cpp create mode 100644 test/lit/vmi/vmi_pre_assignment_combine_group_slot_broadcast_load.pto diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index cc73821a30..3b8a74f959 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -115,6 +115,7 @@ LogicalResult validateVMILayoutAssignedIR(ModuleOp module, bool verifyHelperSupport = true); std::unique_ptr createPTOValidateVMIIRPass(); std::unique_ptr createPTOValidateVMILayoutIRPass(); +std::unique_ptr createVMIPreAssignmentCombinePass(); std::unique_ptr createVMILayoutAssignmentPass(); std::unique_ptr createVMILayoutFoldPass(); std::unique_ptr createVMILayoutRematerializePass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index f000a50060..a2ce108db6 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -826,6 +826,24 @@ def PTOValidateVMILayoutIR "mlir::scf::SCFDialect"]; } +def VMIPreAssignmentCombine + : Pass<"vmi-pre-assignment-combine", "ModuleOp"> { + let summary = "Combine VMI operations before layout assignment"; + let description = [{ + Performs VMI-level structural combines before VMI layout assignment. This + keeps layout assignment focused on choosing and materializing layouts while + still exposing direct semantic operations to the later layout and lowering + passes. + + The pass currently rewrites the semantic pattern + `group_broadcast(group_slot_load(...))` into the equivalent + `group_broadcast_load` operation. + }]; + let constructor = "mlir::pto::createVMIPreAssignmentCombinePass()"; + let dependentDialects = ["mlir::func::FuncDialect", + "mlir::pto::PTODialect"]; +} + def VMILayoutAssignment : Pass<"vmi-layout-assignment", "ModuleOp"> { let summary = "Assign concrete VMI layouts and mask granularities"; let description = [{ diff --git a/include/PTO/Transforms/VMILayoutSupport.h b/include/PTO/Transforms/VMILayoutSupport.h index 7a15864928..b03d12cae4 100644 --- a/include/PTO/Transforms/VMILayoutSupport.h +++ b/include/PTO/Transforms/VMILayoutSupport.h @@ -157,6 +157,7 @@ struct VMIGroupBroadcastSupport { enum class VMIGroupBroadcastLoadSupportKind { E2BVlds, + SlotLoadThenBroadcast, }; struct VMIGroupBroadcastLoadSupport { diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 116d4f10cb..4d95daf90d 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -37,6 +37,7 @@ add_mlir_dialect_library(PTOTransforms PTOValidateVPTOIR.cpp PTOUnrollSIMTForPass.cpp PTOValidateVMIIR.cpp + VMIPreAssignmentCombine.cpp VMILegalizeArithSelect.cpp VMILayoutAssignment.cpp VMILayoutFold.cpp diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 627660ae98..737e204dd2 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -326,7 +326,7 @@ struct LayoutSolver { return existing; if (!isE2BGroupBroadcastLoadCandidate(op)) - return getContiguousLayout(); + return {}; int64_t numGroups = op.getNumGroupsAttr().getInt(); int64_t groupSize = type.getElementCount() / numGroups; int64_t directGroupSize = 256 / getElementBitWidth(type.getElementType()); @@ -546,7 +546,8 @@ struct LayoutSolver { VMISubIOp, VMIMulFOp, VMIMulIOp, VMIFmaOp, VMIDivFOp, VMIMinFOp, VMIMaxFOp, VMINegFOp, VMIAbsFOp, VMIAbsIOp, VMISqrtOp, VMIExpOp, VMILnOp, VMIReluOp, VMIFPToSIOp, VMIAndIOp, VMIOrIOp, VMIXOrIOp, - VMIShLIOp, VMIShRUIOp, VMINotOp, VMISelectOp, VMIBitcastOp>(op); + VMIShLIOp, VMIShRUIOp, VMINotOp, VMISelectOp, VMIBitcastOp, + VMIGroupBroadcastLoadOp>(op); } bool canGroupBroadcastProduceLayout(VMIGroupBroadcastOp broadcast, @@ -567,6 +568,37 @@ struct LayoutSolver { capabilities, assignedSourceType, assignedResultType, numGroups)); } + bool canGroupBroadcastLoadProduceLayout(VMIGroupBroadcastLoadOp load, + VMILayoutAttr resultLayout) { + if (!resultLayout) + return false; + auto resultType = cast(load.getResult().getType()); + int64_t numGroups = load.getNumGroupsAttr().getInt(); + unsigned elementBits = getElementBitWidth(resultType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return false; + std::optional stride = + getConstantIndexValue(load.getSourceGroupStride()); + int64_t alignedStrideElems = 256 / elementBits; + int64_t slots = 0; + if (stride && *stride == 1) + slots = 8; + else if (stride && *stride > 0 && *stride % alignedStrideElems == 0) + slots = 1; + else + return false; + + auto assignedSourceType = + VMIVRegType::get(ctx, numGroups, resultType.getElementType(), + VMILayoutAttr::getGroupSlots(ctx, numGroups, slots)); + auto assignedResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), resultLayout); + VMILayoutSupport supports; + return succeeded(supports.getGroupBroadcastSupport( + capabilities, assignedSourceType, assignedResultType, numGroups)); + } + bool canEquivalenceClassAdoptConsumerLayout(Value value, VMILayoutAttr requestedLayout) { unsigned id = addDataValue(value); @@ -581,6 +613,11 @@ struct LayoutSolver { !canGroupBroadcastProduceLayout(broadcast, requestedLayout)) return false; } + if (auto load = node.value.getDefiningOp()) { + if (node.value == load.getResult() && + !canGroupBroadcastLoadProduceLayout(load, requestedLayout)) + return false; + } } return true; } @@ -588,7 +625,10 @@ struct LayoutSolver { bool isUnsupportedGroupBroadcastResultForLayout(Value value, VMILayoutAttr layout) { auto broadcast = value.getDefiningOp(); - return broadcast && !canGroupBroadcastProduceLayout(broadcast, layout); + if (broadcast) + return !canGroupBroadcastProduceLayout(broadcast, layout); + auto load = value.getDefiningOp(); + return load && !canGroupBroadcastLoadProduceLayout(load, layout); } LogicalResult constrainElementwiseBinary(OpOperand &lhs, OpOperand &rhs, @@ -703,109 +743,6 @@ struct LayoutSolver { return success(); } - bool shouldCommuteTruncFAfterGroupBroadcast(VMIGroupBroadcastOp broadcast) { - auto truncf = broadcast.getSource().getDefiningOp(); - if (!truncf) - return false; - - auto truncSourceType = dyn_cast(truncf.getSource().getType()); - auto truncResultType = dyn_cast(truncf.getResult().getType()); - auto broadcastResultType = - dyn_cast(broadcast.getResult().getType()); - if (!truncSourceType || !truncResultType || !broadcastResultType) - return false; - if (truncSourceType.getElementCount() != - truncResultType.getElementCount() || - truncResultType.getElementCount() != - broadcastResultType.getElementCount()) - return false; - - VMILayoutAttr sourceLayout = truncSourceType.getLayoutAttr(); - bool sourceIsGroupSlotValue = - (sourceLayout && sourceLayout.isGroupSlots()) || - truncf.getSource().getDefiningOp() || - truncf.getSource().getDefiningOp() || - truncf.getSource().getDefiningOp(); - if (!sourceIsGroupSlotValue) - return false; - - unsigned sourceBits = getElementBitWidth(truncSourceType.getElementType()); - unsigned resultBits = getElementBitWidth(truncResultType.getElementType()); - return truncSourceType.getElementType().isF32() && sourceBits > resultBits; - } - - LogicalResult commuteTruncFAfterGroupBroadcast() { - SmallVector broadcasts; - module.walk([&](VMIGroupBroadcastOp broadcast) { - if (shouldCommuteTruncFAfterGroupBroadcast(broadcast)) - broadcasts.push_back(broadcast); - }); - - OpBuilder builder(ctx); - for (VMIGroupBroadcastOp broadcast : broadcasts) { - auto truncf = broadcast.getSource().getDefiningOp(); - if (!truncf) - continue; - - auto truncSourceType = cast(truncf.getSource().getType()); - auto broadcastResultType = - cast(broadcast.getResult().getType()); - auto wideBroadcastType = - VMIVRegType::get(ctx, broadcastResultType.getElementCount(), - truncSourceType.getElementType(), - broadcastResultType.getLayoutAttr()); - - builder.setInsertionPoint(broadcast); - auto wideBroadcast = builder.create( - broadcast.getLoc(), wideBroadcastType, truncf.getSource(), - broadcast.getNumGroupsAttr()); - auto narrow = builder.create( - broadcast.getLoc(), broadcastResultType, wideBroadcast.getResult()); - broadcast.getResult().replaceAllUsesWith(narrow.getResult()); - broadcast.erase(); - if (truncf->use_empty()) - truncf.erase(); - } - return success(); - } - - LogicalResult fuseGroupSlotBroadcastLoads() { - SmallVector broadcasts; - module.walk([&](VMIGroupBroadcastOp broadcast) { - auto load = broadcast.getSource().getDefiningOp(); - if (!load || !load.getResult().hasOneUse()) - return; - if (load.getNumGroupsAttr().getInt() != - broadcast.getNumGroupsAttr().getInt()) - return; - - if (!isE2BGroupBroadcastLoadCandidate( - cast(broadcast.getResult().getType()), - load.getSource().getType(), load.getSourceGroupStride(), - broadcast.getNumGroupsAttr().getInt())) - return; - broadcasts.push_back(broadcast); - }); - - OpBuilder builder(ctx); - for (VMIGroupBroadcastOp broadcast : broadcasts) { - auto load = broadcast.getSource().getDefiningOp(); - if (!load) - continue; - - builder.setInsertionPoint(broadcast); - auto fused = builder.create( - broadcast.getLoc(), broadcast.getResult().getType(), - load.getSource(), load.getOffset(), load.getSourceGroupStride(), - broadcast.getNumGroupsAttr()); - broadcast.getResult().replaceAllUsesWith(fused.getResult()); - broadcast.erase(); - if (load->use_empty()) - load.erase(); - } - return success(); - } - LogicalResult addConstraints() { WalkResult result = module.walk([&](Operation *op) -> WalkResult { if (auto maskAnd = dyn_cast(op)) { @@ -2030,10 +1967,6 @@ struct LayoutSolver { } LogicalResult run() { - if (failed(fuseGroupSlotBroadcastLoads())) - return failure(); - if (failed(commuteTruncFAfterGroupBroadcast())) - return failure(); if (failed(collect())) return failure(); if (failed(addConstraints())) diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index fff22f6528..0a724f09ba 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -1392,25 +1392,13 @@ VMILayoutSupport::getGroupBroadcastLoadSupport( .isSupported()) return fail("requires supported direct memory source"); if (!isa(op.getSource().getType())) - return fail("requires !pto.ptr source for E2B lowering"); + return fail("requires !pto.ptr source"); unsigned elementBits = pto::getPTOStorageElemBitWidth(resultType.getElementType()); - if (elementBits != 16 && elementBits != 32) - return fail("E2B lowering currently supports only 16-bit and 32-bit " - "element types"); - int64_t directGroupSize = 256 / elementBits; - VMILayoutAttr layout = resultType.getLayoutAttr(); if (!layout) - return fail("E2B lowering requires assigned result layout"); - bool contiguousPacketLayout = layout.isContiguous(); - bool splitPacketLayout = layout.isDeinterleaved() && layout.getFactor() == 2 && - layout.getBlockElems() == 1; - if (!contiguousPacketLayout && !splitPacketLayout) - return fail("E2B lowering requires contiguous result layout for " - "direct group size or deinterleaved=2, block_elems=1 " - "result layout for split group size"); + return fail("requires assigned result layout"); std::string fullChunkReason; if (failed(checkFullDataPhysicalChunks(resultType, &fullChunkReason))) @@ -1419,27 +1407,51 @@ VMILayoutSupport::getGroupBroadcastLoadSupport( FailureOr lanesPerPart = getDataLanesPerPart(resultType.getElementType()); - if (failed(lanesPerPart) || *lanesPerPart != (2048 / elementBits)) - return fail("E2B lowering requires one full 256-byte vreg per physical " - "part"); + if (failed(lanesPerPart)) + return fail("requires known result lanes per physical part"); int64_t groupSize = resultType.getElementCount() / numGroups; - if (contiguousPacketLayout && groupSize != directGroupSize) - return fail("E2B contiguous lowering requires logical group size matching " - "the element-width direct packet size"); - if (splitPacketLayout && groupSize != 2 * directGroupSize) - return fail("E2B deinterleaved=2 lowering requires logical group size " - "matching the element-width split packet size"); - if (numGroups % 8 != 0) - return fail("E2B lowering requires num_groups to be a multiple of 8"); - std::optional stride = getConstantIndexValue(op.getSourceGroupStride()); - if (!stride || *stride != 1) - return fail("E2B lowering requires constant unit source_group_stride"); + bool contiguousPacketLayout = layout.isContiguous(); + bool splitPacketLayout = layout.isDeinterleaved() && layout.getFactor() == 2 && + layout.getBlockElems() == 1; + if ((elementBits == 16 || elementBits == 32) && + *lanesPerPart == static_cast(2048 / elementBits) && + (contiguousPacketLayout || splitPacketLayout) && numGroups % 8 == 0 && + stride && *stride == 1) { + int64_t directGroupSize = 256 / elementBits; + if ((contiguousPacketLayout && groupSize == directGroupSize) || + (splitPacketLayout && groupSize == 2 * directGroupSize)) + return VMIGroupBroadcastLoadSupport{ + VMIGroupBroadcastLoadSupportKind::E2BVlds}; + } + + if (elementBits == 0 || 256 % elementBits != 0) + return fail("fallback lowering requires an 8/16/32-bit element type"); + int64_t alignedStrideElems = 256 / elementBits; + int64_t slots = 0; + if (stride && *stride == 1) + slots = 8; + else if (stride && *stride > 0 && *stride % alignedStrideElems == 0) + slots = 1; + else + return fail(Twine("fallback lowering requires constant unit " + "source_group_stride for packed slots or constant " + "positive source_group_stride divisible by ") + + Twine(alignedStrideElems) + " elements for lane-0 slots"); + + auto sourceType = VMIVRegType::get( + resultType.getContext(), numGroups, resultType.getElementType(), + VMILayoutAttr::getGroupSlots(resultType.getContext(), numGroups, slots)); + std::string broadcastReason; + if (failed(getGroupBroadcastSupport(capabilities, sourceType, resultType, + numGroups, &broadcastReason))) + return fail(Twine("fallback broadcast is unsupported; ") + + broadcastReason); return VMIGroupBroadcastLoadSupport{ - VMIGroupBroadcastLoadSupportKind::E2BVlds}; + VMIGroupBroadcastLoadSupportKind::SlotLoadThenBroadcast}; } FailureOr VMILayoutSupport::getGroupBroadcastSupport( diff --git a/lib/PTO/Transforms/VMIPreAssignmentCombine.cpp b/lib/PTO/Transforms/VMIPreAssignmentCombine.cpp new file mode 100644 index 0000000000..afc5ff04f9 --- /dev/null +++ b/lib/PTO/Transforms/VMIPreAssignmentCombine.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under +// the terms and conditions of CANN Open Software License Agreement Version 2.0 +// (the "License"). Please refer to the License for details. You may not use +// this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON +// AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS +// FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository +// for the full text of the License. + +//===- VMIPreAssignmentCombine.cpp - Pre-assignment VMI combines ---------===// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_VMIPREASSIGNMENTCOMBINE +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static LogicalResult fuseGroupSlotBroadcastLoads(ModuleOp module) { + SmallVector broadcasts; + module.walk([&](VMIGroupBroadcastOp broadcast) { + auto load = broadcast.getSource().getDefiningOp(); + if (!load || !load.getResult().hasOneUse()) + return; + if (load.getNumGroupsAttr().getInt() != + broadcast.getNumGroupsAttr().getInt()) + return; + + if (!isa(broadcast.getResult().getType())) + return; + broadcasts.push_back(broadcast); + }); + + OpBuilder builder(module.getContext()); + for (VMIGroupBroadcastOp broadcast : broadcasts) { + auto load = broadcast.getSource().getDefiningOp(); + if (!load) + continue; + + builder.setInsertionPoint(broadcast); + auto fused = builder.create( + broadcast.getLoc(), broadcast.getResult().getType(), load.getSource(), + load.getOffset(), load.getSourceGroupStride(), + broadcast.getNumGroupsAttr()); + broadcast.getResult().replaceAllUsesWith(fused.getResult()); + broadcast.erase(); + if (load->use_empty()) + load.erase(); + } + return success(); +} + +struct VMIPreAssignmentCombinePass + : pto::impl::VMIPreAssignmentCombineBase { + void runOnOperation() override { + if (failed(fuseGroupSlotBroadcastLoads(getOperation()))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createVMIPreAssignmentCombinePass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 70c817e799..30b76b2f6f 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -4824,6 +4824,402 @@ struct OneToNVMIGroupLoadOpPattern : OneToNOpConversionPattern { } }; +static LogicalResult lowerGroupSlotLoadParts( + Operation *op, Value source, Value offset, Value sourceGroupStride, + VMIVRegType resultVMIType, TypeRange resultTypes, int64_t numGroups, + OneToNPatternRewriter &rewriter, SmallVectorImpl &results) { + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); + if (!layout || !layout.isGroupSlots() || layout.getSlots() <= 0) + return rewriter.notifyMatchFailure( + op, "group_slot_load requires explicit group_slots layout"); + if (!isa(source.getType())) + return rewriter.notifyMatchFailure(op, + "group_slot_load requires !pto.ptr source"); + + int64_t slots = layout.getSlots(); + int64_t expectedArity = ceilDivNonNegative(numGroups, slots); + if (static_cast(resultTypes.size()) != expectedArity) + return rewriter.notifyMatchFailure(op, "group_slot_load arity mismatch"); + + auto makeI16 = [&](int64_t value) -> Value { + return rewriter.create(op->getLoc(), value, 16); + }; + Value zeroI16 = makeI16(0); + auto makePtr = [&](Value elementOffset) -> Value { + return rewriter + .create(op->getLoc(), source.getType(), source, + elementOffset) + .getResult(); + }; + + results.reserve(results.size() + resultTypes.size()); + + if (slots == 8) { + std::optional stride = getConstantIndexValue(sourceGroupStride); + if (!stride || *stride != 1) + return rewriter.notifyMatchFailure( + op, "slots=8 group_slot_load requires constant unit stride"); + for (auto [chunk, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_slot_load result must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_slot_load mask"); + int64_t groupBegin = static_cast(chunk) * slots; + int64_t activeGroups = std::min(slots, numGroups - groupBegin); + if (activeGroups <= 0) + return rewriter.notifyMatchFailure( + op, "slots=8 group_slot_load has no active groups for chunk"); + std::string pattern = (Twine("PAT_VL") + Twine(activeGroups)).str(); + FailureOr slotMask = + createPrefixMask(op->getLoc(), *maskType, pattern, rewriter); + if (failed(slotMask)) + return rewriter.notifyMatchFailure( + op, "failed to create slots=8 group_slot_load mask"); + Value groupOffset = + createChunkOffset(op->getLoc(), offset, groupBegin, rewriter); + Value slotBase = makePtr(groupOffset); + results.push_back(rewriter + .create(op->getLoc(), vregType, slotBase, + zeroI16, zeroI16, *slotMask) + .getResult()); + } + return success(); + } + + if (slots != 1) + return rewriter.notifyMatchFailure( + op, "group_slot_load supports only slots=8 or slots=1"); + unsigned elementBits = + pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()); + if (elementBits == 0 || 256 % elementBits != 0) + return rewriter.notifyMatchFailure( + op, "slots=1 group_slot_load requires supported element width"); + int64_t alignedStrideElems = 256 / elementBits; + std::optional constantStride = + getConstantIndexValue(sourceGroupStride); + if (!constantStride || *constantStride <= 0 || + *constantStride % alignedStrideElems != 0) + return rewriter.notifyMatchFailure( + op, Twine("slots=1 group_slot_load requires constant positive " + "source_group_stride divisible by ") + + Twine(alignedStrideElems) + + " elements for 32B lane-0 vsldb alignment"); + + for (auto [group, resultType] : llvm::enumerate(resultTypes)) { + auto vregType = dyn_cast(resultType); + if (!vregType) + return rewriter.notifyMatchFailure(op, + "group_slot_load result must be vreg"); + FailureOr maskType = + getMaskTypeForVReg(vregType, rewriter.getContext()); + if (failed(maskType)) + return rewriter.notifyMatchFailure( + op, "unsupported element type for group_slot_load mask"); + FailureOr oneBlockMask = + createPrefixMask(op->getLoc(), *maskType, "PAT_VL1", rewriter); + if (failed(oneBlockMask)) + return rewriter.notifyMatchFailure(op, + "failed to create group_slot_load mask"); + Value groupOffset = offset; + if (group != 0) { + Value groupIndex = + rewriter.create(op->getLoc(), group); + Value rowOffset = + rewriter + .create(op->getLoc(), sourceGroupStride, + groupIndex) + .getResult(); + groupOffset = + rewriter.create(op->getLoc(), groupOffset, rowOffset) + .getResult(); + } + Value slotBase = makePtr(groupOffset); + results.push_back(rewriter + .create(op->getLoc(), vregType, slotBase, + zeroI16, zeroI16, *oneBlockMask) + .getResult()); + } + return success(); +} + +static LogicalResult lowerGroupBroadcastParts( + Operation *op, ValueRange sourceParts, VMIVRegType sourceVMIType, + VMIVRegType resultVMIType, TypeRange resultTypes, int64_t numGroups, + OneToNPatternRewriter &rewriter, SmallVectorImpl &results) { + FailureOr groupSize = + getGroupSizeFromNumGroups(resultVMIType, numGroups); + if (failed(groupSize)) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires num_groups to evenly divide lane count"); + int64_t lanesPerPart = 0; + int64_t groupCount = 0; + if (failed(checkFullGroupSlotSourceShape(op, sourceVMIType, *groupSize, + numGroups, &lanesPerPart, + &groupCount, rewriter))) + return failure(); + int64_t resultLayoutFactor = 0; + int64_t resultGroupCount = 0; + if (failed(checkFullGroupBroadcastResultShape( + op, resultVMIType, *groupSize, lanesPerPart, &resultLayoutFactor, + &resultGroupCount, rewriter))) + return failure(); + if (resultGroupCount != groupCount) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires matching source/result group slots"); + + if (sourceParts.empty() || resultTypes.empty()) + return rewriter.notifyMatchFailure(op, "group_broadcast arity mismatch"); + + auto firstSourceType = dyn_cast(sourceParts.front().getType()); + if (!firstSourceType) + return rewriter.notifyMatchFailure(op, + "group_broadcast source must be vreg"); + unsigned indexBits = + pto::getPTOStorageElemBitWidth(firstSourceType.getElementType()); + if (indexBits != 8 && indexBits != 16 && indexBits != 32) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires 8/16/32-bit index elements"); + auto indexElementType = IntegerType::get(rewriter.getContext(), indexBits); + auto indexType = + VRegType::get(rewriter.getContext(), firstSourceType.getElementCount(), + indexElementType); + FailureOr allMask = + createAllTrueMaskForVReg(op->getLoc(), firstSourceType, rewriter); + if (failed(allMask)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast all mask"); + VMILayoutAttr resultLayout = resultVMIType.getLayoutAttr(); + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + int64_t selectionGroupSize = *groupSize; + if (resultLayoutFactor != 1 && resultLayout && + resultLayout.isDeinterleaved() && resultLayout.getBlockElems() > 1 && + *groupSize < lanesPerPart) + selectionGroupSize = resultLayout.getBlockElems(); + auto resolveLargeGroupSource = [&](int64_t group, int64_t chunksPerGroup, + int64_t &sourceChunk, + int64_t &baseGroupSlot) { + int64_t slots = sourceLayout.getSlots(); + if (slots > 0) { + sourceChunk = group / slots; + baseGroupSlot = group % slots; + return; + } + sourceChunk = group * chunksPerGroup; + baseGroupSlot = 0; + }; + + results.clear(); + results.resize(resultTypes.size()); + for (auto [flatIndex, resultType] : llvm::enumerate(resultTypes)) { + auto resultVRegType = dyn_cast(resultType); + if (!resultVRegType || resultVRegType != firstSourceType) + return rewriter.notifyMatchFailure( + op, "group_broadcast requires uniform physical vreg types"); + int64_t sourceChunk = flatIndex; + int64_t baseGroupSlot = 0; + Value mappedGroupSlotIndex; + if (resultLayoutFactor == 1) { + if (*groupSize >= lanesPerPart) { + int64_t chunksPerGroup = *groupSize / lanesPerPart; + int64_t group = flatIndex / chunksPerGroup; + resolveLargeGroupSource(group, chunksPerGroup, sourceChunk, + baseGroupSlot); + } else { + VMILayoutAttr sourceLayout = sourceVMIType.getLayoutAttr(); + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast small-group source requires explicit " + "group_slots slots or derivable legacy slot count"); + slots = groupCount / sourceParts.size(); + } + int64_t groupsPerResultChunk = lanesPerPart / *groupSize; + int64_t firstGroup = flatIndex * groupsPerResultChunk; + sourceChunk = firstGroup / slots; + baseGroupSlot = firstGroup % slots; + } + } else { + bool blockFragmentSmallGroup = + resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() > 1 && *groupSize < lanesPerPart; + bool deinterleavedSmallGroup = + resultLayout && resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 1 && *groupSize < lanesPerPart; + if (blockFragmentSmallGroup) { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + int64_t groupsPerResultChunk = + lanesPerPart / resultLayout.getBlockElems(); + int64_t firstGroup = chunk * groupsPerResultChunk; + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, + "group_broadcast block-fragment source requires explicit " + "group_slots slots or derivable legacy slot count"); + slots = groupCount / sourceParts.size(); + } + sourceChunk = firstGroup / slots; + baseGroupSlot = firstGroup % slots; + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } else if (deinterleavedSmallGroup) { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + int64_t slots = sourceLayout.getSlots(); + if (slots <= 0) { + if (sourceParts.empty() || + groupCount % static_cast(sourceParts.size()) != 0) + return rewriter.notifyMatchFailure( + op, "group_broadcast deinterleaved small-group source " + "requires explicit group_slots slots or derivable " + "legacy slot count"); + slots = groupCount / sourceParts.size(); + } + FailureOr index = createMappedGroupSlotIndexVector( + op->getLoc(), resultVMIType, part, chunk, indexType, + *groupSize, slots, sourceChunk, rewriter); + if (failed(index)) + return rewriter.notifyMatchFailure( + op, + "failed to create group_broadcast mapped group-slot index " + "vector"); + mappedGroupSlotIndex = *index; + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } else { + int64_t runningFlatIndex = 0; + bool found = false; + for (int64_t part = 0; part < resultLayoutFactor && !found; ++part) { + FailureOr chunks = getDataChunksInPart(resultVMIType, part); + if (failed(chunks)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to enumerate result chunks"); + for (int64_t chunk = 0; chunk < *chunks; + ++chunk, ++runningFlatIndex) { + if (runningFlatIndex != static_cast(flatIndex)) + continue; + FailureOr firstLogical = + mapPhysicalLaneToLogical(resultVMIType, part, chunk, 0); + FailureOr lastLogical = mapPhysicalLaneToLogical( + resultVMIType, part, chunk, lanesPerPart - 1); + if (failed(firstLogical) || failed(lastLogical)) + return rewriter.notifyMatchFailure( + op, "group_broadcast failed to map result chunk lanes"); + int64_t firstGroup = *firstLogical / *groupSize; + int64_t lastGroup = *lastLogical / *groupSize; + if (firstGroup != lastGroup) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk crosses logical groups"); + int64_t chunksPerGroup = *groupSize / lanesPerPart; + resolveLargeGroupSource(firstGroup, chunksPerGroup, sourceChunk, + baseGroupSlot); + found = true; + break; + } + } + if (!found) + return rewriter.notifyMatchFailure( + op, "group_broadcast result chunk index is out of range"); + } + } + if (*groupSize >= lanesPerPart) { + if (sourceChunk < 0 || + sourceChunk >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "group_broadcast source chunk is out of range"); + if (sourceLayout.getSlots() > 1) { + FailureOr groupSlotIndex = createGroupSlotIndexVector( + op->getLoc(), indexType, selectionGroupSize, baseGroupSlot, + rewriter); + if (failed(groupSlotIndex)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); + results[flatIndex] = + rewriter + .create(op->getLoc(), resultType, + sourceParts[sourceChunk], *groupSlotIndex) + .getResult(); + } else { + results[flatIndex] = + rewriter + .create(op->getLoc(), resultType, + sourceParts[sourceChunk], *allMask, + rewriter.getStringAttr("LOWEST")) + .getResult(); + } + } else { + bool blockFragmentSmallGroup = resultLayout && + resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() > 1; + bool deinterleavedSmallGroup = resultLayout && + resultLayout.isDeinterleaved() && + resultLayout.getBlockElems() == 1; + if (resultLayoutFactor != 1 && !blockFragmentSmallGroup && + !deinterleavedSmallGroup) + return rewriter.notifyMatchFailure( + op, "group_broadcast small-group deinterleaved result is not " + "supported"); + if (sourceChunk < 0 || + sourceChunk >= static_cast(sourceParts.size())) + return rewriter.notifyMatchFailure( + op, "group_broadcast source chunk is out of range"); + FailureOr groupSlotIndex = + mappedGroupSlotIndex + ? FailureOr(mappedGroupSlotIndex) + : createGroupSlotIndexVector(op->getLoc(), indexType, + selectionGroupSize, baseGroupSlot, + rewriter); + if (failed(groupSlotIndex)) + return rewriter.notifyMatchFailure( + op, "failed to create group_broadcast group-slot index vector"); + results[flatIndex] = + rewriter + .create(op->getLoc(), resultType, + sourceParts[sourceChunk], *groupSlotIndex) + .getResult(); + } + } + return success(); +} + struct OneToNVMIGroupSlotLoadOpPattern : OneToNOpConversionPattern { using OneToNOpConversionPattern< @@ -4850,123 +5246,15 @@ struct OneToNVMIGroupSlotLoadOpPattern rewriter); if (failed(source) || failed(offset) || failed(sourceGroupStride)) return failure(); - if (!isa((*source).getType())) - return rewriter.notifyMatchFailure( - op, "group_slot_load requires !pto.ptr source"); TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); int64_t numGroups = op.getNumGroupsAttr().getInt(); - int64_t slots = layout.getSlots(); - int64_t expectedArity = ceilDivNonNegative(numGroups, slots); - if (static_cast(resultTypes.size()) != expectedArity) - return rewriter.notifyMatchFailure(op, "group_slot_load arity mismatch"); - - auto makeI16 = [&](int64_t value) -> Value { - return rewriter.create(op.getLoc(), value, 16); - }; - Value zeroI16 = makeI16(0); - auto makePtr = [&](Value elementOffset) -> Value { - return rewriter - .create(op.getLoc(), (*source).getType(), *source, - elementOffset) - .getResult(); - }; SmallVector results; - results.reserve(resultTypes.size()); - - if (slots == 8) { - std::optional stride = - getConstantIndexValue(op.getSourceGroupStride()); - if (!stride || *stride != 1) - return rewriter.notifyMatchFailure( - op, "slots=8 group_slot_load requires constant unit stride"); - for (auto [chunk, resultType] : llvm::enumerate(resultTypes)) { - auto vregType = dyn_cast(resultType); - if (!vregType) - return rewriter.notifyMatchFailure( - op, "group_slot_load result must be vreg"); - FailureOr maskType = - getMaskTypeForVReg(vregType, rewriter.getContext()); - if (failed(maskType)) - return rewriter.notifyMatchFailure( - op, "unsupported element type for group_slot_load mask"); - int64_t groupBegin = static_cast(chunk) * slots; - int64_t activeGroups = std::min(slots, numGroups - groupBegin); - if (activeGroups <= 0) - return rewriter.notifyMatchFailure( - op, "slots=8 group_slot_load has no active groups for chunk"); - std::string pattern = (Twine("PAT_VL") + Twine(activeGroups)).str(); - FailureOr slotMask = - createPrefixMask(op.getLoc(), *maskType, pattern, rewriter); - if (failed(slotMask)) - return rewriter.notifyMatchFailure( - op, "failed to create slots=8 group_slot_load mask"); - Value groupOffset = - createChunkOffset(op.getLoc(), *offset, groupBegin, rewriter); - Value slotBase = makePtr(groupOffset); - results.push_back(rewriter - .create(op.getLoc(), vregType, slotBase, - zeroI16, zeroI16, *slotMask) - .getResult()); - } - rewriter.replaceOp(op, results, adaptor.getResultMapping()); - return success(); - } - - if (slots != 1) - return rewriter.notifyMatchFailure( - op, "group_slot_load supports only slots=8 or slots=1"); - unsigned elementBits = - pto::getPTOStorageElemBitWidth(resultVMIType.getElementType()); - if (elementBits == 0 || 256 % elementBits != 0) - return rewriter.notifyMatchFailure( - op, "slots=1 group_slot_load requires supported element width"); - int64_t alignedStrideElems = 256 / elementBits; - std::optional constantStride = - getConstantIndexValue(op.getSourceGroupStride()); - if (!constantStride || *constantStride <= 0 || - *constantStride % alignedStrideElems != 0) - return rewriter.notifyMatchFailure( - op, Twine("slots=1 group_slot_load requires constant positive " - "source_group_stride divisible by ") + - Twine(alignedStrideElems) + - " elements for 32B lane-0 vsldb alignment"); - - for (auto [group, resultType] : llvm::enumerate(resultTypes)) { - auto vregType = dyn_cast(resultType); - if (!vregType) - return rewriter.notifyMatchFailure( - op, "group_slot_load result must be vreg"); - FailureOr maskType = - getMaskTypeForVReg(vregType, rewriter.getContext()); - if (failed(maskType)) - return rewriter.notifyMatchFailure( - op, "unsupported element type for group_slot_load mask"); - FailureOr oneBlockMask = - createPrefixMask(op.getLoc(), *maskType, "PAT_VL1", rewriter); - if (failed(oneBlockMask)) - return rewriter.notifyMatchFailure( - op, "failed to create group_slot_load mask"); - Value groupOffset = *offset; - if (group != 0) { - Value groupIndex = - rewriter.create(op.getLoc(), group); - Value rowOffset = rewriter - .create( - op.getLoc(), *sourceGroupStride, groupIndex) - .getResult(); - groupOffset = - rewriter.create(op.getLoc(), groupOffset, rowOffset) - .getResult(); - } - Value slotBase = makePtr(groupOffset); - results.push_back(rewriter - .create(op.getLoc(), vregType, slotBase, - zeroI16, zeroI16, *oneBlockMask) - .getResult()); - } - + if (failed(lowerGroupSlotLoadParts(op, *source, *offset, *sourceGroupStride, + resultVMIType, resultTypes, numGroups, + rewriter, results))) + return failure(); rewriter.replaceOp(op, results, adaptor.getResultMapping()); return success(); } @@ -5858,13 +6146,87 @@ struct OneToNVMIMaskedStoreOpPattern struct OneToNVMIGroupBroadcastLoadOpPattern : OneToNOpConversionPattern { - using OneToNOpConversionPattern< - VMIGroupBroadcastLoadOp>::OneToNOpConversionPattern; + OneToNVMIGroupBroadcastLoadOpPattern( + TypeConverter &typeConverter, MLIRContext *context, + const VMITargetCapabilityRegistry &capabilities) + : OneToNOpConversionPattern(typeConverter, + context), + capabilities(capabilities) {} LogicalResult matchAndRewrite(VMIGroupBroadcastLoadOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto resultVMIType = cast(op.getResult().getType()); + int64_t numGroups = op.getNumGroupsAttr().getInt(); + FailureOr source = + getSingleValue(op, adaptor.getSource(), + "group_broadcast_load source must convert to one value", + rewriter); + FailureOr offset = + getSingleValue(op, adaptor.getOffset(), + "group_broadcast_load offset must convert to one value", + rewriter); + FailureOr sourceGroupStride = getSingleValue( + op, adaptor.getSourceGroupStride(), + "group_broadcast_load source_group_stride must convert to one value", + rewriter); + if (failed(source) || failed(offset) || failed(sourceGroupStride)) + return failure(); + + VMILayoutSupport supports; + std::string supportReason; + FailureOr support = + supports.getGroupBroadcastLoadSupport(capabilities, op, &supportReason); + if (failed(support)) + return rewriter.notifyMatchFailure( + op, Twine("group_broadcast_load has no registered support: ") + + supportReason); + + TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); + if (support->kind == + VMIGroupBroadcastLoadSupportKind::SlotLoadThenBroadcast) { + std::optional stride = + getConstantIndexValue(op.getSourceGroupStride()); + int64_t slots = (stride && *stride == 1) ? 8 : 1; + auto sourceVMIType = VMIVRegType::get( + rewriter.getContext(), numGroups, resultVMIType.getElementType(), + VMILayoutAttr::getGroupSlots(rewriter.getContext(), numGroups, + slots)); + + FailureOr sourceArity = getVMIPhysicalArity(sourceVMIType); + FailureOr sourceElementType = + getVMIVRegPhysicalElementType(sourceVMIType); + if (failed(sourceArity) || failed(sourceElementType)) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load fallback cannot derive physical types"); + + SmallVector sourceTypes; + sourceTypes.reserve(*sourceArity); + FailureOr sourceLanesPerPart = + getDataLanesPerPart(*sourceElementType); + if (failed(sourceLanesPerPart)) + return rewriter.notifyMatchFailure( + op, "group_broadcast_load fallback cannot derive source lanes"); + for (int64_t i = 0; i < *sourceArity; ++i) + sourceTypes.push_back(VRegType::get(rewriter.getContext(), + *sourceLanesPerPart, + *sourceElementType)); + + SmallVector sourceParts; + if (failed(lowerGroupSlotLoadParts( + op, *source, *offset, *sourceGroupStride, sourceVMIType, + sourceTypes, numGroups, rewriter, sourceParts))) + return failure(); + + SmallVector results; + if (failed(lowerGroupBroadcastParts(op, sourceParts, sourceVMIType, + resultVMIType, resultTypes, + numGroups, rewriter, results))) + return failure(); + rewriter.replaceOp(op, results, adaptor.getResultMapping()); + return success(); + } + VMILayoutAttr layout = resultVMIType.getLayoutAttr(); bool contiguousPacketLayout = layout && layout.isContiguous(); bool splitPacketLayout = layout && layout.isDeinterleaved() && @@ -5886,7 +6248,6 @@ struct OneToNVMIGroupBroadcastLoadOpPattern int64_t directGroupSize = 256 / elementBits; StringRef e2bDist = elementBits == 16 ? "E2B_B16" : "E2B_B32"; - int64_t numGroups = op.getNumGroupsAttr().getInt(); if (numGroups <= 0 || resultVMIType.getElementCount() % numGroups != 0) return rewriter.notifyMatchFailure( op, "group_broadcast_load requires valid num_groups"); @@ -5911,21 +6272,10 @@ struct OneToNVMIGroupBroadcastLoadOpPattern op, "group_broadcast_load E2B lowering requires constant unit " "source_group_stride"); - FailureOr source = - getSingleValue(op, adaptor.getSource(), - "group_broadcast_load source must convert to one value", - rewriter); - FailureOr offset = - getSingleValue(op, adaptor.getOffset(), - "group_broadcast_load offset must convert to one value", - rewriter); - if (failed(source) || failed(offset)) - return failure(); if (!isa((*source).getType())) return rewriter.notifyMatchFailure( op, "group_broadcast_load E2B lowering requires !pto.ptr source"); - TypeRange resultTypes = adaptor.getResultMapping().getConvertedTypes(0); FailureOr chunksPerPart = getDataChunksInPart(resultVMIType, 0); if (failed(chunksPerPart) || *chunksPerPart <= 0) return rewriter.notifyMatchFailure( @@ -5980,6 +6330,9 @@ struct OneToNVMIGroupBroadcastLoadOpPattern rewriter.replaceOp(op, results, adaptor.getResultMapping()); return success(); } + +private: + const VMITargetCapabilityRegistry &capabilities; }; struct OneToNVMIStrideLoadOpPattern @@ -8624,9 +8977,8 @@ void populateVMIOneToNConversionPatterns( OneToNVMIMaskBinaryOpPattern, OneToNVMIMaskUnaryOpPattern, OneToNVMILoadOpPattern, OneToNVMIDeinterleaveLoadOpPattern, OneToNVMIGroupLoadOpPattern, - OneToNVMIGroupSlotLoadOpPattern, OneToNVMIGroupBroadcastLoadOpPattern, - OneToNVMIStrideLoadOpPattern, OneToNVMIMaskedLoadOpPattern, - OneToNVMIGatherOpPattern, + OneToNVMIGroupSlotLoadOpPattern, OneToNVMIStrideLoadOpPattern, + OneToNVMIMaskedLoadOpPattern, OneToNVMIGatherOpPattern, OneToNVMIExpandLoadOpPattern, OneToNVMIStoreOpPattern, OneToNVMIInterleaveStoreOpPattern, OneToNVMIGroupStoreOpPattern, OneToNVMIMaskedStoreOpPattern, OneToNVMIStrideStoreOpPattern, @@ -8666,6 +9018,8 @@ void populateVMIOneToNConversionPatterns( OneToNVMISIToFPOpPattern, OneToNVMIBitcastOpPattern, OneToNVMIChannelSplitOpPattern, OneToNVMIChannelMergeOpPattern, OneToNVMIShuffleOpPattern>(typeConverter, patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext(), capabilities); patterns.add< OneToNVMIGroupReduceOpPattern, OneToNVMIGroupReduceOpPattern, @@ -9464,11 +9818,10 @@ verifySupportedVMIToVPTOOps(ModuleOp module, return WalkResult::advance(); load.emitError() << kVMIDiagUnsupportedPrefix - << "pto.vmi.group_broadcast_load currently lowers through E2B " - "only for b16/b32 contiguous direct group size or " - "deinterleaved=2/block_elems=1 split group size full result " - "chunks, num_groups multiple of 8, unit source_group_stride, " - "and supported UB pointer source (" + << "pto.vmi.group_broadcast_load requires either the E2B packet " + "form for b16/b32 direct or split group size, or the generic " + "group-slot-load then group-broadcast fallback with supported UB " + "pointer source and source_group_stride (" << reason << ")"; return WalkResult::interrupt(); } diff --git a/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_load_e2b_b16.pto b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_load_e2b_b16.pto index 3de1463671..b8135eeb05 100644 --- a/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_load_e2b_b16.pto +++ b/test/lit/vmi/vmi_layout_assignment_group_slot_broadcast_load_e2b_b16.pto @@ -6,7 +6,7 @@ // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. // See LICENSE in the root of the software repository for the full text of the License. -// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s +// RUN: pto-test-opt %s -vmi-pre-assignment-combine -vmi-layout-assignment -vmi-to-vpto | FileCheck %s module { func.func @vmi_layout_assignment_group_slot_broadcast_load_e2b_b16( diff --git a/test/lit/vmi/vmi_pre_assignment_combine_group_slot_broadcast_load.pto b/test/lit/vmi/vmi_pre_assignment_combine_group_slot_broadcast_load.pto new file mode 100644 index 0000000000..5092f998be --- /dev/null +++ b/test/lit/vmi/vmi_pre_assignment_combine_group_slot_broadcast_load.pto @@ -0,0 +1,78 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-pre-assignment-combine | FileCheck %s +// RUN: pto-test-opt %s -vmi-pre-assignment-combine -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @e2b_candidate(%src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf16> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 16} + : !pto.ptr -> !pto.vmi.vreg<16xf16> + %out = pto.vmi.group_broadcast %slots {num_groups = 16} + : !pto.vmi.vreg<16xf16> -> !pto.vmi.vreg<256xf16> + return %out : !pto.vmi.vreg<256xf16> + } + + func.func @not_e2b_candidate(%src: !pto.ptr, %off: index) + -> !pto.vmi.vreg<256xf32> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %src[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %out = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + return %out : !pto.vmi.vreg<256xf32> + } + + func.func @not_e2b_consumer_deint2(%scale: !pto.ptr, + %x: !pto.vmi.vreg<256xf16>, %off: index) + -> !pto.vmi.vreg<256xf32> { + %c1 = arith.constant 1 : index + %slots = pto.vmi.group_slot_load %scale[%off], %c1 {num_groups = 8} + : !pto.ptr -> !pto.vmi.vreg<8xf32> + %sf = pto.vmi.group_broadcast %slots {num_groups = 8} + : !pto.vmi.vreg<8xf32> -> !pto.vmi.vreg<256xf32> + %x32 = pto.vmi.extf %x + : !pto.vmi.vreg<256xf16> -> !pto.vmi.vreg<256xf32> + %out = pto.vmi.mulf %x32, %sf + : !pto.vmi.vreg<256xf32>, !pto.vmi.vreg<256xf32> + -> !pto.vmi.vreg<256xf32> + return %out : !pto.vmi.vreg<256xf32> + } +} + +// CHECK-LABEL: func.func @e2b_candidate +// CHECK-NOT: pto.vmi.group_slot_load +// CHECK: pto.vmi.group_broadcast_load +// CHECK-NOT: pto.vmi.group_broadcast + +// CHECK-LABEL: func.func @not_e2b_candidate +// CHECK-NOT: pto.vmi.group_slot_load +// CHECK: pto.vmi.group_broadcast_load +// CHECK-NOT: pto.vmi.group_broadcast + +// CHECK-LABEL: func.func @not_e2b_consumer_deint2 +// CHECK-NOT: pto.vmi.group_slot_load +// CHECK: pto.vmi.group_broadcast_load +// CHECK-NOT: pto.vmi.group_broadcast + +// LOWER-LABEL: func.func @e2b_candidate +// LOWER: pto.vlds {{.*}} {dist = "E2B_B16"} + +// LOWER-LABEL: func.func @not_e2b_candidate +// LOWER-NOT: E2B_B32 +// LOWER: pto.vsldb +// LOWER: pto.vselr +// LOWER-NOT: pto.vmi. + +// LOWER-LABEL: func.func @not_e2b_consumer_deint2 +// LOWER-NOT: E2B_B32 +// LOWER: pto.vsldb +// LOWER: pto.vselr +// LOWER-NOT: pto.vmi. diff --git a/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid.pto b/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid.pto index 3e942b0737..02640d136e 100644 --- a/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_group_broadcast_load_e2b_b16_stride_invalid.pto @@ -21,5 +21,5 @@ module { } // CHECK: VMI-UNSUPPORTED: -// CHECK: pto.vmi.group_broadcast_load currently lowers through E2B -// CHECK: E2B lowering requires constant unit source_group_stride +// CHECK: pto.vmi.group_broadcast_load requires either the E2B packet form +// CHECK: fallback lowering requires constant unit source_group_stride diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 0c69864f66..dd193fea67 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -1808,6 +1808,12 @@ static LogicalResult runVMISemanticPipeline(OwningOpRef &module) { PassManager pm(module->getContext()); pm.enableVerifier(); pm.addPass(pto::createPTOValidateVMIIRPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMIPreAssignmentCombinePass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + pm.addPass(pto::createVMILegalizeArithSelectPass()); pm.addPass(pto::createVMILayoutAssignmentPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); From 55d07dfafb69ac9f15644bcecac9c58608ab733a Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Thu, 2 Jul 2026 09:50:00 +0800 Subject: [PATCH 52/54] Support compact VMI f32 to fp8 truncf layouts --- lib/PTO/Transforms/VMILayoutAssignment.cpp | 3 +- lib/PTO/Transforms/VMILayoutSupport.cpp | 78 +++++++++++-- lib/PTO/Transforms/VMIToVPTO.cpp | 38 ++++--- .../vmi/opt/compute_single_row_vf_vmi_opt.pto | 104 ++++++++++++++++++ ...vpto_truncf_fp8_128_contiguous_invalid.pto | 25 ----- ...vmi_to_vpto_truncf_fp8_128_lane_stride.pto | 42 +++++++ ..._vpto_truncf_unsupported_shape_invalid.pto | 4 +- 7 files changed, 246 insertions(+), 48 deletions(-) create mode 100644 test/lit/vmi/opt/compute_single_row_vf_vmi_opt.pto delete mode 100644 test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto create mode 100644 test/lit/vmi/vmi_to_vpto_truncf_fp8_128_lane_stride.pto diff --git a/lib/PTO/Transforms/VMILayoutAssignment.cpp b/lib/PTO/Transforms/VMILayoutAssignment.cpp index 737e204dd2..736bc28924 100644 --- a/lib/PTO/Transforms/VMILayoutAssignment.cpp +++ b/lib/PTO/Transforms/VMILayoutAssignment.cpp @@ -1948,7 +1948,8 @@ struct LayoutSolver { for (Value operand : it->second) results.push_back(operand.getType()); } else { - for (Type type : func.getFunctionType().getResults()) { + FunctionType functionType = func.getFunctionType(); + for (Type type : functionType.getResults()) { if (auto vregType = dyn_cast(type)) { results.push_back(VMIVRegType::get(ctx, vregType.getElementCount(), vregType.getElementType(), diff --git a/lib/PTO/Transforms/VMILayoutSupport.cpp b/lib/PTO/Transforms/VMILayoutSupport.cpp index 0a724f09ba..0a3649fd32 100644 --- a/lib/PTO/Transforms/VMILayoutSupport.cpp +++ b/lib/PTO/Transforms/VMILayoutSupport.cpp @@ -607,7 +607,8 @@ FailureOr VMILayoutSupport::getPreferredCastLayoutFact( *compactSourceArity > *baselineSourceArity || *compactResultArity >= *baselineResultArity) return baseline; - } else { + } else if (!sourceType.getElementType().isF32() || + !isa(resultType.getElementType())) { VMIVRegType baselineSourceType = VMIVRegType::get(ctx, sourceType.getElementCount(), sourceType.getElementType(), baseline->sourceLayout); @@ -622,6 +623,61 @@ FailureOr VMILayoutSupport::getPreferredCastLayoutFact( *compactResultArity > *baselineResultArity || *compactSourceArity >= *baselineSourceArity) return baseline; + } else { + VMICastLayoutFact best = *baseline; + VMIVRegType baselineSourceType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), baseline->sourceLayout); + VMIVRegType baselineResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), baseline->resultLayout); + FailureOr bestSourceArity = getVMIPhysicalArity(baselineSourceType); + FailureOr bestResultArity = getVMIPhysicalArity(baselineResultType); + if (failed(bestSourceArity) || failed(bestResultArity)) + return baseline; + int64_t bestCost = *bestSourceArity + *bestResultArity; + + for (int64_t sourceFactor = 1; sourceFactor <= baseline->factor; + sourceFactor *= 2) { + if (baseline->factor % sourceFactor != 0) + continue; + int64_t resultLaneStride = baseline->factor / sourceFactor; + if (resultLaneStride != 1 && + !hasDenseLaneStridePackUnpackElement(resultType.getElementType(), + resultLaneStride)) + continue; + + VMILayoutAttr sourceLayout = + sourceFactor == 1 + ? VMILayoutAttr::getContiguous(ctx) + : VMILayoutAttr::getDeinterleaved(ctx, sourceFactor, + /*blockElems=*/1); + VMILayoutAttr resultLayout = + VMILayoutAttr::getContiguous(ctx, resultLaneStride); + VMIVRegType candidateSourceType = + VMIVRegType::get(ctx, sourceType.getElementCount(), + sourceType.getElementType(), sourceLayout); + VMIVRegType candidateResultType = + VMIVRegType::get(ctx, resultType.getElementCount(), + resultType.getElementType(), resultLayout); + FailureOr candidateSourceArity = + getVMIPhysicalArity(candidateSourceType); + FailureOr candidateResultArity = + getVMIPhysicalArity(candidateResultType); + if (failed(candidateSourceArity) || failed(candidateResultArity) || + *candidateSourceArity != sourceFactor * *candidateResultArity) + continue; + + int64_t candidateCost = *candidateSourceArity + *candidateResultArity; + if (candidateCost >= bestCost) + continue; + + best = *baseline; + best.sourceLayout = sourceLayout; + best.resultLayout = resultLayout; + bestCost = candidateCost; + } + return best; } VMICastLayoutFact compact = *baseline; @@ -1581,16 +1637,24 @@ VMILayoutSupport::getTruncFSupport(VMITruncFOp op, std::string *reason) const { return fail("unsupported deinterleaved truncf factor, arity, or result " "element width"); - if (sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1 && - resultLayout.isContiguous() && - resultLayout.getLaneStride() == fact->factor && - *sourceArity == *resultArity) { + int64_t sourceFactor = + sourceLayout.isDeinterleaved() ? sourceLayout.getFactor() : 1; + if (((sourceLayout.isContiguous() && sourceLayout.getLaneStride() == 1) || + (sourceLayout.isDeinterleaved() && sourceLayout.getBlockElems() == 1 && + sourceLayout.getLaneStride() == 1)) && + resultLayout.isContiguous() && resultLayout.getLaneStride() > 0 && + sourceFactor * resultLayout.getLaneStride() == fact->factor && + *sourceArity == sourceFactor * *resultArity) { if (fact->kind == VMICastLayoutKind::Narrow2x) return VMITruncFSupport{ - VMITruncFSupportKind::ContiguousF32ToLaneStrideF16}; + resultLayout.getLaneStride() == 1 + ? VMITruncFSupportKind::Deinterleaved2F32ToContiguousF16 + : VMITruncFSupportKind::ContiguousF32ToLaneStrideF16}; if (fact->kind == VMICastLayoutKind::Narrow4x) return VMITruncFSupport{ - VMITruncFSupportKind::ContiguousF32ToLaneStrideF8}; + resultLayout.getLaneStride() == 1 + ? VMITruncFSupportKind::Deinterleaved4F32ToContiguousF8 + : VMITruncFSupportKind::ContiguousF32ToLaneStrideF8}; } if (!sourceLayout.isDeinterleaved() || !resultLayout.isContiguous() || diff --git a/lib/PTO/Transforms/VMIToVPTO.cpp b/lib/PTO/Transforms/VMIToVPTO.cpp index 30b76b2f6f..59768b33d7 100644 --- a/lib/PTO/Transforms/VMIToVPTO.cpp +++ b/lib/PTO/Transforms/VMIToVPTO.cpp @@ -201,10 +201,11 @@ LogicalResult verifyVMIToVPTOInputTypes(Operation *op) { if (failed(verifyLayoutAssignedVMITypeTree(op, type))) return failure(); if (auto func = dyn_cast(op)) { - for (Type type : func.getFunctionType().getInputs()) + FunctionType functionType = func.getFunctionType(); + for (Type type : functionType.getInputs()) if (failed(verifyLayoutAssignedVMITypeTree(op, type))) return failure(); - for (Type type : func.getFunctionType().getResults()) + for (Type type : functionType.getResults()) if (failed(verifyLayoutAssignedVMITypeTree(op, type))) return failure(); } @@ -7938,22 +7939,32 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { return success(); } - ArrayRef parts; + ArrayRef allParts; int64_t factor = 0; - if (resultBits == 16 && sourceParts.size() == 2 * resultTypes.size()) { + if (resultBits == 16) { static constexpr StringRef kEvenOddParts[] = {"EVEN", "ODD"}; - parts = kEvenOddParts; + allParts = kEvenOddParts; factor = 2; - } else if (resultBits == 8 && - sourceParts.size() == 4 * resultTypes.size()) { + } else if (resultBits == 8) { static constexpr StringRef kPacked4Parts[] = {"P0", "P1", "P2", "P3"}; - parts = kPacked4Parts; + allParts = kPacked4Parts; factor = 4; } else { return rewriter.notifyMatchFailure( op, "unsupported physical truncf source/result width relation"); } + int64_t resultLaneStride = + resultLayout && resultLayout.isContiguous() ? resultLayout.getLaneStride() + : 1; + if (resultLaneStride <= 0 || factor % resultLaneStride != 0) + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf result lane stride"); + int64_t sourceFactor = factor / resultLaneStride; + if (sourceParts.size() != sourceFactor * resultTypes.size()) + return rewriter.notifyMatchFailure( + op, "unsupported physical truncf source/result arity relation"); + FailureOr sourceMask = createAllTrueMaskForVReg(op.getLoc(), sourceType0, rewriter); if (failed(sourceMask)) @@ -7972,15 +7983,16 @@ struct OneToNVMITruncFOpPattern : OneToNOpConversionPattern { op, "failed to build truncf result mask"); SmallVector partials; - partials.reserve(parts.size()); - for (int64_t partIndex = 0; partIndex < factor; ++partIndex) { + partials.reserve(sourceFactor); + for (int64_t partIndex = 0; partIndex < sourceFactor; ++partIndex) { Value sourcePart = sourceParts[partIndex * resultTypes.size() + chunkIndex]; partials.push_back( rewriter .create(op.getLoc(), resultType, sourcePart, *sourceMask, rnd, sat, - rewriter.getStringAttr(parts[partIndex])) + rewriter.getStringAttr( + allParts[partIndex * resultLaneStride])) .getResult()); } @@ -10315,8 +10327,8 @@ verifySupportedVMIToVPTOOps(ModuleOp module, truncf.emitError() << kVMIDiagUnsupportedPrefix << "pto.vmi.truncf supports only f32 deinterleaved=2 source parts " - "to one contiguous f16 result chunk or f32 deinterleaved=4 " - "source parts to one contiguous fp8-like result chunk, or f32 " + "to dense f16 results, f32 source layouts whose factor times the " + "result lane_stride matches the fp8-like narrowing factor, or f32 " "group_slots(num_groups=G, slots=1) to f16 " "group_slots(num_groups=G, slots=1) (" << reason << ")"; diff --git a/test/lit/vmi/opt/compute_single_row_vf_vmi_opt.pto b/test/lit/vmi/opt/compute_single_row_vf_vmi_opt.pto new file mode 100644 index 0000000000..c95d51e506 --- /dev/null +++ b/test/lit/vmi/opt/compute_single_row_vf_vmi_opt.pto @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s | FileCheck %s --check-prefix=VPTO --implicit-check-not=pto.vmi. --implicit-check-not='!pto.vmi' + +// Optimization guard for the ComputeSingleRowVF block-quant path. +// The 128xf32 -> 128xf8 truncf should use a deinterleaved=2 f32 source and a +// lane_stride=2 fp8 result, lowering to P0/P2 only. This keeps the output +// physical register count unchanged while avoiding the extra f32 source parts +// required by deinterleaved=4. + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @ComputeSingleRowVF_fp16_vmi_opt( + %inUb16_u: !pto.ptr, + %scaleUb: !pto.ptr, + %outUb8_u: !pto.ptr, + %fp8MaxValue: f32, + %minScale: f32) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %one = arith.constant 1.000000e+00 : f32 + %pos_inf = arith.constant 0x7F800000 : f32 + + %inUb = pto.castptr %inUb16_u : !pto.ptr -> !pto.ptr + %outUbFp = pto.castptr %outUb8_u : !pto.ptr -> !pto.ptr + %recip_min_scale = arith.divf %one, %minScale : f32 + + pto.vecscope { + %fp8max1 = pto.vmi.broadcast %fp8MaxValue : f32 -> !pto.vmi.vreg<1xf32> + %limit_vec = pto.vmi.broadcast %recip_min_scale : f32 -> !pto.vmi.vreg<128xf32> + %pos_inf_vec = pto.vmi.broadcast %pos_inf : f32 -> !pto.vmi.vreg<128xf32> + + %mask128 = pto.vmi.create_mask %c128 : index -> !pto.vmi.mask<128xpred> + %x16 = pto.vmi.load %inUb[%c0] + : !pto.ptr -> !pto.vmi.vreg<128xf16> + %x = pto.vmi.extf %x16 + : !pto.vmi.vreg<128xf16> -> !pto.vmi.vreg<128xf32> + %abs = pto.vmi.absf %x + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf32> + %amax = pto.vmi.group_reduce_maxf %abs, %mask128 {num_groups = 1} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<1xf32> + %scale_raw = pto.vmi.divf %amax, %fp8max1 + : !pto.vmi.vreg<1xf32>, !pto.vmi.vreg<1xf32> + -> !pto.vmi.vreg<1xf32> + %scale_raw_vec = pto.vmi.group_broadcast %scale_raw {num_groups = 1} + : !pto.vmi.vreg<1xf32> -> !pto.vmi.vreg<128xf32> + %finite_mask = pto.vmi.cmpf "olt", %scale_raw_vec, %pos_inf_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.mask<128xpred> + %scale_clamped_vec = pto.vmi.minf %scale_raw_vec, %limit_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %scale_vec = pto.vmi.select %finite_mask, %scale_clamped_vec, %scale_raw_vec + : !pto.vmi.mask<128xpred>, !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %scale = pto.vmi.group_reduce_maxf %scale_vec, %mask128 {num_groups = 1} + : !pto.vmi.vreg<128xf32>, !pto.vmi.mask<128xpred> + -> !pto.vmi.vreg<1xf32> + + pto.vmi.group_store %scale, %scaleUb[%c0], %c1 {num_groups = 1} + : !pto.vmi.vreg<1xf32>, !pto.ptr + %q = pto.vmi.divf %x, %scale_vec + : !pto.vmi.vreg<128xf32>, !pto.vmi.vreg<128xf32> + -> !pto.vmi.vreg<128xf32> + %q8 = pto.vmi.truncf %q + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf8E4M3FN> + pto.vmi.masked_store %q8, %outUbFp[%c0], %mask128 + : !pto.vmi.vreg<128xf8E4M3FN>, !pto.ptr, + !pto.vmi.mask<128xpred> + } + return + } +} + +// ASSIGN-LABEL: func.func @ComputeSingleRowVF_fp16_vmi_opt( +// ASSIGN: %[[Q:.*]] = pto.vmi.divf {{.*}} : !pto.vmi.vreg<128xf32, #pto.vmi.layout>, !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[Q8:.*]] = pto.vmi.truncf %[[Q]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: %[[STORE:.*]] = pto.vmi.ensure_layout %[[Q8]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf8E4M3FN, #pto.vmi.layout> -> !pto.vmi.vreg<128xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: pto.vmi.masked_store %[[STORE]] + +// VPTO-LABEL: func.func @ComputeSingleRowVF_fp16_vmi_opt( +// VPTO: pto.vcvt {{.*}} {part = "EVEN"} +// VPTO: pto.vcvt {{.*}} {part = "ODD"} +// VPTO-NOT: part = "P1" +// VPTO-NOT: part = "P3" +// VPTO: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// VPTO-NOT: part = "P1" +// VPTO-NOT: part = "P3" +// VPTO: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// VPTO-NOT: part = "P1" +// VPTO-NOT: part = "P3" +// VPTO: pto.vor +// VPTO: pto.vpack +// VPTO: pto.vsts diff --git a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto deleted file mode 100644 index f4c208aeda..0000000000 --- a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_contiguous_invalid.pto +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -// RUN: not pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto 2>&1 | FileCheck %s - -module { - func.func @vmi_to_vpto_truncf_fp8_128_contiguous_invalid( - %input: !pto.vmi.vreg<128xf32>, - %dst: !pto.ptr, - %off: index) { - %packed = pto.vmi.truncf %input - : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf8E4M3FN> - pto.vmi.store %packed, %dst[%off] - : !pto.vmi.vreg<128xf8E4M3FN>, !pto.ptr - return - } -} - -// CHECK: VMI{{-}}UNSUP{{P}}ORTED{{:}} pto.vmi.truncf operand #0 has type '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' but requires '!pto.vmi.vreg<128xf32, #pto.vmi.layout>' -// CHECK: pto.vmi.ensure_layout cannot materialize this conversion diff --git a/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_lane_stride.pto b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_lane_stride.pto new file mode 100644 index 0000000000..9ad0755cc4 --- /dev/null +++ b/test/lit/vmi/vmi_to_vpto_truncf_fp8_128_lane_stride.pto @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: pto-test-opt %s -vmi-layout-assignment | FileCheck %s --check-prefix=ASSIGN +// RUN: pto-test-opt %s -vmi-layout-assignment -vmi-to-vpto | FileCheck %s --check-prefix=LOWER + +module { + func.func @vmi_to_vpto_truncf_fp8_128_lane_stride( + %input: !pto.vmi.vreg<128xf32>, + %dst: !pto.ptr, + %off: index) { + %packed = pto.vmi.truncf %input + : !pto.vmi.vreg<128xf32> -> !pto.vmi.vreg<128xf8E4M3FN> + pto.vmi.store %packed, %dst[%off] + : !pto.vmi.vreg<128xf8E4M3FN>, !pto.ptr + return + } +} + +// ASSIGN-LABEL: func.func @vmi_to_vpto_truncf_fp8_128_lane_stride( +// ASSIGN-SAME: %[[INPUT:.*]]: !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[D2:.*]] = pto.vmi.ensure_layout %[[INPUT]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf32, #pto.vmi.layout> +// ASSIGN: %[[PACKED:.*]] = pto.vmi.truncf %[[D2]] +// ASSIGN-SAME: !pto.vmi.vreg<128xf32, #pto.vmi.layout> -> !pto.vmi.vreg<128xf8E4M3FN, #pto.vmi.layout> +// ASSIGN: pto.vmi.store %[[PACKED]] + +// LOWER-LABEL: func.func @vmi_to_vpto_truncf_fp8_128_lane_stride( +// LOWER: pto.vcvt {{.*}} {part = "P0", rnd = "R", sat = "SAT"} +// LOWER: pto.vcvt {{.*}} {part = "P2", rnd = "R", sat = "SAT"} +// LOWER-NOT: part = "P1" +// LOWER-NOT: part = "P3" +// LOWER: pto.vor +// LOWER: pto.vsts +// LOWER-NOT: pto.vmi. +// LOWER-NOT: !pto.vmi. +// LOWER-NOT: unrealized_conversion_cast diff --git a/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto b/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto index 9d8cb972aa..324b01ea5b 100644 --- a/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto +++ b/test/lit/vmi/vmi_to_vpto_truncf_unsupported_shape_invalid.pto @@ -19,5 +19,5 @@ module { } // CHECK: VMI{{-}}UNSUPPORTED{{:}} pto.vmi.truncf supports only f32 deinterleaved=2 source parts -// CHECK-SAME: one contiguous f16 result chunk -// CHECK-SAME: f32 deinterleaved=4 source parts to one contiguous fp8-like result chunk +// CHECK-SAME: dense f16 results +// CHECK-SAME: f32 source layouts whose factor times the result lane_stride matches the fp8-like narrowing factor From 87fc3fade9eba7ac5a8e4657d78a77dd58929806 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Thu, 2 Jul 2026 14:29:37 +0800 Subject: [PATCH 53/54] Support low-precision VPTO vor lowering --- lib/PTO/IR/VPTO.cpp | 10 ++- lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp | 90 ++++++++++++++++--- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 82 +++++++++++++++-- test/lit/vpto/vmi_truncf_hif8.pto | 22 +++-- test/lit/vpto/vor_low_precision_vpto_llvm.pto | 29 ++++++ 5 files changed, 204 insertions(+), 29 deletions(-) create mode 100644 test/lit/vpto/vor_low_precision_vpto_llvm.pto diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index b2e887e775..fe6a741f50 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -5288,7 +5288,8 @@ LogicalResult VreluOp::verify() { LogicalResult VnotOp::verify() { return verifyUnaryVecOp(*this); } template -static LogicalResult verifyBinaryVecOp(BinaryOp op) { +static LogicalResult verifyBinaryVecOp(BinaryOp op, + bool allowLowPrecision = false) { if (failed(verifyVRegTypeLike(op, op.getLhs().getType(), "lhs type"))) return failure(); if (failed(verifyVRegTypeLike(op, op.getRhs().getType(), "rhs type"))) @@ -5297,7 +5298,8 @@ static LogicalResult verifyBinaryVecOp(BinaryOp op) { return failure(); if (failed(verifyVRegTypeLike(op, op.getResult().getType(), "result type"))) return failure(); - if (failed(verifyNonLowPrecisionVRegElementTypeLike( + if (!allowLowPrecision && + failed(verifyNonLowPrecisionVRegElementTypeLike( op.getOperation(), op.getLhs().getType(), "lhs type"))) return failure(); if (op.getLhs().getType() != op.getRhs().getType() || @@ -5311,7 +5313,9 @@ LogicalResult VsubOp::verify() { return verifyBinaryVecOp(*this); } LogicalResult VmulOp::verify() { return verifyBinaryVecOp(*this); } LogicalResult VdivOp::verify() { return verifyBinaryVecOp(*this); } LogicalResult VandOp::verify() { return verifyBinaryVecOp(*this); } -LogicalResult VorOp::verify() { return verifyBinaryVecOp(*this); } +LogicalResult VorOp::verify() { + return verifyBinaryVecOp(*this, /*allowLowPrecision=*/true); +} LogicalResult VxorOp::verify() { return verifyBinaryVecOp(*this); } template diff --git a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp index ee1da2c13f..ec8c2595da 100644 --- a/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOCANN900LLVMEmitter.cpp @@ -574,6 +574,44 @@ getLowpPayloadABI(Type elementType, MLIRContext *context) { return LowpPayloadABI{IntegerType::get(context, 8), "u8"}; } +static std::string getDirectLowpVLogicElementFragment(Type type) { + if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || + type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + return "fp8e4m3"; + if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + return "fp8e5m2"; + return {}; +} + +static FailureOr +buildDirectLowpVLogicCallee(MLIRContext *context, Type vectorType, + StringRef stem, StringRef mode) { + Type elementType = getElementTypeFromVectorLike(vectorType); + auto lanes = getElementCountFromVectorLike(vectorType); + std::string elem = getDirectLowpVLogicElementFragment(elementType); + if (elem.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm." + stem.str() + "." + + mode.str() + ".v" + + std::to_string(*lanes) + elem) + .getValue(); +} + +static FailureOr +buildLowpPayloadVLogicCallee(MLIRContext *context, Type vectorType, + StringRef stem, StringRef mode) { + Type elementType = getElementTypeFromVectorLike(vectorType); + auto lanes = getElementCountFromVectorLike(vectorType); + std::optional abi = getLowpPayloadABI(elementType, context); + if (!abi || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm." + stem.str() + "." + + mode.str() + ".v" + + std::to_string(*lanes) + + abi->intrinsicElementFragment.str()) + .getValue(); +} + static Type getLowpPayloadCarrierType(Type vectorLikeType, MLIRContext *context) { Type elementType = getElementTypeFromVectorLike(vectorLikeType); @@ -4365,15 +4403,6 @@ class LowerBinaryMaskedOpPattern final : public OpConversionPattern { matchAndRewrite(BinaryOp op, typename BinaryOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef stem = getBinaryMaskedStem(); - FailureOr calleeName = - usesSignedBinaryCANN900Callee() - ? buildCANN900SignedModeTypedCallee( - op.getContext(), op.getResult().getType(), stem, "x") - : buildCANN900ModeTypedCallee(op.getContext(), - op.getResult().getType(), stem, "x"); - if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported binary VPTO signature"); - Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); if (!resultType) @@ -4391,12 +4420,49 @@ class LowerBinaryMaskedOpPattern final : public OpConversionPattern { op, "unexpected converted binary VPTO operand types"); } + Type callResultType = resultType; + Value callLhs = lhs; + Value callRhs = rhs; + FailureOr calleeName = + usesSignedBinaryCANN900Callee() + ? buildCANN900SignedModeTypedCallee( + op.getContext(), op.getResult().getType(), stem, "x") + : buildCANN900ModeTypedCallee(op.getContext(), + op.getResult().getType(), stem, "x"); + + if constexpr (std::is_same_v) { + Type elementType = getElementTypeFromVectorLike(op.getResult().getType()); + if (elementType && pto::isPTOLowPrecisionType(elementType)) { + calleeName = buildDirectLowpVLogicCallee( + op.getContext(), op.getResult().getType(), stem, "x"); + if (failed(calleeName)) { + Type carrierType = getLowpPayloadCarrierType( + op.getResult().getType(), rewriter.getContext()); + if (!carrierType) + return rewriter.notifyMatchFailure( + op, "unsupported low-precision binary payload ABI"); + callResultType = carrierType; + callLhs = castToPayloadABI(op.getLoc(), lhs, + op.getResult().getType(), rewriter); + callRhs = castToPayloadABI(op.getLoc(), rhs, + op.getResult().getType(), rewriter); + calleeName = buildLowpPayloadVLogicCallee( + op.getContext(), op.getResult().getType(), stem, "x"); + } + } + } + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported binary VPTO signature"); + auto call = rewriter.create(op.getLoc(), *calleeName, - TypeRange{resultType}, - ValueRange{lhs, rhs, mask}); + TypeRange{callResultType}, + ValueRange{callLhs, callRhs, mask}); state.plannedDecls.push_back( PlannedDecl{calleeName->str(), call.getCalleeType()}); - rewriter.replaceOp(op, call.getResults()); + Value result = call.getResult(0); + if (callResultType != resultType) + result = rewriter.create(op.getLoc(), resultType, result); + rewriter.replaceOp(op, result); return success(); } diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 2de59cc620..d93fbdfc7a 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -529,6 +529,44 @@ getLowpPayloadABI(Type elementType, MLIRContext *context) { return LowpPayloadABI{IntegerType::get(context, 8), "u8"}; } +static std::string getDirectLowpVLogicElementFragment(Type type) { + if (type.isFloat8E4M3() || type.isFloat8E4M3FN() || + type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) + return "fp8e4m3"; + if (type.isFloat8E5M2() || type.isFloat8E5M2FNUZ()) + return "fp8e5m2"; + return {}; +} + +static FailureOr +buildDirectLowpVLogicCallee(MLIRContext *context, Type vectorType, + StringRef stem, StringRef mode) { + Type elementType = getElementTypeFromVectorLike(vectorType); + auto lanes = getElementCountFromVectorLike(vectorType); + std::string elem = getDirectLowpVLogicElementFragment(elementType); + if (elem.empty() || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm." + stem.str() + "." + + mode.str() + ".v" + + std::to_string(*lanes) + elem) + .getValue(); +} + +static FailureOr +buildLowpPayloadVLogicCallee(MLIRContext *context, Type vectorType, + StringRef stem, StringRef mode) { + Type elementType = getElementTypeFromVectorLike(vectorType); + auto lanes = getElementCountFromVectorLike(vectorType); + std::optional abi = getLowpPayloadABI(elementType, context); + if (!abi || !lanes) + return failure(); + return StringAttr::get(context, "llvm.hivm." + stem.str() + ".v" + + std::to_string(*lanes) + + abi->intrinsicElementFragment.str() + + "." + mode.str()) + .getValue(); +} + static Type getLowpPayloadCarrierType(Type vectorLikeType, MLIRContext *context) { Type elementType = getElementTypeFromVectorLike(vectorLikeType); @@ -4307,11 +4345,6 @@ class LowerBinaryMaskedOpPattern final : public OpConversionPattern { matchAndRewrite(BinaryOp op, typename BinaryOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef stem = getBinaryMaskedStem(); - FailureOr calleeName = - buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); - if (failed(calleeName)) - return rewriter.notifyMatchFailure(op, "unsupported binary VPTO signature"); - Type resultType = this->getTypeConverter()->convertType(op.getResult().getType()); if (!resultType) @@ -4329,12 +4362,45 @@ class LowerBinaryMaskedOpPattern final : public OpConversionPattern { op, "unexpected converted binary VPTO operand types"); } + Type callResultType = resultType; + Value callLhs = lhs; + Value callRhs = rhs; + FailureOr calleeName = + buildLaneTypedCallee(op.getContext(), op.getResult().getType(), stem, ".x"); + + if constexpr (std::is_same_v) { + Type elementType = getElementTypeFromVectorLike(op.getResult().getType()); + if (elementType && pto::isPTOLowPrecisionType(elementType)) { + calleeName = buildDirectLowpVLogicCallee( + op.getContext(), op.getResult().getType(), stem, "x"); + if (failed(calleeName)) { + Type carrierType = getLowpPayloadCarrierType( + op.getResult().getType(), rewriter.getContext()); + if (!carrierType) + return rewriter.notifyMatchFailure( + op, "unsupported low-precision binary payload ABI"); + callResultType = carrierType; + callLhs = castToPayloadABI(op.getLoc(), lhs, + op.getResult().getType(), rewriter); + callRhs = castToPayloadABI(op.getLoc(), rhs, + op.getResult().getType(), rewriter); + calleeName = buildLowpPayloadVLogicCallee( + op.getContext(), op.getResult().getType(), stem, "x"); + } + } + } + if (failed(calleeName)) + return rewriter.notifyMatchFailure(op, "unsupported binary VPTO signature"); + auto call = rewriter.create(op.getLoc(), *calleeName, - TypeRange{resultType}, - ValueRange{lhs, rhs, mask}); + TypeRange{callResultType}, + ValueRange{callLhs, callRhs, mask}); state.plannedDecls.push_back( PlannedDecl{calleeName->str(), call.getCalleeType()}); - rewriter.replaceOp(op, call.getResults()); + Value result = call.getResult(0); + if (callResultType != resultType) + result = rewriter.create(op.getLoc(), resultType, result); + rewriter.replaceOp(op, result); return success(); } diff --git a/test/lit/vpto/vmi_truncf_hif8.pto b/test/lit/vpto/vmi_truncf_hif8.pto index a638759e01..174c2e09a1 100644 --- a/test/lit/vpto/vmi_truncf_hif8.pto +++ b/test/lit/vpto/vmi_truncf_hif8.pto @@ -9,13 +9,23 @@ // RUN: ptoas --pto-arch=a5 --pto-backend=vpto --enable-vmi --emit-vpto %s -o - 2>/dev/null | FileCheck %s // CHECK-LABEL: func.func @vmi_truncf_hif8_default_kernel -// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> -// CHECK-NOT: part = "P1" -// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask +// CHECK: %[[A_P0:.*]] = pto.vcvt {{.*}} {part = "P0", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: %[[A_P1:.*]] = pto.vcvt {{.*}} {part = "P1", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: %[[A_P2:.*]] = pto.vcvt {{.*}} {part = "P2", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: %[[A_P3:.*]] = pto.vcvt {{.*}} {part = "P3", rnd = "A", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: %[[A_M01:.*]] = pto.vor %[[A_P0]], %[[A_P1]], {{.*}} : !pto.vreg<256x!pto.hif8>, !pto.vreg<256x!pto.hif8>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: %[[A_M012:.*]] = pto.vor %[[A_M01]], %[[A_P2]], {{.*}} : !pto.vreg<256x!pto.hif8>, !pto.vreg<256x!pto.hif8>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: pto.vor %[[A_M012]], %[[A_P3]], {{.*}} : !pto.vreg<256x!pto.hif8>, !pto.vreg<256x!pto.hif8>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: pto.vsts {{.*}} : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask // CHECK-LABEL: func.func @vmi_truncf_hif8_hybrid_kernel -// CHECK: pto.vcvt {{.*}} {part = "P0", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> -// CHECK-NOT: part = "P1" -// CHECK: pto.vsts {{.*}} {dist = "PK4_B32"} : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask +// CHECK: %[[H_P0:.*]] = pto.vcvt {{.*}} {part = "P0", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: %[[H_P1:.*]] = pto.vcvt {{.*}} {part = "P1", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: %[[H_P2:.*]] = pto.vcvt {{.*}} {part = "P2", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: %[[H_P3:.*]] = pto.vcvt {{.*}} {part = "P3", rnd = "H", sat = "SAT"} : !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: %[[H_M01:.*]] = pto.vor %[[H_P0]], %[[H_P1]], {{.*}} : !pto.vreg<256x!pto.hif8>, !pto.vreg<256x!pto.hif8>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: %[[H_M012:.*]] = pto.vor %[[H_M01]], %[[H_P2]], {{.*}} : !pto.vreg<256x!pto.hif8>, !pto.vreg<256x!pto.hif8>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: pto.vor %[[H_M012]], %[[H_P3]], {{.*}} : !pto.vreg<256x!pto.hif8>, !pto.vreg<256x!pto.hif8>, !pto.mask -> !pto.vreg<256x!pto.hif8> +// CHECK: pto.vsts {{.*}} : !pto.vreg<256x!pto.hif8>, !pto.ptr, !pto.mask module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { func.func @vmi_truncf_hif8_default_kernel(%src_gm: !pto.ptr, diff --git a/test/lit/vpto/vor_low_precision_vpto_llvm.pto b/test/lit/vpto/vor_low_precision_vpto_llvm.pto new file mode 100644 index 0000000000..582fe4e3e8 --- /dev/null +++ b/test/lit/vpto/vor_low_precision_vpto_llvm.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --pto-arch=a5 --pto-backend=vpto --emit-vpto-llvm-ir %s -o - 2>&1 | FileCheck %s + +module attributes {pto.target_arch = "a5", pto.kernel_kind = #pto.kernel_kind} { + func.func @vor_lowp_direct(%lhs_ptr: !pto.ptr, + %rhs_ptr: !pto.ptr, + %dst_ptr: !pto.ptr) attributes {pto.kernel} { + %c0 = arith.constant 0 : index + pto.vecscope { + %mask = pto.pset_b8 "PAT_ALL" : !pto.mask + %lhs = pto.vlds %lhs_ptr[%c0] : !pto.ptr -> !pto.vreg<256xf8E4M3FN> + %rhs = pto.vlds %rhs_ptr[%c0] : !pto.ptr -> !pto.vreg<256xf8E4M3FN> + %out = pto.vor %lhs, %rhs, %mask : !pto.vreg<256xf8E4M3FN>, !pto.vreg<256xf8E4M3FN>, !pto.mask -> !pto.vreg<256xf8E4M3FN> + pto.vsts %out, %dst_ptr[%c0], %mask : !pto.vreg<256xf8E4M3FN>, !pto.ptr, !pto.mask + } + return + } +} + +// CHECK: declare <256 x float8e4m3> @llvm.hivm.vor.x.v256fp8e4m3(<256 x float8e4m3>, <256 x float8e4m3>, <256 x i1>) +// CHECK-LABEL: define void @vor_lowp_direct_mix_aiv +// CHECK: call <256 x float8e4m3> @llvm.hivm.vor.x.v256fp8e4m3 From 440216c62e4dd28ed8f85e14de7a40bf381cb01c Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Thu, 2 Jul 2026 00:21:44 +0800 Subject: [PATCH 54/54] feat(ptodsl): support vmi backend in ptodsl --- ptodsl/ptodsl/_runtime/native_build.py | 7 ++ ptodsl/tests/test_jit_compile.py | 26 +++++ .../broadcast-dense-group-users/compare.py | 40 -------- .../vmi/broadcast-dense-group-users/golden.py | 47 --------- .../vmi/broadcast-dense-group-users/kernel.py | 88 +++++++++++++++++ .../broadcast-dense-group-users/launch.cpp | 33 ------- .../vmi/broadcast-dense-group-users/main.cpp | 97 ------------------- .../broadcast-dense-group-users/ptoas.flags | 1 - 8 files changed, 121 insertions(+), 218 deletions(-) delete mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/compare.py delete mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/golden.py create mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/kernel.py delete mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp delete mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp delete mode 100644 test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags diff --git a/ptodsl/ptodsl/_runtime/native_build.py b/ptodsl/ptodsl/_runtime/native_build.py index 0821c326be..4b9e09f59d 100644 --- a/ptodsl/ptodsl/_runtime/native_build.py +++ b/ptodsl/ptodsl/_runtime/native_build.py @@ -39,6 +39,11 @@ def _run(cmd: list[str], *, cwd: Path | None = None) -> None: ) +def _mlir_requires_enable_vmi(mlir_path: Path) -> bool: + text = mlir_path.read_text(encoding="utf-8") + return "pto.vmi." in text or "!pto.vmi." in text + + def _run_ptoas( mlir_path: Path, kernel_object: Path, @@ -59,6 +64,8 @@ def _run_ptoas( cmd.append(f"--pto-level={pto_level}") if insert_sync is True: cmd.append("--enable-insert-sync") + if _mlir_requires_enable_vmi(mlir_path): + cmd.append("--enable-vmi") cmd.extend([ "--enable-tile-op-expand", str(mlir_path), diff --git a/ptodsl/tests/test_jit_compile.py b/ptodsl/tests/test_jit_compile.py index b3ab3694c6..cbc522690b 100644 --- a/ptodsl/tests/test_jit_compile.py +++ b/ptodsl/tests/test_jit_compile.py @@ -3849,6 +3849,32 @@ def fake_run_ptoas_cmd(cmd, *, cwd=None): "--enable-insert-sync" in source_ptoas_cmd, "source-backed native build should still pass explicit/effective insert-sync to ptoas", ) + ptoas_cmds.clear() + vmi_mlir_text = ( + 'module attributes {pto.target_arch = "a5"} {\n' + ' module attributes {pto.backend = "vpto", pto.kernel_kind = #pto.kernel_kind, pto.target_arch = "a5"} {\n' + " func.func @vmi_probe(%arg0: !pto.ptr) {\n" + ' %c0 = arith.constant 0 : index\n' + ' %x = pto.vmi.load %arg0[%c0] : !pto.ptr -> !pto.vmi.vreg<256xf32>\n' + " return\n" + " }\n" + " }\n" + "}\n" + ) + mlir_path.write_text(vmi_mlir_text, encoding="utf-8") + with mock.patch.object(native_build_runtime, "resolve_ptoas_binary", return_value=Path("/tmp/fake-ptoas")), mock.patch.object( + native_build_runtime, "_run", side_effect=fake_run_ptoas_cmd + ): + native_build_runtime._run_ptoas( + mlir_path, + kernel_object, + target_arch="a5", + ) + expect(len(ptoas_cmds) == 1, "native build should issue exactly one ptoas command for VMI MLIR") + expect( + "--enable-vmi" in ptoas_cmds[0], + "native build should auto-enable the VMI semantic pipeline when the MLIR contains VMI ops", + ) expect("valid=?" not in default_text, "default alloc_tile() should keep full static valid-shape when valid_shape= is omitted") auto_mode_violation = expect_raises( RuntimeError, diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py b/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py deleted file mode 100644 index 9f34394fa1..0000000000 --- a/test/vpto/cases/vmi/broadcast-dense-group-users/compare.py +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import sys - -import numpy as np - - -def check(name: str, golden_name: str) -> None: - golden = np.fromfile(golden_name, dtype=np.float32) - output = np.fromfile(name, dtype=np.float32) - if golden.shape == output.shape and np.allclose(golden, output, atol=1e-4, rtol=1e-4): - return - if golden.shape != output.shape: - print(f"[ERROR] compare failed {name}: shape golden={golden.shape} output={output.shape}") - sys.exit(2) - diff = np.nonzero(~np.isclose(golden, output, atol=1e-4, rtol=1e-4))[0] - idx = int(diff[0]) if diff.size else -1 - print( - f"[ERROR] compare failed {name} idx={idx} " - f"golden={golden[idx] if idx >= 0 else 'n/a'} " - f"output={output[idx] if idx >= 0 else 'n/a'}" - ) - sys.exit(2) - - -def main() -> None: - check("v2.bin", "golden_v2.bin") - check("v3.bin", "golden_v3.bin") - print("[INFO] compare passed") - - -if __name__ == "__main__": - main() diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py b/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py deleted file mode 100644 index 7df1eedef3..0000000000 --- a/test/vpto/cases/vmi/broadcast-dense-group-users/golden.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2026 Huawei Technologies Co., Ltd. -# This program is free software, you can redistribute it and/or modify it under the terms and conditions of -# CANN Open Software License Agreement Version 2.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. - -import argparse -from pathlib import Path - -import numpy as np - -ROWS = 8 -COLS = 32 -SCALE = np.float32(0.5) -SENTINEL = np.float32(-777.0) - - -def generate(output_dir: Path) -> None: - base = np.linspace(-0.875, 0.625, COLS, dtype=np.float32) - src = np.empty((ROWS, COLS), dtype=np.float32) - for row in range(ROWS): - src[row, :] = base + np.float32(row) * np.float32(0.03125) - copy = np.full((ROWS, COLS), SENTINEL, dtype=np.float32) - sums = np.full(ROWS, SENTINEL, dtype=np.float32) - golden_copy = src + SCALE - golden_sum = np.sum(src * SCALE, axis=1, dtype=np.float32).astype(np.float32) - - output_dir.mkdir(parents=True, exist_ok=True) - src.reshape(-1).tofile(output_dir / "v1.bin") - copy.reshape(-1).tofile(output_dir / "v2.bin") - sums.tofile(output_dir / "v3.bin") - golden_copy.reshape(-1).astype(np.float32).tofile(output_dir / "golden_v2.bin") - golden_sum.astype(np.float32).tofile(output_dir / "golden_v3.bin") - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument("--output-dir", type=Path, default=Path(".")) - args = parser.parse_args() - generate(args.output_dir) - - -if __name__ == "__main__": - main() diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.py b/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.py new file mode 100644 index 0000000000..7d51394b1c --- /dev/null +++ b/test/vpto/cases/vmi/broadcast-dense-group-users/kernel.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from pathlib import Path +import sys + +import numpy as np + + +def _bootstrap_dsl_st_common() -> None: + here = Path(__file__).resolve() + for candidate in here.parents: + common_dir = candidate / "test" / "dsl-st" + if (common_dir / "common.py").exists(): + sys.path.insert(0, str(common_dir)) + return + raise RuntimeError("Unable to locate test/dsl-st/common.py from broadcast-dense-group-users kernel.py") + + +_bootstrap_dsl_st_common() + +from common import auto_main +from ptodsl import pto + + +ROWS = 8 +COLS = 32 +SCALE = np.float32(0.5) + + +@pto.jit( + name="vmi_broadcast_dense_group_users_kernel", + target="a5", + backend="vpto", + mode="explicit", + source="kernel.pto", +) +def vmi_broadcast_dense_group_users_kernel( + src_gm: pto.ptr(pto.f32, "gm"), + copy_gm: pto.ptr(pto.f32, "gm"), + sum_gm: pto.ptr(pto.f32, "gm"), +): + pass + + +def make_inputs(): + base = np.linspace(-0.875, 0.625, COLS, dtype=np.float32) + src = np.empty((ROWS, COLS), dtype=np.float32) + for row in range(ROWS): + src[row, :] = base + np.float32(row) * np.float32(0.03125) + copy = np.zeros((ROWS, COLS), dtype=np.float32) + sums = np.zeros(ROWS, dtype=np.float32) + return [src.reshape(-1), copy.reshape(-1), sums] + + +def make_case(): + host_inputs = make_inputs() + src = host_inputs[0].reshape(ROWS, COLS) + golden_copy = (src + SCALE).astype(np.float32).reshape(-1) + golden_sum = np.sum(src * SCALE, axis=1, dtype=np.float32).astype(np.float32) + return host_inputs, (golden_copy, golden_sum) + + +def check_case(device_inputs, golden): + golden_copy, golden_sum = golden + actual_copy = device_inputs[1].cpu().numpy() + actual_sum = device_inputs[2].cpu().numpy() + np.testing.assert_allclose(actual_copy, golden_copy, rtol=1e-4, atol=1e-4) + np.testing.assert_allclose(actual_sum, golden_sum, rtol=1e-4, atol=1e-4) + + +CASES = [ + { + "name": "vmi_broadcast_dense_group_users", + "kernel": vmi_broadcast_dense_group_users_kernel, + "make_case": make_case, + "check": check_case, + }, +] + + +auto_main(globals()) diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp b/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp deleted file mode 100644 index 21e26d6cf5..0000000000 --- a/test/vpto/cases/vmi/broadcast-dense-group-users/launch.cpp +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#ifndef __VEC_SCOPE__ -#define __VEC_SCOPE__ -#endif -#include -#if !defined(__CCE_AICORE__) && !defined(TMRGSORT_HPP) -struct MrgSortExecutedNumList { - uint16_t mrgSortList0; - uint16_t mrgSortList1; - uint16_t mrgSortList2; - uint16_t mrgSortList3; -}; -#endif -#ifndef __CPU_SIM -#include "acl/acl.h" -#endif - -extern "C" __global__ [aicore] void -vmi_broadcast_dense_group_users_kernel(__gm__ float *src, __gm__ float *copy, - __gm__ float *sum); - -void LaunchVmi_broadcast_dense_group_users_kernel(float *src, float *copy, - float *sum, void *stream) { - vmi_broadcast_dense_group_users_kernel<<<1, nullptr, stream>>>( - (__gm__ float *)src, (__gm__ float *)copy, (__gm__ float *)sum); -} diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp b/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp deleted file mode 100644 index b43a794cdb..0000000000 --- a/test/vpto/cases/vmi/broadcast-dense-group-users/main.cpp +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) 2026 Huawei Technologies Co., Ltd. -// This program is free software, you can redistribute it and/or modify it under the terms and conditions of -// CANN Open Software License Agreement Version 2.0 (the "License"). -// Please refer to the License for details. You may not use this file except in compliance with the License. -// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// See LICENSE in the root of the software repository for the full text of the License. - -#include "acl/acl.h" -#include "test_common.h" -#include -#include - -using namespace PtoTestCommon; - -#define ACL_CHECK(expr) \ - do { \ - const aclError _ret = (expr); \ - if (_ret != ACL_SUCCESS) { \ - std::fprintf(stderr, "[ERROR] %s failed: %d (%s:%d)\n", #expr, \ - (int)_ret, __FILE__, __LINE__); \ - rc = 1; \ - goto cleanup; \ - } \ - } while (0) - -void LaunchVmi_broadcast_dense_group_users_kernel(float *src, float *copy, - float *sum, void *stream); - -int main() { - constexpr size_t kRows = 8; - constexpr size_t kCols = 32; - constexpr size_t kSrcElems = kRows * kCols; - constexpr size_t kSumElems = kRows; - size_t srcBytes = kSrcElems * sizeof(float); - size_t copyBytes = kSrcElems * sizeof(float); - size_t sumBytes = kSumElems * sizeof(float); - float *srcHost = nullptr; - float *copyHost = nullptr; - float *sumHost = nullptr; - float *srcDevice = nullptr; - float *copyDevice = nullptr; - float *sumDevice = nullptr; - int rc = 0; - bool aclInited = false; - bool deviceSet = false; - int deviceId = 0; - aclrtStream stream = nullptr; - - ACL_CHECK(aclInit(nullptr)); - aclInited = true; - if (const char *envDevice = std::getenv("ACL_DEVICE_ID")) - deviceId = std::atoi(envDevice); - ACL_CHECK(aclrtSetDevice(deviceId)); - deviceSet = true; - ACL_CHECK(aclrtCreateStream(&stream)); - ACL_CHECK(aclrtMallocHost((void **)(&srcHost), srcBytes)); - ACL_CHECK(aclrtMallocHost((void **)(©Host), copyBytes)); - ACL_CHECK(aclrtMallocHost((void **)(&sumHost), sumBytes)); - ACL_CHECK(aclrtMalloc((void **)&srcDevice, srcBytes, ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc((void **)©Device, copyBytes, ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc((void **)&sumDevice, sumBytes, ACL_MEM_MALLOC_HUGE_FIRST)); - - ReadFile("./v1.bin", srcBytes, srcHost, srcBytes); - ReadFile("./v2.bin", copyBytes, copyHost, copyBytes); - ReadFile("./v3.bin", sumBytes, sumHost, sumBytes); - ACL_CHECK(aclrtMemcpy(srcDevice, srcBytes, srcHost, srcBytes, - ACL_MEMCPY_HOST_TO_DEVICE)); - ACL_CHECK(aclrtMemcpy(copyDevice, copyBytes, copyHost, copyBytes, - ACL_MEMCPY_HOST_TO_DEVICE)); - ACL_CHECK(aclrtMemcpy(sumDevice, sumBytes, sumHost, sumBytes, - ACL_MEMCPY_HOST_TO_DEVICE)); - LaunchVmi_broadcast_dense_group_users_kernel(srcDevice, copyDevice, sumDevice, - stream); - ACL_CHECK(aclrtSynchronizeStream(stream)); - ACL_CHECK(aclrtMemcpy(copyHost, copyBytes, copyDevice, copyBytes, - ACL_MEMCPY_DEVICE_TO_HOST)); - ACL_CHECK(aclrtMemcpy(sumHost, sumBytes, sumDevice, sumBytes, - ACL_MEMCPY_DEVICE_TO_HOST)); - WriteFile("./v2.bin", copyHost, copyBytes); - WriteFile("./v3.bin", sumHost, sumBytes); - -cleanup: - aclrtFree(srcDevice); - aclrtFree(copyDevice); - aclrtFree(sumDevice); - aclrtFreeHost(srcHost); - aclrtFreeHost(copyHost); - aclrtFreeHost(sumHost); - if (stream) - aclrtDestroyStream(stream); - if (deviceSet) - aclrtResetDevice(deviceId); - if (aclInited) - aclFinalize(); - return rc; -} diff --git a/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags b/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags deleted file mode 100644 index a79aede1ca..0000000000 --- a/test/vpto/cases/vmi/broadcast-dense-group-users/ptoas.flags +++ /dev/null @@ -1 +0,0 @@ ---pto-arch a5 --pto-backend=vpto --enable-vmi