Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) PyPTO Contributors.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
* -----------------------------------------------------------------------------------------------------------
*/

#include <cstdint>

#ifndef __gm__
#define __gm__
#endif
#ifndef __aicore__
#define __aicore__ [aicore]
#endif

#include <pto/pto-inst.hpp>

#include "tensor.h"

using namespace pto;

/**
 * AICore kernel entry: produces `result = src + 1.0f` over a 128 x 128 float region.
 *
 * The region is handled in kIters chunks of kRows x kCols elements. Each chunk is loaded into
 * a vector tile, passed through TADDS with the scalar 1.0f, and stored back to global memory.
 *
 * @param args Argument array; args[0] is the address of the source Tensor descriptor and
 *             args[1] is the address of the result Tensor descriptor.
 */
extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) {
    __gm__ Tensor *in_desc = reinterpret_cast<__gm__ Tensor *>(args[0]);
    __gm__ Tensor *out_desc = reinterpret_cast<__gm__ Tensor *>(args[1]);

    // Float element pointers: buffer address advanced by each descriptor's start offset.
    __gm__ float *in_base = reinterpret_cast<__gm__ float *>(in_desc->buffer.addr) + in_desc->start_offset;
    __gm__ float *out_base = reinterpret_cast<__gm__ float *>(out_desc->buffer.addr) + out_desc->start_offset;

    constexpr int kTotalRows = 128;
    constexpr int kRows = 64;
    constexpr int kCols = 128;
    constexpr int kIters = kTotalRows / kRows;
    constexpr int kChunkElems = kRows * kCols;

    using ChunkShape = Shape<1, 1, 1, kRows, kCols>;
    using ChunkStride = pto::Stride<1, 1, 1, kCols, 1>;
    using ChunkGlobal = GlobalTensor<float, ChunkShape, ChunkStride>;
    using ChunkTile = Tile<TileType::Vec, float, kRows, kCols, BLayout::RowMajor, -1, -1>;

    // Two tiles bound to fixed addresses 0x0 and 0x10000; one float chunk occupies
    // kRows * kCols * 4 = 0x8000 bytes, so the two ranges do not overlap.
    ChunkTile in_tile(kRows, kCols);
    ChunkTile out_tile(kRows, kCols);
    TASSIGN(in_tile, 0x0);
    TASSIGN(out_tile, 0x10000);

    for (int chunk = 0; chunk < kIters; ++chunk) {
        const int elem_offset = chunk * kChunkElems;
        ChunkGlobal in_global(in_base + elem_offset);
        ChunkGlobal out_global(out_base + elem_offset);

        // Load, then set/wait on the MTE2 -> V event before the vector op
        // (presumably so the compute sees the loaded data).
        TLOAD(in_tile, in_global);
        set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
        wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);

        // Compute, then set/wait on the V -> MTE3 event before the store.
        TADDS(out_tile, in_tile, 1.0f);
        set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
        wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);

        // Store, then set/wait on the MTE3 -> S event before the next iteration reuses the
        // tiles (and before returning on the final iteration).
        TSTORE(out_global, out_tile);
        set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
        wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) PyPTO Contributors.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
* -----------------------------------------------------------------------------------------------------------
*/

#include <cstdint>

#ifndef __gm__
#define __gm__
#endif
#ifndef __aicore__
#define __aicore__ [aicore]
#endif

#include <pto/pto-inst.hpp>

#include "backend/sdma/sdma_completion_kernel.h"
#include "platform_comm/comm_context.h"
#include "pto_async_kernel_api.h"
#include "tensor.h"

using namespace pto;

/**
 * Maps a pointer inside this rank's window onto the same offset inside a peer's window.
 *
 * The offset of `local_ptr` from `ctx->windowsIn[ctx->rankId]` is computed on integer
 * addresses and then added to `ctx->windowsIn[peer_rank]`.
 *
 * @param ctx       Communication context providing `windowsIn` and `rankId`.
 * @param local_ptr Pointer to translate; assumed to lie inside this rank's window (not checked).
 * @param peer_rank Index into `ctx->windowsIn` selecting the target rank (not range-checked).
 * @return Pointer at the same offset relative to the peer's window base.
 */
template <typename T>
static inline __aicore__ __gm__ T *comm_remote_ptr(__gm__ CommContext *ctx, __gm__ T *local_ptr, int peer_rank) {
    const uint64_t self_window = ctx->windowsIn[ctx->rankId];
    const uint64_t byte_offset = reinterpret_cast<uint64_t>(local_ptr) - self_window;
    return reinterpret_cast<__gm__ T *>(ctx->windowsIn[peer_rank] + byte_offset);
}

