From c8b3db1cf649de3dbbe4b73a435bf9046cf1bf71 Mon Sep 17 00:00:00 2001 From: MC952-arch Date: Wed, 27 May 2026 18:15:06 +0800 Subject: [PATCH 1/4] regpool: support null comm registration and optimize containers --- flagcx/core/include/reg_pool.h | 14 ++- flagcx/core/include/register.h | 4 +- flagcx/core/reg_pool.cc | 206 ++++++++++++++++++++------------- flagcx/flagcx.cc | 30 +++-- 4 files changed, 162 insertions(+), 92 deletions(-) diff --git a/flagcx/core/include/reg_pool.h b/flagcx/core/include/reg_pool.h index c9cec242f..017537fbc 100644 --- a/flagcx/core/include/reg_pool.h +++ b/flagcx/core/include/reg_pool.h @@ -6,11 +6,13 @@ #include "flagcx.h" #include "net.h" #include "register.h" -#include #include +#include class flagcxRegPool { public: + static constexpr uintptr_t GLOBAL_POOL_KEY = 0; // nullptr comm maps here + flagcxRegPool(); ~flagcxRegPool(); @@ -25,16 +27,18 @@ class flagcxRegPool { flagcxResult_t removeAllP2pHandles(void *comm); flagcxResult_t registerBuffer(void *comm, void *data, size_t length); flagcxResult_t deregisterBuffer(void *comm, void *handle); - std::map> &getGlobalMap(); + std::unordered_map> + &getGlobalMap(); flagcxRegItem *getItem(const void *comm, void *data); void dump(); private: void mapRegItemPages(uintptr_t commKey, flagcxRegItem *reg); - std::map> + std::unordered_map> regMap; // > - std::map> - regPool; // + std::unordered_map> + regPool; // > (only GLOBAL_POOL_KEY owns + // data) uintptr_t pageSize; }; diff --git a/flagcx/core/include/register.h b/flagcx/core/include/register.h index 412c5fc3d..a36155706 100644 --- a/flagcx/core/include/register.h +++ b/flagcx/core/include/register.h @@ -3,7 +3,7 @@ #include "core.h" #include "device.h" -#include +#include #define FLAGCX_IPC_HANDLE_SIZE 64 @@ -60,7 +60,7 @@ struct flagcxRegItem { uintptr_t beginAddr = 0; uintptr_t endAddr = 0; int refCount = 1; - std::list> handles; + std::vector> handles; void *homoRegHandle = nullptr; // backend CCL handle (homo path only) 
flagcxIpcHandleData ipcHandleData = {}; // IPC handle bytes (both paths) }; diff --git a/flagcx/core/reg_pool.cc b/flagcx/core/reg_pool.cc index 0e9489051..dc5462893 100644 --- a/flagcx/core/reg_pool.cc +++ b/flagcx/core/reg_pool.cc @@ -4,8 +4,6 @@ #include #include -#define DEFAULT_REGPOOL_SIZE 16 - flagcxRegPool::flagcxRegPool() { pageSize = sysconf(_SC_PAGESIZE); } flagcxRegPool::~flagcxRegPool() { @@ -24,7 +22,7 @@ inline void flagcxRegPool::getPagedAddr(void *data, size_t length, flagcxResult_t flagcxRegPool::addNetHandle(void *comm, flagcxRegItem *reg, void *handle, struct flagcxProxyConnector *proxyConn) { - if (comm == nullptr || reg == nullptr) { + if (reg == nullptr) { return flagcxSuccess; } for (auto &handlePair : reg->handles) { @@ -43,7 +41,7 @@ flagcxRegPool::addNetHandle(void *comm, flagcxRegItem *reg, void *handle, flagcxResult_t flagcxRegPool::addP2pHandle(void *comm, flagcxRegItem *reg, void *handle, struct flagcxProxyConnector *proxyConn) { - if (comm == nullptr || reg == nullptr) { + if (reg == nullptr) { return flagcxSuccess; } for (auto &handlePair : reg->handles) { @@ -65,17 +63,19 @@ flagcxResult_t flagcxRegPool::removeRegItemNetHandles(void *comm, return flagcxSuccess; } - for (auto it = reg->handles.begin(); it != reg->handles.end();) { - if (it->first.handle) { - FLAGCXCHECK(flagcxNetDeregisterBuffer(comm, it->first.proxyConn, - it->first.handle)); - it->first.handle = nullptr; - it->first.proxyConn = nullptr; + for (size_t i = 0; i < reg->handles.size();) { + auto &entry = reg->handles[i]; + if (entry.first.handle) { + FLAGCXCHECK(flagcxNetDeregisterBuffer(comm, entry.first.proxyConn, + entry.first.handle)); + entry.first.handle = nullptr; + entry.first.proxyConn = nullptr; } - if (it->first.handle == nullptr && it->second.handle == nullptr) { - it = reg->handles.erase(it); + if (entry.first.handle == nullptr && entry.second.handle == nullptr) { + reg->handles[i] = reg->handles.back(); + reg->handles.pop_back(); } else { - ++it; + ++i; 
} } return flagcxSuccess; @@ -87,18 +87,20 @@ flagcxResult_t flagcxRegPool::removeRegItemP2pHandles(void *comm, return flagcxSuccess; } - for (auto it = reg->handles.begin(); it != reg->handles.end();) { - if (it->second.handle) { - flagcxIpcRegInfo *ipcInfo = (flagcxIpcRegInfo *)it->second.handle; + for (size_t i = 0; i < reg->handles.size();) { + auto &entry = reg->handles[i]; + if (entry.second.handle) { + flagcxIpcRegInfo *ipcInfo = (flagcxIpcRegInfo *)entry.second.handle; FLAGCXCHECK(flagcxP2pDeregisterBuffer( reinterpret_cast(comm), ipcInfo)); - it->second.handle = nullptr; - it->second.proxyConn = nullptr; + entry.second.handle = nullptr; + entry.second.proxyConn = nullptr; } - if (it->first.handle == nullptr && it->second.handle == nullptr) { - it = reg->handles.erase(it); + if (entry.first.handle == nullptr && entry.second.handle == nullptr) { + reg->handles[i] = reg->handles.back(); + reg->handles.pop_back(); } else { - ++it; + ++i; } } return flagcxSuccess; @@ -108,13 +110,11 @@ flagcxResult_t flagcxRegPool::removeAllP2pHandles(void *comm) { if (comm == nullptr) { return flagcxSuccess; } - uintptr_t commKey = reinterpret_cast(comm); - auto poolIt = regPool.find(commKey); - if (poolIt == regPool.end()) { - return flagcxSuccess; - } - for (auto &reg : poolIt->second) { - FLAGCXCHECK(removeRegItemP2pHandles(comm, &reg)); + // Iterate over all items in the global pool and remove p2p handles + // associated with this comm + auto &globalPool = regPool[GLOBAL_POOL_KEY]; + for (auto &pair : globalPool) { + FLAGCXCHECK(removeRegItemP2pHandles(comm, &pair.second)); } return flagcxSuccess; } @@ -131,83 +131,134 @@ void flagcxRegPool::mapRegItemPages(uintptr_t commKey, flagcxRegItem *reg) { flagcxResult_t flagcxRegPool::registerBuffer(void *comm, void *data, size_t length) { - if (comm == nullptr || data == nullptr || length == 0) + if (data == nullptr || length == 0) return flagcxSuccess; - uintptr_t commKey = reinterpret_cast(comm); + uintptr_t commKey = + comm ? 
reinterpret_cast(comm) : GLOBAL_POOL_KEY; uintptr_t beginAddr, endAddr; getPagedAddr(data, length, &beginAddr, &endAddr); - auto &regCommPool = regPool[commKey]; - for (auto it = regCommPool.begin(); it != regCommPool.end(); it++) { - // found a place to insert - if (beginAddr < it->beginAddr) { - flagcxRegItem reg{beginAddr, endAddr, 1, {}}; - auto &insertedReg = *regCommPool.insert(it, std::move(reg)); - mapRegItemPages(commKey, &insertedReg); - return flagcxSuccess; - // already inserted, just increase ref count - } else if (it->beginAddr <= beginAddr && it->endAddr >= endAddr) { - it->refCount++; - return flagcxSuccess; + // Always check/insert into the global pool (single source of truth) + auto &globalPool = regPool[GLOBAL_POOL_KEY]; + auto it = globalPool.find(beginAddr); + if (it != globalPool.end()) { + // Already registered: bump refCount + it->second.refCount++; + // If comm is non-null, ensure it's mapped in the comm-specific regMap + if (comm != nullptr) { + mapRegItemPages(commKey, &it->second); } + return flagcxSuccess; } - // not found, insert to the end + // Not found: create new item in global pool flagcxRegItem reg{beginAddr, endAddr, 1, {}}; - regCommPool.push_back(std::move(reg)); - mapRegItemPages(commKey, &regCommPool.back()); + auto [inserted, success] = globalPool.emplace(beginAddr, std::move(reg)); + flagcxRegItem *regPtr = &inserted->second; + + // Map pages in global regMap + mapRegItemPages(GLOBAL_POOL_KEY, regPtr); + + // If comm is non-null, also map pages in comm-specific regMap + if (comm != nullptr) { + mapRegItemPages(commKey, regPtr); + } + return flagcxSuccess; } flagcxResult_t flagcxRegPool::deregisterBuffer(void *comm, void *handle) { - if (comm == nullptr || handle == nullptr) { + if (handle == nullptr) { return flagcxSuccess; } - uintptr_t commKey = reinterpret_cast(comm); + uintptr_t commKey = + comm ? 
reinterpret_cast(comm) : GLOBAL_POOL_KEY; flagcxRegItem *reg = (flagcxRegItem *)handle; - auto &regCommPool = regPool[commKey]; - for (auto it = regCommPool.begin(); it != regCommPool.end(); it++) { - if (&(*it) == reg) { - it->refCount--; - if (it->refCount > 0) { - return flagcxSuccess; + // Find the item in the global pool + auto &globalPool = regPool[GLOBAL_POOL_KEY]; + auto poolIt = globalPool.find(reg->beginAddr); + if (poolIt == globalPool.end() || &poolIt->second != reg) { + WARN("Could not find the given handle in regPool"); + return flagcxInvalidUsage; + } + + reg->refCount--; + + // Remove comm-specific page mappings + if (comm != nullptr && commKey != GLOBAL_POOL_KEY) { + auto mapIt = regMap.find(commKey); + if (mapIt != regMap.end()) { + auto &commMap = mapIt->second; + for (uintptr_t addr = reg->beginAddr; addr < reg->endAddr; + addr += pageSize) { + commMap.erase(addr); } - FLAGCXCHECK(removeRegItemNetHandles(comm, reg)); - FLAGCXCHECK(removeRegItemP2pHandles(comm, reg)); - auto &regCommMap = regMap[commKey]; - for (auto mapIter = regCommMap.begin(); mapIter != regCommMap.end();) { - if (mapIter->second == reg) { - mapIter = regCommMap.erase(mapIter); - } else { - mapIter++; - } + if (commMap.empty()) { + regMap.erase(mapIt); } - regCommPool.erase(it); - return flagcxSuccess; } } - WARN("Could not find the given handle in regPool"); - return flagcxInvalidUsage; + if (reg->refCount > 0) { + return flagcxSuccess; + } + + // refCount == 0: full cleanup + FLAGCXCHECK(removeRegItemNetHandles(comm, reg)); + FLAGCXCHECK(removeRegItemP2pHandles(comm, reg)); + + // Remove from global regMap + auto globalMapIt = regMap.find(GLOBAL_POOL_KEY); + if (globalMapIt != regMap.end()) { + auto &globalMap = globalMapIt->second; + for (uintptr_t addr = reg->beginAddr; addr < reg->endAddr; + addr += pageSize) { + globalMap.erase(addr); + } + if (globalMap.empty()) { + regMap.erase(globalMapIt); + } + } + + // Remove from global pool (this destroys the flagcxRegItem) + 
globalPool.erase(poolIt); + return flagcxSuccess; } -std::map> & +std::unordered_map> & flagcxRegPool::getGlobalMap() { return regMap; } flagcxRegItem *flagcxRegPool::getItem(const void *comm, void *data) { - uintptr_t commKey = reinterpret_cast(comm); uintptr_t beginAddr, endAddr; getPagedAddr(data, 0, &beginAddr, &endAddr); - auto it = regMap[commKey].find(beginAddr); - if (it == regMap[commKey].end()) { - return nullptr; + + // If comm is non-null, check comm-specific regMap first + if (comm != nullptr) { + uintptr_t commKey = reinterpret_cast(comm); + auto mapIt = regMap.find(commKey); + if (mapIt != regMap.end()) { + auto it = mapIt->second.find(beginAddr); + if (it != mapIt->second.end()) { + return it->second; + } + } + } + + // Fall through to global pool + auto globalMapIt = regMap.find(GLOBAL_POOL_KEY); + if (globalMapIt != regMap.end()) { + auto it = globalMapIt->second.find(beginAddr); + if (it != globalMapIt->second.end()) { + return it->second; + } } - return it->second; + + return nullptr; } void flagcxRegPool::dump() { @@ -218,14 +269,13 @@ void flagcxRegPool::dump() { for (auto &p : c.second) { printf("beginAddr(%lu) -> regItem[%lu,%lu,%d]\n", p.first, p.second->beginAddr, p.second->endAddr, p.second->refCount); - auto it = p.second->handles.begin(); - for (; it != p.second->handles.end(); it++) { - printf("handlePtr(%p) -> netHandle[%p,%p] p2pHandle[%p,%p]\n", &(*it), - it->first.handle, it->first.proxyConn, it->second.handle, - it->second.proxyConn); + for (auto &h : p.second->handles) { + printf("handlePtr(%p) -> netHandle[%p,%p] p2pHandle[%p,%p]\n", &h, + h.first.handle, h.first.proxyConn, h.second.handle, + h.second.proxyConn); } } printf("==comm(%lu)==\n", c.first); } printf("========================\n"); -} \ No newline at end of file +} diff --git a/flagcx/flagcx.cc b/flagcx/flagcx.cc index 07603498e..8557cb320 100644 --- a/flagcx/flagcx.cc +++ b/flagcx/flagcx.cc @@ -998,7 +998,9 @@ flagcxOneSideBarrierDeregister(const flagcxComm_t comm, 
flagcxResult_t flagcxCommRegister(const flagcxComm_t comm, void *buff, size_t size, void **handle) { - FLAGCXCHECK(flagcxEnsureCommReady(comm)); + if (comm != nullptr) { + FLAGCXCHECK(flagcxEnsureCommReady(comm)); + } if (buff == NULL || size == 0) { WARN("Invalid buffer or size for buffer registration."); @@ -1007,13 +1009,22 @@ flagcxResult_t flagcxCommRegister(const flagcxComm_t comm, void *buff, // Step 1: Register in globalRegPool (both paths) // Key: heteroComm if available (p2p/net downstream use it), else homoComm - void *regKey = - comm->heteroComm ? (void *)comm->heteroComm : (void *)comm->homoComm; + // If comm is NULL, register in global pool only (GLOBAL_POOL_KEY) + void *regKey = nullptr; + if (comm != nullptr) { + regKey = + comm->heteroComm ? (void *)comm->heteroComm : (void *)comm->homoComm; + } globalRegPool.registerBuffer(regKey, buff, size); flagcxRegItem *regItem = globalRegPool.getItem(regKey, buff); *handle = reinterpret_cast(regItem); + // Null comm: pool-only registration, skip backend steps + if (comm == nullptr) { + return flagcxSuccess; + } + // Re-registration: backend handle + IPC handle already set up if (regItem->refCount > 1) { return flagcxSuccess; @@ -1082,13 +1093,15 @@ flagcxResult_t flagcxCommRegister(const flagcxComm_t comm, void *buff, } flagcxResult_t flagcxCommDeregister(const flagcxComm_t comm, void *handle) { - FLAGCXCHECK(flagcxEnsureCommReady(comm)); + if (comm != nullptr) { + FLAGCXCHECK(flagcxEnsureCommReady(comm)); + } if (handle == nullptr) return flagcxSuccess; flagcxRegItem *regItem = reinterpret_cast(handle); // Backend-specific deregistration (homo path only, last ref only) - if (regItem->refCount == 1) { + if (comm != nullptr && regItem->refCount == 1) { if (useHomoComm(comm) && !useHeteroComm() && regItem->homoRegHandle) { cclAdaptors[flagcxCCLAdaptorDevice]->commDeregister( comm->homoComm, regItem->homoRegHandle); @@ -1096,8 +1109,11 @@ flagcxResult_t flagcxCommDeregister(const flagcxComm_t comm, void 
*handle) { } // Clean up globalRegPool (both paths) - void *regKey = - comm->heteroComm ? (void *)comm->heteroComm : (void *)comm->homoComm; + void *regKey = nullptr; + if (comm != nullptr) { + regKey = + comm->heteroComm ? (void *)comm->heteroComm : (void *)comm->homoComm; + } globalRegPool.deregisterBuffer(regKey, handle); return flagcxSuccess; } From c4702236265131640f6223c9e84531570f6ef4e8 Mon Sep 17 00:00:00 2001 From: MC952-arch Date: Thu, 28 May 2026 12:14:16 +0800 Subject: [PATCH 2/4] Fix issues --- flagcx/core/include/reg_pool.h | 4 +- flagcx/core/include/register.h | 10 ++++- flagcx/core/p2p.cc | 6 +-- flagcx/core/reg_pool.cc | 34 ++++++++------ flagcx/flagcx.cc | 81 ++++++++++++++++++++-------------- 5 files changed, 82 insertions(+), 53 deletions(-) diff --git a/flagcx/core/include/reg_pool.h b/flagcx/core/include/reg_pool.h index 017537fbc..e96be181f 100644 --- a/flagcx/core/include/reg_pool.h +++ b/flagcx/core/include/reg_pool.h @@ -6,6 +6,7 @@ #include "flagcx.h" #include "net.h" #include "register.h" +#include #include #include @@ -36,7 +37,8 @@ class flagcxRegPool { void mapRegItemPages(uintptr_t commKey, flagcxRegItem *reg); std::unordered_map> regMap; // > - std::unordered_map> + std::unordered_map< + uintptr_t, std::unordered_map>> regPool; // > (only GLOBAL_POOL_KEY owns // data) uintptr_t pageSize; diff --git a/flagcx/core/include/register.h b/flagcx/core/include/register.h index a36155706..4d4337f86 100644 --- a/flagcx/core/include/register.h +++ b/flagcx/core/include/register.h @@ -3,6 +3,8 @@ #include "core.h" #include "device.h" +#include +#include #include #define FLAGCX_IPC_HANDLE_SIZE 64 @@ -28,11 +30,13 @@ struct netRegInfo { struct flagcxRegNetHandle { void *handle = NULL; struct flagcxProxyConnector *proxyConn = NULL; + void *ownerComm = NULL; // comm that registered this handle }; struct flagcxRegP2pHandle { void *handle = NULL; struct flagcxProxyConnector *proxyConn = NULL; + void *ownerComm = NULL; // comm that registered this 
handle }; struct flagcxIpcImpInfo { @@ -61,8 +65,10 @@ struct flagcxRegItem { uintptr_t endAddr = 0; int refCount = 1; std::vector> handles; - void *homoRegHandle = nullptr; // backend CCL handle (homo path only) - flagcxIpcHandleData ipcHandleData = {}; // IPC handle bytes (both paths) + flagcxIpcHandleData localIpcHandleData = + {}; // sender's IPC handle bytes (hetero path) + std::unordered_map + homoRegHandles; // commKey → backend CCL handle }; struct flagcxReg { diff --git a/flagcx/core/p2p.cc b/flagcx/core/p2p.cc index 100e99bbf..d2782ed63 100644 --- a/flagcx/core/p2p.cc +++ b/flagcx/core/p2p.cc @@ -817,9 +817,9 @@ static flagcxResult_t p2pRegisterBuffer(flagcxHeteroComm *comm, } else if (legacyIpcCap) { // Different process: get IPC handle for our own buffer char zeros[sizeof(flagcxIpcHandleData)] = {}; - if (memcmp(&regItem->ipcHandleData, zeros, + if (memcmp(&regItem->localIpcHandleData, zeros, sizeof(flagcxIpcHandleData)) != 0) { - memcpy(&handleData, &regItem->ipcHandleData, + memcpy(&handleData, &regItem->localIpcHandleData, sizeof(flagcxIpcHandleData)); } else { flagcxIpcMemHandle_t ipcHandle = NULL; @@ -832,7 +832,7 @@ static flagcxResult_t p2pRegisterBuffer(flagcxHeteroComm *comm, fail); if (handleSize <= sizeof(flagcxIpcHandleData)) { memcpy(&handleData, ipcHandle, handleSize); - memcpy(&regItem->ipcHandleData, ipcHandle, handleSize); + memcpy(&regItem->localIpcHandleData, ipcHandle, handleSize); } deviceAdaptor->ipcMemHandleFree(ipcHandle); } diff --git a/flagcx/core/reg_pool.cc b/flagcx/core/reg_pool.cc index dc5462893..fa32a32aa 100644 --- a/flagcx/core/reg_pool.cc +++ b/flagcx/core/reg_pool.cc @@ -31,8 +31,8 @@ flagcxRegPool::addNetHandle(void *comm, flagcxRegItem *reg, void *handle, return flagcxSuccess; } } - flagcxRegNetHandle netHandle{handle, proxyConn}; - flagcxRegP2pHandle p2pHandle{nullptr, nullptr}; + flagcxRegNetHandle netHandle{handle, proxyConn, comm}; + flagcxRegP2pHandle p2pHandle{nullptr, nullptr, nullptr}; 
reg->handles.push_back(std::make_pair(netHandle, p2pHandle)); return flagcxSuccess; @@ -50,8 +50,8 @@ flagcxRegPool::addP2pHandle(void *comm, flagcxRegItem *reg, void *handle, return flagcxSuccess; } } - flagcxRegNetHandle netHandle{nullptr, nullptr}; - flagcxRegP2pHandle p2pHandle{handle, proxyConn}; + flagcxRegNetHandle netHandle{nullptr, nullptr, nullptr}; + flagcxRegP2pHandle p2pHandle{handle, proxyConn, comm}; reg->handles.push_back(std::make_pair(netHandle, p2pHandle)); return flagcxSuccess; @@ -66,10 +66,11 @@ flagcxResult_t flagcxRegPool::removeRegItemNetHandles(void *comm, for (size_t i = 0; i < reg->handles.size();) { auto &entry = reg->handles[i]; if (entry.first.handle) { - FLAGCXCHECK(flagcxNetDeregisterBuffer(comm, entry.first.proxyConn, - entry.first.handle)); + FLAGCXCHECK(flagcxNetDeregisterBuffer( + entry.first.ownerComm, entry.first.proxyConn, entry.first.handle)); entry.first.handle = nullptr; entry.first.proxyConn = nullptr; + entry.first.ownerComm = nullptr; } if (entry.first.handle == nullptr && entry.second.handle == nullptr) { reg->handles[i] = reg->handles.back(); @@ -92,9 +93,11 @@ flagcxResult_t flagcxRegPool::removeRegItemP2pHandles(void *comm, if (entry.second.handle) { flagcxIpcRegInfo *ipcInfo = (flagcxIpcRegInfo *)entry.second.handle; FLAGCXCHECK(flagcxP2pDeregisterBuffer( - reinterpret_cast(comm), ipcInfo)); + reinterpret_cast(entry.second.ownerComm), + ipcInfo)); entry.second.handle = nullptr; entry.second.proxyConn = nullptr; + entry.second.ownerComm = nullptr; } if (entry.first.handle == nullptr && entry.second.handle == nullptr) { reg->handles[i] = reg->handles.back(); @@ -114,7 +117,7 @@ flagcxResult_t flagcxRegPool::removeAllP2pHandles(void *comm) { // associated with this comm auto &globalPool = regPool[GLOBAL_POOL_KEY]; for (auto &pair : globalPool) { - FLAGCXCHECK(removeRegItemP2pHandles(comm, &pair.second)); + FLAGCXCHECK(removeRegItemP2pHandles(comm, pair.second.get())); } return flagcxSuccess; } @@ -144,18 +147,21 @@ 
flagcxResult_t flagcxRegPool::registerBuffer(void *comm, void *data, auto it = globalPool.find(beginAddr); if (it != globalPool.end()) { // Already registered: bump refCount - it->second.refCount++; + it->second->refCount++; // If comm is non-null, ensure it's mapped in the comm-specific regMap if (comm != nullptr) { - mapRegItemPages(commKey, &it->second); + mapRegItemPages(commKey, it->second.get()); } return flagcxSuccess; } // Not found: create new item in global pool - flagcxRegItem reg{beginAddr, endAddr, 1, {}}; - auto [inserted, success] = globalPool.emplace(beginAddr, std::move(reg)); - flagcxRegItem *regPtr = &inserted->second; + auto reg = std::make_unique(); + reg->beginAddr = beginAddr; + reg->endAddr = endAddr; + reg->refCount = 1; + auto [it2, didInsert] = globalPool.emplace(beginAddr, std::move(reg)); + flagcxRegItem *regPtr = it2->second.get(); // Map pages in global regMap mapRegItemPages(GLOBAL_POOL_KEY, regPtr); @@ -180,7 +186,7 @@ flagcxResult_t flagcxRegPool::deregisterBuffer(void *comm, void *handle) { // Find the item in the global pool auto &globalPool = regPool[GLOBAL_POOL_KEY]; auto poolIt = globalPool.find(reg->beginAddr); - if (poolIt == globalPool.end() || &poolIt->second != reg) { + if (poolIt == globalPool.end() || poolIt->second.get() != reg) { WARN("Could not find the given handle in regPool"); return flagcxInvalidUsage; } diff --git a/flagcx/flagcx.cc b/flagcx/flagcx.cc index 8557cb320..ed28c9ac6 100644 --- a/flagcx/flagcx.cc +++ b/flagcx/flagcx.cc @@ -1025,10 +1025,7 @@ flagcxResult_t flagcxCommRegister(const flagcxComm_t comm, void *buff, return flagcxSuccess; } - // Re-registration: backend handle + IPC handle already set up - if (regItem->refCount > 1) { - return flagcxSuccess; - } + uintptr_t thisCommKey = reinterpret_cast(regKey); flagcxResult_t res = flagcxSuccess; @@ -1037,34 +1034,43 @@ flagcxResult_t flagcxCommRegister(const flagcxComm_t comm, void *buff, // (cudaIpcGetMemHandle is incompatible with ncclMemAlloc VMM 
buffers) // and Step 3 (one-sided MR registration, hetero-only). if (useHomoComm(comm) && !useHeteroComm()) { + // Re-registration: this comm already completed homo backend init + if (regItem->homoRegHandles.count(thisCommKey)) { + return flagcxSuccess; + } void *homoHandle = nullptr; res = cclAdaptors[flagcxCCLAdaptorDevice]->commRegister( comm->homoComm, buff, size, &homoHandle); if (res != flagcxSuccess) goto fail; - regItem->homoRegHandle = homoHandle; + regItem->homoRegHandles[thisCommKey] = homoHandle; return flagcxSuccess; } // Step 2b: Create IPC handle for the buffer (hetero path only) + // Write-once: if localIpcHandleData is already populated, skip { - flagcxIpcMemHandle_t handlePtr = nullptr; - size_t ipcSize = 0; - res = deviceAdaptor->ipcMemHandleCreate(&handlePtr, &ipcSize); - if (res != flagcxSuccess) - goto fail; - res = deviceAdaptor->ipcMemHandleGet(handlePtr, buff); - if (res != flagcxSuccess) { + char zeros[sizeof(flagcxIpcHandleData)] = {}; + if (memcmp(&regItem->localIpcHandleData, zeros, + sizeof(flagcxIpcHandleData)) == 0) { + flagcxIpcMemHandle_t handlePtr = nullptr; + size_t ipcSize = 0; + res = deviceAdaptor->ipcMemHandleCreate(&handlePtr, &ipcSize); + if (res != flagcxSuccess) + goto fail; + res = deviceAdaptor->ipcMemHandleGet(handlePtr, buff); + if (res != flagcxSuccess) { + deviceAdaptor->ipcMemHandleFree(handlePtr); + goto fail; + } + if (ipcSize > sizeof(flagcxIpcHandleData)) { + deviceAdaptor->ipcMemHandleFree(handlePtr); + res = flagcxInternalError; + goto fail; + } + memcpy(&regItem->localIpcHandleData, handlePtr, ipcSize); deviceAdaptor->ipcMemHandleFree(handlePtr); - goto fail; } - if (ipcSize > sizeof(flagcxIpcHandleData)) { - deviceAdaptor->ipcMemHandleFree(handlePtr); - res = flagcxInternalError; - goto fail; - } - memcpy(&regItem->ipcHandleData, handlePtr, ipcSize); - deviceAdaptor->ipcMemHandleFree(handlePtr); } // Step 3: One-sided MR registration (hetero path only) @@ -1081,10 +1087,13 @@ flagcxResult_t flagcxCommRegister(const 
flagcxComm_t comm, void *buff, fail: // Undo Step 2a - if (useHomoComm(comm) && !useHeteroComm() && regItem->homoRegHandle) { - cclAdaptors[flagcxCCLAdaptorDevice]->commDeregister(comm->homoComm, - regItem->homoRegHandle); - regItem->homoRegHandle = nullptr; + if (useHomoComm(comm) && !useHeteroComm()) { + auto it = regItem->homoRegHandles.find(thisCommKey); + if (it != regItem->homoRegHandles.end()) { + cclAdaptors[flagcxCCLAdaptorDevice]->commDeregister(comm->homoComm, + it->second); + regItem->homoRegHandles.erase(it); + } } // Undo Step 1 globalRegPool.deregisterBuffer(regKey, regItem); @@ -1100,20 +1109,26 @@ flagcxResult_t flagcxCommDeregister(const flagcxComm_t comm, void *handle) { return flagcxSuccess; flagcxRegItem *regItem = reinterpret_cast(handle); - // Backend-specific deregistration (homo path only, last ref only) - if (comm != nullptr && regItem->refCount == 1) { - if (useHomoComm(comm) && !useHeteroComm() && regItem->homoRegHandle) { - cclAdaptors[flagcxCCLAdaptorDevice]->commDeregister( - comm->homoComm, regItem->homoRegHandle); - } - } - - // Clean up globalRegPool (both paths) void *regKey = nullptr; if (comm != nullptr) { regKey = comm->heteroComm ? 
(void *)comm->heteroComm : (void *)comm->homoComm; } + + // Backend-specific deregistration (homo path) + if (comm != nullptr) { + uintptr_t thisCommKey = reinterpret_cast(regKey); + if (useHomoComm(comm) && !useHeteroComm()) { + auto it = regItem->homoRegHandles.find(thisCommKey); + if (it != regItem->homoRegHandles.end()) { + cclAdaptors[flagcxCCLAdaptorDevice]->commDeregister(comm->homoComm, + it->second); + regItem->homoRegHandles.erase(it); + } + } + } + + // Clean up globalRegPool (both paths) globalRegPool.deregisterBuffer(regKey, handle); return flagcxSuccess; } From d97dcc052616c5a9332b9be767a8504c2cc1d44b Mon Sep 17 00:00:00 2001 From: MC952-arch Date: Thu, 28 May 2026 14:57:21 +0800 Subject: [PATCH 3/4] test(core): add unit tests for regPool registration APIs --- .github/workflows/test.yml | 91 ++++--- test/script/auto_script.sh | 152 ----------- test/unittest/core/test_reg_pool.cpp | 382 +++++++++++++++++++++++++++ 3 files changed, 443 insertions(+), 182 deletions(-) delete mode 100644 test/script/auto_script.sh create mode 100644 test/unittest/core/test_reg_pool.cpp diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3431d889e..19187002e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,4 +1,4 @@ -name: Run Container and Execute Tests +name: Perf Tests in Container on: push: @@ -8,8 +8,12 @@ on: branches: - main +env: + MPI_HOME: /usr/local/mpi + PERF_BIN: /__w/FlagCX/FlagCX/test/perf/host_api/build/bin + jobs: - test-in-container: + perf-test: runs-on: [self-hosted, cx-build] container: image: localhost:5000/flagscale:cuda12.8.1-cudnn9.7.1-python3.12-torch2.7.0-time2507111538 @@ -32,42 +36,69 @@ jobs: submodules: true set-safe-directory: true - - name: Set up Python and Install Dependencies + - name: Build FlagCX run: | - apt update -y - apt-get install -y python3 python3-pip python3-venv git - python3 -m venv venv - . 
venv/bin/activate cd /__w/FlagCX/FlagCX - git config --global --add safe.directory /__w/FlagCX/FlagCX - pip install setuptools pre-commit - pre-commit install + make -j$(nproc) USE_NVIDIA=1 + + - name: Build perf tests + run: | + cd /__w/FlagCX/FlagCX/test/perf + make -j$(nproc) USE_NVIDIA=1 - - name: Run Code Format Check with pre-commit + - name: Wait for GPU run: | cd /__w/FlagCX/FlagCX - . venv/bin/activate - apt update -y - apt-get install clang-format -y - git fetch --all - if [ -n "$GITHUB_HEAD_REF" ] && [ -n "$GITHUB_BASE_REF" ]; then - from_ref="origin/$GITHUB_HEAD_REF" - to_ref="origin/$GITHUB_BASE_REF" - echo "From reference: $from_ref; To reference: $to_ref" - pre-commit run --from-ref "$from_ref" --to-ref "$to_ref" - fi - continue-on-error: false + source test/script/_gpu_check.sh + wait_for_gpu + + - name: "Perf tests (homoRunner)" + run: | + export PATH=$MPI_HOME/bin:$PATH + export LD_LIBRARY_PATH=/__w/FlagCX/FlagCX/build/lib:$LD_LIBRARY_PATH + set -e + mpirun -np 8 --allow-run-as-root $PERF_BIN/perf_alltoall -b 128M -e 1G -f 2 -p 1 + mpirun -np 8 --allow-run-as-root $PERF_BIN/perf_alltoallv -b 128M -e 1G -f 2 -p 1 + mpirun -np 8 --allow-run-as-root $PERF_BIN/perf_sendrecv -b 128M -e 1G -f 2 -p 1 + mpirun -np 8 --allow-run-as-root $PERF_BIN/perf_allreduce -b 128M -e 1G -f 2 -p 1 + mpirun -np 8 --allow-run-as-root $PERF_BIN/perf_allgather -b 128M -e 1G -f 2 -p 1 + mpirun -np 8 --allow-run-as-root $PERF_BIN/perf_reducescatter -b 128M -e 1G -f 2 -p 1 + mpirun -np 8 --allow-run-as-root $PERF_BIN/perf_broadcast -b 128M -e 1G -f 2 -r 0 -p 1 + mpirun -np 8 --allow-run-as-root $PERF_BIN/perf_gather -b 128M -e 1G -f 2 -r 0 -p 1 + mpirun -np 8 --allow-run-as-root $PERF_BIN/perf_scatter -b 128M -e 1G -f 2 -r 0 -p 1 + mpirun -np 8 --allow-run-as-root $PERF_BIN/perf_reduce -b 128M -e 1G -f 2 -r 0 -p 1 - - name: Check the current working directory + - name: "Perf tests (uniRunner)" run: | - echo "Current directory: $(pwd)" - ls -l ./test/script + export 
PATH=$MPI_HOME/bin:$PATH + export LD_LIBRARY_PATH=/__w/FlagCX/FlagCX/build/lib:$LD_LIBRARY_PATH + set -e + mpirun -np 8 --allow-run-as-root -x FLAGCX_MEM_ENABLE=1 -x FLAGCX_USE_HETERO_COMM=1 $PERF_BIN/perf_alltoall -b 128M -e 1G -f 2 -p 1 + mpirun -np 8 --allow-run-as-root -x FLAGCX_MEM_ENABLE=1 -x FLAGCX_USE_HETERO_COMM=1 $PERF_BIN/perf_alltoallv -b 128M -e 1G -f 2 -p 1 + mpirun -np 8 --allow-run-as-root -x FLAGCX_MEM_ENABLE=1 -x FLAGCX_USE_HETERO_COMM=1 $PERF_BIN/perf_sendrecv -b 128M -e 1G -f 2 -p 1 + mpirun -np 8 --allow-run-as-root -x FLAGCX_MEM_ENABLE=1 -x FLAGCX_USE_HETERO_COMM=1 $PERF_BIN/perf_allgather -b 128M -e 1G -f 2 -p 1 + mpirun -np 8 --allow-run-as-root -x FLAGCX_MEM_ENABLE=1 -x FLAGCX_USE_HETERO_COMM=1 $PERF_BIN/perf_broadcast -b 128M -e 1G -f 2 -r 0 -p 1 + mpirun -np 8 --allow-run-as-root -x FLAGCX_MEM_ENABLE=1 -x FLAGCX_USE_HETERO_COMM=1 $PERF_BIN/perf_gather -b 128M -e 1G -f 2 -r 0 -p 1 + mpirun -np 8 --allow-run-as-root -x FLAGCX_MEM_ENABLE=1 -x FLAGCX_USE_HETERO_COMM=1 $PERF_BIN/perf_scatter -b 128M -e 1G -f 2 -r 0 -p 1 - - name: Ensure script has execute permissions - run: chmod +x /__w/FlagCX/FlagCX/test/script/auto_script.sh + - name: "Registration -R 1 (homoRunner)" + run: | + export PATH=$MPI_HOME/bin:$PATH + export LD_LIBRARY_PATH=/__w/FlagCX/FlagCX/build/lib:$LD_LIBRARY_PATH + mpirun -np 8 --allow-run-as-root $PERF_BIN/perf_allreduce -b 128M -e 1G -f 2 -p 1 -R 1 - - name: Run Auto Test Script in Container + - name: "Registration -R 1 (uniRunner P2P)" run: | - cd /__w/FlagCX/FlagCX - ./test/script/auto_script.sh + export PATH=$MPI_HOME/bin:$PATH + export LD_LIBRARY_PATH=/__w/FlagCX/FlagCX/build/lib:$LD_LIBRARY_PATH + set -e + mpirun -np 8 --allow-run-as-root -x FLAGCX_MEM_ENABLE=1 -x FLAGCX_VMM_ENABLE=0 -x FLAGCX_USE_HETERO_COMM=1 $PERF_BIN/perf_sendrecv -b 128M -e 1G -f 2 -p 1 -R 1 + mpirun -np 8 --allow-run-as-root -x FLAGCX_MEM_ENABLE=1 -x FLAGCX_VMM_ENABLE=0 -x FLAGCX_USE_HETERO_COMM=1 $PERF_BIN/perf_alltoall -b 128M -e 1G -f 2 -p 1 
-R 1 + - name: "Registration -R 1 (uniRunner NET)" + run: | + export PATH=$MPI_HOME/bin:$PATH + export LD_LIBRARY_PATH=/__w/FlagCX/FlagCX/build/lib:$LD_LIBRARY_PATH + set -e + mpirun -np 8 --allow-run-as-root -x FLAGCX_MEM_ENABLE=1 -x FLAGCX_VMM_ENABLE=0 -x FLAGCX_USE_HETERO_COMM=1 -x FLAGCX_P2P_DISABLE=1 $PERF_BIN/perf_sendrecv -b 128M -e 1G -f 2 -p 1 -R 1 + mpirun -np 8 --allow-run-as-root -x FLAGCX_MEM_ENABLE=1 -x FLAGCX_VMM_ENABLE=0 -x FLAGCX_USE_HETERO_COMM=1 -x FLAGCX_P2P_DISABLE=1 $PERF_BIN/perf_alltoall -b 128M -e 1G -f 2 -p 1 -R 1 diff --git a/test/script/auto_script.sh b/test/script/auto_script.sh deleted file mode 100644 index d4b6f001e..000000000 --- a/test/script/auto_script.sh +++ /dev/null @@ -1,152 +0,0 @@ -#!/bin/bash -BUILD_DIR="build" - -mkdir -p $BUILD_DIR - -export MPI_HOME=/usr/local/mpi -export PATH=$MPI_HOME/bin:$PATH -make -j$(nproc) USE_NVIDIA=1 - -if [ $? -ne 0 ]; then - echo "Compilation failed!" - exit 1 -fi - -cd test/perf -make -j$(nproc) USE_NVIDIA=1 - -if [ $? -ne 0 ]; then - echo "Test compilation failed!" - exit 1 -fi - -# Perf binaries are in host_api/build/bin/ with perf_ prefix -PERF_BIN=host_api/build/bin - -source ../script/_gpu_check.sh -wait_for_gpu - -mpirun -np 8 ./$PERF_BIN/perf_alltoall -b 128M -e 1G -f 2 -p 1 -if [ $? -ne 0 ]; then - echo "test_alltoall in homoRunner mode failed!" - exit 1 -fi - -mpirun -np 8 \ - -x FLAGCX_MEM_ENABLE=1 \ - -x FLAGCX_USE_HETERO_COMM=1 \ - ./$PERF_BIN/perf_alltoall -b 128M -e 1G -f 2 -p 1 -if [ $? -ne 0 ]; then - echo "test_alltoall in uniRunner mode failed!" - exit 1 -fi - -mpirun -np 8 ./$PERF_BIN/perf_alltoallv -b 128M -e 1G -f 2 -p 1 -if [ $? -ne 0 ]; then - echo "test_alltoallv in homoRunner mode failed!" - exit 1 -fi - -mpirun -np 8 \ - -x FLAGCX_MEM_ENABLE=1 \ - -x FLAGCX_USE_HETERO_COMM=1 \ - ./$PERF_BIN/perf_alltoallv -b 128M -e 1G -f 2 -p 1 -if [ $? -ne 0 ]; then - echo "test_alltoallv in uniRunner mode failed!" 
- exit 1 -fi - -mpirun -np 8 ./$PERF_BIN/perf_sendrecv -b 128M -e 1G -f 2 -p 1 -if [ $? -ne 0 ]; then - echo "test_sendrecv in homoRunner mode failed!" - exit 1 -fi - -mpirun -np 8 \ - -x FLAGCX_MEM_ENABLE=1 \ - -x FLAGCX_USE_HETERO_COMM=1 \ - ./$PERF_BIN/perf_sendrecv -b 128M -e 1G -f 2 -p 1 -if [ $? -ne 0 ]; then - echo "test_sendrecv in uniRunner mode failed!" - exit 1 -fi - -mpirun -np 8 ./$PERF_BIN/perf_allreduce -b 128M -e 1G -f 2 -p 1 -if [ $? -ne 0 ]; then - echo "test_allreduce in homoRunner mode failed!" - exit 1 -fi - -mpirun -np 8 ./$PERF_BIN/perf_allgather -b 128M -e 1G -f 2 -p 1 -if [ $? -ne 0 ]; then - echo "test_allgather in homoRunner mode failed!" - exit 1 -fi - -mpirun -np 8 \ - -x FLAGCX_MEM_ENABLE=1 \ - -x FLAGCX_USE_HETERO_COMM=1 \ - ./$PERF_BIN/perf_allgather -b 128M -e 1G -f 2 -p 1 -if [ $? -ne 0 ]; then - echo "test_allgather in uniRunner mode failed!" - exit 1 -fi - -mpirun -np 8 ./$PERF_BIN/perf_reducescatter -b 128M -e 1G -f 2 -p 1 -if [ $? -ne 0 ]; then - echo "test_reducescatter in homoRunner mode failed!" - exit 1 -fi - -mpirun -np 8 ./$PERF_BIN/perf_broadcast -b 128M -e 1G -f 2 -r 0 -p 1 -if [ $? -ne 0 ]; then - echo "test_broadcast in homoRunner mode failed!" - exit 1 -fi - -mpirun -np 8 \ - -x FLAGCX_MEM_ENABLE=1 \ - -x FLAGCX_USE_HETERO_COMM=1 \ - ./$PERF_BIN/perf_broadcast -b 128M -e 1G -f 2 -r 0 -p 1 -if [ $? -ne 0 ]; then - echo "test_broadcast in uniRunner mode failed!" - exit 1 -fi - -mpirun -np 8 ./$PERF_BIN/perf_gather -b 128M -e 1G -f 2 -r 0 -p 1 -if [ $? -ne 0 ]; then - echo "test_gather in homoRunner mode failed!" - exit 1 -fi - -mpirun -np 8 \ - -x FLAGCX_MEM_ENABLE=1 \ - -x FLAGCX_USE_HETERO_COMM=1 \ - ./$PERF_BIN/perf_gather -b 128M -e 1G -f 2 -r 0 -p 1 -if [ $? -ne 0 ]; then - echo "test_gather in uniRunner mode failed!" - exit 1 -fi - -mpirun -np 8 ./$PERF_BIN/perf_scatter -b 128M -e 1G -f 2 -r 0 -p 1 -if [ $? -ne 0 ]; then - echo "test_scatter in homoRunner mode failed!" 
- exit 1 -fi - -mpirun -np 8 \ - -x FLAGCX_MEM_ENABLE=1 \ - -x FLAGCX_USE_HETERO_COMM=1 \ - ./$PERF_BIN/perf_scatter -b 128M -e 1G -f 2 -r 0 -p 1 -if [ $? -ne 0 ]; then - echo "test_scatter in uniRunner mode failed!" - exit 1 -fi - -mpirun -np 8 ./$PERF_BIN/perf_reduce -b 128M -e 1G -f 2 -r 0 -p 1 -if [ $? -ne 0 ]; then - echo "test_reduce execution failed!" - exit 1 -fi - -echo "All tests completed successfully!" diff --git a/test/unittest/core/test_reg_pool.cpp b/test/unittest/core/test_reg_pool.cpp new file mode 100644 index 000000000..b2cb7dd88 --- /dev/null +++ b/test/unittest/core/test_reg_pool.cpp @@ -0,0 +1,382 @@ +// Unit tests for flagcxRegPool — buffer registration pool. +// Source: flagcx/core/reg_pool.cc + flagcx/core/include/reg_pool.h +// Links against libflagcx. No MPI, no GPU required. + +#include + +#include "reg_pool.h" +#include +#include + +// Helper: create a fake comm pointer from an integer +static void *fakeComm(uintptr_t id) { return reinterpret_cast(id); } + +// Helper: create a fake proxyConn pointer +static struct flagcxProxyConnector *fakeProxy(uintptr_t id) { + return reinterpret_cast(id); +} + +class RegPoolTest : public ::testing::Test { +protected: + void SetUp() override { + // Use a fresh pool for each test + pool = new flagcxRegPool(); + pageSize = sysconf(_SC_PAGESIZE); + } + void TearDown() override { delete pool; } + + flagcxRegPool *pool; + uintptr_t pageSize; + + // Allocate a page-aligned buffer of given size + void *alignedAddr(uintptr_t base) { + return reinterpret_cast(base * pageSize); + } +}; + +// ============================================================================= +// 1. 
Basic Registration +// ============================================================================= + +TEST_F(RegPoolTest, RegisterBuffer_NullData_NoOp) { + EXPECT_EQ(pool->registerBuffer(fakeComm(1), nullptr, 1024), flagcxSuccess); + EXPECT_EQ(pool->registerBuffer(fakeComm(1), alignedAddr(10), 0), + flagcxSuccess); +} + +TEST_F(RegPoolTest, RegisterBuffer_SingleBuffer) { + void *data = alignedAddr(10); + ASSERT_EQ(pool->registerBuffer(fakeComm(1), data, pageSize), flagcxSuccess); + + flagcxRegItem *item = pool->getItem(fakeComm(1), data); + ASSERT_NE(item, nullptr); + EXPECT_EQ(item->refCount, 1); +} + +TEST_F(RegPoolTest, RegisterBuffer_PageAlignment) { + // Use an address in the middle of a page + void *data = reinterpret_cast(10 * pageSize + 100); + size_t length = pageSize + 200; // spans 2+ pages + + ASSERT_EQ(pool->registerBuffer(fakeComm(1), data, length), flagcxSuccess); + + flagcxRegItem *item = pool->getItem(fakeComm(1), data); + ASSERT_NE(item, nullptr); + EXPECT_EQ(item->beginAddr, 10 * pageSize); + // endAddr should be page-aligned ceiling + uintptr_t expectedEnd = + (10 * pageSize + 100 + length + pageSize - 1) & ~(pageSize - 1); + EXPECT_EQ(item->endAddr, expectedEnd); +} + +// ============================================================================= +// 2. 
Refcounting +// ============================================================================= + +TEST_F(RegPoolTest, RegisterBuffer_SameBufferTwice_IncrementsRefCount) { + void *data = alignedAddr(20); + ASSERT_EQ(pool->registerBuffer(fakeComm(1), data, pageSize), flagcxSuccess); + ASSERT_EQ(pool->registerBuffer(fakeComm(1), data, pageSize), flagcxSuccess); + + flagcxRegItem *item = pool->getItem(fakeComm(1), data); + ASSERT_NE(item, nullptr); + EXPECT_EQ(item->refCount, 2); +} + +TEST_F(RegPoolTest, DeregisterBuffer_DecrementsRefCount) { + void *data = alignedAddr(30); + ASSERT_EQ(pool->registerBuffer(fakeComm(1), data, pageSize), flagcxSuccess); + ASSERT_EQ(pool->registerBuffer(fakeComm(1), data, pageSize), flagcxSuccess); + + flagcxRegItem *item = pool->getItem(fakeComm(1), data); + ASSERT_NE(item, nullptr); + EXPECT_EQ(item->refCount, 2); + + ASSERT_EQ(pool->deregisterBuffer(fakeComm(1), item), flagcxSuccess); + + // Item should still exist with refCount=1 + flagcxRegItem *item2 = pool->getItem(fakeComm(1), data); + ASSERT_NE(item2, nullptr); + EXPECT_EQ(item2->refCount, 1); +} + +TEST_F(RegPoolTest, DeregisterBuffer_LastRef_RemovesItem) { + void *data = alignedAddr(40); + ASSERT_EQ(pool->registerBuffer(fakeComm(1), data, pageSize), flagcxSuccess); + + flagcxRegItem *item = pool->getItem(fakeComm(1), data); + ASSERT_NE(item, nullptr); + + ASSERT_EQ(pool->deregisterBuffer(fakeComm(1), item), flagcxSuccess); + + // Item should be gone + EXPECT_EQ(pool->getItem(fakeComm(1), data), nullptr); +} + +// ============================================================================= +// 3. 
Pointer Stability (Issue #1 validation) +// ============================================================================= + +TEST_F(RegPoolTest, PointerStability_ManyInsertions) { + constexpr int N = 1000; + std::vector items(N); + + // Register N distinct buffers + for (int i = 0; i < N; i++) { + void *data = alignedAddr(100 + i); + ASSERT_EQ(pool->registerBuffer(nullptr, data, pageSize), flagcxSuccess); + items[i] = pool->getItem(nullptr, data); + ASSERT_NE(items[i], nullptr); + } + + // Verify all pointers are still valid (no dangling after rehash) + for (int i = 0; i < N; i++) { + void *data = alignedAddr(100 + i); + flagcxRegItem *current = pool->getItem(nullptr, data); + EXPECT_EQ(current, items[i]) + << "Pointer mismatch at index " << i << " — likely rehash invalidation"; + EXPECT_EQ(current->beginAddr, (100 + i) * pageSize); + } +} + +// ============================================================================= +// 4. Multi-Comm Semantics (Issue #2 validation) +// ============================================================================= + +TEST_F(RegPoolTest, RegisterBuffer_TwoComms_SameBuffer) { + void *data = alignedAddr(50); + void *commA = fakeComm(0x1000); + void *commB = fakeComm(0x2000); + + ASSERT_EQ(pool->registerBuffer(commA, data, pageSize), flagcxSuccess); + ASSERT_EQ(pool->registerBuffer(commB, data, pageSize), flagcxSuccess); + + flagcxRegItem *itemA = pool->getItem(commA, data); + flagcxRegItem *itemB = pool->getItem(commB, data); + ASSERT_NE(itemA, nullptr); + ASSERT_NE(itemB, nullptr); + // Same underlying item + EXPECT_EQ(itemA, itemB); + EXPECT_EQ(itemA->refCount, 2); +} + +TEST_F(RegPoolTest, DeregisterBuffer_OneComm_OtherStillValid) { + void *data = alignedAddr(60); + void *commA = fakeComm(0x3000); + void *commB = fakeComm(0x4000); + + ASSERT_EQ(pool->registerBuffer(commA, data, pageSize), flagcxSuccess); + ASSERT_EQ(pool->registerBuffer(commB, data, pageSize), flagcxSuccess); + + flagcxRegItem *item = pool->getItem(commA, 
data); + ASSERT_NE(item, nullptr); + + // Deregister from commA + ASSERT_EQ(pool->deregisterBuffer(commA, item), flagcxSuccess); + + // commB should still find it + flagcxRegItem *itemB = pool->getItem(commB, data); + ASSERT_NE(itemB, nullptr); + EXPECT_EQ(itemB->refCount, 1); + + // commA's mapping is gone, but global fallback still works + flagcxRegItem *itemA = pool->getItem(commA, data); + EXPECT_NE(itemA, nullptr); // found via global fallback +} + +TEST_F(RegPoolTest, RegisterBuffer_NullComm_GlobalOnly) { + void *data = alignedAddr(70); + + ASSERT_EQ(pool->registerBuffer(nullptr, data, pageSize), flagcxSuccess); + + // Null comm query finds it + flagcxRegItem *item = pool->getItem(nullptr, data); + ASSERT_NE(item, nullptr); + + // Non-null comm also finds it via global fallback + flagcxRegItem *item2 = pool->getItem(fakeComm(0x5000), data); + EXPECT_EQ(item2, item); +} + +// ============================================================================= +// 5. Handle Management (Issue #3 validation) +// ============================================================================= + +TEST_F(RegPoolTest, AddNetHandle_StoresOwnerComm) { + void *data = alignedAddr(80); + void *commA = fakeComm(0x6000); + ASSERT_EQ(pool->registerBuffer(commA, data, pageSize), flagcxSuccess); + flagcxRegItem *item = pool->getItem(commA, data); + ASSERT_NE(item, nullptr); + + void *fakeHandle = reinterpret_cast(0xDEAD); + auto *proxy = fakeProxy(0xBEEF); + ASSERT_EQ(pool->addNetHandle(commA, item, fakeHandle, proxy), flagcxSuccess); + + ASSERT_EQ(item->handles.size(), 1u); + EXPECT_EQ(item->handles[0].first.handle, fakeHandle); + EXPECT_EQ(item->handles[0].first.proxyConn, proxy); + EXPECT_EQ(item->handles[0].first.ownerComm, commA); +} + +TEST_F(RegPoolTest, AddP2pHandle_StoresOwnerComm) { + void *data = alignedAddr(81); + void *commA = fakeComm(0x7000); + ASSERT_EQ(pool->registerBuffer(commA, data, pageSize), flagcxSuccess); + flagcxRegItem *item = pool->getItem(commA, data); + 
ASSERT_NE(item, nullptr); + + void *fakeHandle = reinterpret_cast(0xCAFE); + auto *proxy = fakeProxy(0xFACE); + ASSERT_EQ(pool->addP2pHandle(commA, item, fakeHandle, proxy), flagcxSuccess); + + ASSERT_EQ(item->handles.size(), 1u); + EXPECT_EQ(item->handles[0].second.handle, fakeHandle); + EXPECT_EQ(item->handles[0].second.proxyConn, proxy); + EXPECT_EQ(item->handles[0].second.ownerComm, commA); +} + +TEST_F(RegPoolTest, AddNetHandle_DuplicateProxyConn_Updates) { + void *data = alignedAddr(82); + void *commA = fakeComm(0x8000); + ASSERT_EQ(pool->registerBuffer(commA, data, pageSize), flagcxSuccess); + flagcxRegItem *item = pool->getItem(commA, data); + ASSERT_NE(item, nullptr); + + auto *proxy = fakeProxy(0xAAAA); + void *handle1 = reinterpret_cast(0x1111); + void *handle2 = reinterpret_cast(0x2222); + + ASSERT_EQ(pool->addNetHandle(commA, item, handle1, proxy), flagcxSuccess); + ASSERT_EQ(pool->addNetHandle(commA, item, handle2, proxy), flagcxSuccess); + + // Should update in-place, not add a second entry + EXPECT_EQ(item->handles.size(), 1u); + EXPECT_EQ(item->handles[0].first.handle, handle2); +} + +TEST_F(RegPoolTest, AddNetHandle_NullReg_NoOp) { + EXPECT_EQ(pool->addNetHandle(fakeComm(1), nullptr, nullptr, nullptr), + flagcxSuccess); +} + +// ============================================================================= +// 6. 
Page Mapping +// ============================================================================= + +TEST_F(RegPoolTest, GetItem_DifferentOffsetSamePage) { + // Register at page boundary + void *data = alignedAddr(90); + ASSERT_EQ(pool->registerBuffer(fakeComm(1), data, pageSize), flagcxSuccess); + + // Query with an offset within the same page + void *offsetData = reinterpret_cast(90 * pageSize + 128); + flagcxRegItem *item = pool->getItem(fakeComm(1), offsetData); + ASSERT_NE(item, nullptr); + EXPECT_EQ(item->beginAddr, 90 * pageSize); +} + +TEST_F(RegPoolTest, GetItem_CommSpecificFallsToGlobal) { + void *data = alignedAddr(91); + // Register with null comm (global only) + ASSERT_EQ(pool->registerBuffer(nullptr, data, pageSize), flagcxSuccess); + + // Query with a comm that never registered — should fall through to global + flagcxRegItem *item = pool->getItem(fakeComm(0x9000), data); + EXPECT_NE(item, nullptr); +} + +// ============================================================================= +// 7. 
Edge Cases +// ============================================================================= + +TEST_F(RegPoolTest, DeregisterBuffer_NullHandle_NoOp) { + EXPECT_EQ(pool->deregisterBuffer(fakeComm(1), nullptr), flagcxSuccess); +} + +TEST_F(RegPoolTest, DeregisterBuffer_InvalidHandle_ReturnsError) { + // Create a stack-local regItem that's not in the pool + flagcxRegItem fakeItem; + fakeItem.beginAddr = 999 * pageSize; + fakeItem.endAddr = 1000 * pageSize; + + EXPECT_EQ(pool->deregisterBuffer(fakeComm(1), &fakeItem), flagcxInvalidUsage); +} + +// ============================================================================= +// homoRegHandles (per-comm storage in flagcxRegItem) +// ============================================================================= + +TEST_F(RegPoolTest, HomoRegHandles_PerCommStorage) { + void *data = alignedAddr(100); + ASSERT_EQ(pool->registerBuffer(nullptr, data, pageSize), flagcxSuccess); + flagcxRegItem *item = pool->getItem(nullptr, data); + ASSERT_NE(item, nullptr); + + uintptr_t commKeyA = 0xA000; + uintptr_t commKeyB = 0xB000; + void *handleA = reinterpret_cast(0xAAAA); + void *handleB = reinterpret_cast(0xBBBB); + + item->homoRegHandles[commKeyA] = handleA; + item->homoRegHandles[commKeyB] = handleB; + + EXPECT_EQ(item->homoRegHandles.size(), 2u); + EXPECT_EQ(item->homoRegHandles[commKeyA], handleA); + EXPECT_EQ(item->homoRegHandles[commKeyB], handleB); +} + +TEST_F(RegPoolTest, HomoRegHandles_EraseOneComm_OtherRemains) { + void *data = alignedAddr(101); + ASSERT_EQ(pool->registerBuffer(nullptr, data, pageSize), flagcxSuccess); + flagcxRegItem *item = pool->getItem(nullptr, data); + ASSERT_NE(item, nullptr); + + uintptr_t commKeyA = 0xC000; + uintptr_t commKeyB = 0xD000; + item->homoRegHandles[commKeyA] = reinterpret_cast(0x1); + item->homoRegHandles[commKeyB] = reinterpret_cast(0x2); + + item->homoRegHandles.erase(commKeyA); + + EXPECT_EQ(item->homoRegHandles.count(commKeyA), 0u); + EXPECT_EQ(item->homoRegHandles.count(commKeyB), 
1u); + EXPECT_EQ(item->homoRegHandles[commKeyB], reinterpret_cast(0x2)); +} + +// ============================================================================= +// localIpcHandleData (write-once semantics) +// ============================================================================= + +TEST_F(RegPoolTest, LocalIpcHandleData_InitiallyZero) { + void *data = alignedAddr(102); + ASSERT_EQ(pool->registerBuffer(nullptr, data, pageSize), flagcxSuccess); + flagcxRegItem *item = pool->getItem(nullptr, data); + ASSERT_NE(item, nullptr); + + char zeros[sizeof(flagcxIpcHandleData)] = {}; + EXPECT_EQ( + memcmp(&item->localIpcHandleData, zeros, sizeof(flagcxIpcHandleData)), 0); +} + +TEST_F(RegPoolTest, LocalIpcHandleData_WriteOnce) { + void *data = alignedAddr(103); + ASSERT_EQ(pool->registerBuffer(nullptr, data, pageSize), flagcxSuccess); + flagcxRegItem *item = pool->getItem(nullptr, data); + ASSERT_NE(item, nullptr); + + // Simulate writing IPC handle data + char fakeIpc[sizeof(flagcxIpcHandleData)]; + memset(fakeIpc, 0xAB, sizeof(fakeIpc)); + memcpy(&item->localIpcHandleData, fakeIpc, sizeof(flagcxIpcHandleData)); + + // Verify it's non-zero now + char zeros[sizeof(flagcxIpcHandleData)] = {}; + EXPECT_NE( + memcmp(&item->localIpcHandleData, zeros, sizeof(flagcxIpcHandleData)), 0); + + // Verify content matches what we wrote + EXPECT_EQ( + memcmp(&item->localIpcHandleData, fakeIpc, sizeof(flagcxIpcHandleData)), + 0); +} From b85ca8356db579aae9ecd5585c8b57632e1f031a Mon Sep 17 00:00:00 2001 From: MC952-arch Date: Thu, 28 May 2026 16:13:44 +0800 Subject: [PATCH 4/4] Fix issues --- .github/workflows/test.yml | 1 + flagcx/core/include/reg_pool.h | 1 + flagcx/core/include/register.h | 1 - flagcx/core/init.cc | 3 +- flagcx/core/net.cc | 3 +- flagcx/core/p2p.cc | 5 +- flagcx/core/reg_pool.cc | 104 ++++++++++++++++++++------- flagcx/flagcx.cc | 50 +++++++++---- test/script/_gpu_check.sh | 12 +++- test/unittest/core/test_reg_pool.cpp | 51 +++++++++++++ 10 files changed, 186 
insertions(+), 45 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 19187002e..2bae088de 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,6 +47,7 @@ jobs: make -j$(nproc) USE_NVIDIA=1 - name: Wait for GPU + shell: bash run: | cd /__w/FlagCX/FlagCX source test/script/_gpu_check.sh diff --git a/flagcx/core/include/reg_pool.h b/flagcx/core/include/reg_pool.h index e96be181f..0440f29f6 100644 --- a/flagcx/core/include/reg_pool.h +++ b/flagcx/core/include/reg_pool.h @@ -26,6 +26,7 @@ class flagcxRegPool { struct flagcxProxyConnector *proxyConn); flagcxResult_t removeRegItemP2pHandles(void *comm, flagcxRegItem *reg); flagcxResult_t removeAllP2pHandles(void *comm); + flagcxResult_t removeAllNetHandles(void *comm); flagcxResult_t registerBuffer(void *comm, void *data, size_t length); flagcxResult_t deregisterBuffer(void *comm, void *handle); std::unordered_map> diff --git a/flagcx/core/include/register.h b/flagcx/core/include/register.h index 4d4337f86..1eb7b5e74 100644 --- a/flagcx/core/include/register.h +++ b/flagcx/core/include/register.h @@ -4,7 +4,6 @@ #include "core.h" #include "device.h" #include -#include #include #define FLAGCX_IPC_HANDLE_SIZE 64 diff --git a/flagcx/core/init.cc b/flagcx/core/init.cc index dca03f73d..363ab6bf3 100644 --- a/flagcx/core/init.cc +++ b/flagcx/core/init.cc @@ -465,8 +465,9 @@ flagcxResult_t flagcxHeteroCommUserRank(const flagcxHeteroComm_t comm, flagcxResult_t flagcxHeteroCommDestroy(flagcxHeteroComm_t comm) { FLAGCXCHECK(flagcxHeteroRmaProxyStop(comm)); - // Clean up P2P IPC handles while proxy is still alive and peerSocks valid + // Clean up P2P/Net handles while proxy is still alive and peerSocks valid FLAGCXCHECK(globalRegPool.removeAllP2pHandles(comm)); + FLAGCXCHECK(globalRegPool.removeAllNetHandles(comm)); // Stop: send stop + close peerSocks FLAGCXCHECK(flagcxProxyStop(comm)); // Destroy: join thread, free proxy resources diff --git a/flagcx/core/net.cc 
b/flagcx/core/net.cc index 4f925c84c..ebf8d33a9 100644 --- a/flagcx/core/net.cc +++ b/flagcx/core/net.cc @@ -430,7 +430,8 @@ static flagcxResult_t netRegisterBuffer(flagcxHeteroComm *comm, peerProxyConn = &peerConn->proxyConn; for (auto it = regRecord->handles.begin(); it != regRecord->handles.end(); it++) { - if (it->first.proxyConn == peerProxyConn && it->first.handle) { + if (it->first.proxyConn == peerProxyConn && it->first.handle && + it->first.ownerComm == comm) { found = true; outHandle[p] = it->first.handle; *outRegBufFlag = 1; diff --git a/flagcx/core/p2p.cc b/flagcx/core/p2p.cc index d2782ed63..8fb96d67d 100644 --- a/flagcx/core/p2p.cc +++ b/flagcx/core/p2p.cc @@ -764,10 +764,11 @@ static flagcxResult_t p2pRegisterBuffer(flagcxHeteroComm *comm, for (int p = 0; p < nPeers; p++) { int peerRank = peerRanks[p]; - // Check cache: existing info with handleReady for this peer + // Check cache: existing info with handleReady for this peer (this comm + // only) flagcxIpcRegInfo *existingInfo = NULL; for (auto &handlePair : regItem->handles) { - if (handlePair.second.handle) { + if (handlePair.second.handle && handlePair.second.ownerComm == comm) { flagcxIpcRegInfo *info = (flagcxIpcRegInfo *)handlePair.second.handle; if (info->peerRank == peerRank) { existingInfo = info; diff --git a/flagcx/core/reg_pool.cc b/flagcx/core/reg_pool.cc index fa32a32aa..bb0bfd484 100644 --- a/flagcx/core/reg_pool.cc +++ b/flagcx/core/reg_pool.cc @@ -22,12 +22,13 @@ inline void flagcxRegPool::getPagedAddr(void *data, size_t length, flagcxResult_t flagcxRegPool::addNetHandle(void *comm, flagcxRegItem *reg, void *handle, struct flagcxProxyConnector *proxyConn) { - if (reg == nullptr) { + if (reg == nullptr || comm == nullptr) { return flagcxSuccess; } for (auto &handlePair : reg->handles) { if (handlePair.first.proxyConn == proxyConn) { handlePair.first.handle = handle; + handlePair.first.ownerComm = comm; return flagcxSuccess; } } @@ -41,12 +42,13 @@ flagcxRegPool::addNetHandle(void 
*comm, flagcxRegItem *reg, void *handle, flagcxResult_t flagcxRegPool::addP2pHandle(void *comm, flagcxRegItem *reg, void *handle, struct flagcxProxyConnector *proxyConn) { - if (reg == nullptr) { + if (reg == nullptr || comm == nullptr) { return flagcxSuccess; } for (auto &handlePair : reg->handles) { if (handlePair.second.proxyConn == proxyConn) { handlePair.second.handle = handle; + handlePair.second.ownerComm = comm; return flagcxSuccess; } } @@ -59,13 +61,15 @@ flagcxRegPool::addP2pHandle(void *comm, flagcxRegItem *reg, void *handle, flagcxResult_t flagcxRegPool::removeRegItemNetHandles(void *comm, flagcxRegItem *reg) { - if (comm == nullptr || reg == nullptr) { + if (reg == nullptr) { return flagcxSuccess; } for (size_t i = 0; i < reg->handles.size();) { auto &entry = reg->handles[i]; - if (entry.first.handle) { + // comm == nullptr: remove all; comm != nullptr: remove only this comm's + if (entry.first.handle && + (comm == nullptr || entry.first.ownerComm == comm)) { FLAGCXCHECK(flagcxNetDeregisterBuffer( entry.first.ownerComm, entry.first.proxyConn, entry.first.handle)); entry.first.handle = nullptr; @@ -84,13 +88,15 @@ flagcxResult_t flagcxRegPool::removeRegItemNetHandles(void *comm, flagcxResult_t flagcxRegPool::removeRegItemP2pHandles(void *comm, flagcxRegItem *reg) { - if (comm == nullptr || reg == nullptr) { + if (reg == nullptr) { return flagcxSuccess; } for (size_t i = 0; i < reg->handles.size();) { auto &entry = reg->handles[i]; - if (entry.second.handle) { + // comm == nullptr: remove all; comm != nullptr: remove only this comm's + if (entry.second.handle && + (comm == nullptr || entry.second.ownerComm == comm)) { flagcxIpcRegInfo *ipcInfo = (flagcxIpcRegInfo *)entry.second.handle; FLAGCXCHECK(flagcxP2pDeregisterBuffer( reinterpret_cast(entry.second.ownerComm), @@ -122,6 +128,19 @@ flagcxResult_t flagcxRegPool::removeAllP2pHandles(void *comm) { return flagcxSuccess; } +flagcxResult_t flagcxRegPool::removeAllNetHandles(void *comm) { + if (comm == 
nullptr) { + return flagcxSuccess; + } + // Iterate over all items in the global pool and remove net handles + // associated with this comm + auto &globalPool = regPool[GLOBAL_POOL_KEY]; + for (auto &pair : globalPool) { + FLAGCXCHECK(removeRegItemNetHandles(comm, pair.second.get())); + } + return flagcxSuccess; +} + void flagcxRegPool::mapRegItemPages(uintptr_t commKey, flagcxRegItem *reg) { if (reg == nullptr) { return; @@ -142,20 +161,53 @@ flagcxResult_t flagcxRegPool::registerBuffer(void *comm, void *data, uintptr_t beginAddr, endAddr; getPagedAddr(data, length, &beginAddr, &endAddr); - // Always check/insert into the global pool (single source of truth) - auto &globalPool = regPool[GLOBAL_POOL_KEY]; - auto it = globalPool.find(beginAddr); - if (it != globalPool.end()) { - // Already registered: bump refCount - it->second->refCount++; - // If comm is non-null, ensure it's mapped in the comm-specific regMap + // Check if ANY page in [beginAddr, endAddr) already belongs to an existing + // item via regMap. This handles partial overlaps where the new buffer starts + // on an unmapped page but overlaps an existing registration. 
+ flagcxRegItem *existing = nullptr; + auto globalMapIt = regMap.find(GLOBAL_POOL_KEY); + if (globalMapIt != regMap.end()) { + for (uintptr_t addr = beginAddr; addr < endAddr; addr += pageSize) { + auto it = globalMapIt->second.find(addr); + if (it != globalMapIt->second.end()) { + existing = it->second; + break; + } + } + } + + if (existing) { + existing->refCount++; + // Extend backward if new buffer starts before existing range + if (beginAddr < existing->beginAddr) { + uintptr_t oldBegin = existing->beginAddr; + existing->beginAddr = beginAddr; + for (uintptr_t addr = beginAddr; addr < oldBegin; addr += pageSize) { + regMap[GLOBAL_POOL_KEY][addr] = existing; + } + // Update regPool key to match new beginAddr + auto &globalPool = regPool[GLOBAL_POOL_KEY]; + auto node = globalPool.extract(oldBegin); + node.key() = beginAddr; + globalPool.insert(std::move(node)); + } + // Extend forward if new buffer goes beyond existing range + if (endAddr > existing->endAddr) { + uintptr_t oldEnd = existing->endAddr; + existing->endAddr = endAddr; + for (uintptr_t addr = oldEnd; addr < endAddr; addr += pageSize) { + regMap[GLOBAL_POOL_KEY][addr] = existing; + } + } + // Ensure comm-specific mapping covers full range if (comm != nullptr) { - mapRegItemPages(commKey, it->second.get()); + mapRegItemPages(commKey, existing); } return flagcxSuccess; } // Not found: create new item in global pool + auto &globalPool = regPool[GLOBAL_POOL_KEY]; auto reg = std::make_unique(); reg->beginAddr = beginAddr; reg->endAddr = endAddr; @@ -212,20 +264,24 @@ flagcxResult_t flagcxRegPool::deregisterBuffer(void *comm, void *handle) { return flagcxSuccess; } - // refCount == 0: full cleanup - FLAGCXCHECK(removeRegItemNetHandles(comm, reg)); - FLAGCXCHECK(removeRegItemP2pHandles(comm, reg)); + // refCount == 0: full cleanup (nullptr = remove all handles) + FLAGCXCHECK(removeRegItemNetHandles(nullptr, reg)); + FLAGCXCHECK(removeRegItemP2pHandles(nullptr, reg)); - // Remove from global regMap - auto 
globalMapIt = regMap.find(GLOBAL_POOL_KEY); - if (globalMapIt != regMap.end()) { - auto &globalMap = globalMapIt->second; + // Remove ALL regMap entries (global + comm-specific) that reference this item + for (auto mapIt = regMap.begin(); mapIt != regMap.end();) { + auto &pageMap = mapIt->second; for (uintptr_t addr = reg->beginAddr; addr < reg->endAddr; addr += pageSize) { - globalMap.erase(addr); + auto it = pageMap.find(addr); + if (it != pageMap.end() && it->second == reg) { + pageMap.erase(it); + } } - if (globalMap.empty()) { - regMap.erase(globalMapIt); + if (pageMap.empty()) { + mapIt = regMap.erase(mapIt); + } else { + ++mapIt; } } diff --git a/flagcx/flagcx.cc b/flagcx/flagcx.cc index ed28c9ac6..ad4770643 100644 --- a/flagcx/flagcx.cc +++ b/flagcx/flagcx.cc @@ -1109,26 +1109,48 @@ flagcxResult_t flagcxCommDeregister(const flagcxComm_t comm, void *handle) { return flagcxSuccess; flagcxRegItem *regItem = reinterpret_cast(handle); - void *regKey = nullptr; - if (comm != nullptr) { - regKey = - comm->heteroComm ? 
(void *)comm->heteroComm : (void *)comm->homoComm; + // Null comm: only valid if no backend handles exist on this item + // AND the item is not mapped under any comm-specific key + if (comm == nullptr) { + if (!regItem->homoRegHandles.empty() || !regItem->handles.empty()) { + WARN("flagcxCommDeregister: comm is nullptr but handle has backend " + "registrations that require a valid comm to clean up"); + return flagcxInvalidArgument; + } + // Check if item is mapped under any non-global commKey + auto &globalMap = globalRegPool.getGlobalMap(); + for (auto &[key, pageMap] : globalMap) { + if (key == flagcxRegPool::GLOBAL_POOL_KEY) + continue; + if (pageMap.find(regItem->beginAddr) != pageMap.end()) { + WARN("flagcxCommDeregister: comm is nullptr but handle has " + "comm-specific regMap entries that require a valid comm"); + return flagcxInvalidArgument; + } + } + globalRegPool.deregisterBuffer(nullptr, handle); + return flagcxSuccess; } + void *regKey = + comm->heteroComm ? (void *)comm->heteroComm : (void *)comm->homoComm; + // Backend-specific deregistration (homo path) - if (comm != nullptr) { - uintptr_t thisCommKey = reinterpret_cast(regKey); - if (useHomoComm(comm) && !useHeteroComm()) { - auto it = regItem->homoRegHandles.find(thisCommKey); - if (it != regItem->homoRegHandles.end()) { - cclAdaptors[flagcxCCLAdaptorDevice]->commDeregister(comm->homoComm, - it->second); - regItem->homoRegHandles.erase(it); - } + uintptr_t thisCommKey = reinterpret_cast(regKey); + if (useHomoComm(comm) && !useHeteroComm()) { + auto it = regItem->homoRegHandles.find(thisCommKey); + if (it != regItem->homoRegHandles.end()) { + cclAdaptors[flagcxCCLAdaptorDevice]->commDeregister(comm->homoComm, + it->second); + regItem->homoRegHandles.erase(it); } } - // Clean up globalRegPool (both paths) + // Remove this comm's net/p2p handles from the regItem + globalRegPool.removeRegItemNetHandles(regKey, regItem); + globalRegPool.removeRegItemP2pHandles(regKey, regItem); + + // Clean up 
globalRegPool (refCount--, page mappings, item removal at 0) globalRegPool.deregisterBuffer(regKey, handle); return flagcxSuccess; } diff --git a/test/script/_gpu_check.sh b/test/script/_gpu_check.sh index a75c8ca10..7199edd53 100644 --- a/test/script/_gpu_check.sh +++ b/test/script/_gpu_check.sh @@ -7,8 +7,16 @@ wait_for_gpu() { while true; do - IFS=$'\n' read -d '' -r -a memory_usage_array <<< "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits)" - IFS=$'\n' read -d '' -r -a memory_total_array <<< "$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits)" + if ! mapfile -t memory_usage_array < <(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits); then + echo "nvidia-smi failed, retrying..." + sleep 1m + continue + fi + if ! mapfile -t memory_total_array < <(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits); then + echo "nvidia-smi failed, retrying..." + sleep 1m + continue + fi need_wait=false diff --git a/test/unittest/core/test_reg_pool.cpp b/test/unittest/core/test_reg_pool.cpp index b2cb7dd88..30734f319 100644 --- a/test/unittest/core/test_reg_pool.cpp +++ b/test/unittest/core/test_reg_pool.cpp @@ -380,3 +380,54 @@ TEST_F(RegPoolTest, LocalIpcHandleData_WriteOnce) { memcmp(&item->localIpcHandleData, fakeIpc, sizeof(flagcxIpcHandleData)), 0); } + +// ============================================================================= +// 10. 
Register/Deregister Symmetry (API contract scenarios) +// ============================================================================= + +TEST_F(RegPoolTest, RegisterNullComm_DeregisterNullComm_Works) { + // Register(nullptr) + Deregister(nullptr) → pool-only, works + void *data = alignedAddr(200); + ASSERT_EQ(pool->registerBuffer(nullptr, data, pageSize), flagcxSuccess); + + flagcxRegItem *item = pool->getItem(nullptr, data); + ASSERT_NE(item, nullptr); + EXPECT_EQ(item->refCount, 1); + + ASSERT_EQ(pool->deregisterBuffer(nullptr, item), flagcxSuccess); + EXPECT_EQ(pool->getItem(nullptr, data), nullptr); +} + +TEST_F(RegPoolTest, RegisterComm_DeregisterComm_Works) { + // Register(comm) + Deregister(comm) → full cleanup, works + void *data = alignedAddr(201); + void *commA = fakeComm(0xA000); + ASSERT_EQ(pool->registerBuffer(commA, data, pageSize), flagcxSuccess); + + flagcxRegItem *item = pool->getItem(commA, data); + ASSERT_NE(item, nullptr); + EXPECT_EQ(item->refCount, 1); + + ASSERT_EQ(pool->deregisterBuffer(commA, item), flagcxSuccess); + EXPECT_EQ(pool->getItem(commA, data), nullptr); + EXPECT_EQ(pool->getItem(nullptr, data), nullptr); +} + +TEST_F(RegPoolTest, RegisterComm_DeregisterNullComm_PoolCleanupOnly) { + // Register(comm) + Deregister(nullptr) → pool removes item but cannot + // clean backend handles. At pool level this succeeds (pool doesn't know + // about backend handles). The flagcxCommDeregister layer guards this. 
+ void *data = alignedAddr(202); + void *commA = fakeComm(0xB000); + ASSERT_EQ(pool->registerBuffer(commA, data, pageSize), flagcxSuccess); + + flagcxRegItem *item = pool->getItem(commA, data); + ASSERT_NE(item, nullptr); + + // Pool-level deregister with nullptr succeeds (no backend awareness) + ASSERT_EQ(pool->deregisterBuffer(nullptr, item), flagcxSuccess); + // Item removed from global pool + EXPECT_EQ(pool->getItem(nullptr, data), nullptr); + // Comm-specific mappings also cleaned up (no dangling pointers) + EXPECT_EQ(pool->getItem(commA, data), nullptr); +}