-
Notifications
You must be signed in to change notification settings - Fork 63
Add: SDMA workspace overlay + async completion demo on a5 onboard #1179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| /* | ||
| * Copyright (c) PyPTO Contributors. | ||
| * This program is free software, you can redistribute it and/or modify it under the terms and conditions of | ||
| * CANN Open Software License Agreement Version 2.0 (the "License"). | ||
| * Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| * See LICENSE in the root of the software repository for the full text of the License. | ||
| * ----------------------------------------------------------------------------------------------------------- | ||
| */ | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| #ifndef __gm__ | ||
| #define __gm__ | ||
| #endif | ||
| #ifndef __aicore__ | ||
| #define __aicore__ [aicore] | ||
| #endif | ||
|
|
||
| #include <pto/pto-inst.hpp> | ||
|
|
||
| #include "tensor.h" | ||
|
|
||
| using namespace pto; | ||
|
|
||
| extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { | ||
| __gm__ Tensor *src_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); | ||
| __gm__ Tensor *result_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); | ||
|
|
||
| __gm__ float *src = reinterpret_cast<__gm__ float *>(src_tensor->buffer.addr) + src_tensor->start_offset; | ||
| __gm__ float *result = reinterpret_cast<__gm__ float *>(result_tensor->buffer.addr) + result_tensor->start_offset; | ||
|
|
||
| constexpr int kTotalRows = 128; | ||
| constexpr int kRows = 64; | ||
| constexpr int kCols = 128; | ||
| constexpr int kIters = kTotalRows / kRows; | ||
| using DynShapeDim5 = Shape<1, 1, 1, kRows, kCols>; | ||
| using DynStrideDim5 = pto::Stride<1, 1, 1, kCols, 1>; | ||
| using GlobalData = GlobalTensor<float, DynShapeDim5, DynStrideDim5>; | ||
| using TileData = Tile<TileType::Vec, float, kRows, kCols, BLayout::RowMajor, -1, -1>; | ||
|
|
||
| TileData src_tile(kRows, kCols); | ||
| TileData result_tile(kRows, kCols); | ||
| TASSIGN(src_tile, 0x0); | ||
| TASSIGN(result_tile, 0x10000); | ||
|
|
||
| constexpr int kChunkElems = kRows * kCols; | ||
| for (int iter = 0; iter < kIters; ++iter) { | ||
| GlobalData src_global(src + iter * kChunkElems); | ||
| GlobalData result_global(result + iter * kChunkElems); | ||
| TLOAD(src_tile, src_global); | ||
| set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); | ||
| wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); | ||
|
|
||
| TADDS(result_tile, src_tile, 1.0f); | ||
| set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); | ||
| wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); | ||
|
|
||
| TSTORE(result_global, result_tile); | ||
| set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); | ||
| wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| /* | ||
| * Copyright (c) PyPTO Contributors. | ||
| * This program is free software, you can redistribute it and/or modify it under the terms and conditions of | ||
| * CANN Open Software License Agreement Version 2.0 (the "License"). | ||
| * Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| * See LICENSE in the root of the software repository for the full text of the License. | ||
| * ----------------------------------------------------------------------------------------------------------- | ||
| */ | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| #ifndef __gm__ | ||
| #define __gm__ | ||
| #endif | ||
| #ifndef __aicore__ | ||
| #define __aicore__ [aicore] | ||
| #endif | ||
|
|
||
| #include <pto/pto-inst.hpp> | ||
|
|
||
| #include "backend/sdma/sdma_completion_kernel.h" | ||
| #include "platform_comm/comm_context.h" | ||
| #include "pto_async_kernel_api.h" | ||
| #include "tensor.h" | ||
|
|
||
| using namespace pto; | ||
|
|
||
| template <typename T> | ||
| static inline __aicore__ __gm__ T *comm_remote_ptr(__gm__ CommContext *ctx, __gm__ T *local_ptr, int peer_rank) { | ||
| uint64_t local_base = ctx->windowsIn[ctx->rankId]; | ||
| uint64_t offset = reinterpret_cast<uint64_t>(local_ptr) - local_base; | ||
| return reinterpret_cast<__gm__ T *>(ctx->windowsIn[peer_rank] + offset); | ||
| } | ||
|
|
||
| extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { | ||
| __gm__ Tensor *in_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); | ||
| __gm__ Tensor *out_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); | ||
| __gm__ CommContext *comm_ctx = reinterpret_cast<__gm__ CommContext *>(args[2]); | ||
|
|
||
| __gm__ float *local_in = reinterpret_cast<__gm__ float *>(in_tensor->buffer.addr) + in_tensor->start_offset; | ||
| __gm__ float *local_out = reinterpret_cast<__gm__ float *>(out_tensor->buffer.addr) + out_tensor->start_offset; | ||
|
|
||
| int rank = static_cast<int>(comm_ctx->rankId); | ||
| int nranks = static_cast<int>(comm_ctx->rankNum); | ||
| if (nranks != 2 || comm_ctx->workSpace == 0) { | ||
| pipe_barrier(PIPE_ALL); | ||
| return; | ||
| } | ||
| int peer_rank = 1 - rank; | ||
|
|
||
| constexpr int kElems = 128 * 128; | ||
| using FlatShape = Shape<1, 1, 1, 1, kElems>; | ||
| using FlatStride = pto::Stride<kElems, kElems, kElems, kElems, 1>; | ||
| using GlobalData = GlobalTensor<float, FlatShape, FlatStride>; | ||
| using ScratchTile = Tile<TileType::Vec, uint8_t, 1, SDMA_SCRATCH_ALIGNMENT>; | ||
|
|
||
| __gm__ float *remote_in = comm_remote_ptr(comm_ctx, local_in, peer_rank); | ||
| GlobalData remote_global(remote_in); | ||
| GlobalData local_global(local_out); | ||
|
|
||
| ScratchTile scratch_tile; | ||
| TASSIGN(scratch_tile, 0x0); | ||
|
|
||
| AsyncCtx async_ctx = get_async_ctx(args); | ||
| send_request_entry( | ||
| async_ctx, | ||
| SdmaTget(local_global, remote_global, scratch_tile, reinterpret_cast<__gm__ uint8_t *>(comm_ctx->workSpace)) | ||
| ); | ||
| } |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,52 @@ | ||||||||||||||||||||||||||||||||||||||
| /* | ||||||||||||||||||||||||||||||||||||||
| * Copyright (c) PyPTO Contributors. | ||||||||||||||||||||||||||||||||||||||
| * This program is free software, you can redistribute it and/or modify it under the terms and conditions of | ||||||||||||||||||||||||||||||||||||||
| * CANN Open Software License Agreement Version 2.0 (the "License"). | ||||||||||||||||||||||||||||||||||||||
| * Please refer to the License for details. You may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||
| * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||||||||||||||||||||||||||||||||||||||
| * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||||||||||||||||||||||||||||||||||||||
| * See LICENSE in the root of the software repository for the full text of the License. | ||||||||||||||||||||||||||||||||||||||
| * ----------------------------------------------------------------------------------------------------------- | ||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| #include <stdint.h> | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| #include "platform_comm/comm_context.h" | ||||||||||||||||||||||||||||||||||||||
| #include "pto_orchestration_api.h" | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| extern "C" { | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| __attribute__((visibility("default"))) PTO2OrchestrationConfig | ||||||||||||||||||||||||||||||||||||||
| sdma_async_completion_orchestration_config(const L2TaskArgs &orch_args) { | ||||||||||||||||||||||||||||||||||||||
| (void)orch_args; | ||||||||||||||||||||||||||||||||||||||
| return PTO2OrchestrationConfig{.expected_arg_count = 4}; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| __attribute__((visibility("default"))) PTO2OrchestrationConfig aicpu_orchestration_config(const L2TaskArgs &orch_args) { | ||||||||||||||||||||||||||||||||||||||
| return sdma_async_completion_orchestration_config(orch_args); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| __attribute__((visibility("default"))) void sdma_async_completion_orchestration(const L2TaskArgs &orch_args) { | ||||||||||||||||||||||||||||||||||||||
| if (orch_args.tensor_count() + orch_args.scalar_count() != 4) { | ||||||||||||||||||||||||||||||||||||||
| LOG_ERROR("sdma_async_completion_demo: expected 4 args"); | ||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const Tensor &input = orch_args.tensor(0).ref(); | ||||||||||||||||||||||||||||||||||||||
| const Tensor &out = orch_args.tensor(1).ref(); | ||||||||||||||||||||||||||||||||||||||
| const Tensor &result = orch_args.tensor(2).ref(); | ||||||||||||||||||||||||||||||||||||||
| auto *comm_ctx = reinterpret_cast<CommContext *>(static_cast<uintptr_t>(orch_args.scalar(0))); | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+30
to
+38
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win Validate the tensor/scalar split, not just total arg count. A call with 4 args but the wrong mix can pass this guard and still fail when accessing Proposed fix- if (orch_args.tensor_count() + orch_args.scalar_count() != 4) {
- LOG_ERROR("sdma_async_completion_demo: expected 4 args");
+ if (orch_args.tensor_count() != 3 || orch_args.scalar_count() != 1) {
+ LOG_ERROR("sdma_async_completion_demo: expected 3 tensor args and 1 scalar arg");
return;
}📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| L0TaskArgs producer_args; | ||||||||||||||||||||||||||||||||||||||
| producer_args.add_input(input); | ||||||||||||||||||||||||||||||||||||||
| producer_args.add_output(out); | ||||||||||||||||||||||||||||||||||||||
| producer_args.add_scalar(reinterpret_cast<uint64_t>(comm_ctx)); | ||||||||||||||||||||||||||||||||||||||
| rt_submit_aiv_task(0, producer_args); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| L0TaskArgs consumer_args; | ||||||||||||||||||||||||||||||||||||||
| consumer_args.add_input(out); | ||||||||||||||||||||||||||||||||||||||
| consumer_args.add_output(result); | ||||||||||||||||||||||||||||||||||||||
| rt_submit_aiv_task(1, consumer_args); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| } // extern "C" | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current check only verifies that the sum of
tensor_count()andscalar_count()is 4. If the orchestrator is invoked with an unexpected combination of arguments (e.g., 2 tensors and 2 scalars), accessingorch_args.tensor(2)or other indices will result in an out-of-bounds access and potentially crash. It is safer to explicitly validate thattensor_count()is exactly 3 andscalar_count()is exactly 1.References