/**
 * AICore kernel entry: submits an SdmaTget request whose operands are this rank's output
 * buffer and the peer rank's input buffer, each viewed as 128 * 128 floats.
 *
 * If the context does not describe exactly two ranks, or its workspace address is zero, the
 * kernel issues `pipe_barrier(PIPE_ALL)` and returns without submitting a request.
 *
 * @param args Argument array; args[0] is the address of the input Tensor descriptor, args[1]
 *             the address of the output Tensor descriptor, and args[2] the address of the
 *             CommContext. `args` is also passed to `get_async_ctx`.
 */
extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) {
    __gm__ Tensor *src_desc = reinterpret_cast<__gm__ Tensor *>(args[0]);
    __gm__ Tensor *dst_desc = reinterpret_cast<__gm__ Tensor *>(args[1]);
    __gm__ CommContext *ctx = reinterpret_cast<__gm__ CommContext *>(args[2]);

    // Float element pointers: buffer address advanced by each descriptor's start offset.
    __gm__ float *src_local = reinterpret_cast<__gm__ float *>(src_desc->buffer.addr) + src_desc->start_offset;
    __gm__ float *dst_local = reinterpret_cast<__gm__ float *>(dst_desc->buffer.addr) + dst_desc->start_offset;

    const int self_rank = static_cast<int>(ctx->rankId);
    const int rank_count = static_cast<int>(ctx->rankNum);

    // Only a two-rank setup with a non-zero workspace address is handled.
    const bool can_transfer = (rank_count == 2) && (ctx->workSpace != 0);
    if (!can_transfer) {
        pipe_barrier(PIPE_ALL);
        return;
    }

    // With exactly two ranks the peer is the other index.
    const int other_rank = 1 - self_rank;

    constexpr int kElems = 128 * 128;
    using LinearShape = Shape<1, 1, 1, 1, kElems>;
    using LinearStride = pto::Stride<kElems, kElems, kElems, kElems, 1>;
    using LinearGlobal = GlobalTensor<float, LinearShape, LinearStride>;
    using ScratchTile = Tile<TileType::Vec, uint8_t, 1, SDMA_SCRATCH_ALIGNMENT>;

    // Peer-side address of the input buffer, at the same window offset as the local one.
    __gm__ float *src_remote = comm_remote_ptr(ctx, src_local, other_rank);
    LinearGlobal remote_view(src_remote);
    LinearGlobal local_view(dst_local);

    // Scratch tile bound to address 0x0 and handed to the request.
    ScratchTile scratch;
    TASSIGN(scratch, 0x0);

    AsyncCtx completion_ctx = get_async_ctx(args);
    send_request_entry(
        completion_ctx,
        SdmaTget(local_view, remote_view, scratch, reinterpret_cast<__gm__ uint8_t *>(ctx->workSpace))
    );
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright (c) PyPTO Contributors.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
* -----------------------------------------------------------------------------------------------------------
*/

#include <stdint.h>

#include "platform_comm/comm_context.h"
#include "pto_orchestration_api.h"

extern "C" {

/**
 * Returns the orchestration configuration for the SDMA async-completion demo.
 *
 * @param orch_args Ignored; the configuration does not depend on the arguments.
 * @return Configuration with `expected_arg_count` set to 4.
 */
__attribute__((visibility("default"))) PTO2OrchestrationConfig
sdma_async_completion_orchestration_config(const L2TaskArgs &orch_args) {
    static_cast<void>(orch_args);  // intentionally unused
    PTO2OrchestrationConfig config{.expected_arg_count = 4};
    return config;
}

/**
 * Exported alias that forwards to sdma_async_completion_orchestration_config().
 *
 * @param orch_args Passed through unchanged.
 * @return Whatever sdma_async_completion_orchestration_config() returns for `orch_args`.
 */
__attribute__((visibility("default"))) PTO2OrchestrationConfig aicpu_orchestration_config(const L2TaskArgs &orch_args) {
    PTO2OrchestrationConfig config = sdma_async_completion_orchestration_config(orch_args);
    return config;
}

/**
 * Orchestration entry point for the SDMA async-completion demo.
 *
 * Requires exactly three tensor arguments and one scalar argument. On any other combination
 * it logs an error and returns without submitting any task.
 *
 * @param orch_args Argument pack supplied by the runtime.
 */
__attribute__((visibility("default"))) void sdma_async_completion_orchestration(const L2TaskArgs &orch_args) {
// Check the exact tensor/scalar split rather than only the total: a total of 4 with the wrong
// mix (e.g. 2 tensors + 2 scalars) would otherwise reach tensor(2) with an out-of-range index.
if (orch_args.tensor_count() != 3 || orch_args.scalar_count() != 1) {
LOG_ERROR("sdma_async_completion_demo: expected 3 tensors and 1 scalar");
return;
}
Comment on lines +30 to +33

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current check only verifies that the sum of tensor_count() and scalar_count() is 4. If the orchestrator is invoked with an unexpected combination of arguments (e.g., 2 tensors and 2 scalars), accessing orch_args.tensor(2) or other indices will result in an out-of-bounds access and potentially crash. It is safer to explicitly validate that tensor_count() is exactly 3 and scalar_count() is exactly 1.

Suggested change
-if (orch_args.tensor_count() + orch_args.scalar_count() != 4) {
-LOG_ERROR("sdma_async_completion_demo: expected 4 args");
-return;
-}
+if (orch_args.tensor_count() != 3 || orch_args.scalar_count() != 1) {
+LOG_ERROR("sdma_async_completion_demo: expected 3 tensors and 1 scalar");
+return;
+}
References
  1. Ensure that index-based accessors perform bounds checks to prevent undefined behavior or out-of-bounds memory access.


// Bind the three tensor arguments by index: 0 = input, 1 = out, 2 = result.
const Tensor &input = orch_args.tensor(0).ref();
const Tensor &out = orch_args.tensor(1).ref();
const Tensor &result = orch_args.tensor(2).ref();
// Scalar 0 is converted through uintptr_t into a CommContext pointer; it is only forwarded
// below, never dereferenced here.
auto *comm_ctx = reinterpret_cast<CommContext *>(static_cast<uintptr_t>(orch_args.scalar(0)));
Comment on lines +30 to +38

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win

Validate the tensor/scalar split, not just total arg count.

A call with 4 args but the wrong mix can pass this guard and still fail when accessing tensor(2) or scalar(0).

Proposed fix
-    if (orch_args.tensor_count() + orch_args.scalar_count() != 4) {
-        LOG_ERROR("sdma_async_completion_demo: expected 4 args");
+    if (orch_args.tensor_count() != 3 || orch_args.scalar_count() != 1) {
+        LOG_ERROR("sdma_async_completion_demo: expected 3 tensor args and 1 scalar arg");
         return;
     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-if (orch_args.tensor_count() + orch_args.scalar_count() != 4) {
-LOG_ERROR("sdma_async_completion_demo: expected 4 args");
-return;
-}
-Tensor input = from_tensor_arg(orch_args.tensor(0));
-Tensor out = from_tensor_arg(orch_args.tensor(1));
-Tensor result = from_tensor_arg(orch_args.tensor(2));
-auto *comm_ctx = reinterpret_cast<CommContext *>(static_cast<uintptr_t>(orch_args.scalar(0)));
+if (orch_args.tensor_count() != 3 || orch_args.scalar_count() != 1) {
+LOG_ERROR("sdma_async_completion_demo: expected 3 tensor args and 1 scalar arg");
+return;
+}
+Tensor input = from_tensor_arg(orch_args.tensor(0));
+Tensor out = from_tensor_arg(orch_args.tensor(1));
+Tensor result = from_tensor_arg(orch_args.tensor(2));
+auto *comm_ctx = reinterpret_cast<CommContext *>(static_cast<uintptr_t>(orch_args.scalar(0)));
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@examples/a5/tensormap_and_ringbuffer/sdma_async_completion_demo/kernels/orchestration/sdma_async_completion_orch.cpp`
around lines 31 - 39, The current guard in sdma_async_completion_orch.cpp only
checks the total argument count, so a bad tensor/scalar mix can still reach
tensor(2) and scalar(0). Update the validation around the orchestration argument
parsing to verify the exact split expected by Tensor accessors and the comm_ctx
scalar, not just orch_args.tensor_count() + orch_args.scalar_count(). Keep the
existing error handling in the same flow so invalid inputs are rejected before
from_tensor_arg() and reinterpret_cast<CommContext *> are used.


// First task: `input` as input, `out` as output, plus the CommContext address as a scalar.
// The leading 0 passed to rt_submit_aiv_task is presumably the kernel/function id — TODO confirm.
L0TaskArgs producer_args;
producer_args.add_input(input);
producer_args.add_output(out);
producer_args.add_scalar(reinterpret_cast<uint64_t>(comm_ctx));
rt_submit_aiv_task(0, producer_args);

// Second task (id 1): `out` as input, `result` as output. Sharing `out` with the first task
// presumably lets the runtime order this task after it — verify against the runtime's
// dependency tracking.
L0TaskArgs consumer_args;
consumer_args.add_input(out);
consumer_args.add_output(result);
rt_submit_aiv_task(1, consumer_args);
}

} // extern "C"
Loading
Loading