Skip to content

[PAL] Add torch plugin support for Sunrise#483

Open
cyx11111 wants to merge 4 commits into
flagos-ai:mainfrom
cyx11111:main
Open

[PAL] Add torch plugin support for Sunrise#483
cyx11111 wants to merge 4 commits into
flagos-ai:mainfrom
cyx11111:main

Conversation

@cyx11111
Copy link
Copy Markdown
Contributor

PR Category
PAL

PR Types
New Features

PR Description
This PR introduces torch plugin support on Sunrise

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds PyTorch plugin support for the Sunrise/PTPU (PCCL + Tang runtime) backend, mirroring the existing TSM/Enflame PrivateUse1 paths and adding Sunrise-specific handling for p2p coalescing and stream/event wrapping.

Changes:

  • Wires USE_SUNRISE_ADAPTOR into the Torch plugin (__init__.py, _build_config.py, backend_flagcx.{hpp,cpp}, event_flagcx.hpp, stream_guard_flagcx.hpp) including a new flagcxPtpuEvent and a PTPU-aware flagcxStreamGuard that uses torchpt::get_stream_from_pool instead of getStreamFromExternal.
  • Implements PTPU-specific p2p via per-pair PCCL sub-comms (getOrInitPtpuPairComm) and rewrites startCoalescing/endCoalescing for Sunrise to defer and reorder send/recv ops in peer-ascending order.
  • Adds a sunrise_comm_traits.h header (aliased to default traits) and updates the README PCCL row (gather marked unsupported, PyTorch support marked supported).

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
README.md Marks PCCL PyTorch support as ✓ and PCCL gather as unsupported.
plugin/torch/_build_config.py Adds sunrise adaptor entries, pt-smi detection, and Sunrise include/lib/library wiring via torch_ptpu and CMAKE_TANG_TOOLKIT_DIR.
plugin/torch/flagcx/init.py Pre-imports torch_ptpu, reclaims the ptpu device→flagcx backend mapping, and adds ptpu to the device-name list.
plugin/torch/flagcx/include/backend_flagcx.hpp Adds ptpu device name, PTPU flagcxWork event, pair-comm helper, and PTPU coalescing state.
plugin/torch/flagcx/include/event_flagcx.hpp Adds flagcxPtpuEvent wrapping torchpt::PTPUEvent with a primed-event workaround for FlagCX-allocated tangStream_t.
plugin/torch/flagcx/include/stream_guard_flagcx.hpp Adds PTPU stream guard using get_stream_from_pool and manual current-stream restore.
plugin/torch/flagcx/src/backend_flagcx.cpp Adds Sunrise device check, PTPU pair-comm init via store, deferred/sorted p2p coalescing, and PTPU send/recv pair-comm routing.
flagcx/adaptor/include/device_api/comm_traits.h Dispatches to sunrise_comm_traits.h under USE_SUNRISE_ADAPTOR.
flagcx/adaptor/include/device_api/sunrise_comm_traits.h New header providing inner-type stubs and aliasing DeviceAPI to the default traits.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +1314 to +1348
// Allocate a unique_id buffer. flagcxGetUniqueId() also creates a small
// bootstrap listener as a side-effect; on the follower side we simply
// overwrite the buffer with the leader's id (the follower's listener
// remains unused, matching how initComm() handles the global comm).
flagcxUniqueId_t uid = nullptr;
C10D_FLAGCX_CHECK(flagcxGetUniqueId(&uid), std::nullopt);
if (numRanks == 2) {
const std::string storeKey = "flagcx/p2p/" + key + "/uniqueId";
if (p2pRank == 0) {
auto vec = std::vector<uint8_t>(reinterpret_cast<uint8_t *>(uid),
reinterpret_cast<uint8_t *>(uid) +
sizeof(flagcxUniqueId));
store_->set(storeKey, std::string(vec.begin(), vec.end()));
} else {
try {
auto vec = store_->get(storeKey);
TORCH_CHECK_WITH(DistBackendError, vec.size() == sizeof(flagcxUniqueId),
"Invalid size for flagcxUniqueId on p2p key '", key,
"'");
std::memcpy(reinterpret_cast<uint8_t *>(uid), vec.data(),
sizeof(flagcxUniqueId));
} catch (const std::exception &e) {
C10_THROW_ERROR(DistBackendError,
std::string("Failed to retrieve PCCL p2p unique id "
"from the store for key '") +
key + "': " + e.what());
}
}
}

flagcxComm_t pairComm = nullptr;
C10D_FLAGCX_CHECK(flagcxCommInitRank(&pairComm, numRanks, uid, p2pRank),
std::nullopt);
ptpuPairComms_.emplace(key, pairComm);
return pairComm;
Copy link
Copy Markdown
Contributor Author

@cyx11111 cyx11111 May 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in 2309728

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants