From 901b2e372f81dd0ebc0c2f188dec336a85350f86 Mon Sep 17 00:00:00 2001 From: jvjhfhg Date: Mon, 18 May 2026 17:35:53 +0800 Subject: [PATCH] Add: SDMA workspace overlay + async completion demo on a5 onboard Layers the host-side SDMA workspace allocation on top of the comm backend from the previous commit. Until CANN exposes the missing SDMA primitives on a5, this overlay is the only piece of comm work that fails on real a5 silicon -- aclnnShmemSdmaStarsQuery raises an AICPU exception (InnerCode=0x715002a) that aborts the entire ACL thread context. Dropping this commit therefore unblocks the non-SDMA comm demos (async_notify_demo etc.) without touching the deferred-completion runtime, which is already SDMA-aware on the kernel side (dormant until a kernel registers an SDMA condition). - Wire SdmaWorkspaceManager into comm_alloc_windows under SIMPLER_ENABLE_PTO_SDMA_WORKSPACE: pre-allocates the per-rank workspace via aclnnShmemSdmaStarsQuery and overlays the result into CommContext.workSpace/.workSpaceSize. On CANN 8.5 the dlsym fails by design and we demote to "no workspace" rather than failing comm_init. - a5 onboard CMakeLists forces SIMPLER_ENABLE_PTO_SDMA_WORKSPACE ON, requires PTO_ISA_ROOT (with FATAL_ERROR message pointing to the workspace coupling), adds pto-isa headers to the include path, and links libnnopbase. - runtime_compiler._init_a5 enforces the same PTO_ISA_ROOT env contract as _init_a2a3. - Migrate sdma_async_completion_demo to examples/a5/ (kernels + orch byte-identical with the a2a3 version; test.py platform- renamed). --- .../kernels/aiv/kernel_consumer.cpp | 64 ++++++ .../kernels/aiv/kernel_sdma_tget_async.cpp | 71 ++++++ .../sdma_async_completion_orch.cpp | 52 +++++ .../test_sdma_async_completion_demo.py | 210 ++++++++++++++++++ simpler_setup/runtime_compiler.py | 5 + src/a5/platform/onboard/host/CMakeLists.txt | 24 ++ src/a5/platform/onboard/host/comm_hccl.cpp | 10 +- 7 files changed, 433 insertions(+), 3 deletions(-) create mode 100644 examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/aiv/kernel_consumer.cpp create mode 100644 examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/aiv/kernel_sdma_tget_async.cpp create mode 100644 examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/orchestration/sdma_async_completion_orch.cpp create mode 100644 examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/test_sdma_async_completion_demo.py diff --git a/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/aiv/kernel_consumer.cpp b/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/aiv/kernel_consumer.cpp new file mode 100644 index 000000000..c0d698bcf --- /dev/null +++ b/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/aiv/kernel_consumer.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +#include + +#ifndef __gm__ +#define __gm__ +#endif +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include + +#include "tensor.h" + +using namespace pto; + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *src_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *result_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); + + __gm__ float *src = reinterpret_cast<__gm__ float *>(src_tensor->buffer.addr) + src_tensor->start_offset; + __gm__ float *result = reinterpret_cast<__gm__ float *>(result_tensor->buffer.addr) + result_tensor->start_offset; + + constexpr int kTotalRows = 128; + constexpr int kRows = 64; + constexpr int kCols = 128; + constexpr int kIters = kTotalRows / kRows; + using DynShapeDim5 = Shape<1, 1, 1, kRows, kCols>; + using DynStrideDim5 = pto::Stride<1, 1, 1, kCols, 1>; + using GlobalData = GlobalTensor; + using TileData = Tile; + + TileData src_tile(kRows, kCols); + TileData result_tile(kRows, kCols); + TASSIGN(src_tile, 0x0); + TASSIGN(result_tile, 0x10000); + + constexpr int kChunkElems = kRows * kCols; + for (int iter = 0; iter < kIters; ++iter) { + GlobalData src_global(src + iter * kChunkElems); + GlobalData result_global(result + iter * kChunkElems); + TLOAD(src_tile, src_global); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TADDS(result_tile, src_tile, 1.0f); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + TSTORE(result_global, result_tile); + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + } +} diff --git a/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/aiv/kernel_sdma_tget_async.cpp b/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/aiv/kernel_sdma_tget_async.cpp new file mode 100644 index 000000000..475e35b6d --- /dev/null +++ b/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/aiv/kernel_sdma_tget_async.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +#include + +#ifndef __gm__ +#define __gm__ +#endif +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include + +#include "backend/sdma/sdma_completion_kernel.h" +#include "platform_comm/comm_context.h" +#include "pto_async_kernel_api.h" +#include "tensor.h" + +using namespace pto; + +template +static inline __aicore__ __gm__ T *comm_remote_ptr(__gm__ CommContext *ctx, __gm__ T *local_ptr, int peer_rank) { + uint64_t local_base = ctx->windowsIn[ctx->rankId]; + uint64_t offset = reinterpret_cast(local_ptr) - local_base; + return reinterpret_cast<__gm__ T *>(ctx->windowsIn[peer_rank] + offset); +} + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *in_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *out_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ CommContext *comm_ctx = reinterpret_cast<__gm__ CommContext *>(args[2]); + + __gm__ float *local_in = reinterpret_cast<__gm__ float *>(in_tensor->buffer.addr) + in_tensor->start_offset; + __gm__ float *local_out = reinterpret_cast<__gm__ float *>(out_tensor->buffer.addr) + out_tensor->start_offset; + + int rank = static_cast(comm_ctx->rankId); + int nranks = static_cast(comm_ctx->rankNum); + if (nranks != 2 || comm_ctx->workSpace == 0) { + pipe_barrier(PIPE_ALL); + return; + } + int peer_rank = 1 - rank; + + constexpr int kElems = 128 * 128; + using FlatShape = Shape<1, 1, 1, 1, kElems>; + using FlatStride = pto::Stride; + using GlobalData = GlobalTensor; + using ScratchTile = Tile; + + __gm__ float *remote_in = comm_remote_ptr(comm_ctx, local_in, peer_rank); + GlobalData remote_global(remote_in); + GlobalData local_global(local_out); + + ScratchTile scratch_tile; + TASSIGN(scratch_tile, 0x0); + + AsyncCtx async_ctx = get_async_ctx(args); + send_request_entry( + async_ctx, + SdmaTget(local_global, remote_global, scratch_tile, reinterpret_cast<__gm__ uint8_t *>(comm_ctx->workSpace)) + ); +} diff --git a/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/orchestration/sdma_async_completion_orch.cpp b/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/orchestration/sdma_async_completion_orch.cpp new file mode 100644 index 000000000..a33c96730 --- /dev/null +++ b/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/orchestration/sdma_async_completion_orch.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +#include + +#include "platform_comm/comm_context.h" +#include "pto_orchestration_api.h" + +extern "C" { + +__attribute__((visibility("default"))) PTO2OrchestrationConfig +sdma_async_completion_orchestration_config(const L2TaskArgs &orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{.expected_arg_count = 4}; +} + +__attribute__((visibility("default"))) PTO2OrchestrationConfig aicpu_orchestration_config(const L2TaskArgs &orch_args) { + return sdma_async_completion_orchestration_config(orch_args); +} + +__attribute__((visibility("default"))) void sdma_async_completion_orchestration(const L2TaskArgs &orch_args) { + if (orch_args.tensor_count() + orch_args.scalar_count() != 4) { + LOG_ERROR("sdma_async_completion_demo: expected 4 args"); + return; + } + + const Tensor &input = orch_args.tensor(0).ref(); + const Tensor &out = orch_args.tensor(1).ref(); + const Tensor &result = orch_args.tensor(2).ref(); + auto *comm_ctx = reinterpret_cast(static_cast(orch_args.scalar(0))); + + L0TaskArgs producer_args; + producer_args.add_input(input); + producer_args.add_output(out); + producer_args.add_scalar(reinterpret_cast(comm_ctx)); + rt_submit_aiv_task(0, producer_args); + + L0TaskArgs consumer_args; + consumer_args.add_input(out); + consumer_args.add_output(result); + rt_submit_aiv_task(1, consumer_args); +} + +} // extern "C" diff --git a/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/test_sdma_async_completion_demo.py b/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/test_sdma_async_completion_demo.py new file mode 100644 index 000000000..42ea2eca4 --- /dev/null +++ b/examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/test_sdma_async_completion_demo.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""SDMA deferred completion smoke test for onboard a5. + +Each rank stages its input inside the HCCL window. The deferred producer +TGET_ASYNCs the peer rank's input into local ``out`` and registers the PTO +AsyncEvent through ``defer_pto_async_event``. The consumer depends on the +producer output and writes ``result = out + 1``. Correct ``out`` and +``result`` therefore validate both the SDMA completion polling and the +deferred-release dependency path. +""" + +from __future__ import annotations + +import argparse +import os + +import pytest +import torch +from simpler.task_interface import ( + ArgDirection, + CallConfig, + ChipCallable, + CommBufferSpec, + CoreCallable, + DataType, + TaskArgs, + Tensor, + TensorArgType, +) +from simpler.worker import Worker + +from simpler_setup.elf_parser import extract_text_section +from simpler_setup.kernel_compiler import KernelCompiler +from simpler_setup.pto_isa import ensure_pto_isa_root +from simpler_setup.torch_interop import make_tensor_arg + +HERE = os.path.dirname(os.path.abspath(__file__)) +N = 128 * 128 +DTYPE_NBYTES = 4 + + +def parse_device_range(spec: str) -> list[int]: + if "," in spec: + return [int(x) for x in spec.split(",") if x] + if "-" in spec: + lo, hi = (int(x) for x in spec.split("-")) + return list(range(lo, hi + 1)) + return [int(spec)] + + +def build_chip_callable(platform: str, pto_isa_commit: str | None, clone_protocol: str) -> ChipCallable: + kc = KernelCompiler(platform=platform) + runtime = "tensormap_and_ringbuffer" + pto_isa_root = ensure_pto_isa_root(commit=pto_isa_commit, clone_protocol=clone_protocol) + include_dirs = kc.get_orchestration_include_dirs(runtime) + extra_includes = list(include_dirs) + [str(kc.project_root / "src" / "common")] + + children = [] + for func_id, rel in [ + (0, "kernels/aiv/kernel_sdma_tget_async.cpp"), + (1, "kernels/aiv/kernel_consumer.cpp"), + ]: + kernel = kc.compile_incore( + source_path=os.path.join(HERE, rel), + core_type="aiv", + pto_isa_root=pto_isa_root, + extra_include_dirs=extra_includes, + ) + if not platform.endswith("sim"): + kernel = extract_text_section(kernel) + children.append( + ( + func_id, + CoreCallable.build( + signature=[ArgDirection.IN, ArgDirection.OUT, ArgDirection.OUT, ArgDirection.IN], + arg_index=[0, 1, 2, 3], + binary=kernel, + ), + ) + ) + + orch = kc.compile_orchestration( + runtime_name=runtime, + source_path=os.path.join(HERE, "kernels/orchestration/sdma_async_completion_orch.cpp"), + extra_include_dirs=[str(kc.project_root / "src" / "common")], + ) + return ChipCallable.build( + signature=[ArgDirection.IN, ArgDirection.OUT, ArgDirection.OUT, ArgDirection.IN], + func_name="sdma_async_completion_orchestration", + binary=orch, + children=children, + ) + + +def run( + platform: str = "a5", + device_ids: list[int] | None = None, + pto_isa_commit: str | None = None, +) -> int: + if device_ids is None: + device_ids = [0, 1] + nranks = len(device_ids) + if nranks != 2: + raise ValueError(f"sdma_async_completion_demo needs exactly 2 devices, got {device_ids}") + if platform.endswith("sim"): + raise ValueError("sdma_async_completion_demo requires onboard a5 hardware") + + input_nbytes = N * DTYPE_NBYTES + window_size = max(input_nbytes, 4 * 1024) + + # `inputs` must live in shared memory: `orch.copy_to` stages each rank's + # data into its HCCL window from the forked chip child, which reads `src` + # out of its own address space. + inputs = [ + torch.tensor([float(rank * 1000 + (i % 251)) / 10.0 for i in range(N)], dtype=torch.float32).share_memory_() + for rank in range(nranks) + ] + out = [torch.zeros(N, dtype=torch.float32).share_memory_() for _ in range(nranks)] + result = [torch.zeros(N, dtype=torch.float32).share_memory_() for _ in range(nranks)] + + chip_callable = build_chip_callable(platform, pto_isa_commit, "https") + worker = Worker( + level=3, + platform=platform, + runtime="tensormap_and_ringbuffer", + device_ids=device_ids, + num_sub_workers=0, + ) + chip_cid = worker.register(chip_callable) + try: + worker.init() + + def orch_fn(orch, _args, cfg): + with orch.allocate_domain( + name="default", + workers=list(range(nranks)), + window_size=window_size, + buffers=[ + CommBufferSpec(name="input_window", dtype="float32", count=N, nbytes=input_nbytes), + ], + ) as handle: + # Stage every rank's input window before submitting any kernel: + # each producer TGET_ASYNCs the *peer* rank's window, so all + # windows must hold real data before execution begins. + for rank in range(nranks): + orch.copy_to( + rank, + dst=handle[rank].buffer_ptrs["input_window"], + src=inputs[rank].data_ptr(), + size=input_nbytes, + ) + for rank in range(nranks): + domain = handle[rank] + args = TaskArgs() + args.add_tensor( + Tensor.make( + data=domain.buffer_ptrs["input_window"], + shapes=(N,), + dtype=DataType.FLOAT32, + child_memory=True, + ), + TensorArgType.INPUT, + ) + args.add_tensor(make_tensor_arg(out[rank]), TensorArgType.OUTPUT_EXISTING) + args.add_tensor(make_tensor_arg(result[rank]), TensorArgType.OUTPUT_EXISTING) + args.add_scalar(domain.device_ctx) + orch.submit_next_level(chip_cid, args, cfg, worker=rank) + + worker.run(orch_fn, args=None, config=CallConfig()) + + ok = True + for rank in range(nranks): + peer = 1 - rank + expected_out = inputs[peer] + expected_result = expected_out + 1.0 + max_out = float(torch.max(torch.abs(out[rank] - expected_out))) + max_result = float(torch.max(torch.abs(result[rank] - expected_result))) + print(f"[sdma_async_completion_demo] rank {rank}: max_out={max_out:.3e} max_result={max_result:.3e}") + ok = ok and max_out <= 1e-3 and max_result <= 1e-3 + return 0 if ok else 1 + finally: + worker.close() + + +@pytest.mark.platforms(["a5"]) +@pytest.mark.runtime("tensormap_and_ringbuffer") +@pytest.mark.device_count(2) +def test_sdma_async_completion_demo(st_device_ids, st_platform) -> None: + assert run(st_platform, [int(d) for d in st_device_ids]) == 0 + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--platform", default="a5") + parser.add_argument("-d", "--device", default="0-1") + parser.add_argument("--pto-isa-commit", default=None) + args = parser.parse_args() + return run(args.platform, parse_device_range(args.device), args.pto_isa_commit) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/simpler_setup/runtime_compiler.py b/simpler_setup/runtime_compiler.py index d773768c8..07a805bf1 100644 --- a/simpler_setup/runtime_compiler.py +++ b/simpler_setup/runtime_compiler.py @@ -167,6 +167,11 @@ def _init_a2a3sim(self): def _init_a5(self): """Initialize toolchains for real a5 hardware.""" env_manager.ensure("ASCEND_HOME_PATH") + # a5 onboard host_runtime hard-depends on pto-isa headers + CANN-9.0 + # aclnn syms (cf. src/a5/platform/onboard/host/CMakeLists.txt + # SIMPLER_ENABLE_PTO_SDMA_WORKSPACE marker). PTO_ISA_ROOT must be + # populated by the caller — same contract as a2a3 onboard. + env_manager.ensure("PTO_ISA_ROOT") # AICore: Bisheng CCE compiler with A5 platform ccec = CCECToolchain(platform="a5") diff --git a/src/a5/platform/onboard/host/CMakeLists.txt b/src/a5/platform/onboard/host/CMakeLists.txt index 7450a1eb9..8ac6687bd 100644 --- a/src/a5/platform/onboard/host/CMakeLists.txt +++ b/src/a5/platform/onboard/host/CMakeLists.txt @@ -33,6 +33,21 @@ else() message(FATAL_ERROR "MUST set CUSTOM_INCLUDE_DIRS to build Host runtime") endif() +# SIMPLER_ENABLE_PTO_SDMA_WORKSPACE: marker for an outstanding architectural +# issue. The SDMA workspace init in comm_hccl.cpp pulls in pto-isa +# headers and CANN-9.0-only aclnn symbols (aclnnShmemSdmaStarsQuery*), +# even though it is logically orthogonal to HCCL comm bootstrap and only +# needed by the sdma_async_completion_demo. Until that coupling is +# refactored away, the macro is forced ON: PTO_ISA_ROOT and CANN 9.0+ are +# hard build/runtime preconditions for a5 onboard. +set(SIMPLER_ENABLE_PTO_SDMA_WORKSPACE ON) +if(NOT DEFINED ENV{PTO_ISA_ROOT}) + message(FATAL_ERROR + "a5 onboard host_runtime requires PTO_ISA_ROOT " + "(SIMPLER_ENABLE_PTO_SDMA_WORKSPACE is forced ON; needs pto-isa headers + CANN 9.0+)") +endif() +list(APPEND CMAKE_CUSTOM_INCLUDE_DIRS "$ENV{PTO_ISA_ROOT}/include") + # Build complete source list: core host sources + sources from CUSTOM_SOURCE_DIRS set(HOST_RUNTIME_SOURCES "") list(APPEND HOST_RUNTIME_SOURCES @@ -97,6 +112,10 @@ target_compile_options(host_runtime # src/common/platform/shared/host/platform_compile_info.cpp. target_compile_definitions(host_runtime PRIVATE SIMPLER_PLATFORM_NAME="a5") +if(SIMPLER_ENABLE_PTO_SDMA_WORKSPACE) + target_compile_definitions(host_runtime PRIVATE SIMPLER_ENABLE_PTO_SDMA_WORKSPACE=1) +endif() + # Include directories - always include local headers target_include_directories(host_runtime PRIVATE @@ -151,6 +170,11 @@ target_link_directories(host_runtime ${ASCEND_HOME_PATH}/runtime/lib64 ) +if(SIMPLER_ENABLE_PTO_SDMA_WORKSPACE) + target_link_directories(host_runtime PRIVATE ${ASCEND_HOME_PATH}/${CMAKE_SYSTEM_PROCESSOR}-linux/lib64) + target_link_libraries(host_runtime PRIVATE nnopbase) +endif() + set_target_properties(host_runtime PROPERTIES OUTPUT_NAME "host_runtime") # Apply compiler sanitizers to this host-compiled target. No-op unless diff --git a/src/a5/platform/onboard/host/comm_hccl.cpp b/src/a5/platform/onboard/host/comm_hccl.cpp index ae23b2612..2bc37a495 100644 --- a/src/a5/platform/onboard/host/comm_hccl.cpp +++ b/src/a5/platform/onboard/host/comm_hccl.cpp @@ -45,6 +45,9 @@ #include "acl/acl.h" #include "hccl/hccl_comm.h" #include "hccl/hccl_types.h" +#ifdef SIMPLER_ENABLE_PTO_SDMA_WORKSPACE +#include "pto/comm/async/sdma/sdma_workspace_manager.hpp" +#endif // Thin wrappers around the HCCL public APIs we use. Kept as a translation // layer in case we need to swap (e.g., InitConfig variant) later. @@ -87,6 +90,9 @@ struct CommHandle_ { bool owns_device_ctx = false; std::vector derived_contexts; std::unordered_map> domain_allocations; +#ifdef SIMPLER_ENABLE_PTO_SDMA_WORKSPACE + std::unique_ptr sdma_workspace; +#endif }; // ============================================================================ @@ -843,9 +849,7 @@ extern "C" int comm_alloc_windows(CommHandle h, size_t win_size, uint64_t *devic if (alloc_windows_via_ipc(h, effective_win_size) != 0) return -1; // Optional PTO-ISA async SDMA workspace pre-allocation (overlays the comm - // backend's output; comm-side flow does not care about workSpace). No-op - // when SIMPLER_ENABLE_PTO_SDMA_WORKSPACE is undefined (this PR ships A - // without the macro; the overlay PR turns it on). + // backend's output; comm-side flow does not care about workSpace). ensure_sdma_workspace(h); void *newDevMem = nullptr;