From 13f49dc78ba5d9beefaffb067845d9a5db6eb3cc Mon Sep 17 00:00:00 2001 From: mikethegoblin Date: Fri, 15 May 2026 14:22:09 +0800 Subject: [PATCH 01/11] implement O(1) mr lookup --- flagcx/core/flagcx_p2p.cc | 123 +++++++++++++++++++++++++++++++------- 1 file changed, 101 insertions(+), 22 deletions(-) diff --git a/flagcx/core/flagcx_p2p.cc b/flagcx/core/flagcx_p2p.cc index 21deb1bc..fe0d3edd 100644 --- a/flagcx/core/flagcx_p2p.cc +++ b/flagcx/core/flagcx_p2p.cc @@ -12,6 +12,7 @@ #include "flagcx_p2p.h" #include "adaptor.h" +#include "debug.h" #include "flagcx_net.h" #include "flagcx_net_adaptor.h" #include "ib_common.h" @@ -174,6 +175,7 @@ static std::vector gNotifyList; static std::mutex gNotifyMutex; static std::unordered_map gMemRegInfo; +static std::unordered_map gMrToBaseAddr; static std::mutex gMemMutex; static uint64_t gNextMrId = 1; @@ -374,15 +376,28 @@ static bool findMemReg(uintptr_t addr, FlagcxP2pMemRegEntry *out) { } static FlagcxP2pMemRegEntry *findMemRegByMr(FlagcxP2pMr mr) { - for (std::unordered_map::iterator it = - gMemRegInfo.begin(); - it != gMemRegInfo.end(); ++it) { - if (it->second.mrId == mr) - return &it->second; - } + std::unordered_map::const_iterator mrIt = + gMrToBaseAddr.find(mr); + if (mrIt == gMrToBaseAddr.end()) + return NULL; + + std::unordered_map::iterator entryIt = + gMemRegInfo.find(mrIt->second); + if (entryIt != gMemRegInfo.end()) + return &entryIt->second; + return NULL; } +static bool memRegContains(const FlagcxP2pMemRegEntry &entry, uintptr_t addr, + size_t size) { + if (addr < entry.baseAddr) + return false; + + const uintptr_t offset = addr - entry.baseAddr; + return offset <= entry.size && size <= entry.size - offset; +} + static int resolveIbDevN(int netDev) { if (netDev < 0 || netDev >= flagcxNMergedIbDevs) return 0; @@ -1113,6 +1128,7 @@ void flagcxP2pEngineDestroy(FlagcxP2pEngine *engine) { engine->adaptor->deregMr(&devCtx, it->second.mhandle); } gMemRegInfo.clear(); + gMrToBaseAddr.clear(); } if (engine->topoMgr) { @@ -1308,6 +1324,7 @@ int flagcxP2pEngineReg(FlagcxP2pEngine *engine, uintptr_t data, size_t size, gMemRegInfo.find(data); if (existing != gMemRegInfo.end()) { mrId = existing->second.mrId; + gMrToBaseAddr[mrId] = existing->first; return 0; } @@ -1337,6 +1354,7 @@ int flagcxP2pEngineReg(FlagcxP2pEngine *engine, uintptr_t data, size_t size, } gMemRegInfo[data] = entry; + gMrToBaseAddr[entry.mrId] = data; mrId = entry.mrId; return 0; } @@ -1346,18 +1364,24 @@ void flagcxP2pEngineMrDestroy(FlagcxP2pEngine *engine, FlagcxP2pMr mr) { return; std::lock_guard lock(gMemMutex); - for (std::unordered_map::iterator it = - gMemRegInfo.begin(); - it != gMemRegInfo.end(); ++it) { - if (it->second.mrId == mr) { - struct { - int ibDevN; - } devCtx = {it->second.ibDevN}; - engine->adaptor->deregMr(&devCtx, it->second.mhandle); - gMemRegInfo.erase(it); - return; - } + std::unordered_map::iterator mrIt = + gMrToBaseAddr.find(mr); + if (mrIt == gMrToBaseAddr.end()) + return; + + std::unordered_map::iterator entryIt = + gMemRegInfo.find(mrIt->second); + if (entryIt == gMemRegInfo.end()) { + gMrToBaseAddr.erase(mrIt); + return; } + + struct { + int ibDevN; + } devCtx = {entryIt->second.ibDevN}; + engine->adaptor->deregMr(&devCtx, entryIt->second.mhandle); + gMemRegInfo.erase(entryIt); + gMrToBaseAddr.erase(mrIt); } int flagcxP2pEnginePrepareDesc(FlagcxP2pEngine *engine, FlagcxP2pMr mr, @@ -1457,22 +1481,73 @@ int flagcxP2pEngineReadVector(FlagcxP2pConn *conn, std::vector descs, int numIovs, uint64_t *transferId, std::vector ipcBufs) { - (void)mrIds; - if (conn == NULL || numIovs <= 0 || transferId == NULL) + if (conn == NULL || numIovs <= 0 || transferId == NULL) { + fprintf(stderr, + "[FlagCX P2P] ReadVector early exit: invalid args (conn=%p, " + "numIovs=%d, transferId=%p)\n", + conn, numIovs, (void *)transferId); return -1; + } + + if (dstVec.size() < static_cast(numIovs) || + sizeVec.size() < static_cast(numIovs) || + descs.size() < static_cast(numIovs)) { + fprintf(stderr, + "[FlagCX P2P] ReadVector early exit: vector length mismatch " + "(numIovs=%d)\n", + numIovs); + return -1; + } if (conn->isLocal && (conn->sameProcess || !ipcBufs.empty())) { - return startLocalTransfer(conn, dstVec, sizeVec, descs, numIovs, transferId, - ipcBufs, false); + fprintf(stderr, + "[FlagCX P2P] ReadVector taking local transfer path: numIovs=%d\n", + numIovs); + int rc = startLocalTransfer(conn, dstVec, sizeVec, descs, numIovs, + transferId, ipcBufs, false); + fprintf(stderr, "[FlagCX P2P] ReadVector local transfer returned: rc=%d\n", + rc); + return rc; + } + + if (mrIds.size() < static_cast(numIovs)) { + fprintf(stderr, + "[FlagCX P2P] ReadVector early exit: mrIds length mismatch " + "(numIovs=%d)\n", + numIovs); + return -1; } std::vector localEntries(numIovs); { + auto t0 = std::chrono::high_resolution_clock::now(); std::lock_guard memLock(gMemMutex); for (int i = 0; i < numIovs; i++) { - if (!findMemReg((uintptr_t)dstVec[i], &localEntries[i])) + FlagcxP2pMemRegEntry *entry = findMemRegByMr(mrIds[i]); + if (entry == NULL) { + fprintf( + stderr, + "[FlagCX P2P] ReadVector memReg lookup failed: iov=%d, mr=%lu\n", i, + (unsigned long)mrIds[i]); return -1; + } + + if (!memRegContains(*entry, reinterpret_cast(dstVec[i]), + sizeVec[i])) { + fprintf(stderr, + "[FlagCX P2P] ReadVector memReg bounds check failed: iov=%d, " + "mr=%lu, addr=%p, size=%zu\n", + i, (unsigned long)mrIds[i], dstVec[i], sizeVec[i]); + return -1; + } + + localEntries[i] = *entry; } + auto t1 = std::chrono::high_resolution_clock::now(); + double ms = std::chrono::duration(t1 - t0).count(); + fprintf(stderr, + "[FlagCX P2P] ReadVector memReg lookup: numIovs=%d, time=%.4f ms\n", + numIovs, ms); } ensureAsyncWorkerStarted(); @@ -1501,6 +1576,10 @@ int flagcxP2pEngineReadVector(FlagcxP2pConn *conn, pthread_cond_signal(&gAsyncWorker.cv); *transferId = xferId; + fprintf( + stderr, + "[FlagCX P2P] ReadVector submitted async task: xferId=%lu, numIovs=%d\n", + (unsigned long)xferId, numIovs); return 0; } From a4d9aed8dbb9296747b189f48c3e4e43a1aafd6a Mon Sep 17 00:00:00 2001 From: mikethegoblin Date: Mon, 18 May 2026 15:00:19 +0800 Subject: [PATCH 02/11] modify MR lookup logic in write vector function as well --- flagcx/core/flagcx_p2p.cc | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/flagcx/core/flagcx_p2p.cc b/flagcx/core/flagcx_p2p.cc index fe0d3edd..babffe56 100644 --- a/flagcx/core/flagcx_p2p.cc +++ b/flagcx/core/flagcx_p2p.cc @@ -1650,21 +1650,35 @@ int flagcxP2pEngineWriteVector(FlagcxP2pConn *conn, std::vector descs, int numIovs, uint64_t *transferId, std::vector ipcBufs) { - (void)mrIds; if (conn == NULL || numIovs <= 0 || transferId == NULL) return -1; + if (dstVec.size() < static_cast(numIovs) || + sizeVec.size() < static_cast(numIovs) || + descs.size() < static_cast(numIovs)) + return -1; + if (conn->isLocal && (conn->sameProcess || !ipcBufs.empty())) { return startLocalTransfer(conn, dstVec, sizeVec, descs, numIovs, transferId, ipcBufs, true); } + if (mrIds.size() < static_cast(numIovs)) + return -1; + std::vector localEntries(numIovs); { std::lock_guard memLock(gMemMutex); for (int i = 0; i < numIovs; i++) { - if (!findMemReg((uintptr_t)dstVec[i], &localEntries[i])) + FlagcxP2pMemRegEntry *entry = findMemRegByMr(mrIds[i]); + if (entry == NULL) + return -1; + + if (!memRegContains(*entry, reinterpret_cast(dstVec[i]), + sizeVec[i])) return -1; + + localEntries[i] = *entry; } } From 9f834d2994c2774a63771f7d6e83eac26700306a Mon Sep 17 00:00:00 2001 From: mikethegoblin Date: Mon, 18 May 2026 17:39:57 +0800 Subject: [PATCH 03/11] add batch poll CQ --- flagcx/adaptor/include/flagcx_net_adaptor.h | 4 + flagcx/adaptor/net/ibrc_p2p_adaptor.cc | 116 ++++++++++++++++++-- flagcx/core/flagcx_p2p.cc | 70 ++++++++---- 3 files changed, 163 insertions(+), 27 deletions(-) diff --git a/flagcx/adaptor/include/flagcx_net_adaptor.h b/flagcx/adaptor/include/flagcx_net_adaptor.h index 169e5571..ce1c6071 100644 --- a/flagcx/adaptor/include/flagcx_net_adaptor.h +++ b/flagcx/adaptor/include/flagcx_net_adaptor.h @@ -135,6 +135,10 @@ struct flagcxNetAdaptor_latest { const size_t *sizes, int srcRank, int dstRank, void **srcHandles, void **dstHandles, void **requests, int *posted); + // Optional batch completion test — polls CQ once for multiple requests. + // If NULL, caller falls back to per-request test(). + flagcxResult_t (*testBatch)(void **requests, int nRequests, int *doneFlags, + int *doneCount); }; #define flagcxNetAdaptor flagcxNetAdaptor_latest diff --git a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc index 4860a20e..841bb928 100644 --- a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc +++ b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc @@ -72,6 +72,7 @@ struct flagcxP2pConnMeta { #define FLAGCX_P2P_REQ_UNUSED 0 #define FLAGCX_P2P_REQ_IPUT 1 #define FLAGCX_P2P_REQ_IGET 2 +#define FLAGCX_P2P_BATCH_POLL_SIZE 32 struct flagcxP2pRequest { int type; @@ -672,6 +673,88 @@ static flagcxResult_t flagcxP2pTest(void *request, int *done, int *sizes) { return flagcxSuccess; } +static flagcxResult_t flagcxP2pTestBatch(void **requests, int nRequests, + int *doneFlags, int *doneCount) { + if (nRequests == 0) { + *doneCount = 0; + return flagcxSuccess; + } + + // Initialize all done flags to 0 + for (int i = 0; i < nRequests; i++) { + doneFlags[i] = 0; + } + + // Get CQ from first valid request (all requests share the same CQ) + struct ibv_cq *cq = nullptr; + struct flagcxP2pRequest *baseReqs = nullptr; + for (int i = 0; i < nRequests; i++) { + struct flagcxP2pRequest *req = (struct flagcxP2pRequest *)requests[i]; + if (req != nullptr && req->type != FLAGCX_P2P_REQ_UNUSED) { + cq = req->cq; + baseReqs = req->reqs; + break; + } + } + + if (cq == nullptr) { + // All requests are NULL or already done + int completed = 0; + for (int i = 0; i < nRequests; i++) { + struct flagcxP2pRequest *req = (struct flagcxP2pRequest *)requests[i]; + if (req == nullptr || req->type == FLAGCX_P2P_REQ_UNUSED) { + doneFlags[i] = 1; + completed++; + } + } + *doneCount = completed; + return flagcxSuccess; + } + + // Batch poll the CQ + struct ibv_wc wcs[FLAGCX_P2P_BATCH_POLL_SIZE]; + int nCqe = 0; + FLAGCXCHECK(flagcxWrapIbvPollCq(cq, FLAGCX_P2P_BATCH_POLL_SIZE, wcs, &nCqe)); + + int completed = 0; + + // Process all returned completions + for (int w = 0; w < nCqe; w++) { + struct ibv_wc *wc = &wcs[w]; + + if (wc->status != IBV_WC_SUCCESS) { + WARN("NET/IB_P2P : CQ error: status=%d opcode=%d wr_id=%lu", wc->status, + wc->opcode, wc->wr_id); + return flagcxRemoteError; + } + + // Map CQE back to request via wr_id + uint32_t reqIdx = wc->wr_id; + if (reqIdx >= FLAGCX_P2P_MAX_REQUESTS) { + WARN("NET/IB_P2P : invalid wr_id %u in CQE", reqIdx); + return flagcxInternalError; + } + struct flagcxP2pRequest *completedReq = &baseReqs[reqIdx]; + + completedReq->events--; + if (completedReq->events == 0) { + completedReq->type = FLAGCX_P2P_REQ_UNUSED; + + // Check if this completion matches any of our requested requests + for (int i = 0; i < nRequests; i++) { + if (requests[i] == completedReq) { + doneFlags[i] = 1; + completed++; + break; + } + } + } + } + + *doneCount = completed; + return flagcxSuccess; +} + /* ------------------------------------------------------------------ */ /* Close */ /* ------------------------------------------------------------------ */ @@ -793,20 +876,39 @@ static flagcxResult_t flagcxP2pGetDevFromName(char *name, int *dev) { struct flagcxNetAdaptor flagcxNetIbP2p = { // Basic functions - "IB_P2P", flagcxP2pInit, flagcxP2pDevices, flagcxP2pGetProperties, + "IB_P2P", + flagcxP2pInit, + flagcxP2pDevices, + flagcxP2pGetProperties, // Setup functions - flagcxP2pListen, flagcxP2pConnect, flagcxP2pAccept, flagcxP2pCloseSend, - flagcxP2pCloseRecv, flagcxP2pCloseListen, + flagcxP2pListen, + flagcxP2pConnect, + flagcxP2pAccept, + flagcxP2pCloseSend, + flagcxP2pCloseRecv, + flagcxP2pCloseListen, // Memory region functions - flagcxP2pRegMr, flagcxP2pRegMrDmaBuf, flagcxP2pDeregMr, + flagcxP2pRegMr, + flagcxP2pRegMrDmaBuf, + flagcxP2pDeregMr, // Two-sided functions (stubs) - flagcxP2pIsend, flagcxP2pIrecv, flagcxP2pIflush, flagcxP2pTest, + flagcxP2pIsend, + flagcxP2pIrecv, + flagcxP2pIflush, + flagcxP2pTest, // One-sided functions - flagcxP2pIput, flagcxP2pIget, flagcxP2pIputSignal, + flagcxP2pIput, + flagcxP2pIget, + flagcxP2pIputSignal, // Device name lookup - flagcxP2pGetDevFromName}; + flagcxP2pGetDevFromName, + + // Optional batch operations + nullptr, // iputBatch + flagcxP2pTestBatch, // testBatch +}; diff --git a/flagcx/core/flagcx_p2p.cc b/flagcx/core/flagcx_p2p.cc index babffe56..d8d1059c 100644 --- a/flagcx/core/flagcx_p2p.cc +++ b/flagcx/core/flagcx_p2p.cc @@ -188,7 +188,6 @@ static uint64_t gNextXferId = 1; /* ------------------------------------------------------------------ */ static constexpr int kWindowSize = 64; -static constexpr int kBatchPollCqe = 32; enum AsyncXferOp { ASYNC_XFER_READ, ASYNC_XFER_WRITE }; @@ -291,27 +290,58 @@ static void asyncWorkerFunc() { // Batch-poll completions for in-flight requests int newlyCompleted = 0; - for (int i = completed; i < issued; i++) { - int slot = i % kWindowSize; - if (inflightReqs[slot] == nullptr) { - if (i == completed) - completed++; - continue; + + if (adaptor->testBatch != nullptr) { + // Collect non-null in-flight requests for batch testing + void *batchRequests[kWindowSize]; + int batchIndices[kWindowSize]; + int batchCount = 0; + + for (int i = completed; i < issued; i++) { + int slot = i % kWindowSize; + if (inflightReqs[slot] != nullptr) { + batchRequests[batchCount] = inflightReqs[slot]; + batchIndices[batchCount] = i; + batchCount++; + } } - int done = 0, sizes = 0; - flagcxResult_t res = adaptor->test(inflightReqs[slot], &done, &sizes); - if (res != flagcxSuccess) { - inflightReqs[slot] = nullptr; - if (i == completed) - completed++; - newlyCompleted++; - continue; + + if (batchCount > 0) { + int doneFlags[kWindowSize]; + int doneCount = 0; + flagcxResult_t res = adaptor->testBatch(batchRequests, batchCount, + doneFlags, &doneCount); + if (res != flagcxSuccess) { + error = true; + } else { + for (int b = 0; b < batchCount; b++) { + if (doneFlags[b]) { + int i = batchIndices[b]; + int slot = i % kWindowSize; + inflightReqs[slot] = nullptr; + newlyCompleted++; + } + } + } } - if (done) { - inflightReqs[slot] = nullptr; - if (i == completed) - completed++; - newlyCompleted++; + } else { + // Fallback: per-request polling + for (int i = completed; i < issued; i++) { + int slot = i % kWindowSize; + if (inflightReqs[slot] == nullptr) { + continue; + } + int done = 0, sizes = 0; + flagcxResult_t res = adaptor->test(inflightReqs[slot], &done, &sizes); + if (res != flagcxSuccess) { + inflightReqs[slot] = nullptr; + newlyCompleted++; + continue; + } + if (done) { + inflightReqs[slot] = nullptr; + newlyCompleted++; + } } } From 8d6377beb5330d723dd92532e3e1c8e8aa5efff3 Mon Sep 17 00:00:00 2001 From: mikethegoblin Date: Tue, 19 May 2026 11:08:01 +0800 Subject: [PATCH 04/11] move CQ polling to separate thread --- flagcx/adaptor/net/ibrc_p2p_adaptor.cc | 251 +++++++++++++++---------- 1 file changed, 155 insertions(+), 96 deletions(-) diff --git a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc index 841bb928..1bfa2db3 100644 --- a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc +++ b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc @@ -15,11 +15,14 @@ #include #include +#include #include +#include #include #include #include #include +#include /* ------------------------------------------------------------------ */ /* Internal structs */ @@ -79,6 +82,7 @@ struct flagcxP2pRequest { int events; // outstanding CQEs expected struct ibv_cq *cq; // CQ to poll for this request struct flagcxP2pRequest *reqs; // back-pointer to owning reqs[] array + std::atomic *reqDone; // pointer to owning comm's reqDone[] array }; // P2P send comm — one QP, one CQ, blocking connect @@ -90,6 +94,8 @@ struct flagcxP2pSendComm { struct flagcxP2pRequest reqs[FLAGCX_P2P_MAX_REQUESTS]; uint64_t putSignalScratchpad; struct ibv_mr *putSignalScratchpadMr; + std::atomic reqDone[FLAGCX_P2P_MAX_REQUESTS]; + std::atomic cqError{false}; }; // P2P recv comm — symmetric with send comm so both sides can initiate transfers @@ -101,6 +107,8 @@ struct flagcxP2pRecvComm { struct flagcxP2pRequest reqs[FLAGCX_P2P_MAX_REQUESTS]; uint64_t putSignalScratchpad; struct ibv_mr *putSignalScratchpadMr; + std::atomic reqDone[FLAGCX_P2P_MAX_REQUESTS]; + std::atomic cqError{false}; }; /* ------------------------------------------------------------------ */ @@ -111,11 +119,128 @@ static struct flagcxP2pDevCtx flagcxP2pDevCtxs[MAX_IB_DEVS]; static int flagcxP2pInitialized = 0; static pthread_mutex_t flagcxP2pInitLock = PTHREAD_MUTEX_INITIALIZER; +/* ------------------------------------------------------------------ */ +/* Background CQ Poller */ +/* ------------------------------------------------------------------ */ + +struct CqPollEntry { + struct ibv_cq *cq; + struct flagcxP2pRequest *reqs; + std::atomic *reqDone; + std::atomic *cqError; + bool active; +}; + +struct CqPoller { + std::thread thread; + std::mutex mutex; + std::vector entries; + std::atomic running{false}; +}; + +static CqPoller gCqPoller; + +static void cqPollerFunc() { + while (gCqPoller.running.load(std::memory_order_relaxed)) { + std::vector snapshot; + { + std::lock_guard lock(gCqPoller.mutex); + snapshot = gCqPoller.entries; + } + + bool anyWork = false; + for (auto &entry : snapshot) { + if (!entry.active) + continue; + + struct ibv_wc wcs[FLAGCX_P2P_BATCH_POLL_SIZE]; + int nCqe = 0; + if (flagcxWrapIbvPollCq(entry.cq, FLAGCX_P2P_BATCH_POLL_SIZE, wcs, + &nCqe) != flagcxSuccess) { + entry.cqError->store(true, std::memory_order_release); + continue; + } + + for (int i = 0; i < nCqe; i++) { + if (wcs[i].status != IBV_WC_SUCCESS) { + WARN("NET/IB_P2P : CQ poller got error status %d for wr_id %lu", + wcs[i].status, wcs[i].wr_id); + entry.cqError->store(true, std::memory_order_release); + break; + } + uint32_t idx = (uint32_t)wcs[i].wr_id; + if (idx >= FLAGCX_P2P_MAX_REQUESTS) + continue; + + entry.reqs[idx].events--; + if (entry.reqs[idx].events == 0) { + entry.reqs[idx].type = FLAGCX_P2P_REQ_UNUSED; + entry.reqDone[idx].store(1, std::memory_order_release); + } + } + if (nCqe > 0) + anyWork = true; + } + + if (!anyWork) { + std::this_thread::sleep_for(std::chrono::microseconds(1)); + } + } +} + +static void ensureCqPollerStarted() { + if (!gCqPoller.running.load(std::memory_order_acquire)) { + std::lock_guard lock(gCqPoller.mutex); + if (!gCqPoller.running.load(std::memory_order_relaxed)) { + gCqPoller.running.store(true, std::memory_order_release); + gCqPoller.thread = std::thread(cqPollerFunc); + } + } +} + +static void cqPollerRegister(struct ibv_cq *cq, struct flagcxP2pRequest *reqs, + std::atomic *reqDone, + std::atomic *cqError) { + ensureCqPollerStarted(); + std::lock_guard lock(gCqPoller.mutex); + gCqPoller.entries.push_back({cq, reqs, reqDone, cqError, true}); +} + +static void cqPollerStop() { + if (gCqPoller.running.load(std::memory_order_acquire)) { + gCqPoller.running.store(false, std::memory_order_release); + if (gCqPoller.thread.joinable()) { + gCqPoller.thread.join(); + } + std::lock_guard lock(gCqPoller.mutex); + gCqPoller.entries.clear(); + } +} + +static void cqPollerUnregister(struct ibv_cq *cq) { + bool anyActive = false; + { + std::lock_guard lock(gCqPoller.mutex); + for (auto &entry : gCqPoller.entries) { + if (entry.cq == cq) { + entry.active = false; + } else if (entry.active) { + anyActive = true; + } + } + } + std::this_thread::sleep_for(std::chrono::microseconds(100)); + if (!anyActive) { + cqPollerStop(); + } +} + /* ------------------------------------------------------------------ */ /* Request helpers */ /* ------------------------------------------------------------------ */ static flagcxResult_t flagcxP2pGetRequest(struct flagcxP2pRequest *reqs, + std::atomic *reqDone, struct ibv_cq *cq, int type, struct flagcxP2pRequest **req) { for (int i = 0; i < FLAGCX_P2P_MAX_REQUESTS; i++) { @@ -124,6 +249,8 @@ static flagcxResult_t flagcxP2pGetRequest(struct flagcxP2pRequest *reqs, reqs[i].events = 0; reqs[i].cq = cq; reqs[i].reqs = reqs; + reqs[i].reqDone = reqDone; + reqDone[i].store(0, std::memory_order_relaxed); *req = &reqs[i]; return flagcxSuccess; } @@ -411,6 +538,8 @@ static flagcxResult_t flagcxP2pConnect(int dev, void *opaqueHandle, FLAGCXCHECK(flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady))); FLAGCXCHECK(flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady))); + cqPollerRegister(comm->base.cq, comm->reqs, comm->reqDone, &comm->cqError); + *sendComm = comm; return flagcxSuccess; } @@ -469,6 +598,8 @@ static flagcxResult_t flagcxP2pAccept(void *listenComm, void **recvComm) { FLAGCXCHECK(flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady))); FLAGCXCHECK(flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady))); + cqPollerRegister(comm->base.cq, comm->reqs, comm->reqDone, &comm->cqError); + *recvComm = comm; return flagcxSuccess; } @@ -486,7 +617,7 @@ static flagcxResult_t flagcxP2pIput(void *sendComm, uint64_t srcOff, struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles; struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->base.cq, + FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, comm->base.cq, FLAGCX_P2P_REQ_IPUT, &req)); struct ibv_sge sge; @@ -527,7 +658,7 @@ static flagcxResult_t flagcxP2pIget(void *sendComm, uint64_t srcOff, struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles; struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->base.cq, + FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, comm->base.cq, FLAGCX_P2P_REQ_IGET, &req)); struct ibv_sge sge; @@ -569,7 +700,7 @@ flagcxP2pIputSignal(void *sendComm, uint64_t srcOff, uint64_t dstOff, (struct flagcxP2pMrHandle *)signalHandles; struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->base.cq, + FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, comm->base.cq, FLAGCX_P2P_REQ_IPUT, &req)); bool chainData = (size > 0 && srcHandles != NULL && dstHandles != NULL); @@ -638,34 +769,14 @@ static flagcxResult_t flagcxP2pTest(void *request, int *done, int *sizes) { return flagcxSuccess; } - int nCqe = 0; - struct ibv_wc wc; - FLAGCXCHECK(flagcxWrapIbvPollCq(req->cq, 1, &wc, &nCqe)); - - if (nCqe == 0) - return flagcxSuccess; - - if (wc.status != IBV_WC_SUCCESS) { - WARN("NET/IB_P2P : CQ error: status=%d opcode=%d wr_id=%lu", wc.status, - wc.opcode, wc.wr_id); - return flagcxRemoteError; - } - - // Map CQE back to the correct request via wr_id - uint32_t reqIdx = wc.wr_id; - if (reqIdx >= FLAGCX_P2P_MAX_REQUESTS) { - WARN("NET/IB_P2P : invalid wr_id %u in CQE", reqIdx); + uint32_t idx = (uint32_t)(req - req->reqs); + if (idx >= FLAGCX_P2P_MAX_REQUESTS) { + WARN("NET/IB_P2P : invalid request index %u in test()", idx); return flagcxInternalError; } - struct flagcxP2pRequest *completedReq = &req->reqs[reqIdx]; - completedReq->events--; - if (completedReq->events == 0) { - completedReq->type = FLAGCX_P2P_REQ_UNUSED; - } - - // Check if the originally requested op is done - if (req->type == FLAGCX_P2P_REQ_UNUSED) { + if (req->reqDone[idx].load(std::memory_order_acquire)) { + req->reqDone[idx].store(0, std::memory_order_relaxed); *done = 1; if (sizes) *sizes = 0; @@ -675,82 +786,26 @@ static flagcxResult_t flagcxP2pTest(void *request, int *done, int *sizes) { static flagcxResult_t flagcxP2pTestBatch(void **requests, int nRequests, int *doneFlags, int *doneCount) { - if (nRequests == 0) { - *doneCount = 0; - return flagcxSuccess; - } - - // Initialize all done flags to 0 + int completed = 0; for (int i = 0; i < nRequests; i++) { doneFlags[i] = 0; - } - - // Get CQ from first valid request (all requests share the same CQ) - struct ibv_cq *cq = nullptr; - struct flagcxP2pRequest *baseReqs = nullptr; - for (int i = 0; i < nRequests; i++) { struct flagcxP2pRequest *req = (struct flagcxP2pRequest *)requests[i]; - if (req != nullptr && req->type != FLAGCX_P2P_REQ_UNUSED) { - cq = req->cq; - baseReqs = req->reqs; - break; + if (req == NULL || req->type == FLAGCX_P2P_REQ_UNUSED) { + doneFlags[i] = 1; + completed++; + continue; } - } - if (cq == nullptr) { - // All requests are NULL or already done - int completed = 0; - for (int i = 0; i < nRequests; i++) { - struct flagcxP2pRequest *req = (struct flagcxP2pRequest *)requests[i]; - if (req == nullptr || req->type == FLAGCX_P2P_REQ_UNUSED) { - doneFlags[i] = 1; - completed++; - } - } - *doneCount = completed; - return flagcxSuccess; - } - - // Batch poll the CQ - struct ibv_wc wcs[FLAGCX_P2P_BATCH_POLL_SIZE]; - int nCqe = 0; - FLAGCXCHECK(flagcxWrapIbvPollCq(cq, FLAGCX_P2P_BATCH_POLL_SIZE, wcs, &nCqe)); - - int completed = 0; - - // Process all returned completions - for (int w = 0; w < nCqe; w++) { - struct ibv_wc *wc = &wcs[w]; + uint32_t idx = (uint32_t)(req - req->reqs); + if (idx >= FLAGCX_P2P_MAX_REQUESTS) + continue; - if (wc->status != IBV_WC_SUCCESS) { - WARN("NET/IB_P2P : CQ error: status=%d opcode=%d wr_id=%lu", wc->status, - wc->opcode, wc->wr_id); - return flagcxRemoteError; - } - - // Map CQE back to request via wr_id - uint32_t reqIdx = wc->wr_id; - if (reqIdx >= FLAGCX_P2P_MAX_REQUESTS) { - WARN("NET/IB_P2P : invalid wr_id %u in CQE", reqIdx); - return flagcxInternalError; - } - struct flagcxP2pRequest *completedReq = &baseReqs[reqIdx]; - - completedReq->events--; - if (completedReq->events == 0) { - completedReq->type = FLAGCX_P2P_REQ_UNUSED; - - // Check if this completion matches any of our requested requests - for (int i = 0; i < nRequests; i++) { - if (requests[i] == completedReq) { - doneFlags[i] = 1; - completed++; - break; - } - } + if (req->reqDone[idx].load(std::memory_order_acquire)) { + req->reqDone[idx].store(0, std::memory_order_relaxed); + doneFlags[i] = 1; + completed++; } } - *doneCount = completed; return flagcxSuccess; } @@ -793,6 +848,8 @@ static void flagcxP2pDrainCq(struct ibv_cq *cq) { static flagcxResult_t flagcxP2pCloseSend(void *sendComm) { struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; if (comm) { + if (comm->base.cq) + cqPollerUnregister(comm->base.cq); flagcxP2pDrainCq(comm->base.cq); if (comm->qp.qp) FLAGCXCHECK(flagcxWrapIbvDestroyQp(comm->qp.qp)); @@ -810,6 +867,8 @@ static flagcxResult_t flagcxP2pCloseSend(void *sendComm) { static flagcxResult_t flagcxP2pCloseRecv(void *recvComm) { struct flagcxP2pRecvComm *comm = (struct flagcxP2pRecvComm *)recvComm; if (comm) { + if (comm->base.cq) + cqPollerUnregister(comm->base.cq); flagcxP2pDrainCq(comm->base.cq); if (comm->qp.qp) FLAGCXCHECK(flagcxWrapIbvDestroyQp(comm->qp.qp)); From 387659fc1046958c4deb6f2cda41473a3b0ec09e Mon Sep 17 00:00:00 2001 From: mikethegoblin Date: Thu, 21 May 2026 15:18:26 +0800 Subject: [PATCH 05/11] add multi-QP implementation --- flagcx/adaptor/net/ibrc_p2p_adaptor.cc | 384 +++++++++++++++++-------- flagcx/core/flagcx_p2p.cc | 21 +- 2 files changed, 271 insertions(+), 134 deletions(-) diff --git a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc index 1bfa2db3..9eff3199 100644 --- a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc +++ b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc @@ -4,7 +4,7 @@ * IBRC P2P Net Adaptor — implements flagcxNetAdaptor for one-sided RDMA * (P2P) use cases. Shares IB device discovery and utility code with the * existing IBRC adaptor but uses P2P-native handle formats, eager PD - * allocation, and simplified (single-QP, no-FIFO) connection setup. + * allocation, and simplified (no-FIFO) connection setup. ************************************************************************/ #include "flagcx_common.h" @@ -76,6 +76,7 @@ struct flagcxP2pConnMeta { #define FLAGCX_P2P_REQ_IPUT 1 #define FLAGCX_P2P_REQ_IGET 2 #define FLAGCX_P2P_BATCH_POLL_SIZE 32 +#define FLAGCX_P2P_QPS_PER_CONN 4 struct flagcxP2pRequest { int type; @@ -85,16 +86,22 @@ struct flagcxP2pRequest { std::atomic *reqDone; // pointer to owning comm's reqDone[] array }; -// P2P send comm — one QP, one CQ, blocking connect +struct flagcxP2pChannel { + struct ibv_cq *cq; + struct flagcxIbQp qp; +}; + +// P2P send comm — fixed QP/CQ channels, blocking connect struct flagcxP2pSendComm { int ibDevN; // MUST be first field struct flagcxIbNetCommDevBase base; - struct flagcxIbQp qp; + struct flagcxP2pChannel channels[FLAGCX_P2P_QPS_PER_CONN]; struct flagcxSocket sock; struct flagcxP2pRequest reqs[FLAGCX_P2P_MAX_REQUESTS]; uint64_t putSignalScratchpad; struct ibv_mr *putSignalScratchpadMr; std::atomic reqDone[FLAGCX_P2P_MAX_REQUESTS]; + std::atomic nextChannel{0}; std::atomic cqError{false}; }; @@ -102,12 +109,13 @@ struct flagcxP2pSendComm { struct flagcxP2pRecvComm { int ibDevN; // MUST be first field struct flagcxIbNetCommDevBase base; - struct flagcxIbQp qp; + struct flagcxP2pChannel channels[FLAGCX_P2P_QPS_PER_CONN]; struct flagcxSocket sock; struct flagcxP2pRequest reqs[FLAGCX_P2P_MAX_REQUESTS]; uint64_t putSignalScratchpad; struct ibv_mr *putSignalScratchpadMr; std::atomic reqDone[FLAGCX_P2P_MAX_REQUESTS]; + std::atomic nextChannel{0}; std::atomic cqError{false}; }; @@ -142,44 +150,43 @@ static CqPoller gCqPoller; static void cqPollerFunc() { while (gCqPoller.running.load(std::memory_order_relaxed)) { - std::vector snapshot; + bool anyWork = false; { + // Keep unregister serialized with polling so CQ teardown cannot race a + // snapshot that still contains the CQ. std::lock_guard lock(gCqPoller.mutex); - snapshot = gCqPoller.entries; - } - - bool anyWork = false; - for (auto &entry : snapshot) { - if (!entry.active) - continue; - - struct ibv_wc wcs[FLAGCX_P2P_BATCH_POLL_SIZE]; - int nCqe = 0; - if (flagcxWrapIbvPollCq(entry.cq, FLAGCX_P2P_BATCH_POLL_SIZE, wcs, - &nCqe) != flagcxSuccess) { - entry.cqError->store(true, std::memory_order_release); - continue; - } + for (auto &entry : gCqPoller.entries) { + if (!entry.active) + continue; - for (int i = 0; i < nCqe; i++) { - if (wcs[i].status != IBV_WC_SUCCESS) { - WARN("NET/IB_P2P : CQ poller got error status %d for wr_id %lu", - wcs[i].status, wcs[i].wr_id); + struct ibv_wc wcs[FLAGCX_P2P_BATCH_POLL_SIZE]; + int nCqe = 0; + if (flagcxWrapIbvPollCq(entry.cq, FLAGCX_P2P_BATCH_POLL_SIZE, wcs, + &nCqe) != flagcxSuccess) { entry.cqError->store(true, std::memory_order_release); - break; - } - uint32_t idx = (uint32_t)wcs[i].wr_id; - if (idx >= FLAGCX_P2P_MAX_REQUESTS) continue; + } - entry.reqs[idx].events--; - if (entry.reqs[idx].events == 0) { - entry.reqs[idx].type = FLAGCX_P2P_REQ_UNUSED; - entry.reqDone[idx].store(1, std::memory_order_release); + for (int i = 0; i < nCqe; i++) { + if (wcs[i].status != IBV_WC_SUCCESS) { + WARN("NET/IB_P2P : CQ poller got error status %d for wr_id %lu", + wcs[i].status, wcs[i].wr_id); + entry.cqError->store(true, std::memory_order_release); + break; + } + uint32_t idx = (uint32_t)wcs[i].wr_id; + if (idx >= FLAGCX_P2P_MAX_REQUESTS) + continue; + + entry.reqs[idx].events--; + if (entry.reqs[idx].events == 0) { + entry.reqs[idx].type = FLAGCX_P2P_REQ_UNUSED; + entry.reqDone[idx].store(1, std::memory_order_release); + } } + if (nCqe > 0) + anyWork = true; } - if (nCqe > 0) - anyWork = true; } if (!anyWork) { @@ -229,7 +236,6 @@ static void cqPollerUnregister(struct ibv_cq *cq) { } } } - std::this_thread::sleep_for(std::chrono::microseconds(100)); if (!anyActive) { cqPollerStop(); } @@ -403,10 +409,13 @@ static flagcxResult_t flagcxP2pListen(int dev, void *opaqueHandle, return flagcxSuccess; } -// Helper: set up PD (from eager init), CQ, QP, and GID for a connection +static flagcxResult_t flagcxP2pReleasePd(int ibDevN); +static void flagcxP2pDrainCq(struct ibv_cq *cq); + +// Helper: set up PD (from eager init), CQs, QPs, and GID for a connection static flagcxResult_t flagcxP2pSetupConn(int dev, struct flagcxIbNetCommDevBase *base, - struct flagcxIbQp *qp, + struct flagcxP2pChannel *channels, int *outIbDevN) { struct flagcxIbMergedDev *mergedDev = flagcxIbMergedDevs + dev; int ibDevN = mergedDev->devs[0]; // v1: single physical NIC @@ -421,26 +430,51 @@ static flagcxResult_t flagcxP2pSetupConn(int dev, base->pd = ibDev->pd; pthread_mutex_unlock(&ibDev->lock); - // Create CQ for this connection - FLAGCXCHECK(flagcxWrapIbvCreateCq( - &base->cq, ibDev->context, 2 * FLAGCX_P2P_MAX_REQUESTS, NULL, NULL, 0)); + int accessFlags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | + IBV_ACCESS_REMOTE_ATOMIC; // Get GID info - FLAGCXCHECK(flagcxIbGetGidIndex(ibDev->context, ibDev->portNum, - ibDev->portAttr.gid_tbl_len, - &base->gidInfo.localGidIndex)); - FLAGCXCHECK(flagcxWrapIbvQueryGid(ibDev->context, ibDev->portNum, - base->gidInfo.localGidIndex, - &base->gidInfo.localGid)); + flagcxResult_t res; + FLAGCXCHECKGOTO(flagcxIbGetGidIndex(ibDev->context, ibDev->portNum, + ibDev->portAttr.gid_tbl_len, + &base->gidInfo.localGidIndex), + res, setup_fail); + FLAGCXCHECKGOTO(flagcxWrapIbvQueryGid(ibDev->context, ibDev->portNum, + base->gidInfo.localGidIndex, + &base->gidInfo.localGid), + res, setup_fail); base->gidInfo.linkLayer = ibDev->link; - // Create RC QP with remote write, read, and atomic access - int accessFlags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | - IBV_ACCESS_REMOTE_ATOMIC; - FLAGCXCHECK(flagcxIbCreateQp(ibDev->portNum, base, accessFlags, qp)); - qp->devIndex = 0; + // Create RC QPs with remote write, read, and atomic access. + for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) { + FLAGCXCHECKGOTO(flagcxWrapIbvCreateCq(&channels[i].cq, ibDev->context, + 2 * FLAGCX_P2P_MAX_REQUESTS, NULL, + NULL, 0), + res, setup_fail); + base->cq = channels[i].cq; + FLAGCXCHECKGOTO( + flagcxIbCreateQp(ibDev->portNum, base, accessFlags, &channels[i].qp), + res, setup_fail); + channels[i].qp.devIndex = 0; + } return flagcxSuccess; + +setup_fail: + for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) { + if (channels[i].qp.qp) { + flagcxWrapIbvDestroyQp(channels[i].qp.qp); + channels[i].qp.qp = NULL; + } + if (channels[i].cq) { + flagcxWrapIbvDestroyCq(channels[i].cq); + channels[i].cq = NULL; + } + } + base->cq = NULL; + flagcxP2pReleasePd(ibDevN); + base->pd = NULL; + return res; } // Helper: build local connection metadata @@ -483,65 +517,135 @@ flagcxP2pTransitionQp(struct flagcxIbQp *qp, return flagcxSuccess; } +static void flagcxP2pRegisterChannels(struct flagcxP2pChannel *channels, + struct flagcxP2pRequest *reqs, + std::atomic *reqDone, + std::atomic *cqError) { + for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) + cqPollerRegister(channels[i].cq, reqs, reqDone, cqError); +} + +static void flagcxP2pUnregisterChannels(struct flagcxP2pChannel *channels) { + for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) + if (channels[i].cq) + cqPollerUnregister(channels[i].cq); +} + +static flagcxResult_t +flagcxP2pDestroyChannels(struct flagcxP2pChannel *channels) { + for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) + flagcxP2pDrainCq(channels[i].cq); + for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) { + if (channels[i].qp.qp) { + FLAGCXCHECK(flagcxWrapIbvDestroyQp(channels[i].qp.qp)); + channels[i].qp.qp = NULL; + } + } + for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) { + if (channels[i].cq) { + FLAGCXCHECK(flagcxWrapIbvDestroyCq(channels[i].cq)); + channels[i].cq = NULL; + } + } + return flagcxSuccess; +} + +static inline struct flagcxP2pChannel * +flagcxP2pNextChannel(struct flagcxP2pChannel *channels, + std::atomic *nextChannel) { + uint32_t idx = nextChannel->fetch_add(1, std::memory_order_relaxed); + return channels + (idx % FLAGCX_P2P_QPS_PER_CONN); +} + static flagcxResult_t flagcxP2pConnect(int dev, void *opaqueHandle, void **sendComm) { struct flagcxP2pListenHandle *handle = (struct flagcxP2pListenHandle *)opaqueHandle; + flagcxResult_t res; *sendComm = NULL; // Allocate send comm struct flagcxP2pSendComm *comm; FLAGCXCHECK(flagcxCalloc(&comm, 1)); + int ready = 0; + auto connectStart = std::chrono::steady_clock::time_point(); + struct flagcxP2pConnMeta localMeta[FLAGCX_P2P_QPS_PER_CONN]; + struct flagcxP2pConnMeta remoteMeta[FLAGCX_P2P_QPS_PER_CONN]; + int localReady = 1, remoteReady = 0; // TCP connect (blocking with timeout) - FLAGCXCHECK(flagcxSocketInit(&comm->sock, &handle->connectAddr, handle->magic, - flagcxSocketTypeNetIb, NULL, 1)); - FLAGCXCHECK(flagcxSocketConnect(&comm->sock)); - int ready = 0; - auto connectStart = std::chrono::steady_clock::now(); + FLAGCXCHECKGOTO(flagcxSocketInit(&comm->sock, &handle->connectAddr, + handle->magic, flagcxSocketTypeNetIb, NULL, + 1), + res, connect_fail); + FLAGCXCHECKGOTO(flagcxSocketConnect(&comm->sock), res, connect_fail); + connectStart = std::chrono::steady_clock::now(); while (!ready) { - FLAGCXCHECK(flagcxSocketReady(&comm->sock, &ready)); + FLAGCXCHECKGOTO(flagcxSocketReady(&comm->sock, &ready), res, connect_fail); if (!ready) { if (std::chrono::steady_clock::now() - connectStart > std::chrono::seconds(30)) { WARN("NET/IB_P2P : connect socket ready timed out after 30s"); - free(comm); - return flagcxSystemError; + res = flagcxSystemError; + goto connect_fail; } std::this_thread::sleep_for(std::chrono::milliseconds(1)); } } - // Set up PD, CQ, QP - FLAGCXCHECK(flagcxP2pSetupConn(dev, &comm->base, &comm->qp, &comm->ibDevN)); + // Set up PD, CQs, QPs + FLAGCXCHECKGOTO( + flagcxP2pSetupConn(dev, &comm->base, comm->channels, &comm->ibDevN), res, + connect_fail); // Exchange connection metadata - struct flagcxP2pConnMeta localMeta, remoteMeta; - flagcxP2pBuildConnMeta(&localMeta, &comm->base, &comm->qp, comm->ibDevN); - FLAGCXCHECK(flagcxSocketSend(&comm->sock, &localMeta, sizeof(localMeta))); - FLAGCXCHECK(flagcxSocketRecv(&comm->sock, &remoteMeta, sizeof(remoteMeta))); - - // Transition QP to RTR then RTS - FLAGCXCHECK( - flagcxP2pTransitionQp(&comm->qp, &comm->base, &remoteMeta, comm->ibDevN)); + for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) + flagcxP2pBuildConnMeta(&localMeta[i], &comm->base, &comm->channels[i].qp, + comm->ibDevN); + FLAGCXCHECKGOTO(flagcxSocketSend(&comm->sock, localMeta, sizeof(localMeta)), + res, connect_fail); + FLAGCXCHECKGOTO(flagcxSocketRecv(&comm->sock, remoteMeta, sizeof(remoteMeta)), + res, connect_fail); + + // Transition each matched QP to RTR then RTS. + for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) + FLAGCXCHECKGOTO(flagcxP2pTransitionQp(&comm->channels[i].qp, &comm->base, + &remoteMeta[i], comm->ibDevN), + res, connect_fail); // Register putSignal scratchpad MR comm->putSignalScratchpad = 0; - FLAGCXCHECK(flagcxWrapIbvRegMr( - &comm->putSignalScratchpadMr, comm->base.pd, &comm->putSignalScratchpad, - sizeof(comm->putSignalScratchpad), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC)); + FLAGCXCHECKGOTO( + flagcxWrapIbvRegMr(&comm->putSignalScratchpadMr, comm->base.pd, + &comm->putSignalScratchpad, + sizeof(comm->putSignalScratchpad), + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC), + res, connect_fail); // Exchange ready - int localReady = 1, remoteReady = 0; - FLAGCXCHECK(flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady))); - FLAGCXCHECK(flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady))); + FLAGCXCHECKGOTO( + flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady)), res, + connect_fail); + FLAGCXCHECKGOTO( + flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady)), res, + connect_fail); - cqPollerRegister(comm->base.cq, comm->reqs, comm->reqDone, &comm->cqError); + flagcxP2pRegisterChannels(comm->channels, comm->reqs, comm->reqDone, + &comm->cqError); *sendComm = comm; return flagcxSuccess; + +connect_fail: + if (comm->putSignalScratchpadMr) + flagcxWrapIbvDeregMr(comm->putSignalScratchpadMr); + flagcxP2pDestroyChannels(comm->channels); + if (comm->base.pd) + flagcxP2pReleasePd(comm->ibDevN); + flagcxSocketClose(&comm->sock); + free(comm); + return res; } static flagcxResult_t flagcxP2pAccept(void *listenComm, void **recvComm) { @@ -555,6 +659,9 @@ static flagcxResult_t flagcxP2pAccept(void *listenComm, void **recvComm) { // TCP accept (blocking, no timeout) flagcxResult_t res; int ready; + struct flagcxP2pConnMeta localMeta[FLAGCX_P2P_QPS_PER_CONN]; + struct flagcxP2pConnMeta remoteMeta[FLAGCX_P2P_QPS_PER_CONN]; + int localReady = 1, remoteReady = 0; FLAGCXCHECKGOTO(flagcxSocketInit(&comm->sock), res, accept_fail); FLAGCXCHECKGOTO(flagcxSocketAccept(&comm->sock, &lComm->sock), res, accept_fail); @@ -571,37 +678,59 @@ static flagcxResult_t flagcxP2pAccept(void *listenComm, void **recvComm) { return res; } - // Set up PD, CQ, QP - FLAGCXCHECK( - flagcxP2pSetupConn(lComm->dev, &comm->base, &comm->qp, &comm->ibDevN)); + // Set up PD, CQs, QPs + FLAGCXCHECKGOTO(flagcxP2pSetupConn(lComm->dev, &comm->base, comm->channels, + &comm->ibDevN), + res, accept_cleanup); // Exchange connection metadata (accept receives first, then sends) - struct flagcxP2pConnMeta localMeta, remoteMeta; - flagcxP2pBuildConnMeta(&localMeta, &comm->base, &comm->qp, comm->ibDevN); - FLAGCXCHECK(flagcxSocketRecv(&comm->sock, &remoteMeta, sizeof(remoteMeta))); - FLAGCXCHECK(flagcxSocketSend(&comm->sock, &localMeta, sizeof(localMeta))); - - // Transition QP to RTR then RTS - FLAGCXCHECK( - flagcxP2pTransitionQp(&comm->qp, &comm->base, &remoteMeta, comm->ibDevN)); + for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) + flagcxP2pBuildConnMeta(&localMeta[i], &comm->base, &comm->channels[i].qp, + comm->ibDevN); + FLAGCXCHECKGOTO(flagcxSocketRecv(&comm->sock, remoteMeta, sizeof(remoteMeta)), + res, accept_cleanup); + FLAGCXCHECKGOTO(flagcxSocketSend(&comm->sock, localMeta, sizeof(localMeta)), + res, accept_cleanup); + + // Transition each matched QP to RTR then RTS. + for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) + FLAGCXCHECKGOTO(flagcxP2pTransitionQp(&comm->channels[i].qp, &comm->base, + &remoteMeta[i], comm->ibDevN), + res, accept_cleanup); // Register putSignal scratchpad MR (symmetric with connect) comm->putSignalScratchpad = 0; - FLAGCXCHECK(flagcxWrapIbvRegMr( - &comm->putSignalScratchpadMr, comm->base.pd, &comm->putSignalScratchpad, - sizeof(comm->putSignalScratchpad), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC)); + FLAGCXCHECKGOTO( + flagcxWrapIbvRegMr(&comm->putSignalScratchpadMr, comm->base.pd, + &comm->putSignalScratchpad, + sizeof(comm->putSignalScratchpad), + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC), + res, accept_cleanup); // Exchange ready - int localReady = 1, remoteReady = 0; - FLAGCXCHECK(flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady))); - FLAGCXCHECK(flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady))); + FLAGCXCHECKGOTO( + flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady)), res, + accept_cleanup); + FLAGCXCHECKGOTO( + flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady)), res, + accept_cleanup); - cqPollerRegister(comm->base.cq, comm->reqs, comm->reqDone, &comm->cqError); + flagcxP2pRegisterChannels(comm->channels, comm->reqs, comm->reqDone, + &comm->cqError); *recvComm = comm; return flagcxSuccess; + +accept_cleanup: + if (comm->putSignalScratchpadMr) + flagcxWrapIbvDeregMr(comm->putSignalScratchpadMr); + flagcxP2pDestroyChannels(comm->channels); + if (comm->base.pd) + flagcxP2pReleasePd(comm->ibDevN); + flagcxSocketClose(&comm->sock); + free(comm); + return res; } /* ------------------------------------------------------------------ */ @@ -617,8 +746,11 @@ static flagcxResult_t flagcxP2pIput(void *sendComm, uint64_t srcOff, struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles; struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, comm->base.cq, + FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, NULL, FLAGCX_P2P_REQ_IPUT, &req)); + struct flagcxP2pChannel *channel = + flagcxP2pNextChannel(comm->channels, &comm->nextChannel); + req->cq = channel->cq; struct ibv_sge sge; memset(&sge, 0, sizeof(sge)); @@ -641,9 +773,13 @@ static flagcxResult_t flagcxP2pIput(void *sendComm, uint64_t srcOff, wr.sg_list = &sge; wr.num_sge = 1; - struct ibv_send_wr *bad_wr; - FLAGCXCHECK(flagcxWrapIbvPostSend(comm->qp.qp, &wr, &bad_wr)); req->events = 1; + struct ibv_send_wr *bad_wr; + flagcxResult_t res = flagcxWrapIbvPostSend(channel->qp.qp, &wr, &bad_wr); + if (res != flagcxSuccess) { + flagcxP2pFreeRequest(req); + return res; + } *request = req; return flagcxSuccess; @@ -658,8 +794,11 @@ static flagcxResult_t flagcxP2pIget(void *sendComm, uint64_t srcOff, struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles; struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, comm->base.cq, + FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, NULL, FLAGCX_P2P_REQ_IGET, &req)); + struct flagcxP2pChannel *channel = + flagcxP2pNextChannel(comm->channels, &comm->nextChannel); + req->cq = channel->cq; struct ibv_sge sge; memset(&sge, 0, sizeof(sge)); @@ -682,9 +821,13 @@ static flagcxResult_t flagcxP2pIget(void *sendComm, uint64_t srcOff, wr.sg_list = &sge; wr.num_sge = 1; - struct ibv_send_wr *bad_wr; - FLAGCXCHECK(flagcxWrapIbvPostSend(comm->qp.qp, &wr, &bad_wr)); req->events = 1; + struct ibv_send_wr *bad_wr; + flagcxResult_t res = flagcxWrapIbvPostSend(channel->qp.qp, &wr, &bad_wr); + if (res != flagcxSuccess) { + flagcxP2pFreeRequest(req); + return res; + } *request = req; return flagcxSuccess; @@ -700,8 +843,11 @@ flagcxP2pIputSignal(void *sendComm, uint64_t srcOff, uint64_t dstOff, (struct flagcxP2pMrHandle *)signalHandles; struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, comm->base.cq, + FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, NULL, FLAGCX_P2P_REQ_IPUT, &req)); + struct flagcxP2pChannel *channel = + flagcxP2pNextChannel(comm->channels, &comm->nextChannel); + req->cq = channel->cq; bool chainData = (size > 0 && srcHandles != NULL && dstHandles != NULL); @@ -748,10 +894,14 @@ flagcxP2pIputSignal(void *sendComm, uint64_t srcOff, uint64_t dstOff, wr[1].num_sge = 1; wr[1].next = NULL; - struct ibv_send_wr *bad_wr; - FLAGCXCHECK( - flagcxWrapIbvPostSend(comm->qp.qp, chainData ? &wr[0] : &wr[1], &bad_wr)); req->events = 1; + struct ibv_send_wr *bad_wr; + flagcxResult_t res = flagcxWrapIbvPostSend( + channel->qp.qp, chainData ? &wr[0] : &wr[1], &bad_wr); + if (res != flagcxSuccess) { + flagcxP2pFreeRequest(req); + return res; + } *request = req; return flagcxSuccess; @@ -848,15 +998,10 @@ static void flagcxP2pDrainCq(struct ibv_cq *cq) { static flagcxResult_t flagcxP2pCloseSend(void *sendComm) { struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; if (comm) { - if (comm->base.cq) - cqPollerUnregister(comm->base.cq); - flagcxP2pDrainCq(comm->base.cq); - if (comm->qp.qp) - FLAGCXCHECK(flagcxWrapIbvDestroyQp(comm->qp.qp)); + flagcxP2pUnregisterChannels(comm->channels); + FLAGCXCHECK(flagcxP2pDestroyChannels(comm->channels)); if (comm->putSignalScratchpadMr) FLAGCXCHECK(flagcxWrapIbvDeregMr(comm->putSignalScratchpadMr)); - if (comm->base.cq) - FLAGCXCHECK(flagcxWrapIbvDestroyCq(comm->base.cq)); FLAGCXCHECK(flagcxP2pReleasePd(comm->ibDevN)); FLAGCXCHECK(flagcxSocketClose(&comm->sock)); free(comm); @@ -867,15 +1012,10 @@ static flagcxResult_t flagcxP2pCloseSend(void *sendComm) { static flagcxResult_t flagcxP2pCloseRecv(void *recvComm) { struct flagcxP2pRecvComm *comm = (struct flagcxP2pRecvComm *)recvComm; if (comm) { - if (comm->base.cq) - cqPollerUnregister(comm->base.cq); - flagcxP2pDrainCq(comm->base.cq); - if (comm->qp.qp) - FLAGCXCHECK(flagcxWrapIbvDestroyQp(comm->qp.qp)); + flagcxP2pUnregisterChannels(comm->channels); + FLAGCXCHECK(flagcxP2pDestroyChannels(comm->channels)); if (comm->putSignalScratchpadMr) FLAGCXCHECK(flagcxWrapIbvDeregMr(comm->putSignalScratchpadMr)); - if (comm->base.cq) - FLAGCXCHECK(flagcxWrapIbvDestroyCq(comm->base.cq)); FLAGCXCHECK(flagcxP2pReleasePd(comm->ibDevN)); FLAGCXCHECK(flagcxSocketClose(&comm->sock)); free(comm); diff --git a/flagcx/core/flagcx_p2p.cc b/flagcx/core/flagcx_p2p.cc index d8d1059c..cd39b0a7 100644 --- a/flagcx/core/flagcx_p2p.cc +++ b/flagcx/core/flagcx_p2p.cc @@ -56,10 +56,17 @@ struct FlagcxP2pListenHandleView { static_assert(sizeof(FlagcxP2pListenHandleView) <= FLAGCX_NET_HANDLE_MAXSIZE, "listen handle must fit in FLAGCX_NET_HANDLE_MAXSIZE"); +static constexpr int kP2pQpsPerConn = 4; + +struct FlagcxP2pChannelView { + struct ibv_cq *cq; + struct flagcxIbQp qp; +}; + struct FlagcxP2pCommView { int ibDevN; struct flagcxIbNetCommDevBase base; - struct flagcxIbQp qp; + struct FlagcxP2pChannelView channels[kP2pQpsPerConn]; struct flagcxSocket sock; }; @@ -187,7 +194,7 @@ static uint64_t gNextXferId = 1; /* Async Transfer Worker Infrastructure */ /* ------------------------------------------------------------------ */ -static constexpr int kWindowSize = 64; +static constexpr int kWindowSize = 8; enum AsyncXferOp { ASYNC_XFER_READ, ASYNC_XFER_WRITE }; @@ -1550,7 +1557,6 @@ int flagcxP2pEngineReadVector(FlagcxP2pConn *conn, std::vector localEntries(numIovs); { - auto t0 = std::chrono::high_resolution_clock::now(); std::lock_guard memLock(gMemMutex); for (int i = 0; i < numIovs; i++) { FlagcxP2pMemRegEntry *entry = findMemRegByMr(mrIds[i]); @@ -1573,11 +1579,6 @@ int flagcxP2pEngineReadVector(FlagcxP2pConn *conn, localEntries[i] = *entry; } - auto t1 = std::chrono::high_resolution_clock::now(); - double ms = std::chrono::duration(t1 - t0).count(); - fprintf(stderr, - "[FlagCX P2P] ReadVector memReg lookup: numIovs=%d, time=%.4f ms\n", - numIovs, ms); } ensureAsyncWorkerStarted(); @@ -1606,10 +1607,6 @@ int flagcxP2pEngineReadVector(FlagcxP2pConn *conn, pthread_cond_signal(&gAsyncWorker.cv); *transferId = xferId; - fprintf( - stderr, - "[FlagCX P2P] ReadVector submitted async task: xferId=%lu, numIovs=%d\n", - (unsigned long)xferId, numIovs); return 0; } From d61bb24cadd6b56f865b77bd5cc53db68e774e1f Mon Sep 17 00:00:00 2001 From: mikethegoblin Date: Fri, 22 May 2026 15:27:54 +0800 Subject: [PATCH 06/11] add batch get implementation --- flagcx/adaptor/include/flagcx_net_adaptor.h | 8 ++ flagcx/adaptor/net/ibrc_p2p_adaptor.cc | 80 ++++++++++++- flagcx/core/flagcx_p2p.cc | 125 ++++++++++++++++++++ 3 files changed, 212 insertions(+), 1 deletion(-) diff --git a/flagcx/adaptor/include/flagcx_net_adaptor.h b/flagcx/adaptor/include/flagcx_net_adaptor.h index ce1c6071..248e5c32 100644 --- a/flagcx/adaptor/include/flagcx_net_adaptor.h +++ b/flagcx/adaptor/include/flagcx_net_adaptor.h @@ -25,6 +25,7 @@ typedef enum { // regMr, regMrDmaBuf, deregMr, isend, irecv, iflush, test, // iput, iget, iputSignal, getDevFromName // v2 — adds iputBatch (optional one-sided batch WRITE) +// latest — adds optional batch helpers for one-sided transfers struct flagcxNetAdaptor_v1 { // Basic functions @@ -139,6 +140,13 @@ struct flagcxNetAdaptor_latest { // If NULL, caller falls back to per-request test(). flagcxResult_t (*testBatch)(void **requests, int nRequests, int *doneFlags, int *doneCount); + // Optional one-side batch READ. Success returns one logical request for the + // full batch. + flagcxResult_t (*igetBatch)(void *sendComm, int count, + const uint64_t *srcOffs, const uint64_t *dstOffs, + const size_t *sizes, int srcRank, int dstRank, + void *const *srcHandles, void *const *dstHandles, + void **request); }; #define flagcxNetAdaptor flagcxNetAdaptor_latest diff --git a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc index 9eff3199..f32a93bd 100644 --- a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc +++ b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc @@ -71,12 +71,22 @@ struct flagcxP2pConnMeta { }; // P2P request — simplified from flagcxIbRequest -#define FLAGCX_P2P_MAX_REQUESTS 128 +#define FLAGCX_P2P_MAX_REQUESTS 256 #define FLAGCX_P2P_REQ_UNUSED 0 #define FLAGCX_P2P_REQ_IPUT 1 #define FLAGCX_P2P_REQ_IGET 2 #define FLAGCX_P2P_BATCH_POLL_SIZE 32 #define FLAGCX_P2P_QPS_PER_CONN 4 +#define FLAGCX_P2P_IGET_BATCH_MAX_WR 64 +#define FLAGCX_P2P_READ_BATCH_WINDOW 8 + +// flagcxIbCreateQp configures each P2P QP with 2 * MAX_REQUESTS send WRs. +// The read engine round-robins an 8-batch window across four QPs, so one QP +// can hold at most two fixed-size READ batches. +static_assert(2 * FLAGCX_P2P_IGET_BATCH_MAX_WR <= 2 * MAX_REQUESTS, + "P2P READ batch window exceeds QP send queue capacity"); +static_assert(FLAGCX_P2P_READ_BATCH_WINDOW / FLAGCX_P2P_QPS_PER_CONN == 2, + "P2P READ batch capacity assumes two batches per QP"); struct flagcxP2pRequest { int type; @@ -833,6 +843,73 @@ static flagcxResult_t flagcxP2pIget(void *sendComm, uint64_t srcOff, return flagcxSuccess; } +static flagcxResult_t +flagcxP2pIgetBatch(void *sendComm, int count, const uint64_t *srcOffs, + const uint64_t *dstOffs, const size_t *sizes, int srcRank, + int dstRank, void *const *srcHandles, + void *const *dstHandles, void **request) { + struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; + if (count <= 0 || count > FLAGCX_P2P_IGET_BATCH_MAX_WR || srcOffs == NULL || + dstOffs == NULL || sizes == NULL || srcHandles == NULL || + dstHandles == NULL || request == NULL) { + WARN("NET/IB_P2P : invalid igetBatch arguments, count %d", count); + return flagcxInternalError; + } + + struct flagcxP2pRequest *req; + FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, NULL, + FLAGCX_P2P_REQ_IGET, &req)); + struct flagcxP2pChannel *channel = + flagcxP2pNextChannel(comm->channels, &comm->nextChannel); + req->cq = channel->cq; + + struct ibv_send_wr wrs[FLAGCX_P2P_IGET_BATCH_MAX_WR]; + struct ibv_sge sges[FLAGCX_P2P_IGET_BATCH_MAX_WR]; + memset(wrs, 0, sizeof(wrs)); + memset(sges, 0, sizeof(sges)); + + for (int i = 0; i < count; i++) { + if (srcHandles[i] == NULL || dstHandles[i] == NULL) { + WARN("NET/IB_P2P : igetBatch handle %d is NULL", i); + flagcxP2pFreeRequest(req); + return flagcxInternalError; + } + + struct flagcxP2pMrHandle *src = (struct flagcxP2pMrHandle *)srcHandles[i]; + struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles[i]; + + sges[i].addr = dst->baseVa + dstOffs[i]; + sges[i].length = (uint32_t)sizes[i]; + if ((size_t)sges[i].length != sizes[i]) { + WARN("NET/IB_P2P : igetBatch size %zu exceeds 32-bit limit", sizes[i]); + flagcxP2pFreeRequest(req); + return flagcxInternalError; + } + sges[i].lkey = dst->lkey; + + wrs[i].opcode = IBV_WR_RDMA_READ; + wrs[i].send_flags = + i == count - 1 ? IBV_SEND_SIGNALED : 0; // final CQE tracks batch + wrs[i].wr_id = i == count - 1 ? req - comm->reqs : 0; + wrs[i].wr.rdma.remote_addr = src->baseVa + srcOffs[i]; + wrs[i].wr.rdma.rkey = src->rkey; + wrs[i].sg_list = &sges[i]; + wrs[i].num_sge = 1; + wrs[i].next = i + 1 < count ? &wrs[i + 1] : NULL; + } + + req->events = 1; + struct ibv_send_wr *bad_wr; + flagcxResult_t res = flagcxWrapIbvPostSend(channel->qp.qp, &wrs[0], &bad_wr); + if (res != flagcxSuccess) { + flagcxP2pFreeRequest(req); + return res; + } + + *request = req; + return flagcxSuccess; +} + static flagcxResult_t flagcxP2pIputSignal(void *sendComm, uint64_t srcOff, uint64_t dstOff, size_t size, int srcRank, int dstRank, void **srcHandles, @@ -1110,4 +1187,5 @@ struct flagcxNetAdaptor flagcxNetIbP2p = { // Optional batch operations nullptr, // iputBatch flagcxP2pTestBatch, // testBatch + flagcxP2pIgetBatch, // igetBatch }; diff --git a/flagcx/core/flagcx_p2p.cc b/flagcx/core/flagcx_p2p.cc index cd39b0a7..851f70e8 100644 --- a/flagcx/core/flagcx_p2p.cc +++ b/flagcx/core/flagcx_p2p.cc @@ -195,6 +195,7 @@ static uint64_t gNextXferId = 1; /* ------------------------------------------------------------------ */ static constexpr int kWindowSize = 8; +static constexpr int kIgetBatchSize = 64; enum AsyncXferOp { ASYNC_XFER_READ, ASYNC_XFER_WRITE }; @@ -224,6 +225,123 @@ static FlagcxP2pCommView *getCommView(void *comm) { return reinterpret_cast(comm); } +struct AsyncReadBatchEntry { + void *request = nullptr; + int firstIov = 0; + int count = 0; +}; + +static bool asyncReadBatched(std::shared_ptr task, + struct flagcxNetAdaptor *adaptor, int connIbDevN) { + AsyncReadBatchEntry inflight[kWindowSize]; + int issuedIovs = 0; + int completedIovs = 0; + int issuedBatches = 0; + int completedBatches = 0; + + while (completedIovs < task->numIovs) { + while (issuedIovs < task->numIovs && + issuedBatches - completedBatches < kWindowSize) { + const int count = std::min(kIgetBatchSize, task->numIovs - issuedIovs); + uint64_t srcOffs[kIgetBatchSize] = {}; + uint64_t dstOffs[kIgetBatchSize] = {}; + FlagcxP2pMrHandleView remoteMrs[kIgetBatchSize] = {}; + void *srcHandles[kIgetBatchSize] = {}; + void *dstHandles[kIgetBatchSize] = {}; + + for (int batchIov = 0; batchIov < count; batchIov++) { + const int iov = issuedIovs + batchIov; + if (task->localEntries[iov].ibDevN != connIbDevN) + return false; + + FlagcxP2pMrHandleView *localMr = + reinterpret_cast( + task->localEntries[iov].mhandle); + dstOffs[batchIov] = (uintptr_t)task->dataVec[iov] - localMr->baseVa; + remoteMrs[batchIov].baseVa = task->descs[iov].addr; + remoteMrs[batchIov].rkey = task->descs[iov].rkey; + srcHandles[batchIov] = &remoteMrs[batchIov]; + dstHandles[batchIov] = task->localEntries[iov].mhandle; + } + + void *request = nullptr; + flagcxResult_t rc = + adaptor->igetBatch(task->conn->sendComm, count, srcOffs, dstOffs, + task->sizeVec.data() + issuedIovs, 0, 0, + srcHandles, dstHandles, &request); + if (rc != flagcxSuccess) + return false; + + AsyncReadBatchEntry &entry = inflight[issuedBatches % kWindowSize]; + entry.request = request; + entry.firstIov = issuedIovs; + entry.count = count; + issuedIovs += count; + issuedBatches++; + } + + int newlyCompleted = 0; + if (adaptor->testBatch != nullptr) { + void *batchRequests[kWindowSize]; + int batchIndices[kWindowSize]; + int batchCount = 0; + for (int batch = completedBatches; batch < issuedBatches; batch++) { + AsyncReadBatchEntry &entry = inflight[batch % kWindowSize]; + if (entry.request != nullptr) { + batchRequests[batchCount] = entry.request; + batchIndices[batchCount] = batch; + batchCount++; + } + } + + if (batchCount > 0) { + int doneFlags[kWindowSize]; + int doneCount = 0; + flagcxResult_t rc = adaptor->testBatch(batchRequests, batchCount, + doneFlags, &doneCount); + if (rc != flagcxSuccess) + return false; + + for (int i = 0; i < batchCount; i++) { + if (doneFlags[i]) { + inflight[batchIndices[i] % kWindowSize].request = nullptr; + newlyCompleted++; + } + } + } + } else { + for (int batch = completedBatches; batch < issuedBatches; batch++) { + AsyncReadBatchEntry &entry = inflight[batch % kWindowSize]; + if (entry.request == nullptr) + continue; + + int done = 0; + int sizes = 0; + flagcxResult_t rc = adaptor->test(entry.request, &done, &sizes); + if (rc != flagcxSuccess) + return false; + if (done) { + entry.request = nullptr; + newlyCompleted++; + } + } + } + + while (completedBatches < issuedBatches) { + AsyncReadBatchEntry &entry = inflight[completedBatches % kWindowSize]; + if (entry.request != nullptr) + break; + completedIovs += entry.count; + completedBatches++; + } + + if (newlyCompleted == 0 && issuedIovs >= task->numIovs) + std::this_thread::sleep_for(std::chrono::microseconds(1)); + } + + return true; +} + static void asyncWorkerFunc() { while (true) { std::shared_ptr task; @@ -244,6 +362,13 @@ static void asyncWorkerFunc() { const int numIovs = task->numIovs; const int connIbDevN = getCommView(conn->sendComm)->ibDevN; + if (task->op == ASYNC_XFER_READ && adaptor->igetBatch != nullptr) { + bool ok = asyncReadBatched(task, adaptor, connIbDevN); + task->result.store(ok ? 0 : -1, std::memory_order_release); + task->done.store(true, std::memory_order_release); + continue; + } + std::vector inflightReqs(kWindowSize, nullptr); int issued = 0, completed = 0; bool error = false; From f151939076449cb964650428d1f793300d850ac5 Mon Sep 17 00:00:00 2001 From: mikethegoblin Date: Thu, 28 May 2026 16:36:22 +0800 Subject: [PATCH 07/11] add FlagcxSlice and FlagcxTransferTask implementation --- flagcx/core/flagcx_p2p.cc | 343 +++++++++++++++++++++++++++----------- 1 file changed, 247 insertions(+), 96 deletions(-) diff --git a/flagcx/core/flagcx_p2p.cc b/flagcx/core/flagcx_p2p.cc index 851f70e8..7cddf31f 100644 --- a/flagcx/core/flagcx_p2p.cc +++ b/flagcx/core/flagcx_p2p.cc @@ -197,25 +197,129 @@ static uint64_t gNextXferId = 1; static constexpr int kWindowSize = 8; static constexpr int kIgetBatchSize = 64; -enum AsyncXferOp { ASYNC_XFER_READ, ASYNC_XFER_WRITE }; +enum FlagcxSlicePolicyKind { + FLAGCX_POLICY_NIXL = 0, + FLAGCX_POLICY_FLAGCX = 1, +}; -struct AsyncTransferTask { - FlagcxP2pConn *conn; - AsyncXferOp op; - int numIovs; - std::vector dataVec; - std::vector sizeVec; - std::vector descs; - std::vector localEntries; - std::atomic done{false}; +struct FlagcxTransferTask; + +struct FlagcxSlice { + // WRITE: local source VA; READ: local destination VA. + uint64_t srcVa = 0; + // WRITE: remote destination VA; READ: remote source VA. + uint64_t dstVa = 0; + uint32_t length = 0; + uint32_t lkey = 0; + uint32_t rkey = 0; + uint8_t opcode = 0; + FlagcxTransferTask *task = nullptr; + + void markSuccess(); + void markFailed(); +}; + +struct FlagcxTransferTask { + FlagcxP2pConn *conn = nullptr; + std::atomic sliceCount{0}; + std::atomic doneSliceCount{0}; std::atomic result{0}; + std::vector sliceList; + + bool isAllDone() const { + const uint64_t total = sliceCount.load(std::memory_order_acquire); + const uint64_t done = doneSliceCount.load(std::memory_order_acquire); + return total > 0 && done >= total; + } +}; + +void FlagcxSlice::markSuccess() { + if (task) + task->doneSliceCount.fetch_add(1, std::memory_order_release); +} + +void FlagcxSlice::markFailed() { + if (task) { + task->result.store(-1, std::memory_order_release); + task->doneSliceCount.fetch_add(1, std::memory_order_release); + } +} + +struct NixlSlicePolicy { + static constexpr bool kFurtherCut = false; + static constexpr size_t kBlockSize = SIZE_MAX; + static constexpr size_t kFragmentSize = 0; }; +struct FlagcxSlicePolicy { + static constexpr bool kFurtherCut = true; + static constexpr size_t kBlockSize = 64 * 1024; + static constexpr size_t kFragmentSize = 4 * 1024; +}; + +template +static int buildSlices(FlagcxTransferTask *task, uint64_t srcVa, uint64_t dstVa, + size_t totalLen, uint32_t lkey, uint32_t rkey, + uint8_t opcode) { + if (task == nullptr || totalLen == 0) + return -1; + + if constexpr (!Policy::kFurtherCut) { + if (totalLen > UINT32_MAX) + return -1; + FlagcxSlice *slice = new FlagcxSlice; + slice->srcVa = srcVa; + slice->dstVa = dstVa; + slice->length = static_cast(totalLen); + slice->lkey = lkey; + slice->rkey = rkey; + slice->opcode = opcode; + slice->task = task; + task->sliceList.push_back(slice); + task->sliceCount.fetch_add(1, std::memory_order_release); + return 0; + } else { + size_t off = 0; + while (off < totalLen) { + const size_t remaining = totalLen - off; + const bool mergeTail = + remaining <= Policy::kBlockSize + Policy::kFragmentSize; + const size_t len = mergeTail ? remaining : Policy::kBlockSize; + if (len > UINT32_MAX) + return -1; + + FlagcxSlice *slice = new FlagcxSlice; + slice->srcVa = srcVa + off; + slice->dstVa = dstVa + off; + slice->length = static_cast(len); + slice->lkey = lkey; + slice->rkey = rkey; + slice->opcode = opcode; + slice->task = task; + task->sliceList.push_back(slice); + task->sliceCount.fetch_add(1, std::memory_order_release); + + off += len; + if (mergeTail) + break; + } + return 0; + } +} + +static void cleanupTransferTask(FlagcxTransferTask *task) { + if (task == nullptr) + return; + for (FlagcxSlice *slice : task->sliceList) + delete slice; + delete task; +} + struct AsyncWorker { std::thread thread; pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; pthread_cond_t cv = PTHREAD_COND_INITIALIZER; - std::deque> queue; + std::deque> queue; std::atomic stop{false}; }; @@ -227,56 +331,78 @@ static FlagcxP2pCommView *getCommView(void *comm) { struct AsyncReadBatchEntry { void *request = nullptr; - int firstIov = 0; + int firstSlice = 0; int count = 0; }; -static bool asyncReadBatched(std::shared_ptr task, - struct flagcxNetAdaptor *adaptor, int connIbDevN) { +static inline bool isReadSlice(const FlagcxSlice *slice) { + return slice != nullptr && slice->opcode == IBV_WR_RDMA_READ; +} + +static inline uint64_t sliceLocalVa(const FlagcxSlice *slice) { + return slice->srcVa; +} + +static inline uint64_t sliceRemoteVa(const FlagcxSlice *slice) { + return slice->dstVa; +} + +static void markSlicesFailed(std::shared_ptr task, + int firstSlice) { + if (!task) + return; + for (size_t i = firstSlice; i < task->sliceList.size(); i++) + task->sliceList[i]->markFailed(); +} + +static bool asyncReadBatched(std::shared_ptr task, + struct flagcxNetAdaptor *adaptor) { AsyncReadBatchEntry inflight[kWindowSize]; - int issuedIovs = 0; - int completedIovs = 0; + const int numSlices = static_cast(task->sliceList.size()); + int issuedSlices = 0; + int completedSlices = 0; int issuedBatches = 0; int completedBatches = 0; - while (completedIovs < task->numIovs) { - while (issuedIovs < task->numIovs && + while (completedSlices < numSlices) { + while (issuedSlices < numSlices && issuedBatches - completedBatches < kWindowSize) { - const int count = std::min(kIgetBatchSize, task->numIovs - issuedIovs); + const int count = std::min(kIgetBatchSize, numSlices - issuedSlices); uint64_t srcOffs[kIgetBatchSize] = {}; uint64_t dstOffs[kIgetBatchSize] = {}; + size_t sizes[kIgetBatchSize] = {}; FlagcxP2pMrHandleView remoteMrs[kIgetBatchSize] = {}; + FlagcxP2pMrHandleView localMrs[kIgetBatchSize] = {}; void *srcHandles[kIgetBatchSize] = {}; void *dstHandles[kIgetBatchSize] = {}; - for (int batchIov = 0; batchIov < count; batchIov++) { - const int iov = issuedIovs + batchIov; - if (task->localEntries[iov].ibDevN != connIbDevN) + for (int batchSlice = 0; batchSlice < count; batchSlice++) { + const int sliceIdx = issuedSlices + batchSlice; + FlagcxSlice *slice = task->sliceList[sliceIdx]; + if (!isReadSlice(slice)) return false; - FlagcxP2pMrHandleView *localMr = - reinterpret_cast( - task->localEntries[iov].mhandle); - dstOffs[batchIov] = (uintptr_t)task->dataVec[iov] - localMr->baseVa; - remoteMrs[batchIov].baseVa = task->descs[iov].addr; - remoteMrs[batchIov].rkey = task->descs[iov].rkey; - srcHandles[batchIov] = &remoteMrs[batchIov]; - dstHandles[batchIov] = task->localEntries[iov].mhandle; + remoteMrs[batchSlice].baseVa = sliceRemoteVa(slice); + remoteMrs[batchSlice].rkey = slice->rkey; + localMrs[batchSlice].baseVa = sliceLocalVa(slice); + localMrs[batchSlice].lkey = slice->lkey; + srcHandles[batchSlice] = &remoteMrs[batchSlice]; + dstHandles[batchSlice] = &localMrs[batchSlice]; + sizes[batchSlice] = slice->length; } void *request = nullptr; flagcxResult_t rc = adaptor->igetBatch(task->conn->sendComm, count, srcOffs, dstOffs, - task->sizeVec.data() + issuedIovs, 0, 0, - srcHandles, dstHandles, &request); + sizes, 0, 0, srcHandles, dstHandles, &request); if (rc != flagcxSuccess) return false; AsyncReadBatchEntry &entry = inflight[issuedBatches % kWindowSize]; entry.request = request; - entry.firstIov = issuedIovs; + entry.firstSlice = issuedSlices; entry.count = count; - issuedIovs += count; + issuedSlices += count; issuedBatches++; } @@ -304,7 +430,11 @@ static bool asyncReadBatched(std::shared_ptr task, for (int i = 0; i < batchCount; i++) { if (doneFlags[i]) { - inflight[batchIndices[i] % kWindowSize].request = nullptr; + AsyncReadBatchEntry &entry = + inflight[batchIndices[i] % kWindowSize]; + for (int s = 0; s < entry.count; s++) + task->sliceList[entry.firstSlice + s]->markSuccess(); + entry.request = nullptr; newlyCompleted++; } } @@ -321,6 +451,8 @@ static bool asyncReadBatched(std::shared_ptr task, if (rc != flagcxSuccess) return false; if (done) { + for (int s = 0; s < entry.count; s++) + task->sliceList[entry.firstSlice + s]->markSuccess(); entry.request = nullptr; newlyCompleted++; } @@ -331,11 +463,11 @@ static bool asyncReadBatched(std::shared_ptr task, AsyncReadBatchEntry &entry = inflight[completedBatches % kWindowSize]; if (entry.request != nullptr) break; - completedIovs += entry.count; + completedSlices += entry.count; completedBatches++; } - if (newlyCompleted == 0 && issuedIovs >= task->numIovs) + if (newlyCompleted == 0 && issuedSlices >= numSlices) std::this_thread::sleep_for(std::chrono::microseconds(1)); } @@ -344,7 +476,7 @@ static bool asyncReadBatched(std::shared_ptr task, static void asyncWorkerFunc() { while (true) { - std::shared_ptr task; + std::shared_ptr task; pthread_mutex_lock(&gAsyncWorker.mutex); while (gAsyncWorker.queue.empty() && !gAsyncWorker.stop.load()) { pthread_cond_wait(&gAsyncWorker.cv, &gAsyncWorker.mutex); @@ -359,13 +491,20 @@ static void asyncWorkerFunc() { FlagcxP2pConn *conn = task->conn; struct flagcxNetAdaptor *adaptor = conn->engine->adaptor; - const int numIovs = task->numIovs; - const int connIbDevN = getCommView(conn->sendComm)->ibDevN; + const int numSlices = static_cast(task->sliceList.size()); - if (task->op == ASYNC_XFER_READ && adaptor->igetBatch != nullptr) { - bool ok = asyncReadBatched(task, adaptor, connIbDevN); + if (numSlices == 0) { + task->result.store(-1, std::memory_order_release); + task->sliceCount.store(1, std::memory_order_release); + task->doneSliceCount.store(1, std::memory_order_release); + continue; + } + + if (isReadSlice(task->sliceList[0]) && adaptor->igetBatch != nullptr) { + bool ok = asyncReadBatched(task, adaptor); + if (!ok) + markSlicesFailed(task, 0); task->result.store(ok ? 0 : -1, std::memory_order_release); - task->done.store(true, std::memory_order_release); continue; } @@ -373,42 +512,35 @@ static void asyncWorkerFunc() { int issued = 0, completed = 0; bool error = false; - while (completed < numIovs && !error) { + while (completed < numSlices && !error) { // Post up to kWindowSize ahead of completed - while (issued < numIovs && (issued - completed) < kWindowSize) { - if (task->localEntries[issued].ibDevN != connIbDevN) { - error = true; - break; - } - - FlagcxP2pMrHandleView *localMr = - reinterpret_cast( - task->localEntries[issued].mhandle); - + while (issued < numSlices && (issued - completed) < kWindowSize) { + FlagcxSlice *slice = task->sliceList[issued]; + FlagcxP2pMrHandleView localMr; FlagcxP2pMrHandleView remoteMr; + memset(&localMr, 0, sizeof(localMr)); memset(&remoteMr, 0, sizeof(remoteMr)); - remoteMr.baseVa = task->descs[issued].addr; - remoteMr.rkey = task->descs[issued].rkey; + + localMr.baseVa = sliceLocalVa(slice); + localMr.lkey = slice->lkey; + remoteMr.baseVa = sliceRemoteVa(slice); + remoteMr.rkey = slice->rkey; void *request = NULL; flagcxResult_t rc; - if (task->op == ASYNC_XFER_READ) { + if (isReadSlice(slice)) { const uint64_t srcOff = 0; - const uint64_t dstOff = - (uintptr_t)task->dataVec[issued] - localMr->baseVa; - rc = adaptor->iget(conn->sendComm, srcOff, dstOff, - task->sizeVec[issued], 0, 0, (void **)&remoteMr, - (void **)task->localEntries[issued].mhandle, - &request); + const uint64_t dstOff = 0; + rc = + adaptor->iget(conn->sendComm, srcOff, dstOff, slice->length, 0, 0, + (void **)&remoteMr, (void **)&localMr, &request); } else { - const uint64_t srcOff = - (uintptr_t)task->dataVec[issued] - localMr->baseVa; + const uint64_t srcOff = 0; const uint64_t dstOff = 0; - rc = adaptor->iput(conn->sendComm, srcOff, dstOff, - task->sizeVec[issued], 0, 0, - (void **)task->localEntries[issued].mhandle, - (void **)&remoteMr, &request); + rc = + adaptor->iput(conn->sendComm, srcOff, dstOff, slice->length, 0, 0, + (void **)&localMr, (void **)&remoteMr, &request); } if (rc != flagcxSuccess) { @@ -450,6 +582,7 @@ static void asyncWorkerFunc() { if (doneFlags[b]) { int i = batchIndices[b]; int slot = i % kWindowSize; + task->sliceList[i]->markSuccess(); inflightReqs[slot] = nullptr; newlyCompleted++; } @@ -466,11 +599,13 @@ static void asyncWorkerFunc() { int done = 0, sizes = 0; flagcxResult_t res = adaptor->test(inflightReqs[slot], &done, &sizes); if (res != flagcxSuccess) { + task->sliceList[i]->markFailed(); inflightReqs[slot] = nullptr; newlyCompleted++; continue; } if (done) { + task->sliceList[i]->markSuccess(); inflightReqs[slot] = nullptr; newlyCompleted++; } @@ -484,13 +619,15 @@ static void asyncWorkerFunc() { } // Yield briefly if no progress was made - if (newlyCompleted == 0 && issued >= numIovs) { + if (newlyCompleted == 0 && issued >= numSlices) { std::this_thread::sleep_for(std::chrono::microseconds(1)); } } - task->result.store(error ? -1 : 0, std::memory_order_release); - task->done.store(true, std::memory_order_release); + if (error) { + markSlicesFailed(task, completed); + task->result.store(-1, std::memory_order_release); + } } } @@ -518,7 +655,7 @@ static void stopAsyncWorker() { } // Map from transfer ID to async task (for XferStatus polling) -static std::unordered_map> +static std::unordered_map> gAsyncXferMap; static std::mutex gAsyncXferMutex; @@ -1680,7 +1817,12 @@ int flagcxP2pEngineReadVector(FlagcxP2pConn *conn, return -1; } - std::vector localEntries(numIovs); + std::shared_ptr task(new FlagcxTransferTask, + cleanupTransferTask); + task->conn = conn; + task->sliceList.reserve(numIovs); + const int connIbDevN = getCommView(conn->sendComm)->ibDevN; + { std::lock_guard memLock(gMemMutex); for (int i = 0; i < numIovs; i++) { @@ -1702,21 +1844,23 @@ int flagcxP2pEngineReadVector(FlagcxP2pConn *conn, return -1; } - localEntries[i] = *entry; + if (entry->ibDevN != connIbDevN) + return -1; + + FlagcxP2pMrHandleView *localMr = + reinterpret_cast(entry->mhandle); + if (buildSlices( + task.get(), + static_cast(reinterpret_cast(dstVec[i])), + descs[i].addr, sizeVec[i], localMr->lkey, descs[i].rkey, + IBV_WR_RDMA_READ) != 0) { + return -1; + } } } ensureAsyncWorkerStarted(); - auto task = std::make_shared(); - task->conn = conn; - task->op = ASYNC_XFER_READ; - task->numIovs = numIovs; - task->dataVec = std::move(dstVec); - task->sizeVec = std::move(sizeVec); - task->descs = std::move(descs); - task->localEntries = std::move(localEntries); - const uint64_t xferId = [&] { std::lock_guard lock(gAsyncXferMutex); uint64_t id = gNextXferId++; @@ -1818,7 +1962,12 @@ int flagcxP2pEngineWriteVector(FlagcxP2pConn *conn, if (mrIds.size() < static_cast(numIovs)) return -1; - std::vector localEntries(numIovs); + std::shared_ptr task(new FlagcxTransferTask, + cleanupTransferTask); + task->conn = conn; + task->sliceList.reserve(numIovs); + const int connIbDevN = getCommView(conn->sendComm)->ibDevN; + { std::lock_guard memLock(gMemMutex); for (int i = 0; i < numIovs; i++) { @@ -1830,21 +1979,23 @@ int flagcxP2pEngineWriteVector(FlagcxP2pConn *conn, sizeVec[i])) return -1; - localEntries[i] = *entry; + if (entry->ibDevN != connIbDevN) + return -1; + + FlagcxP2pMrHandleView *localMr = + reinterpret_cast(entry->mhandle); + if (buildSlices( + task.get(), + static_cast(reinterpret_cast(dstVec[i])), + descs[i].addr, sizeVec[i], localMr->lkey, descs[i].rkey, + IBV_WR_RDMA_WRITE) != 0) { + return -1; + } } } ensureAsyncWorkerStarted(); - auto task = std::make_shared(); - task->conn = conn; - task->op = ASYNC_XFER_WRITE; - task->numIovs = numIovs; - task->dataVec = std::move(dstVec); - task->sizeVec = std::move(sizeVec); - task->descs = std::move(descs); - task->localEntries = std::move(localEntries); - const uint64_t xferId = [&] { std::lock_guard lock(gAsyncXferMutex); uint64_t id = gNextXferId++; @@ -1905,7 +2056,7 @@ bool flagcxP2pEngineXferStatus(FlagcxP2pConn *conn, uint64_t transferId) { std::lock_guard lock(gAsyncXferMutex); auto it = gAsyncXferMap.find(transferId); if (it != gAsyncXferMap.end()) { - if (it->second->done.load(std::memory_order_acquire)) { + if (it->second->isAllDone()) { gAsyncXferMap.erase(it); return true; } From 0d4e2771dc502faa0474cfa53b1c141dca6d56b5 Mon Sep 17 00:00:00 2001 From: leoda1 Date: Fri, 29 May 2026 10:52:18 +0800 Subject: [PATCH 08/11] add workerPool to process all slices, all configuration settings can be independently adjusted using environment variables. --- flagcx/adaptor/net/ibrc_p2p_adaptor.cc | 885 +++++++---------- flagcx/core/flagcx_p2p.cc | 1203 +++++++++++++++--------- flagcx/include/flagcx_p2p.h | 55 ++ 3 files changed, 1156 insertions(+), 987 deletions(-) diff --git a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc index f32a93bd..c8c3deab 100644 --- a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc +++ b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc @@ -9,6 +9,7 @@ #include "flagcx_common.h" #include "flagcx_net_adaptor.h" +#include "flagcx_p2p.h" #include "ib_common.h" #include "ibvwrap.h" #include "socket.h" @@ -19,15 +20,66 @@ #include #include #include +#include +#include #include #include #include #include +struct FlagcxSlice; + +extern struct ibv_cq *flagcxP2pPoolGetSharedCq(int ibDevN, + struct ibv_context *ctx); +extern void flagcxP2pPoolRegisterQp(int ibDevN, void *sendComm, + struct ibv_qp *qp); +extern void flagcxP2pPoolUnregisterQp(int ibDevN, struct ibv_qp *qp); +extern flagcxResult_t flagcxP2pPoolSubmit(int ibDevN, void *sendComm, + FlagcxSlice **slices, int count); + /* ------------------------------------------------------------------ */ /* Internal structs */ /* ------------------------------------------------------------------ */ +struct FlagcxTransferTask { + std::atomic sliceCount{0}; + std::atomic doneSliceCount{0}; + std::vector sliceList; + + bool isAllDone() const { + auto total = sliceCount.load(std::memory_order_acquire); + auto done = doneSliceCount.load(std::memory_order_acquire); + return total > 0 && done >= total; + } +}; + +enum FlagcxSliceOp : uint8_t { + FLAGCX_SLICE_OP_WRITE = 0, + FLAGCX_SLICE_OP_READ = 1, +}; + +struct FlagcxSlice { + uint64_t srcVa; + uint64_t dstVa; + uint32_t length; + uint32_t lkey; + uint32_t rkey; + uint8_t opcode; + std::string peerNicPath; + FlagcxTransferTask *task; + volatile int *qpDepth; + + inline void markSuccess() { + if (task) + task->doneSliceCount.fetch_add(1, std::memory_order_release); + } + + inline void markFailed() { + if (task) + task->doneSliceCount.fetch_add(1, std::memory_order_release); + } +}; + // Per-device context — created at init, holds eagerly allocated PD. // Passed as the `comm` parameter to regMr/deregMr when no connection exists. // ibDevN MUST be the first field so regMr can cast any comm pointer to extract @@ -70,63 +122,28 @@ struct flagcxP2pConnMeta { enum ibv_mtu mtu; }; -// P2P request — simplified from flagcxIbRequest -#define FLAGCX_P2P_MAX_REQUESTS 256 -#define FLAGCX_P2P_REQ_UNUSED 0 -#define FLAGCX_P2P_REQ_IPUT 1 -#define FLAGCX_P2P_REQ_IGET 2 -#define FLAGCX_P2P_BATCH_POLL_SIZE 32 -#define FLAGCX_P2P_QPS_PER_CONN 4 -#define FLAGCX_P2P_IGET_BATCH_MAX_WR 64 -#define FLAGCX_P2P_READ_BATCH_WINDOW 8 - -// flagcxIbCreateQp configures each P2P QP with 2 * MAX_REQUESTS send WRs. -// The read engine round-robins an 8-batch window across four QPs, so one QP -// can hold at most two fixed-size READ batches. -static_assert(2 * FLAGCX_P2P_IGET_BATCH_MAX_WR <= 2 * MAX_REQUESTS, - "P2P READ batch window exceeds QP send queue capacity"); -static_assert(FLAGCX_P2P_READ_BATCH_WINDOW / FLAGCX_P2P_QPS_PER_CONN == 2, - "P2P READ batch capacity assumes two batches per QP"); - -struct flagcxP2pRequest { - int type; - int events; // outstanding CQEs expected - struct ibv_cq *cq; // CQ to poll for this request - struct flagcxP2pRequest *reqs; // back-pointer to owning reqs[] array - std::atomic *reqDone; // pointer to owning comm's reqDone[] array -}; - -struct flagcxP2pChannel { - struct ibv_cq *cq; - struct flagcxIbQp qp; +struct flagcxP2pSliceReq { + FlagcxTransferTask task; + FlagcxSlice slice; }; -// P2P send comm — fixed QP/CQ channels, blocking connect +// Field order through `sock` mirrors core's FlagcxP2pCommView — do not reorder. struct flagcxP2pSendComm { - int ibDevN; // MUST be first field + int ibDevN; struct flagcxIbNetCommDevBase base; - struct flagcxP2pChannel channels[FLAGCX_P2P_QPS_PER_CONN]; + struct flagcxIbQp qp_list_[kFlagcxP2pMaxQpsPerEngine]; struct flagcxSocket sock; - struct flagcxP2pRequest reqs[FLAGCX_P2P_MAX_REQUESTS]; - uint64_t putSignalScratchpad; - struct ibv_mr *putSignalScratchpadMr; - std::atomic reqDone[FLAGCX_P2P_MAX_REQUESTS]; std::atomic nextChannel{0}; - std::atomic cqError{false}; + int numQps{0}; // resolved from flagcxP2pGlobalConfig().qpsPerConn at connect/accept }; -// P2P recv comm — symmetric with send comm so both sides can initiate transfers struct flagcxP2pRecvComm { - int ibDevN; // MUST be first field + int ibDevN; struct flagcxIbNetCommDevBase base; - struct flagcxP2pChannel channels[FLAGCX_P2P_QPS_PER_CONN]; + struct flagcxIbQp qp_list_[kFlagcxP2pMaxQpsPerEngine]; struct flagcxSocket sock; - struct flagcxP2pRequest reqs[FLAGCX_P2P_MAX_REQUESTS]; - uint64_t putSignalScratchpad; - struct ibv_mr *putSignalScratchpadMr; - std::atomic reqDone[FLAGCX_P2P_MAX_REQUESTS]; std::atomic nextChannel{0}; - std::atomic cqError{false}; + int numQps{0}; // resolved from flagcxP2pGlobalConfig().qpsPerConn at connect/accept }; /* ------------------------------------------------------------------ */ @@ -137,149 +154,6 @@ static struct flagcxP2pDevCtx flagcxP2pDevCtxs[MAX_IB_DEVS]; static int flagcxP2pInitialized = 0; static pthread_mutex_t flagcxP2pInitLock = PTHREAD_MUTEX_INITIALIZER; -/* ------------------------------------------------------------------ */ -/* Background CQ Poller */ -/* ------------------------------------------------------------------ */ - -struct CqPollEntry { - struct ibv_cq *cq; - struct flagcxP2pRequest *reqs; - std::atomic *reqDone; - std::atomic *cqError; - bool active; -}; - -struct CqPoller { - std::thread thread; - std::mutex mutex; - std::vector entries; - std::atomic running{false}; -}; - -static CqPoller gCqPoller; - -static void cqPollerFunc() { - while (gCqPoller.running.load(std::memory_order_relaxed)) { - bool anyWork = false; - { - // Keep unregister serialized with polling so CQ teardown cannot race a - // snapshot that still contains the CQ. - std::lock_guard lock(gCqPoller.mutex); - for (auto &entry : gCqPoller.entries) { - if (!entry.active) - continue; - - struct ibv_wc wcs[FLAGCX_P2P_BATCH_POLL_SIZE]; - int nCqe = 0; - if (flagcxWrapIbvPollCq(entry.cq, FLAGCX_P2P_BATCH_POLL_SIZE, wcs, - &nCqe) != flagcxSuccess) { - entry.cqError->store(true, std::memory_order_release); - continue; - } - - for (int i = 0; i < nCqe; i++) { - if (wcs[i].status != IBV_WC_SUCCESS) { - WARN("NET/IB_P2P : CQ poller got error status %d for wr_id %lu", - wcs[i].status, wcs[i].wr_id); - entry.cqError->store(true, std::memory_order_release); - break; - } - uint32_t idx = (uint32_t)wcs[i].wr_id; - if (idx >= FLAGCX_P2P_MAX_REQUESTS) - continue; - - entry.reqs[idx].events--; - if (entry.reqs[idx].events == 0) { - entry.reqs[idx].type = FLAGCX_P2P_REQ_UNUSED; - entry.reqDone[idx].store(1, std::memory_order_release); - } - } - if (nCqe > 0) - anyWork = true; - } - } - - if (!anyWork) { - std::this_thread::sleep_for(std::chrono::microseconds(1)); - } - } -} - -static void ensureCqPollerStarted() { - if (!gCqPoller.running.load(std::memory_order_acquire)) { - std::lock_guard lock(gCqPoller.mutex); - if (!gCqPoller.running.load(std::memory_order_relaxed)) { - gCqPoller.running.store(true, std::memory_order_release); - gCqPoller.thread = std::thread(cqPollerFunc); - } - } -} - -static void cqPollerRegister(struct ibv_cq *cq, struct flagcxP2pRequest *reqs, - std::atomic *reqDone, - std::atomic *cqError) { - ensureCqPollerStarted(); - std::lock_guard lock(gCqPoller.mutex); - gCqPoller.entries.push_back({cq, reqs, reqDone, cqError, true}); -} - -static void cqPollerStop() { - if (gCqPoller.running.load(std::memory_order_acquire)) { - gCqPoller.running.store(false, std::memory_order_release); - if (gCqPoller.thread.joinable()) { - gCqPoller.thread.join(); - } - std::lock_guard lock(gCqPoller.mutex); - gCqPoller.entries.clear(); - } -} - -static void cqPollerUnregister(struct ibv_cq *cq) { - bool anyActive = false; - { - std::lock_guard lock(gCqPoller.mutex); - for (auto &entry : gCqPoller.entries) { - if (entry.cq == cq) { - entry.active = false; - } else if (entry.active) { - anyActive = true; - } - } - } - if (!anyActive) { - cqPollerStop(); - } -} - -/* ------------------------------------------------------------------ */ -/* Request helpers */ -/* ------------------------------------------------------------------ */ - -static flagcxResult_t flagcxP2pGetRequest(struct flagcxP2pRequest *reqs, - std::atomic *reqDone, - struct ibv_cq *cq, int type, - struct flagcxP2pRequest **req) { - for (int i = 0; i < FLAGCX_P2P_MAX_REQUESTS; i++) { - if (reqs[i].type == FLAGCX_P2P_REQ_UNUSED) { - reqs[i].type = type; - reqs[i].events = 0; - reqs[i].cq = cq; - reqs[i].reqs = reqs; - reqs[i].reqDone = reqDone; - reqDone[i].store(0, std::memory_order_relaxed); - *req = &reqs[i]; - return flagcxSuccess; - } - } - WARN("NET/IB_P2P : unable to allocate request"); - *req = NULL; - return flagcxInternalError; -} - -static inline void flagcxP2pFreeRequest(struct flagcxP2pRequest *req) { - req->type = FLAGCX_P2P_REQ_UNUSED; -} - /* ------------------------------------------------------------------ */ /* Init / Devices / Properties */ /* ------------------------------------------------------------------ */ @@ -420,13 +294,12 @@ static flagcxResult_t flagcxP2pListen(int dev, void *opaqueHandle, } static flagcxResult_t flagcxP2pReleasePd(int ibDevN); -static void flagcxP2pDrainCq(struct ibv_cq *cq); // Helper: set up PD (from eager init), CQs, QPs, and GID for a connection -static flagcxResult_t flagcxP2pSetupConn(int dev, +static flagcxResult_t flagcxP2pSetupConn(int dev, void *outerComm, struct flagcxIbNetCommDevBase *base, - struct flagcxP2pChannel *channels, - int *outIbDevN) { + struct flagcxIbQp *qp_list, + int *outIbDevN, int numQps) { struct flagcxIbMergedDev *mergedDev = flagcxIbMergedDevs + dev; int ibDevN = mergedDev->devs[0]; // v1: single physical NIC *outIbDevN = ibDevN; @@ -440,6 +313,18 @@ static flagcxResult_t flagcxP2pSetupConn(int dev, base->pd = ibDev->pd; pthread_mutex_unlock(&ibDev->lock); + // Step 0: pull the shared CQ from the per-ibDev WorkerPool. The pool is + // lazily created on first call (and lives for the process lifetime). + struct ibv_cq *sharedCq = + flagcxP2pPoolGetSharedCq(ibDevN, ibDev->context); + if (sharedCq == NULL) { + WARN("NET/IB_P2P : pool[%d] returned NULL shared CQ", ibDevN); + flagcxP2pReleasePd(ibDevN); + base->pd = NULL; + return flagcxInternalError; + } + base->cq = sharedCq; + int accessFlags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC; @@ -455,32 +340,25 @@ static flagcxResult_t flagcxP2pSetupConn(int dev, res, setup_fail); base->gidInfo.linkLayer = ibDev->link; - // Create RC QPs with remote write, read, and atomic access. - for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) { - FLAGCXCHECKGOTO(flagcxWrapIbvCreateCq(&channels[i].cq, ibDev->context, - 2 * FLAGCX_P2P_MAX_REQUESTS, NULL, - NULL, 0), - res, setup_fail); - base->cq = channels[i].cq; + for (int i = 0; i < numQps; i++) { FLAGCXCHECKGOTO( - flagcxIbCreateQp(ibDev->portNum, base, accessFlags, &channels[i].qp), + flagcxIbCreateQp(ibDev->portNum, base, accessFlags, &qp_list[i]), res, setup_fail); - channels[i].qp.devIndex = 0; + qp_list[i].devIndex = 0; + flagcxP2pPoolRegisterQp(ibDevN, outerComm, qp_list[i].qp); } return flagcxSuccess; setup_fail: - for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) { - if (channels[i].qp.qp) { - flagcxWrapIbvDestroyQp(channels[i].qp.qp); - channels[i].qp.qp = NULL; - } - if (channels[i].cq) { - flagcxWrapIbvDestroyCq(channels[i].cq); - channels[i].cq = NULL; + for (int i = 0; i < numQps; i++) { + if (qp_list[i].qp) { + flagcxP2pPoolUnregisterQp(ibDevN, qp_list[i].qp); + flagcxWrapIbvDestroyQp(qp_list[i].qp); + qp_list[i].qp = NULL; } } + // Do not destroy sharedCq — owned by the pool. base->cq = NULL; flagcxP2pReleasePd(ibDevN); base->pd = NULL; @@ -527,44 +405,25 @@ flagcxP2pTransitionQp(struct flagcxIbQp *qp, return flagcxSuccess; } -static void flagcxP2pRegisterChannels(struct flagcxP2pChannel *channels, - struct flagcxP2pRequest *reqs, - std::atomic *reqDone, - std::atomic *cqError) { - for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) - cqPollerRegister(channels[i].cq, reqs, reqDone, cqError); -} - -static void flagcxP2pUnregisterChannels(struct flagcxP2pChannel *channels) { - for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) - if (channels[i].cq) - cqPollerUnregister(channels[i].cq); -} - -static flagcxResult_t -flagcxP2pDestroyChannels(struct flagcxP2pChannel *channels) { - for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) - flagcxP2pDrainCq(channels[i].cq); - for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) { - if (channels[i].qp.qp) { - FLAGCXCHECK(flagcxWrapIbvDestroyQp(channels[i].qp.qp)); - channels[i].qp.qp = NULL; - } - } - for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) { - if (channels[i].cq) { - FLAGCXCHECK(flagcxWrapIbvDestroyCq(channels[i].cq)); - channels[i].cq = NULL; +static flagcxResult_t flagcxP2pDestroyQps(int ibDevN, + struct flagcxIbQp *qp_list, + int numQps) { + for (int i = 0; i < numQps; i++) { + if (qp_list[i].qp) { + flagcxP2pPoolUnregisterQp(ibDevN, qp_list[i].qp); + FLAGCXCHECK(flagcxWrapIbvDestroyQp(qp_list[i].qp)); + qp_list[i].qp = NULL; } } return flagcxSuccess; } -static inline struct flagcxP2pChannel * -flagcxP2pNextChannel(struct flagcxP2pChannel *channels, - std::atomic *nextChannel) { +static inline struct flagcxIbQp * +flagcxP2pNextQp(struct flagcxIbQp *qp_list, + std::atomic *nextChannel, int qpCount) { + uint32_t mod = (qpCount > 0) ? (uint32_t)qpCount : 1u; uint32_t idx = nextChannel->fetch_add(1, std::memory_order_relaxed); - return channels + (idx % FLAGCX_P2P_QPS_PER_CONN); + return qp_list + (idx % mod); } static flagcxResult_t flagcxP2pConnect(int dev, void *opaqueHandle, @@ -579,9 +438,10 @@ static flagcxResult_t flagcxP2pConnect(int dev, void *opaqueHandle, FLAGCXCHECK(flagcxCalloc(&comm, 1)); int ready = 0; auto connectStart = std::chrono::steady_clock::time_point(); - struct flagcxP2pConnMeta localMeta[FLAGCX_P2P_QPS_PER_CONN]; - struct flagcxP2pConnMeta remoteMeta[FLAGCX_P2P_QPS_PER_CONN]; + struct flagcxP2pConnMeta localMeta[kFlagcxP2pMaxQpsPerEngine]; + struct flagcxP2pConnMeta remoteMeta[kFlagcxP2pMaxQpsPerEngine]; int localReady = 1, remoteReady = 0; + uint32_t localNumQps = 0, remoteNumQps = 0, agreedNumQps = 0; // TCP connect (blocking with timeout) FLAGCXCHECKGOTO(flagcxSocketInit(&comm->sock, &handle->connectAddr, @@ -603,36 +463,51 @@ static flagcxResult_t flagcxP2pConnect(int dev, void *opaqueHandle, } } - // Set up PD, CQs, QPs + // numQps negotiation must happen before setup so we only create the + // QPs we'll actually use; both peers agree on min(). + localNumQps = (uint32_t)flagcxP2pGlobalConfig().qpsPerConn; + if (localNumQps == 0 || localNumQps > (uint32_t)kFlagcxP2pMaxQpsPerEngine) + localNumQps = (uint32_t)kFlagcxP2pMaxQpsPerEngine; + FLAGCXCHECKGOTO( + flagcxSocketSend(&comm->sock, &localNumQps, sizeof(localNumQps)), res, + connect_fail); FLAGCXCHECKGOTO( - flagcxP2pSetupConn(dev, &comm->base, comm->channels, &comm->ibDevN), res, + flagcxSocketRecv(&comm->sock, &remoteNumQps, sizeof(remoteNumQps)), res, connect_fail); + if (remoteNumQps == 0 || remoteNumQps > (uint32_t)kFlagcxP2pMaxQpsPerEngine) { + WARN("NET/IB_P2P : peer advertised invalid numQps=%u (max=%d)", + remoteNumQps, kFlagcxP2pMaxQpsPerEngine); + res = flagcxInternalError; + goto connect_fail; + } + agreedNumQps = std::min(localNumQps, remoteNumQps); + if (localNumQps != remoteNumQps) { + INFO(FLAGCX_NET, + "NET/IB_P2P : numQps mismatch (local=%u remote=%u) — using min=%u", + localNumQps, remoteNumQps, agreedNumQps); + } + comm->numQps = (int)agreedNumQps; + + FLAGCXCHECKGOTO(flagcxP2pSetupConn(dev, comm, &comm->base, comm->qp_list_, + &comm->ibDevN, comm->numQps), + res, connect_fail); - // Exchange connection metadata - for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) - flagcxP2pBuildConnMeta(&localMeta[i], &comm->base, &comm->channels[i].qp, + for (int i = 0; i < comm->numQps; i++) + flagcxP2pBuildConnMeta(&localMeta[i], &comm->base, &comm->qp_list_[i], comm->ibDevN); - FLAGCXCHECKGOTO(flagcxSocketSend(&comm->sock, localMeta, sizeof(localMeta)), + FLAGCXCHECKGOTO(flagcxSocketSend(&comm->sock, localMeta, + comm->numQps * sizeof(localMeta[0])), res, connect_fail); - FLAGCXCHECKGOTO(flagcxSocketRecv(&comm->sock, remoteMeta, sizeof(remoteMeta)), + FLAGCXCHECKGOTO(flagcxSocketRecv(&comm->sock, remoteMeta, + comm->numQps * sizeof(remoteMeta[0])), res, connect_fail); // Transition each matched QP to RTR then RTS. - for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) - FLAGCXCHECKGOTO(flagcxP2pTransitionQp(&comm->channels[i].qp, &comm->base, + for (int i = 0; i < comm->numQps; i++) + FLAGCXCHECKGOTO(flagcxP2pTransitionQp(&comm->qp_list_[i], &comm->base, &remoteMeta[i], comm->ibDevN), res, connect_fail); - // Register putSignal scratchpad MR - comm->putSignalScratchpad = 0; - FLAGCXCHECKGOTO( - flagcxWrapIbvRegMr(&comm->putSignalScratchpadMr, comm->base.pd, - &comm->putSignalScratchpad, - sizeof(comm->putSignalScratchpad), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC), - res, connect_fail); - // Exchange ready FLAGCXCHECKGOTO( flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady)), res, @@ -641,16 +516,12 @@ static flagcxResult_t flagcxP2pConnect(int dev, void *opaqueHandle, flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady)), res, connect_fail); - flagcxP2pRegisterChannels(comm->channels, comm->reqs, comm->reqDone, - &comm->cqError); *sendComm = comm; return flagcxSuccess; connect_fail: - if (comm->putSignalScratchpadMr) - flagcxWrapIbvDeregMr(comm->putSignalScratchpadMr); - flagcxP2pDestroyChannels(comm->channels); + flagcxP2pDestroyQps(comm->ibDevN, comm->qp_list_, comm->numQps); if (comm->base.pd) flagcxP2pReleasePd(comm->ibDevN); flagcxSocketClose(&comm->sock); @@ -669,9 +540,10 @@ static flagcxResult_t flagcxP2pAccept(void *listenComm, void **recvComm) { // TCP accept (blocking, no timeout) flagcxResult_t res; int ready; - struct flagcxP2pConnMeta localMeta[FLAGCX_P2P_QPS_PER_CONN]; - struct flagcxP2pConnMeta remoteMeta[FLAGCX_P2P_QPS_PER_CONN]; + struct flagcxP2pConnMeta localMeta[kFlagcxP2pMaxQpsPerEngine]; + struct flagcxP2pConnMeta remoteMeta[kFlagcxP2pMaxQpsPerEngine]; int localReady = 1, remoteReady = 0; + uint32_t localNumQps = 0, remoteNumQps = 0, agreedNumQps = 0; FLAGCXCHECKGOTO(flagcxSocketInit(&comm->sock), res, accept_fail); FLAGCXCHECKGOTO(flagcxSocketAccept(&comm->sock, &lComm->sock), res, accept_fail); @@ -688,36 +560,51 @@ static flagcxResult_t flagcxP2pAccept(void *listenComm, void **recvComm) { return res; } - // Set up PD, CQs, QPs - FLAGCXCHECKGOTO(flagcxP2pSetupConn(lComm->dev, &comm->base, comm->channels, - &comm->ibDevN), + // accept side mirrors connect: recv numQps first, then send. + FLAGCXCHECKGOTO( + flagcxSocketRecv(&comm->sock, &remoteNumQps, sizeof(remoteNumQps)), res, + accept_cleanup); + localNumQps = (uint32_t)flagcxP2pGlobalConfig().qpsPerConn; + if (localNumQps == 0 || localNumQps > (uint32_t)kFlagcxP2pMaxQpsPerEngine) + localNumQps = (uint32_t)kFlagcxP2pMaxQpsPerEngine; + FLAGCXCHECKGOTO( + flagcxSocketSend(&comm->sock, &localNumQps, sizeof(localNumQps)), res, + accept_cleanup); + if (remoteNumQps == 0 || remoteNumQps > (uint32_t)kFlagcxP2pMaxQpsPerEngine) { + WARN("NET/IB_P2P : peer advertised invalid numQps=%u (max=%d)", + remoteNumQps, kFlagcxP2pMaxQpsPerEngine); + res = flagcxInternalError; + goto accept_cleanup; + } + agreedNumQps = std::min(localNumQps, remoteNumQps); + if (localNumQps != remoteNumQps) { + INFO(FLAGCX_NET, + "NET/IB_P2P : numQps mismatch (local=%u remote=%u) — using min=%u", + localNumQps, remoteNumQps, agreedNumQps); + } + comm->numQps = (int)agreedNumQps; + + FLAGCXCHECKGOTO(flagcxP2pSetupConn(lComm->dev, comm, &comm->base, + comm->qp_list_, &comm->ibDevN, + comm->numQps), res, accept_cleanup); - // Exchange connection metadata (accept receives first, then sends) - for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) - flagcxP2pBuildConnMeta(&localMeta[i], &comm->base, &comm->channels[i].qp, + for (int i = 0; i < comm->numQps; i++) + flagcxP2pBuildConnMeta(&localMeta[i], &comm->base, &comm->qp_list_[i], comm->ibDevN); - FLAGCXCHECKGOTO(flagcxSocketRecv(&comm->sock, remoteMeta, sizeof(remoteMeta)), + FLAGCXCHECKGOTO(flagcxSocketRecv(&comm->sock, remoteMeta, + comm->numQps * sizeof(remoteMeta[0])), res, accept_cleanup); - FLAGCXCHECKGOTO(flagcxSocketSend(&comm->sock, localMeta, sizeof(localMeta)), + FLAGCXCHECKGOTO(flagcxSocketSend(&comm->sock, localMeta, + comm->numQps * sizeof(localMeta[0])), res, accept_cleanup); // Transition each matched QP to RTR then RTS. - for (int i = 0; i < FLAGCX_P2P_QPS_PER_CONN; i++) - FLAGCXCHECKGOTO(flagcxP2pTransitionQp(&comm->channels[i].qp, &comm->base, + for (int i = 0; i < comm->numQps; i++) + FLAGCXCHECKGOTO(flagcxP2pTransitionQp(&comm->qp_list_[i], &comm->base, &remoteMeta[i], comm->ibDevN), res, accept_cleanup); - // Register putSignal scratchpad MR (symmetric with connect) - comm->putSignalScratchpad = 0; - FLAGCXCHECKGOTO( - flagcxWrapIbvRegMr(&comm->putSignalScratchpadMr, comm->base.pd, - &comm->putSignalScratchpad, - sizeof(comm->putSignalScratchpad), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC), - res, accept_cleanup); - // Exchange ready FLAGCXCHECKGOTO( flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady)), res, @@ -726,16 +613,12 @@ static flagcxResult_t flagcxP2pAccept(void *listenComm, void **recvComm) { flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady)), res, accept_cleanup); - flagcxP2pRegisterChannels(comm->channels, comm->reqs, comm->reqDone, - &comm->cqError); *recvComm = comm; return flagcxSuccess; accept_cleanup: - if (comm->putSignalScratchpadMr) - flagcxWrapIbvDeregMr(comm->putSignalScratchpadMr); - flagcxP2pDestroyChannels(comm->channels); + flagcxP2pDestroyQps(comm->ibDevN, comm->qp_list_, comm->numQps); if (comm->base.pd) flagcxP2pReleasePd(comm->ibDevN); flagcxSocketClose(&comm->sock); @@ -747,100 +630,66 @@ static flagcxResult_t flagcxP2pAccept(void *listenComm, void **recvComm) { /* One-sided transfers: iput / iget / iputSignal */ /* ------------------------------------------------------------------ */ +static flagcxResult_t flagcxP2pBuildSingleSliceReq( + struct flagcxP2pSendComm *comm, uint64_t localVa, uint64_t remoteVa, + size_t size, uint32_t lkey, uint32_t rkey, uint8_t opcode, + void **request) { + if ((uint32_t)size != size) { + WARN("NET/IB_P2P : single-op size %zu exceeds 32-bit limit", size); + return flagcxInternalError; + } + + auto *req = new struct flagcxP2pSliceReq; + req->slice.srcVa = localVa; + req->slice.dstVa = remoteVa; + req->slice.length = (uint32_t)size; + req->slice.lkey = lkey; + req->slice.rkey = rkey; + req->slice.opcode = opcode; + req->slice.peerNicPath = std::string(); + req->slice.task = &req->task; + req->slice.qpDepth = NULL; + req->task.sliceList.push_back(&req->slice); + req->task.sliceCount.fetch_add(1, std::memory_order_release); + + FlagcxSlice *slicePtr = &req->slice; + flagcxResult_t rc = + flagcxP2pPoolSubmit(comm->ibDevN, comm, &slicePtr, 1); + if (rc != flagcxSuccess) { + delete req; + return rc; + } + + *request = req; + return flagcxSuccess; +} + static flagcxResult_t flagcxP2pIput(void *sendComm, uint64_t srcOff, uint64_t dstOff, size_t size, int srcRank, int dstRank, void **srcHandles, void **dstHandles, void **request) { + (void)srcRank; + (void)dstRank; struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; struct flagcxP2pMrHandle *src = (struct flagcxP2pMrHandle *)srcHandles; struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles; - - struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, NULL, - FLAGCX_P2P_REQ_IPUT, &req)); - struct flagcxP2pChannel *channel = - flagcxP2pNextChannel(comm->channels, &comm->nextChannel); - req->cq = channel->cq; - - struct ibv_sge sge; - memset(&sge, 0, sizeof(sge)); - sge.addr = src->baseVa + srcOff; - sge.length = (uint32_t)size; - if ((size_t)sge.length != size) { - WARN("NET/IB_P2P : iput size %zu exceeds 32-bit limit", size); - flagcxP2pFreeRequest(req); - return flagcxInternalError; - } - sge.lkey = src->lkey; - - struct ibv_send_wr wr; - memset(&wr, 0, sizeof(wr)); - wr.opcode = IBV_WR_RDMA_WRITE; - wr.send_flags = IBV_SEND_SIGNALED; - wr.wr_id = req - comm->reqs; - wr.wr.rdma.remote_addr = dst->baseVa + dstOff; - wr.wr.rdma.rkey = dst->rkey; - wr.sg_list = &sge; - wr.num_sge = 1; - - req->events = 1; - struct ibv_send_wr *bad_wr; - flagcxResult_t res = flagcxWrapIbvPostSend(channel->qp.qp, &wr, &bad_wr); - if (res != flagcxSuccess) { - flagcxP2pFreeRequest(req); - return res; - } - - *request = req; - return flagcxSuccess; + return flagcxP2pBuildSingleSliceReq( + comm, src->baseVa + srcOff, dst->baseVa + dstOff, size, src->lkey, + dst->rkey, FLAGCX_SLICE_OP_WRITE, request); } static flagcxResult_t flagcxP2pIget(void *sendComm, uint64_t srcOff, uint64_t dstOff, size_t size, int srcRank, int dstRank, void **srcHandles, void **dstHandles, void **request) { + (void)srcRank; + (void)dstRank; struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; struct flagcxP2pMrHandle *src = (struct flagcxP2pMrHandle *)srcHandles; struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles; - - struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, NULL, - FLAGCX_P2P_REQ_IGET, &req)); - struct flagcxP2pChannel *channel = - flagcxP2pNextChannel(comm->channels, &comm->nextChannel); - req->cq = channel->cq; - - struct ibv_sge sge; - memset(&sge, 0, sizeof(sge)); - sge.addr = dst->baseVa + dstOff; - sge.length = (uint32_t)size; - if ((size_t)sge.length != size) { - WARN("NET/IB_P2P : iget size %zu exceeds 32-bit limit", size); - flagcxP2pFreeRequest(req); - return flagcxInternalError; - } - sge.lkey = dst->lkey; - - struct ibv_send_wr wr; - memset(&wr, 0, sizeof(wr)); - wr.opcode = IBV_WR_RDMA_READ; - wr.send_flags = IBV_SEND_SIGNALED; - wr.wr_id = req - comm->reqs; - wr.wr.rdma.remote_addr = src->baseVa + srcOff; - wr.wr.rdma.rkey = src->rkey; - wr.sg_list = &sge; - wr.num_sge = 1; - - req->events = 1; - struct ibv_send_wr *bad_wr; - flagcxResult_t res = flagcxWrapIbvPostSend(channel->qp.qp, &wr, &bad_wr); - if (res != flagcxSuccess) { - flagcxP2pFreeRequest(req); - return res; - } - - *request = req; - return flagcxSuccess; + return flagcxP2pBuildSingleSliceReq( + comm, dst->baseVa + dstOff, src->baseVa + srcOff, size, dst->lkey, + src->rkey, FLAGCX_SLICE_OP_READ, request); } static flagcxResult_t @@ -848,139 +697,142 @@ flagcxP2pIgetBatch(void *sendComm, int count, const uint64_t *srcOffs, const uint64_t *dstOffs, const size_t *sizes, int srcRank, int dstRank, void *const *srcHandles, void *const *dstHandles, void **request) { + (void)srcRank; + (void)dstRank; struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; - if (count <= 0 || count > FLAGCX_P2P_IGET_BATCH_MAX_WR || srcOffs == NULL || + const int maxWrPerPost = (int)flagcxP2pGlobalConfig().maxWrPerPost; + if (count <= 0 || count > maxWrPerPost || srcOffs == NULL || dstOffs == NULL || sizes == NULL || srcHandles == NULL || dstHandles == NULL || request == NULL) { - WARN("NET/IB_P2P : invalid igetBatch arguments, count %d", count); + WARN("NET/IB_P2P : invalid igetBatch arguments, count %d (max %d)", count, + maxWrPerPost); return flagcxInternalError; } - struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, NULL, - FLAGCX_P2P_REQ_IGET, &req)); - struct flagcxP2pChannel *channel = - flagcxP2pNextChannel(comm->channels, &comm->nextChannel); - req->cq = channel->cq; - - struct ibv_send_wr wrs[FLAGCX_P2P_IGET_BATCH_MAX_WR]; - struct ibv_sge sges[FLAGCX_P2P_IGET_BATCH_MAX_WR]; - memset(wrs, 0, sizeof(wrs)); - memset(sges, 0, sizeof(sges)); - + auto *req = new struct flagcxP2pSliceReq; + req->task.sliceList.reserve(count); for (int i = 0; i < count; i++) { - if (srcHandles[i] == NULL || dstHandles[i] == NULL) { - WARN("NET/IB_P2P : igetBatch handle %d is NULL", i); - flagcxP2pFreeRequest(req); - return flagcxInternalError; - } - - struct flagcxP2pMrHandle *src = (struct flagcxP2pMrHandle *)srcHandles[i]; - struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles[i]; - - sges[i].addr = dst->baseVa + dstOffs[i]; - sges[i].length = (uint32_t)sizes[i]; - if ((size_t)sges[i].length != sizes[i]) { - WARN("NET/IB_P2P : igetBatch size %zu exceeds 32-bit limit", sizes[i]); - flagcxP2pFreeRequest(req); + if (srcHandles[i] == NULL || dstHandles[i] == NULL || + (uint32_t)sizes[i] != sizes[i]) { + WARN("NET/IB_P2P : igetBatch slice %d invalid", i); + for (auto *s : req->task.sliceList) + delete s; + delete req; return flagcxInternalError; } - sges[i].lkey = dst->lkey; - - wrs[i].opcode = IBV_WR_RDMA_READ; - wrs[i].send_flags = - i == count - 1 ? IBV_SEND_SIGNALED : 0; // final CQE tracks batch - wrs[i].wr_id = i == count - 1 ? req - comm->reqs : 0; - wrs[i].wr.rdma.remote_addr = src->baseVa + srcOffs[i]; - wrs[i].wr.rdma.rkey = src->rkey; - wrs[i].sg_list = &sges[i]; - wrs[i].num_sge = 1; - wrs[i].next = i + 1 < count ? &wrs[i + 1] : NULL; + auto *src = (struct flagcxP2pMrHandle *)srcHandles[i]; + auto *dst = (struct flagcxP2pMrHandle *)dstHandles[i]; + auto *s = new FlagcxSlice{dst->baseVa + dstOffs[i], + src->baseVa + srcOffs[i], + (uint32_t)sizes[i], + dst->lkey, + src->rkey, + FLAGCX_SLICE_OP_READ, + std::string(), + &req->task, + NULL}; + req->task.sliceList.push_back(s); + req->task.sliceCount.fetch_add(1, std::memory_order_release); } - req->events = 1; - struct ibv_send_wr *bad_wr; - flagcxResult_t res = flagcxWrapIbvPostSend(channel->qp.qp, &wrs[0], &bad_wr); - if (res != flagcxSuccess) { - flagcxP2pFreeRequest(req); - return res; + flagcxResult_t rc = flagcxP2pPoolSubmit(comm->ibDevN, comm, + req->task.sliceList.data(), count); + if (rc != flagcxSuccess) { + for (auto *s : req->task.sliceList) + delete s; + delete req; + return rc; } - *request = req; return flagcxSuccess; } -static flagcxResult_t -flagcxP2pIputSignal(void *sendComm, uint64_t srcOff, uint64_t dstOff, - size_t size, int srcRank, int dstRank, void **srcHandles, - void **dstHandles, uint64_t signalOff, void **signalHandles, - uint64_t signalValue, void **request) { +static flagcxResult_t flagcxP2pIputSignal(void *, uint64_t, uint64_t, size_t, + int, int, void **, void **, uint64_t, + void **, uint64_t, void **) { + WARN("NET/IB_P2P : iputSignal not supported"); + return flagcxInternalError; +} + +/* ------------------------------------------------------------------ */ +/* Slice batch: pool worker passes the chosen QP. wr_id = ptr|1. */ +/* ------------------------------------------------------------------ */ + +static inline uint32_t flagcxSliceOpcodeToVerbs(uint8_t op) { + return op == FLAGCX_SLICE_OP_READ ? IBV_WR_RDMA_READ : IBV_WR_RDMA_WRITE; +} + +extern "C" flagcxResult_t flagcxP2pSliceBatch(void *sendComm, + struct ibv_qp *qp, int count, + FlagcxSlice **slices) { struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; - struct flagcxP2pMrHandle *signalInfo = - (struct flagcxP2pMrHandle *)signalHandles; - - struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->reqDone, NULL, - FLAGCX_P2P_REQ_IPUT, &req)); - struct flagcxP2pChannel *channel = - flagcxP2pNextChannel(comm->channels, &comm->nextChannel); - req->cq = channel->cq; - - bool chainData = (size > 0 && srcHandles != NULL && dstHandles != NULL); - - struct ibv_sge sge[2]; - struct ibv_send_wr wr[2]; - memset(sge, 0, sizeof(sge)); - memset(wr, 0, sizeof(wr)); - - // wr[0]: RDMA WRITE for data (unsignaled, chained to wr[1]) - if (chainData) { - struct flagcxP2pMrHandle *src = (struct flagcxP2pMrHandle *)srcHandles; - struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles; - - sge[0].addr = src->baseVa + srcOff; - sge[0].length = (uint32_t)size; - if ((size_t)sge[0].length != size) { - WARN("NET/IB_P2P : iputSignal size %zu exceeds 32-bit limit", size); - flagcxP2pFreeRequest(req); + const char *opLabel = + (slices != NULL && count > 0 && slices[0] != NULL && + slices[0]->opcode == FLAGCX_SLICE_OP_READ) + ? "READ" + : "WRITE"; + const int maxWrPerPost = (int)flagcxP2pGlobalConfig().maxWrPerPost; + if (count <= 0 || count > maxWrPerPost || slices == NULL || + qp == NULL || comm == NULL) { + WARN("NET/IB_P2P : invalid sliceBatch arguments (op=%s, count=%d, qp=%p, " + "max=%d)", + opLabel, count, (void *)qp, maxWrPerPost); + return flagcxInternalError; + } + + // count can be up to flagcxP2pGlobalConfig().maxWrPerPost (default 256, + // bounded at 1024). Heap-allocate to keep the stack small. + std::vector wrs(count); + std::vector sges(count); + + for (int i = 0; i < count; i++) { + FlagcxSlice *s = slices[i]; + if (s == NULL) { + WARN("NET/IB_P2P : sliceBatch slice[%d] is NULL", i); + for (int k = 0; k < i; k++) + slices[k]->markFailed(); + for (int k = i; k < count; k++) + if (slices[k]) + slices[k]->markFailed(); return flagcxInternalError; } - sge[0].lkey = src->lkey; - - wr[0].opcode = IBV_WR_RDMA_WRITE; - wr[0].send_flags = 0; // unsignaled - wr[0].wr.rdma.remote_addr = dst->baseVa + dstOff; - wr[0].wr.rdma.rkey = dst->rkey; - wr[0].sg_list = &sge[0]; - wr[0].num_sge = 1; - wr[0].next = &wr[1]; // chain to atomic + + sges[i].addr = s->srcVa; + sges[i].length = s->length; + sges[i].lkey = s->lkey; + + wrs[i].opcode = flagcxSliceOpcodeToVerbs(s->opcode); + wrs[i].send_flags = IBV_SEND_SIGNALED; + wrs[i].wr_id = ((uintptr_t)s) | 1ull; + wrs[i].wr.rdma.remote_addr = s->dstVa; + wrs[i].wr.rdma.rkey = s->rkey; + wrs[i].sg_list = &sges[i]; + wrs[i].num_sge = 1; + wrs[i].next = (i + 1 < count) ? &wrs[i + 1] : NULL; } - // wr[1]: ATOMIC FETCH_AND_ADD for signal (signaled) - sge[1].addr = (uintptr_t)&comm->putSignalScratchpad; - sge[1].length = sizeof(comm->putSignalScratchpad); - sge[1].lkey = comm->putSignalScratchpadMr->lkey; - - wr[1].opcode = IBV_WR_ATOMIC_FETCH_AND_ADD; - wr[1].send_flags = IBV_SEND_SIGNALED; - wr[1].wr_id = req - comm->reqs; - wr[1].wr.atomic.remote_addr = signalInfo->baseVa + signalOff; - wr[1].wr.atomic.rkey = signalInfo->rkey; - wr[1].wr.atomic.compare_add = signalValue; - wr[1].sg_list = &sge[1]; - wr[1].num_sge = 1; - wr[1].next = NULL; - - req->events = 1; - struct ibv_send_wr *bad_wr; - flagcxResult_t res = flagcxWrapIbvPostSend( - channel->qp.qp, chainData ? &wr[0] : &wr[1], &bad_wr); + struct ibv_send_wr *bad_wr = NULL; + flagcxResult_t res = flagcxWrapIbvPostSend(qp, &wrs[0], &bad_wr); if (res != flagcxSuccess) { - flagcxP2pFreeRequest(req); + int failedFrom = 0; + if (bad_wr != NULL) { + ptrdiff_t off = bad_wr - &wrs[0]; + if (off >= 0 && off < count) + failedFrom = (int)off; + } + // Slices in [failedFrom..count) never went on the wire — roll back + // their share of the pool's qpDepth pre-bump so the gate doesn't leak. + for (int k = failedFrom; k < count; k++) { + if (slices[k]->qpDepth != NULL) + __sync_fetch_and_sub(slices[k]->qpDepth, 1); + slices[k]->markFailed(); + } + WARN("NET/IB_P2P : sliceBatch ibv_post_send failed (op=%s, count=%d, " + "failedFrom=%d)", + opLabel, count, failedFrom); return res; } - *request = req; return flagcxSuccess; } @@ -988,25 +840,36 @@ flagcxP2pIputSignal(void *sendComm, uint64_t srcOff, uint64_t dstOff, /* Test */ /* ------------------------------------------------------------------ */ +// Single-slice path uses the wrapper's embedded `slice`; batch path +// heap-allocates each — distinguish by address. +static inline void flagcxP2pFreeSliceReq(struct flagcxP2pSliceReq *req) { + if (!req) + return; + for (auto *s : req->task.sliceList) { + if (s != &req->slice) + delete s; + } + delete req; +} + static flagcxResult_t flagcxP2pTest(void *request, int *done, int *sizes) { *done = 0; - struct flagcxP2pRequest *req = (struct flagcxP2pRequest *)request; - if (req == NULL || req->type == FLAGCX_P2P_REQ_UNUSED) { + if (sizes) + *sizes = 0; + if (request == NULL) { *done = 1; return flagcxSuccess; } - - uint32_t idx = (uint32_t)(req - req->reqs); - if (idx >= FLAGCX_P2P_MAX_REQUESTS) { - WARN("NET/IB_P2P : invalid request index %u in test()", idx); - return flagcxInternalError; - } - - if (req->reqDone[idx].load(std::memory_order_acquire)) { - req->reqDone[idx].store(0, std::memory_order_relaxed); + auto *req = static_cast(request); + if (req->task.isAllDone()) { *done = 1; - if (sizes) - *sizes = 0; + if (sizes) { + uint64_t total = 0; + for (auto *s : req->task.sliceList) + total += s->length; + *sizes = (int)std::min(total, (uint64_t)INT32_MAX); + } + flagcxP2pFreeSliceReq(req); } return flagcxSuccess; } @@ -1016,21 +879,16 @@ static flagcxResult_t flagcxP2pTestBatch(void **requests, int nRequests, int completed = 0; for (int i = 0; i < nRequests; i++) { doneFlags[i] = 0; - struct flagcxP2pRequest *req = (struct flagcxP2pRequest *)requests[i]; - if (req == NULL || req->type == FLAGCX_P2P_REQ_UNUSED) { + auto *req = static_cast(requests[i]); + if (req == NULL) { doneFlags[i] = 1; completed++; continue; } - - uint32_t idx = (uint32_t)(req - req->reqs); - if (idx >= FLAGCX_P2P_MAX_REQUESTS) - continue; - - if (req->reqDone[idx].load(std::memory_order_acquire)) { - req->reqDone[idx].store(0, std::memory_order_relaxed); + if (req->task.isAllDone()) { doneFlags[i] = 1; completed++; + flagcxP2pFreeSliceReq(req); } } *doneCount = completed; @@ -1059,26 +917,10 @@ static flagcxResult_t flagcxP2pReleasePd(int ibDevN) { return flagcxSuccess; } -// Helper: drain CQ before destroying resources -static void flagcxP2pDrainCq(struct ibv_cq *cq) { - if (!cq) - return; - struct ibv_wc wcs[64]; - int nCqe = 0; - for (int i = 0; i < 16; i++) { - flagcxWrapIbvPollCq(cq, 64, wcs, &nCqe); - if (nCqe == 0) - break; - } -} - static flagcxResult_t flagcxP2pCloseSend(void *sendComm) { struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; if (comm) { - flagcxP2pUnregisterChannels(comm->channels); - FLAGCXCHECK(flagcxP2pDestroyChannels(comm->channels)); - if (comm->putSignalScratchpadMr) - FLAGCXCHECK(flagcxWrapIbvDeregMr(comm->putSignalScratchpadMr)); + FLAGCXCHECK(flagcxP2pDestroyQps(comm->ibDevN, comm->qp_list_, comm->numQps)); FLAGCXCHECK(flagcxP2pReleasePd(comm->ibDevN)); FLAGCXCHECK(flagcxSocketClose(&comm->sock)); free(comm); @@ -1089,10 +931,7 @@ static flagcxResult_t flagcxP2pCloseSend(void *sendComm) { static flagcxResult_t flagcxP2pCloseRecv(void *recvComm) { struct flagcxP2pRecvComm *comm = (struct flagcxP2pRecvComm *)recvComm; if (comm) { - flagcxP2pUnregisterChannels(comm->channels); - FLAGCXCHECK(flagcxP2pDestroyChannels(comm->channels)); - if (comm->putSignalScratchpadMr) - FLAGCXCHECK(flagcxWrapIbvDeregMr(comm->putSignalScratchpadMr)); + FLAGCXCHECK(flagcxP2pDestroyQps(comm->ibDevN, comm->qp_list_, comm->numQps)); FLAGCXCHECK(flagcxP2pReleasePd(comm->ibDevN)); FLAGCXCHECK(flagcxSocketClose(&comm->sock)); free(comm); diff --git a/flagcx/core/flagcx_p2p.cc b/flagcx/core/flagcx_p2p.cc index 7cddf31f..aec5bac9 100644 --- a/flagcx/core/flagcx_p2p.cc +++ b/flagcx/core/flagcx_p2p.cc @@ -16,12 +16,15 @@ #include "flagcx_net.h" #include "flagcx_net_adaptor.h" #include "ib_common.h" +#include "ibvwrap.h" #include "p2p_topo.h" +#include "param.h" #include "socket.h" #include #include #include +#include #include #include #include @@ -41,6 +44,122 @@ extern struct flagcxNetAdaptor flagcxNetIbP2p; +struct FlagcxSlice; + +extern "C" flagcxResult_t flagcxP2pSliceBatch(void *sendComm, + struct ibv_qp *qp, int count, + FlagcxSlice **slices); + +namespace { + +FLAGCX_PARAM(P2pQpsPerConn, "P2P_QPS_PER_CONN", 4); +FLAGCX_PARAM(P2pWorkersPerPool, "P2P_WORKERS_PER_POOL", 2); +FLAGCX_PARAM(P2pShardCount, "P2P_SHARD_COUNT", 8); +FLAGCX_PARAM(P2pCqDepth, "P2P_CQ_DEPTH", 4096); +FLAGCX_PARAM(P2pMaxWrPerPost, "P2P_MAX_WR_PER_POST", 64); +FLAGCX_PARAM(P2pMaxRequests, "P2P_MAX_REQUESTS", 256); +FLAGCX_PARAM(P2pBatchPollSize, "P2P_BATCH_POLL_SIZE", 32); +FLAGCX_PARAM(P2pReadBatchWindow, "P2P_READ_BATCH_WINDOW", 8); +FLAGCX_PARAM(P2pSliceSize, "P2P_SLICE_SIZE", 65536); +FLAGCX_PARAM(P2pFragmentLimit, "P2P_FRAGMENT_LIMIT", 4096); +FLAGCX_PARAM(P2pMaxSge, "P2P_MAX_SGE", 4); +FLAGCX_PARAM(P2pMaxInline, "P2P_MAX_INLINE", 64); +FLAGCX_PARAM(P2pIbPort, "P2P_IB_PORT", 1); +FLAGCX_PARAM(P2pGidIndex, "P2P_GID_INDEX", -1); +FLAGCX_PARAM(P2pMtu, "P2P_MTU", 4096); +FLAGCX_PARAM(P2pIbTc, "P2P_IB_TC", -1); +FLAGCX_PARAM(P2pRetryCnt, "P2P_RETRY_CNT", 7); +FLAGCX_PARAM(P2pNotifMaxPeers, "P2P_NOTIF_MAX_PEERS", 64); +FLAGCX_PARAM(P2pDestDevAffinity, "P2P_DEST_DEV_AFFINITY", 0); + +template +inline T clampParam(int64_t v, T lo, T hi, T deft, const char *name) { + if (v < (int64_t)lo || v > (int64_t)hi) { + INFO(FLAGCX_INIT, + "Ignore FLAGCX_%s=%lld (out of [%lld,%lld]); using default %lld", + name, (long long)v, (long long)lo, (long long)hi, (long long)deft); + return deft; + } + return (T)v; +} + +void loadGlobalConfig(FlagcxP2pGlobalConfig &c) { + c.qpsPerConn = clampParam(flagcxParamP2pQpsPerConn(), 1, kFlagcxP2pMaxQpsPerEngine, 4, "P2P_QPS_PER_CONN"); + c.workersPerPool = clampParam(flagcxParamP2pWorkersPerPool(), 1, 8, 2, "P2P_WORKERS_PER_POOL"); + c.shardCount = clampParam(flagcxParamP2pShardCount(), 1, 64, 8, "P2P_SHARD_COUNT"); + c.sharedCqDepth = clampParam(flagcxParamP2pCqDepth(), 1, 1u<<20, 4096, "P2P_CQ_DEPTH"); + c.maxWrPerPost = clampParam(flagcxParamP2pMaxWrPerPost(),1, 1024, 256, "P2P_MAX_WR_PER_POST"); + c.maxRequests = clampParam(flagcxParamP2pMaxRequests(), 1, 1u<<16, 256, "P2P_MAX_REQUESTS"); + c.batchPollSize = clampParam(flagcxParamP2pBatchPollSize(), 1, 256, 32, "P2P_BATCH_POLL_SIZE"); + c.readBatchWindow = clampParam(flagcxParamP2pReadBatchWindow(), 1, 256, 8, "P2P_READ_BATCH_WINDOW"); + c.sliceSize = clampParam(flagcxParamP2pSliceSize(), 1024, 1u<<26, 65536, "P2P_SLICE_SIZE"); + c.fragmentLimit = clampParam(flagcxParamP2pFragmentLimit(), 0, c.sliceSize, 4096, "P2P_FRAGMENT_LIMIT"); + c.maxSge = clampParam(flagcxParamP2pMaxSge(), 1, 32, 4, "P2P_MAX_SGE"); + c.maxInline = clampParam(flagcxParamP2pMaxInline(), 0, 1024, 64, "P2P_MAX_INLINE"); + c.ibPort = clampParam(flagcxParamP2pIbPort(), 1, 255, 1, "P2P_IB_PORT"); + c.gidIndex = clampParam(flagcxParamP2pGidIndex(), -1, 255, -1, "P2P_GID_INDEX"); + { + int64_t mv = flagcxParamP2pMtu(); + if (mv == 512 || mv == 1024 || mv == 2048 || mv == 4096) { + c.mtuLength = (int)mv; + } else { + WARN("Ignore FLAGCX_P2P_MTU=%lld (must be 512/1024/2048/4096); using 4096", + (long long)mv); + c.mtuLength = 4096; + } + } + c.ibTrafficClass = clampParam(flagcxParamP2pIbTc(), -1, 255, -1, "P2P_IB_TC"); + c.retryCnt = clampParam(flagcxParamP2pRetryCnt(), 0, 7, 7, "P2P_RETRY_CNT"); + c.notifMaxPeers = clampParam(flagcxParamP2pNotifMaxPeers(), 1, 1024, 64, "P2P_NOTIF_MAX_PEERS"); + c.enableDestDeviceAffinity = (flagcxParamP2pDestDevAffinity() != 0); +} + +void dumpGlobalConfigImpl(const FlagcxP2pGlobalConfig &c); + +FlagcxP2pGlobalConfig &mutableGlobalConfig() { + static FlagcxP2pGlobalConfig cfg; + static std::once_flag once; + // Important: dumpGlobalConfigImpl reads cfg directly (no recursion back + // through mutableGlobalConfig), so the lambda is safe to call on the + // same thread that holds the once_flag. + std::call_once(once, [] { + loadGlobalConfig(cfg); + dumpGlobalConfigImpl(cfg); + }); + return cfg; +} + +void dumpGlobalConfigImpl(const FlagcxP2pGlobalConfig &c) { + INFO(FLAGCX_INIT, "=== FlagCX P2P GlobalConfig ==="); + INFO(FLAGCX_INIT, + "qpsPerConn=%d workersPerPool=%d shardCount=%d", + c.qpsPerConn, c.workersPerPool, c.shardCount); + INFO(FLAGCX_INIT, + "sharedCqDepth=%zu maxWrPerPost=%zu maxRequests=%zu " + "batchPollSize=%zu readBatchWindow=%zu", + c.sharedCqDepth, c.maxWrPerPost, c.maxRequests, c.batchPollSize, + c.readBatchWindow); + INFO(FLAGCX_INIT, "sliceSize=%zu fragmentLimit=%zu", + c.sliceSize, c.fragmentLimit); + INFO(FLAGCX_INIT, + "ibPort=%u gidIndex=%d mtu=%d tc=%d retry=%d " + "maxSge=%zu maxInline=%zu", + (unsigned)c.ibPort, c.gidIndex, c.mtuLength, c.ibTrafficClass, + c.retryCnt, c.maxSge, c.maxInline); + INFO(FLAGCX_INIT, "notifMaxPeers=%d destDevAffinity=%d", + c.notifMaxPeers, (int)c.enableDestDeviceAffinity); +} + +} // namespace + +const FlagcxP2pGlobalConfig &flagcxP2pGlobalConfig() { + return mutableGlobalConfig(); +} + +void flagcxP2pDumpGlobalConfig() { + dumpGlobalConfigImpl(flagcxP2pGlobalConfig()); +} + struct FlagcxP2pMrHandleView { uintptr_t baseVa; uint32_t lkey; @@ -56,17 +175,10 @@ struct FlagcxP2pListenHandleView { static_assert(sizeof(FlagcxP2pListenHandleView) <= FLAGCX_NET_HANDLE_MAXSIZE, "listen handle must fit in FLAGCX_NET_HANDLE_MAXSIZE"); -static constexpr int kP2pQpsPerConn = 4; - -struct FlagcxP2pChannelView { - struct ibv_cq *cq; - struct flagcxIbQp qp; -}; - struct FlagcxP2pCommView { int ibDevN; struct flagcxIbNetCommDevBase base; - struct FlagcxP2pChannelView channels[kP2pQpsPerConn]; + struct flagcxIbQp qp_list_[kFlagcxP2pMaxQpsPerEngine]; struct flagcxSocket sock; }; @@ -130,7 +242,6 @@ struct FlagcxP2pEngine { #if defined(__linux__) int notifEpollFd; #endif - std::thread notifThread; std::atomic stopNotif; std::unordered_map notifPeers; std::mutex notifPeerMutex; @@ -190,474 +301,660 @@ static std::unordered_map gXferMap; static std::mutex gXferMutex; static uint64_t gNextXferId = 1; -/* ------------------------------------------------------------------ */ -/* Async Transfer Worker Infrastructure */ -/* ------------------------------------------------------------------ */ - -static constexpr int kWindowSize = 8; -static constexpr int kIgetBatchSize = 64; - -enum FlagcxSlicePolicyKind { - FLAGCX_POLICY_NIXL = 0, - FLAGCX_POLICY_FLAGCX = 1, -}; - -struct FlagcxTransferTask; - -struct FlagcxSlice { - // WRITE: local source VA; READ: local destination VA. - uint64_t srcVa = 0; - // WRITE: remote destination VA; READ: remote source VA. - uint64_t dstVa = 0; - uint32_t length = 0; - uint32_t lkey = 0; - uint32_t rkey = 0; - uint8_t opcode = 0; - FlagcxTransferTask *task = nullptr; - - void markSuccess(); - void markFailed(); -}; - struct FlagcxTransferTask { - FlagcxP2pConn *conn = nullptr; std::atomic sliceCount{0}; std::atomic doneSliceCount{0}; - std::atomic result{0}; std::vector sliceList; bool isAllDone() const { - const uint64_t total = sliceCount.load(std::memory_order_acquire); - const uint64_t done = doneSliceCount.load(std::memory_order_acquire); + auto total = sliceCount.load(std::memory_order_acquire); + auto done = doneSliceCount.load(std::memory_order_acquire); return total > 0 && done >= total; } }; -void FlagcxSlice::markSuccess() { - if (task) - task->doneSliceCount.fetch_add(1, std::memory_order_release); -} +enum FlagcxSliceOp : uint8_t { + FLAGCX_SLICE_OP_WRITE = 0, + FLAGCX_SLICE_OP_READ = 1, +}; + +struct FlagcxSlice { + // WRITE: local source VA; READ: local destination VA. + uint64_t srcVa = 0; + // WRITE: remote destination VA; READ: remote source VA. + uint64_t dstVa; + uint32_t length; + uint32_t lkey; + uint32_t rkey; + uint8_t opcode; + std::string peerNicPath; + FlagcxTransferTask *task; + volatile int *qpDepth; -void FlagcxSlice::markFailed() { - if (task) { - task->result.store(-1, std::memory_order_release); - task->doneSliceCount.fetch_add(1, std::memory_order_release); + inline void markSuccess() { + if (task) + task->doneSliceCount.fetch_add(1, std::memory_order_release); } -} -struct NixlSlicePolicy { + inline void markFailed() { + if (task) + task->doneSliceCount.fetch_add(1, std::memory_order_release); + } +}; + +struct FlagcxNixlSlicePolicy { static constexpr bool kFurtherCut = false; static constexpr size_t kBlockSize = SIZE_MAX; static constexpr size_t kFragmentSize = 0; }; -struct FlagcxSlicePolicy { +struct FlagcxConnectorSlicePolicy { static constexpr bool kFurtherCut = true; static constexpr size_t kBlockSize = 64 * 1024; static constexpr size_t kFragmentSize = 4 * 1024; }; template -static int buildSlices(FlagcxTransferTask *task, uint64_t srcVa, uint64_t dstVa, - size_t totalLen, uint32_t lkey, uint32_t rkey, - uint8_t opcode) { - if (task == nullptr || totalLen == 0) - return -1; - - if constexpr (!Policy::kFurtherCut) { - if (totalLen > UINT32_MAX) - return -1; - FlagcxSlice *slice = new FlagcxSlice; - slice->srcVa = srcVa; - slice->dstVa = dstVa; - slice->length = static_cast(totalLen); - slice->lkey = lkey; - slice->rkey = rkey; - slice->opcode = opcode; - slice->task = task; - task->sliceList.push_back(slice); +inline void flagcxBuildSlices(FlagcxTransferTask *task, uint64_t srcVa, + uint64_t dstVa, size_t totalLen, uint32_t lkey, + uint32_t rkey, uint8_t opcode, + const std::string &peerNicPath) { + if (!Policy::kFurtherCut) { + auto *s = new FlagcxSlice{srcVa, dstVa, (uint32_t)totalLen, + lkey, rkey, opcode, + peerNicPath, task, nullptr}; + task->sliceList.push_back(s); task->sliceCount.fetch_add(1, std::memory_order_release); - return 0; - } else { - size_t off = 0; - while (off < totalLen) { - const size_t remaining = totalLen - off; - const bool mergeTail = - remaining <= Policy::kBlockSize + Policy::kFragmentSize; - const size_t len = mergeTail ? remaining : Policy::kBlockSize; - if (len > UINT32_MAX) - return -1; + return; + } - FlagcxSlice *slice = new FlagcxSlice; - slice->srcVa = srcVa + off; - slice->dstVa = dstVa + off; - slice->length = static_cast(len); - slice->lkey = lkey; - slice->rkey = rkey; - slice->opcode = opcode; - slice->task = task; - task->sliceList.push_back(slice); - task->sliceCount.fetch_add(1, std::memory_order_release); - - off += len; - if (mergeTail) - break; - } - return 0; + size_t off = 0; + while (off < totalLen) { + bool merge = + (totalLen - off) <= Policy::kBlockSize + Policy::kFragmentSize; + size_t len = merge ? (totalLen - off) : Policy::kBlockSize; + auto *s = new FlagcxSlice{srcVa + off, dstVa + off, (uint32_t)len, + lkey, rkey, opcode, + peerNicPath, task, nullptr}; + task->sliceList.push_back(s); + task->sliceCount.fetch_add(1, std::memory_order_release); + off += len; + if (merge) + break; } } -static void cleanupTransferTask(FlagcxTransferTask *task) { - if (task == nullptr) - return; - for (FlagcxSlice *slice : task->sliceList) - delete slice; - delete task; -} +namespace { -struct AsyncWorker { - std::thread thread; - pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; - pthread_cond_t cv = PTHREAD_COND_INITIALIZER; - std::deque> queue; - std::atomic stop{false}; +struct PoolSubmitItem { + void *sendComm; // adaptor sendComm view + FlagcxSlice *slice; // owned by caller (engine ReadVector/WriteVector) }; -static AsyncWorker gAsyncWorker; +static void notifPollThreadFunc(FlagcxP2pEngine *engine); -static FlagcxP2pCommView *getCommView(void *comm) { - return reinterpret_cast(comm); -} +struct PoolQpEntry { + struct ibv_qp *qp; + void *sendComm; // owning conn (flagcxP2pSendComm/RecvComm) + volatile int wrDepth; -struct AsyncReadBatchEntry { - void *request = nullptr; - int firstSlice = 0; - int count = 0; + PoolQpEntry(struct ibv_qp *q, void *sc) : qp(q), sendComm(sc), wrDepth(0) {} + PoolQpEntry(const PoolQpEntry &) = delete; + PoolQpEntry &operator=(const PoolQpEntry &) = delete; }; -static inline bool isReadSlice(const FlagcxSlice *slice) { - return slice != nullptr && slice->opcode == IBV_WR_RDMA_READ; +class FlagcxWorkerPool { +public: + FlagcxWorkerPool(int ibDevN, struct ibv_context *ctx); + ~FlagcxWorkerPool(); + FlagcxWorkerPool(const FlagcxWorkerPool &) = delete; + FlagcxWorkerPool &operator=(const FlagcxWorkerPool &) = delete; + + struct ibv_cq *getSharedCq() const { return shared_cq_; } + void registerQp(void *sendComm, struct ibv_qp *qp); + void unregisterQp(struct ibv_qp *qp); + + flagcxResult_t submitPostSend(void *sendComm, FlagcxSlice **slices, + int count); + + void startNotif(FlagcxP2pEngine *engine); + void stopNotif(); + +private: + void transferWorkerLoop(int tid); + void performPostSend(int tid); + void performPollCq(); + void notifWorkerLoop(); + + static uint64_t nowNs() { + using clk = std::chrono::steady_clock; + return std::chrono::duration_cast( + clk::now().time_since_epoch()) + .count(); + } + + int ibDevN_; + struct ibv_cq *shared_cq_ = nullptr; + + int numWorkers_; + int numShards_; + size_t maxWrPerPost_; + size_t batchPollSize_; + int maxWrDepth_ = 0; + + std::mutex qp_mu_; + std::vector> qpEntries_; + std::unordered_map qpNumToIdx_; + std::vector> workerQpIdx_; + std::vector workerQpCursor_; + std::atomic qpRegisterCounter_{0}; + + std::vector> slice_queues_; + std::unique_ptr slice_locks_; + std::atomic shardRoundRobin_{0}; + + std::atomic submitted_{0}; + std::atomic processed_{0}; + std::atomic suspended_flag_{0}; + std::condition_variable cv_; + std::mutex cv_mu_; + + std::atomic running_{true}; + std::vector transferThreads_; + + FlagcxP2pEngine *engine_ = nullptr; + std::thread notifThread_; + std::atomic notifSpawned_{false}; +}; + +FlagcxWorkerPool::FlagcxWorkerPool(int ibDevN, struct ibv_context *ctx) + : ibDevN_(ibDevN) { + const auto &C = flagcxP2pGlobalConfig(); + numWorkers_ = C.workersPerPool; + numShards_ = C.shardCount; + maxWrPerPost_ = C.maxWrPerPost; + batchPollSize_ = C.batchPollSize; + + if (numWorkers_ > 0 && numShards_ % numWorkers_ != 0) { + int rounded = + ((numShards_ + numWorkers_ - 1) / numWorkers_) * numWorkers_; + INFO(FLAGCX_INIT, + "NET/IB_P2P : pool[%d] rounded shardCount %d → %d for even " + "worker assignment (W=%d)", + ibDevN_, numShards_, rounded, numWorkers_); + numShards_ = rounded; + } + + if (numWorkers_ > 0 && C.qpsPerConn % numWorkers_ != 0) { + WARN("NET/IB_P2P : pool[%d] qpsPerConn=%d not divisible by " + "workersPerPool=%d — some workers may starve for conns where " + "they own no QP", + ibDevN_, C.qpsPerConn, numWorkers_); + } + + flagcxResult_t res = flagcxWrapIbvCreateCq( + &shared_cq_, ctx, (int)C.sharedCqDepth, NULL, NULL, 0); + if (res != flagcxSuccess) { + WARN("NET/IB_P2P : pool[%d] failed to create shared CQ", ibDevN_); + shared_cq_ = nullptr; + return; + } + INFO(FLAGCX_INIT, + "NET/IB_P2P : pool[%d] shared CQ created (depth=%zu, workers=%d, " + "shards=%d, qpsPerConn=%d)", + ibDevN_, C.sharedCqDepth, numWorkers_, numShards_, C.qpsPerConn); + + slice_queues_.resize(numShards_); + slice_locks_.reset(new std::mutex[numShards_]); + + workerQpIdx_.resize(numWorkers_); + workerQpCursor_.assign(numWorkers_, 0); + + transferThreads_.reserve(numWorkers_); + for (int t = 0; t < numWorkers_; t++) { + transferThreads_.emplace_back([this, t] { this->transferWorkerLoop(t); }); + } +} + +FlagcxWorkerPool::~FlagcxWorkerPool() { + running_.store(false, std::memory_order_release); + cv_.notify_all(); + for (auto &t : transferThreads_) { + if (t.joinable()) + t.join(); + } + // notifThread_ is joined explicitly via stopNotif() in EngineDestroy; + // by the time ~pool runs at process exit it should already be joined. + if (notifThread_.joinable()) { + notifThread_.join(); + } + // CQ destruction skipped: pool is process-lived; OS reclaims. } -static inline uint64_t sliceLocalVa(const FlagcxSlice *slice) { - return slice->srcVa; +void FlagcxWorkerPool::startNotif(FlagcxP2pEngine *engine) { + if (engine == nullptr) + return; + // Compare-exchange ensures only the first attach spawns the thread. + bool expected = false; + if (!notifSpawned_.compare_exchange_strong(expected, true)) { + // Already attached — keep existing engine pointer (or update if NULL). + if (engine_ == nullptr) + engine_ = engine; + return; + } + engine_ = engine; + notifThread_ = std::thread([this] { this->notifWorkerLoop(); }); + INFO(FLAGCX_INIT, "NET/IB_P2P : pool[%d] notifWorker spawned", ibDevN_); } -static inline uint64_t sliceRemoteVa(const FlagcxSlice *slice) { - return slice->dstVa; +void FlagcxWorkerPool::stopNotif() { + // Caller must have set engine_->stopNotif before calling — that breaks + // the epoll loop in notifPollThreadFunc. + if (notifThread_.joinable()) { + notifThread_.join(); + } + engine_ = nullptr; + notifSpawned_.store(false, std::memory_order_release); } -static void markSlicesFailed(std::shared_ptr task, - int firstSlice) { - if (!task) +void FlagcxWorkerPool::notifWorkerLoop() { + if (engine_ == nullptr) return; - for (size_t i = firstSlice; i < task->sliceList.size(); i++) - task->sliceList[i]->markFailed(); -} - -static bool asyncReadBatched(std::shared_ptr task, - struct flagcxNetAdaptor *adaptor) { - AsyncReadBatchEntry inflight[kWindowSize]; - const int numSlices = static_cast(task->sliceList.size()); - int issuedSlices = 0; - int completedSlices = 0; - int issuedBatches = 0; - int completedBatches = 0; - - while (completedSlices < numSlices) { - while (issuedSlices < numSlices && - issuedBatches - completedBatches < kWindowSize) { - const int count = std::min(kIgetBatchSize, numSlices - issuedSlices); - uint64_t srcOffs[kIgetBatchSize] = {}; - uint64_t dstOffs[kIgetBatchSize] = {}; - size_t sizes[kIgetBatchSize] = {}; - FlagcxP2pMrHandleView remoteMrs[kIgetBatchSize] = {}; - FlagcxP2pMrHandleView localMrs[kIgetBatchSize] = {}; - void *srcHandles[kIgetBatchSize] = {}; - void *dstHandles[kIgetBatchSize] = {}; - - for (int batchSlice = 0; batchSlice < count; batchSlice++) { - const int sliceIdx = issuedSlices + batchSlice; - FlagcxSlice *slice = task->sliceList[sliceIdx]; - if (!isReadSlice(slice)) - return false; - - remoteMrs[batchSlice].baseVa = sliceRemoteVa(slice); - remoteMrs[batchSlice].rkey = slice->rkey; - localMrs[batchSlice].baseVa = sliceLocalVa(slice); - localMrs[batchSlice].lkey = slice->lkey; - srcHandles[batchSlice] = &remoteMrs[batchSlice]; - dstHandles[batchSlice] = &localMrs[batchSlice]; - sizes[batchSlice] = slice->length; - } + // Reuse the original engine-side body — same behavior, just owned by + // the pool's thread. + notifPollThreadFunc(engine_); +} - void *request = nullptr; - flagcxResult_t rc = - adaptor->igetBatch(task->conn->sendComm, count, srcOffs, dstOffs, - sizes, 0, 0, srcHandles, dstHandles, &request); - if (rc != flagcxSuccess) - return false; - - AsyncReadBatchEntry &entry = inflight[issuedBatches % kWindowSize]; - entry.request = request; - entry.firstSlice = issuedSlices; - entry.count = count; - issuedSlices += count; - issuedBatches++; - } +void FlagcxWorkerPool::registerQp(void *sendComm, struct ibv_qp *qp) { + if (!qp || numWorkers_ <= 0) + return; - int newlyCompleted = 0; - if (adaptor->testBatch != nullptr) { - void *batchRequests[kWindowSize]; - int batchIndices[kWindowSize]; - int batchCount = 0; - for (int batch = completedBatches; batch < issuedBatches; batch++) { - AsyncReadBatchEntry &entry = inflight[batch % kWindowSize]; - if (entry.request != nullptr) { - batchRequests[batchCount] = entry.request; - batchIndices[batchCount] = batch; - batchCount++; - } - } + std::lock_guard lk(qp_mu_); - if (batchCount > 0) { - int doneFlags[kWindowSize]; - int doneCount = 0; - flagcxResult_t rc = adaptor->testBatch(batchRequests, batchCount, - doneFlags, &doneCount); - if (rc != flagcxSuccess) - return false; - - for (int i = 0; i < batchCount; i++) { - if (doneFlags[i]) { - AsyncReadBatchEntry &entry = - inflight[batchIndices[i] % kWindowSize]; - for (int s = 0; s < entry.count; s++) - task->sliceList[entry.firstSlice + s]->markSuccess(); - entry.request = nullptr; - newlyCompleted++; - } - } + if (maxWrDepth_ == 0) { + struct ibv_qp_attr attr; + struct ibv_qp_init_attr initAttr; + memset(&attr, 0, sizeof(attr)); + memset(&initAttr, 0, sizeof(initAttr)); + if (flagcxWrapIbvQueryQp(qp, &attr, IBV_QP_CAP, &initAttr) == + flagcxSuccess) { + int cap = (int)initAttr.cap.max_send_wr; + if (cap > 0) { + maxWrDepth_ = cap; + INFO(FLAGCX_INIT, + "NET/IB_P2P : pool[%d] resolved max_wr_depth=%d from first QP", + ibDevN_, cap); } } else { - for (int batch = completedBatches; batch < issuedBatches; batch++) { - AsyncReadBatchEntry &entry = inflight[batch % kWindowSize]; - if (entry.request == nullptr) - continue; - - int done = 0; - int sizes = 0; - flagcxResult_t rc = adaptor->test(entry.request, &done, &sizes); - if (rc != flagcxSuccess) - return false; - if (done) { - for (int s = 0; s < entry.count; s++) - task->sliceList[entry.firstSlice + s]->markSuccess(); - entry.request = nullptr; - newlyCompleted++; - } - } + WARN("NET/IB_P2P : pool[%d] ibv_query_qp failed; max_wr_depth " + "stays unresolved (slice posts will fall back to no gate)", + ibDevN_); } + } - while (completedBatches < issuedBatches) { - AsyncReadBatchEntry &entry = inflight[completedBatches % kWindowSize]; - if (entry.request != nullptr) - break; - completedSlices += entry.count; - completedBatches++; + int idx = (int)qpEntries_.size(); + qpEntries_.emplace_back(new PoolQpEntry(qp, sendComm)); + qpNumToIdx_[qp->qp_num] = idx; + + int slot = + qpRegisterCounter_.fetch_add(1, std::memory_order_relaxed) % + numWorkers_; + workerQpIdx_[slot].push_back(idx); +} + +void FlagcxWorkerPool::unregisterQp(struct ibv_qp *qp) { + if (!qp) + return; + std::lock_guard lk(qp_mu_); + auto it = qpNumToIdx_.find(qp->qp_num); + if (it == qpNumToIdx_.end()) + return; + int idx = it->second; + qpNumToIdx_.erase(it); + for (auto &shard : workerQpIdx_) { + auto vit = std::find(shard.begin(), shard.end(), idx); + if (vit != shard.end()) { + shard.erase(vit); + break; } + } + // Slot kept alive (NULL'd) so any in-flight slice's qpDepth pointer stays valid. + qpEntries_[idx]->qp = nullptr; + qpEntries_[idx]->sendComm = nullptr; +} - if (newlyCompleted == 0 && issuedSlices >= numSlices) - std::this_thread::sleep_for(std::chrono::microseconds(1)); +flagcxResult_t FlagcxWorkerPool::submitPostSend(void *sendComm, + FlagcxSlice **slices, + int count) { + if (count <= 0 || slices == nullptr) + return flagcxSuccess; + + int shard = (int)(shardRoundRobin_.fetch_add(1, std::memory_order_relaxed) % + numShards_); + int enqueued = 0; + { + std::lock_guard lk(slice_locks_[shard]); + auto &q = slice_queues_[shard]; + q.reserve(q.size() + count); + for (int i = 0; i < count; i++) { + if (slices[i] == nullptr) + continue; + q.push_back({sendComm, slices[i]}); + enqueued++; + } } + if (enqueued == 0) + return flagcxSuccess; + submitted_.fetch_add(enqueued, std::memory_order_release); - return true; + if (suspended_flag_.load(std::memory_order_acquire) > 0) { + std::lock_guard lk(cv_mu_); + cv_.notify_all(); + } + return flagcxSuccess; } -static void asyncWorkerFunc() { - while (true) { - std::shared_ptr task; - pthread_mutex_lock(&gAsyncWorker.mutex); - while (gAsyncWorker.queue.empty() && !gAsyncWorker.stop.load()) { - pthread_cond_wait(&gAsyncWorker.cv, &gAsyncWorker.mutex); - } - if (gAsyncWorker.stop.load() && gAsyncWorker.queue.empty()) { - pthread_mutex_unlock(&gAsyncWorker.mutex); - return; - } - task = gAsyncWorker.queue.front(); - gAsyncWorker.queue.pop_front(); - pthread_mutex_unlock(&gAsyncWorker.mutex); - - FlagcxP2pConn *conn = task->conn; - struct flagcxNetAdaptor *adaptor = conn->engine->adaptor; - const int numSlices = static_cast(task->sliceList.size()); - - if (numSlices == 0) { - task->result.store(-1, std::memory_order_release); - task->sliceCount.store(1, std::memory_order_release); - task->doneSliceCount.store(1, std::memory_order_release); +void FlagcxWorkerPool::transferWorkerLoop(int tid) { + const static uint64_t kWaitPeriodInNano = 100ull * 1000 * 1000; // 100ms + uint64_t last_wait_ts = nowNs(); + + while (running_.load(std::memory_order_relaxed)) { + auto processed_slice_count = + processed_.load(std::memory_order_relaxed); + auto submitted_slice_count = + submitted_.load(std::memory_order_relaxed); + + if (processed_slice_count == submitted_slice_count) { + uint64_t curr_wait_ts = nowNs(); + if (curr_wait_ts - last_wait_ts > kWaitPeriodInNano) { + std::unique_lock lock(cv_mu_); + suspended_flag_.fetch_add(1); + if (processed_.load(std::memory_order_relaxed) == + submitted_.load(std::memory_order_relaxed)) { + cv_.wait_for(lock, std::chrono::seconds(1)); + } + suspended_flag_.fetch_sub(1); + last_wait_ts = curr_wait_ts; + } continue; } - if (isReadSlice(task->sliceList[0]) && adaptor->igetBatch != nullptr) { - bool ok = asyncReadBatched(task, adaptor); - if (!ok) - markSlicesFailed(task, 0); - task->result.store(ok ? 0 : -1, std::memory_order_release); - continue; - } + performPostSend(tid); + performPollCq(); + } +} - std::vector inflightReqs(kWindowSize, nullptr); - int issued = 0, completed = 0; - bool error = false; - - while (completed < numSlices && !error) { - // Post up to kWindowSize ahead of completed - while (issued < numSlices && (issued - completed) < kWindowSize) { - FlagcxSlice *slice = task->sliceList[issued]; - FlagcxP2pMrHandleView localMr; - FlagcxP2pMrHandleView remoteMr; - memset(&localMr, 0, sizeof(localMr)); - memset(&remoteMr, 0, sizeof(remoteMr)); - - localMr.baseVa = sliceLocalVa(slice); - localMr.lkey = slice->lkey; - remoteMr.baseVa = sliceRemoteVa(slice); - remoteMr.rkey = slice->rkey; - - void *request = NULL; - flagcxResult_t rc; - - if (isReadSlice(slice)) { - const uint64_t srcOff = 0; - const uint64_t dstOff = 0; - rc = - adaptor->iget(conn->sendComm, srcOff, dstOff, slice->length, 0, 0, - (void **)&remoteMr, (void **)&localMr, &request); - } else { - const uint64_t srcOff = 0; - const uint64_t dstOff = 0; - rc = - adaptor->iput(conn->sendComm, srcOff, dstOff, slice->length, 0, 0, - (void **)&localMr, (void **)&remoteMr, &request); - } +void FlagcxWorkerPool::performPostSend(int tid) { + if (numWorkers_ <= 0) + return; - if (rc != flagcxSuccess) { - error = true; - break; - } + std::vector myQpEntries; + int curMaxDepth; + { + std::lock_guard lk(qp_mu_); + myQpEntries.reserve(workerQpIdx_[tid].size()); + for (int idx : workerQpIdx_[tid]) + myQpEntries.push_back(qpEntries_[idx].get()); + curMaxDepth = maxWrDepth_; + } + + for (int s = tid; s < numShards_; s += numWorkers_) { + std::vector local; + { + std::lock_guard lk(slice_locks_[s]); + if (slice_queues_[s].empty()) + continue; + local.swap(slice_queues_[s]); + } + if (local.empty()) + continue; - inflightReqs[issued % kWindowSize] = request; - issued++; + size_t i = 0; + while (i < local.size()) { + size_t j = i + 1; + while (j < local.size() && local[j].sendComm == local[i].sendComm && + local[j].slice->opcode == local[i].slice->opcode) { + j++; + } + void *sc = local[i].sendComm; + std::vector myQpOnComm; + myQpOnComm.reserve(myQpEntries.size()); + for (PoolQpEntry *e : myQpEntries) { + if (e && e->qp != nullptr && e->sendComm == sc) + myQpOnComm.push_back(e); + } + if (myQpOnComm.empty()) { + WARN("NET/IB_P2P : pool[%d] worker %d owns no QP for conn %p; " + "failing %zu slices", + ibDevN_, tid, sc, j - i); + for (size_t k = i; k < j; k++) + local[k].slice->markFailed(); + processed_.fetch_add(j - i, std::memory_order_release); + i = j; + continue; } - // Batch-poll completions for in-flight requests - int newlyCompleted = 0; - - if (adaptor->testBatch != nullptr) { - // Collect non-null in-flight requests for batch testing - void *batchRequests[kWindowSize]; - int batchIndices[kWindowSize]; - int batchCount = 0; - - for (int i = completed; i < issued; i++) { - int slot = i % kWindowSize; - if (inflightReqs[slot] != nullptr) { - batchRequests[batchCount] = inflightReqs[slot]; - batchIndices[batchCount] = i; - batchCount++; + size_t &cursor = workerQpCursor_[tid]; + const size_t ringSz = myQpOnComm.size(); + + while (i < j) { + const size_t take = std::min(maxWrPerPost_, j - i); + PoolQpEntry *chosen = nullptr; + for (size_t k = 0; k < ringSz; k++) { + PoolQpEntry *e = myQpOnComm[(cursor + k) % ringSz]; + int cur = e->wrDepth; + if (curMaxDepth == 0 || cur + (int)take <= curMaxDepth) { + chosen = e; + cursor = (cursor + k + 1) % ringSz; + break; } } - if (batchCount > 0) { - int doneFlags[kWindowSize]; - int doneCount = 0; - flagcxResult_t res = adaptor->testBatch(batchRequests, batchCount, - doneFlags, &doneCount); - if (res != flagcxSuccess) { - error = true; - } else { - for (int b = 0; b < batchCount; b++) { - if (doneFlags[b]) { - int i = batchIndices[b]; - int slot = i % kWindowSize; - task->sliceList[i]->markSuccess(); - inflightReqs[slot] = nullptr; - newlyCompleted++; - } - } + if (chosen == nullptr) { + { + std::lock_guard lk(slice_locks_[s]); + for (size_t k = i; k < j; k++) + slice_queues_[s].push_back(local[k]); } + i = j; + break; } - } else { - // Fallback: per-request polling - for (int i = completed; i < issued; i++) { - int slot = i % kWindowSize; - if (inflightReqs[slot] == nullptr) { - continue; - } - int done = 0, sizes = 0; - flagcxResult_t res = adaptor->test(inflightReqs[slot], &done, &sizes); - if (res != flagcxSuccess) { - task->sliceList[i]->markFailed(); - inflightReqs[slot] = nullptr; - newlyCompleted++; - continue; - } - if (done) { - task->sliceList[i]->markSuccess(); - inflightReqs[slot] = nullptr; - newlyCompleted++; - } + + volatile int *depthPtr = &chosen->wrDepth; + __sync_fetch_and_add(depthPtr, (int)take); + + std::vector chunk; + chunk.reserve(take); + for (size_t k = 0; k < take; k++) { + FlagcxSlice *sl = local[i + k].slice; + sl->qpDepth = depthPtr; + chunk.push_back(sl); } - } - // Advance completed pointer over contiguous completions - while (completed < issued && - inflightReqs[completed % kWindowSize] == nullptr) { - completed++; - } + struct ibv_qp *qp = chosen->qp; + flagcxResult_t rc = flagcxP2pSliceBatch(sc, qp, (int)take, + chunk.data()); - // Yield briefly if no progress was made - if (newlyCompleted == 0 && issued >= numSlices) { - std::this_thread::sleep_for(std::chrono::microseconds(1)); + if (rc != flagcxSuccess) { + processed_.fetch_add(take, std::memory_order_release); + } + i += take; } } + } +} + +void FlagcxWorkerPool::performPollCq() { + if (shared_cq_ == nullptr) + return; - if (error) { - markSlicesFailed(task, completed); - task->result.store(-1, std::memory_order_release); + constexpr int kMaxPollBatch = 256; + struct ibv_wc wcs[kMaxPollBatch]; + int batch = (int)std::min(batchPollSize_, kMaxPollBatch); + int n = 0; + if (flagcxWrapIbvPollCq(shared_cq_, batch, wcs, &n) != flagcxSuccess) { + WARN("NET/IB_P2P : ibv_poll_cq failed on shared CQ %p", shared_cq_); + return; + } + if (n == 0) + return; + + uint64_t sliceProgressed = 0; + std::unordered_map qpDepthSet; + for (int i = 0; i < n; i++) { + uintptr_t raw = (uintptr_t)wcs[i].wr_id; + if (raw == 0 || (raw & 1ull) == 0) + continue; + + FlagcxSlice *slice = + reinterpret_cast(raw & ~(uintptr_t)1ull); + if (slice->qpDepth != NULL) + qpDepthSet[slice->qpDepth]++; + if (wcs[i].status != IBV_WC_SUCCESS) { + WARN("NET/IB_P2P : pool poll error status %d for slice %p", + wcs[i].status, slice); + slice->markFailed(); + } else { + slice->markSuccess(); } + sliceProgressed++; + } + for (auto &entry : qpDepthSet) + __sync_fetch_and_sub(entry.first, entry.second); + if (sliceProgressed > 0) + processed_.fetch_add(sliceProgressed, std::memory_order_release); +} + +// ---- Per-ibDev singleton plumbing ----------------------------------- + +static std::unique_ptr gPools[MAX_IB_DEVS]; +static std::mutex gPoolMu; + +static FlagcxWorkerPool *getOrCreatePool(int ibDevN, struct ibv_context *ctx) { + if (ibDevN < 0 || ibDevN >= MAX_IB_DEVS || ctx == NULL) + return NULL; + std::lock_guard lk(gPoolMu); + if (!gPools[ibDevN]) + gPools[ibDevN].reset(new FlagcxWorkerPool(ibDevN, ctx)); + return gPools[ibDevN].get(); +} + +static FlagcxWorkerPool *lookupPool(int ibDevN) { + if (ibDevN < 0 || ibDevN >= MAX_IB_DEVS) + return nullptr; + std::lock_guard lk(gPoolMu); + return gPools[ibDevN].get(); +} + +} // namespace + +// ---- Hooks consumed by ibrc_p2p_adaptor.cc (forward-declared there). ---- +struct ibv_cq *flagcxP2pPoolGetSharedCq(int ibDevN, struct ibv_context *ctx) { + FlagcxWorkerPool *pool = getOrCreatePool(ibDevN, ctx); + return pool ? pool->getSharedCq() : NULL; +} + +void flagcxP2pPoolRegisterQp(int ibDevN, void *sendComm, struct ibv_qp *qp) { + if (qp == nullptr) + return; + FlagcxWorkerPool *pool = lookupPool(ibDevN); + if (pool) + pool->registerQp(sendComm, qp); +} + +void flagcxP2pPoolUnregisterQp(int ibDevN, struct ibv_qp *qp) { + if (qp == nullptr) + return; + FlagcxWorkerPool *pool = lookupPool(ibDevN); + if (pool) + pool->unregisterQp(qp); +} + +flagcxResult_t flagcxP2pPoolSubmit(int ibDevN, void *sendComm, + FlagcxSlice **slices, int count) { + FlagcxWorkerPool *pool = lookupPool(ibDevN); + if (pool == nullptr) { + WARN("NET/IB_P2P : flagcxP2pPoolSubmit on uninitialized pool[%d]", ibDevN); + return flagcxInternalError; } + return pool->submitPostSend(sendComm, slices, count); } -static std::mutex gAsyncWorkerLifecycleMutex; +void flagcxP2pPoolStartNotif(int ibDevN, struct ibv_context *ctx, + FlagcxP2pEngine *engine) { + FlagcxWorkerPool *pool = getOrCreatePool(ibDevN, ctx); + if (pool == nullptr) { + WARN("NET/IB_P2P : pool[%d] cannot be created for notif", ibDevN); + return; + } + pool->startNotif(engine); +} -static void ensureAsyncWorkerStarted() { - std::lock_guard lock(gAsyncWorkerLifecycleMutex); - if (gAsyncWorker.thread.joinable() && !gAsyncWorker.stop.load()) - return; // already running - // If previously stopped, join the old thread before restarting - if (gAsyncWorker.thread.joinable()) { - gAsyncWorker.thread.join(); +void flagcxP2pPoolStopNotif() { + // Stop notif on whichever pool currently owns it (only one does — the + // first that StartNotif touched). + std::lock_guard lk(gPoolMu); + for (int i = 0; i < MAX_IB_DEVS; i++) { + if (gPools[i]) + gPools[i]->stopNotif(); } - gAsyncWorker.stop.store(false); - gAsyncWorker.thread = std::thread(asyncWorkerFunc); } -static void stopAsyncWorker() { - std::lock_guard lock(gAsyncWorkerLifecycleMutex); - gAsyncWorker.stop.store(true); - pthread_cond_broadcast(&gAsyncWorker.cv); - if (gAsyncWorker.thread.joinable()) { - gAsyncWorker.thread.join(); +struct PoolTransferTask { + FlagcxTransferTask fx; + FlagcxP2pConn *conn; + std::atomic postOk{true}; +}; + +static FlagcxP2pCommView *getCommView(void *comm) { + return reinterpret_cast(comm); +} + +static bool buildAndSubmitToPool(PoolTransferTask *task, + const std::vector &dataVec, + const std::vector &sizeVec, + const std::vector &descs, + const std::vector + &localEntries, + int numIovs, void *sendComm, + int connIbDevN, uint8_t opcode) { + for (int i = 0; i < numIovs; i++) { + if (localEntries[i].ibDevN != connIbDevN) { + WARN("NET/IB_P2P : iov[%d] ibDevN mismatch (%d vs conn %d)", i, + localEntries[i].ibDevN, connIbDevN); + for (auto *s : task->fx.sliceList) + s->markFailed(); + return false; + } + auto *localMr = + reinterpret_cast(localEntries[i].mhandle); + uint64_t localVa = (uintptr_t)dataVec[i]; + uint64_t remoteVa = descs[i].addr; + flagcxBuildSlices( + &task->fx, localVa, remoteVa, sizeVec[i], localMr->lkey, + descs[i].rkey, opcode, std::string()); + } + + if (task->fx.sliceList.empty()) { + return false; } + + flagcxResult_t rc = flagcxP2pPoolSubmit(connIbDevN, sendComm, + task->fx.sliceList.data(), + (int)task->fx.sliceList.size()); + if (rc != flagcxSuccess) { + task->postOk.store(false, std::memory_order_release); + for (auto *s : task->fx.sliceList) + s->markFailed(); + return false; + } + return true; } -// Map from transfer ID to async task (for XferStatus polling) -static std::unordered_map> - gAsyncXferMap; -static std::mutex gAsyncXferMutex; +static std::unordered_map> + gPoolXferMap; +static std::mutex gPoolXferMutex; + static bool findMemReg(uintptr_t addr, FlagcxP2pMemRegEntry *out) { for (std::unordered_map::const_iterator it = @@ -1362,8 +1659,8 @@ FlagcxP2pEngine *flagcxP2pEngineCreate() { #endif } - if (engine->notifListenActive) { - engine->notifThread = std::thread(notifPollThreadFunc, engine); + if (engine->notifListenActive && flagcxNIbDevs > 0) { + flagcxP2pPoolStartNotif(0, flagcxIbDevs[0].context, engine); } return engine; } @@ -1372,16 +1669,12 @@ void flagcxP2pEngineDestroy(FlagcxP2pEngine *engine) { if (engine == NULL) return; - stopAsyncWorker(); - engine->stopNotif = true; if (engine->notifListenActive) { flagcxSocketClose(&engine->notifListenSock); engine->notifListenActive = false; } - if (engine->notifThread.joinable()) { - engine->notifThread.join(); - } + flagcxP2pPoolStopNotif(); { std::lock_guard lock(engine->notifPeerMutex); @@ -1817,12 +2110,7 @@ int flagcxP2pEngineReadVector(FlagcxP2pConn *conn, return -1; } - std::shared_ptr task(new FlagcxTransferTask, - cleanupTransferTask); - task->conn = conn; - task->sliceList.reserve(numIovs); - const int connIbDevN = getCommView(conn->sendComm)->ibDevN; - + std::vector localEntries(numIovs); { std::lock_guard memLock(gMemMutex); for (int i = 0; i < numIovs; i++) { @@ -1844,37 +2132,32 @@ int flagcxP2pEngineReadVector(FlagcxP2pConn *conn, return -1; } - if (entry->ibDevN != connIbDevN) - return -1; - - FlagcxP2pMrHandleView *localMr = - reinterpret_cast(entry->mhandle); - if (buildSlices( - task.get(), - static_cast(reinterpret_cast(dstVec[i])), - descs[i].addr, sizeVec[i], localMr->lkey, descs[i].rkey, - IBV_WR_RDMA_READ) != 0) { - return -1; - } + localEntries[i] = *entry; } } - ensureAsyncWorkerStarted(); + const int connIbDevN = getCommView(conn->sendComm)->ibDevN; + auto task = std::make_shared(); + task->conn = conn; - const uint64_t xferId = [&] { - std::lock_guard lock(gAsyncXferMutex); - uint64_t id = gNextXferId++; - gAsyncXferMap[id] = task; - return id; - }(); + if (!buildAndSubmitToPool(task.get(), dstVec, sizeVec, descs, localEntries, + numIovs, conn->sendComm, connIbDevN, + FLAGCX_SLICE_OP_READ)) { + // sentinel so isAllDone() converges (needs total>0) + auto *sentinel = new FlagcxSlice{0, 0, 0, 0, 0, FLAGCX_SLICE_OP_READ, + std::string(), &task->fx, nullptr}; + task->fx.sliceList.push_back(sentinel); + task->fx.sliceCount.fetch_add(1, std::memory_order_release); + sentinel->markFailed(); + task->postOk.store(false, std::memory_order_release); + } + uint64_t xferId; { - pthread_mutex_lock(&gAsyncWorker.mutex); - gAsyncWorker.queue.push_back(task); - pthread_mutex_unlock(&gAsyncWorker.mutex); + std::lock_guard lock(gPoolXferMutex); + xferId = gNextXferId++; + gPoolXferMap[xferId] = task; } - pthread_cond_signal(&gAsyncWorker.cv); - *transferId = xferId; return 0; } @@ -1962,12 +2245,7 @@ int flagcxP2pEngineWriteVector(FlagcxP2pConn *conn, if (mrIds.size() < static_cast(numIovs)) return -1; - std::shared_ptr task(new FlagcxTransferTask, - cleanupTransferTask); - task->conn = conn; - task->sliceList.reserve(numIovs); - const int connIbDevN = getCommView(conn->sendComm)->ibDevN; - + std::vector localEntries(numIovs); { std::lock_guard memLock(gMemMutex); for (int i = 0; i < numIovs; i++) { @@ -1979,37 +2257,31 @@ int flagcxP2pEngineWriteVector(FlagcxP2pConn *conn, sizeVec[i])) return -1; - if (entry->ibDevN != connIbDevN) - return -1; - - FlagcxP2pMrHandleView *localMr = - reinterpret_cast(entry->mhandle); - if (buildSlices( - task.get(), - static_cast(reinterpret_cast(dstVec[i])), - descs[i].addr, sizeVec[i], localMr->lkey, descs[i].rkey, - IBV_WR_RDMA_WRITE) != 0) { - return -1; - } + localEntries[i] = *entry; } } - ensureAsyncWorkerStarted(); + const int connIbDevN = getCommView(conn->sendComm)->ibDevN; + auto task = std::make_shared(); + task->conn = conn; - const uint64_t xferId = [&] { - std::lock_guard lock(gAsyncXferMutex); - uint64_t id = gNextXferId++; - gAsyncXferMap[id] = task; - return id; - }(); + if (!buildAndSubmitToPool(task.get(), dstVec, sizeVec, descs, localEntries, + numIovs, conn->sendComm, connIbDevN, + FLAGCX_SLICE_OP_WRITE)) { + auto *sentinel = new FlagcxSlice{0, 0, 0, 0, 0, FLAGCX_SLICE_OP_WRITE, + std::string(), &task->fx, nullptr}; + task->fx.sliceList.push_back(sentinel); + task->fx.sliceCount.fetch_add(1, std::memory_order_release); + sentinel->markFailed(); + task->postOk.store(false, std::memory_order_release); + } + uint64_t xferId; { - pthread_mutex_lock(&gAsyncWorker.mutex); - gAsyncWorker.queue.push_back(task); - pthread_mutex_unlock(&gAsyncWorker.mutex); + std::lock_guard lock(gPoolXferMutex); + xferId = gNextXferId++; + gPoolXferMap[xferId] = task; } - pthread_cond_signal(&gAsyncWorker.cv); - *transferId = xferId; return 0; } @@ -2051,13 +2323,16 @@ bool flagcxP2pEngineXferStatus(FlagcxP2pConn *conn, uint64_t transferId) { if (conn == NULL) return true; - // Check async transfer map first (for vectored transfers) { - std::lock_guard lock(gAsyncXferMutex); - auto it = gAsyncXferMap.find(transferId); - if (it != gAsyncXferMap.end()) { - if (it->second->isAllDone()) { - gAsyncXferMap.erase(it); + std::lock_guard lock(gPoolXferMutex); + auto it = gPoolXferMap.find(transferId); + if (it != gPoolXferMap.end()) { + auto &task = it->second; + if (task->fx.isAllDone()) { + for (auto *s : task->fx.sliceList) + delete s; + task->fx.sliceList.clear(); + gPoolXferMap.erase(it); return true; } return false; diff --git a/flagcx/include/flagcx_p2p.h b/flagcx/include/flagcx_p2p.h index 42226526..f636a3dc 100644 --- a/flagcx/include/flagcx_p2p.h +++ b/flagcx/include/flagcx_p2p.h @@ -26,6 +26,8 @@ #define FLAGCX_P2P_DESC_SIZE 64 #define FLAGCX_P2P_IPC_INFO_SIZE 128 +constexpr int kFlagcxP2pMaxQpsPerEngine = 8; + /* ------------------------------------------------------------------ */ /* Opaque handle types */ /* ------------------------------------------------------------------ */ @@ -424,4 +426,57 @@ int flagcxP2pEngineGetIpcInfo(FlagcxP2pEngine *engine, uintptr_t addr, int flagcxP2pEngineUpdateIpcInfo(char *ipcBuf, uintptr_t addr, uintptr_t baseAddr, size_t size); +/* ================================================================== */ +/* Global runtime configuration */ +/* ================================================================== */ + +struct FlagcxP2pGlobalConfig { + /* Worker pool / QP topology */ + int qpsPerConn = 4; /* FLAGCX_P2P_QPS_PER_CONN */ + int workersPerPool = 2; /* FLAGCX_P2P_WORKERS_PER_POOL */ + int shardCount = 8; /* FLAGCX_P2P_SHARD_COUNT */ + + /* CQ / WR / completion-queue depth */ + size_t sharedCqDepth = 4096; /* FLAGCX_P2P_CQ_DEPTH */ + size_t maxWrPerPost = 256; /* FLAGCX_P2P_MAX_WR_PER_POST */ + size_t maxRequests = 256; /* FLAGCX_P2P_MAX_REQUESTS */ + size_t batchPollSize = 32; /* FLAGCX_P2P_BATCH_POLL_SIZE */ + size_t readBatchWindow = 8; /* FLAGCX_P2P_READ_BATCH_WINDOW */ + + /* Slice cut policy */ + size_t sliceSize = 64 * 1024; /* FLAGCX_P2P_SLICE_SIZE */ + size_t fragmentLimit = 4 * 1024; /* FLAGCX_P2P_FRAGMENT_LIMIT */ + + /* IB QP attributes — verbs-clean (plain int) so this header does + not pull . */ + size_t maxSge = 4; /* FLAGCX_P2P_MAX_SGE */ + size_t maxInline = 64; /* FLAGCX_P2P_MAX_INLINE */ + uint8_t ibPort = 1; /* FLAGCX_P2P_IB_PORT */ + int gidIndex = -1; /* FLAGCX_P2P_GID_INDEX (-1=auto) */ + int mtuLength = 4096; /* FLAGCX_P2P_MTU */ + int ibTrafficClass = -1; /* FLAGCX_P2P_IB_TC (-1=off) */ + int retryCnt = 7; /* FLAGCX_P2P_RETRY_CNT */ + + /* Notification */ + int notifMaxPeers = 64; /* FLAGCX_P2P_NOTIF_MAX_PEERS */ + + /* Misc */ + bool enableDestDeviceAffinity = false; /* FLAGCX_P2P_DEST_DEV_AFFINITY */ +}; + +/* Returns the lazy-loaded singleton (mooncake::globalConfig() shape). + First call materializes the struct and parses env vars exactly once. */ +const FlagcxP2pGlobalConfig &flagcxP2pGlobalConfig(); + +/* Logs the resolved config once. Implicitly invoked at first + flagcxP2pGlobalConfig() call. */ +void flagcxP2pDumpGlobalConfig(); + +/* Clamp size-limited fields against ibv_query_device() results — call + once from the adaptor's init path after IB attributes are known. The + four uint32 inputs are the obvious ibv_device_attr counterparts; we + take plain ints to keep verbs out of this header. */ +void flagcxP2pClampToDeviceLimits(uint32_t maxQpWr, uint32_t maxSge, + uint32_t maxCqe, uint32_t maxQp); + #endif /* FLAGCX_P2P_H_ */ From 4fdea9e5f4b1290d21f66e78172cb90c5d7f904e Mon Sep 17 00:00:00 2001 From: leoda1 Date: Fri, 29 May 2026 11:41:45 +0800 Subject: [PATCH 09/11] move strcut FlagcxSlice from cc to h --- flagcx/adaptor/net/ibrc_p2p_adaptor.cc | 41 ------------------- flagcx/core/flagcx_p2p.cc | 43 -------------------- flagcx/include/flagcx_p2p.h | 54 ++++++++++++++++++++++---- 3 files changed, 47 insertions(+), 91 deletions(-) diff --git a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc index c8c3deab..3c915db0 100644 --- a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc +++ b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc @@ -27,8 +27,6 @@ #include #include -struct FlagcxSlice; - extern struct ibv_cq *flagcxP2pPoolGetSharedCq(int ibDevN, struct ibv_context *ctx); extern void flagcxP2pPoolRegisterQp(int ibDevN, void *sendComm, @@ -41,45 +39,6 @@ extern flagcxResult_t flagcxP2pPoolSubmit(int ibDevN, void *sendComm, /* Internal structs */ /* ------------------------------------------------------------------ */ -struct FlagcxTransferTask { - std::atomic sliceCount{0}; - std::atomic doneSliceCount{0}; - std::vector sliceList; - - bool isAllDone() const { - auto total = sliceCount.load(std::memory_order_acquire); - auto done = doneSliceCount.load(std::memory_order_acquire); - return total > 0 && done >= total; - } -}; - -enum FlagcxSliceOp : uint8_t { - FLAGCX_SLICE_OP_WRITE = 0, - FLAGCX_SLICE_OP_READ = 1, -}; - -struct FlagcxSlice { - uint64_t srcVa; - uint64_t dstVa; - uint32_t length; - uint32_t lkey; - uint32_t rkey; - uint8_t opcode; - std::string peerNicPath; - FlagcxTransferTask *task; - volatile int *qpDepth; - - inline void markSuccess() { - if (task) - task->doneSliceCount.fetch_add(1, std::memory_order_release); - } - - inline void markFailed() { - if (task) - task->doneSliceCount.fetch_add(1, std::memory_order_release); - } -}; - // Per-device context — created at init, holds eagerly allocated PD. // Passed as the `comm` parameter to regMr/deregMr when no connection exists. // ibDevN MUST be the first field so regMr can cast any comm pointer to extract diff --git a/flagcx/core/flagcx_p2p.cc b/flagcx/core/flagcx_p2p.cc index aec5bac9..1d9be096 100644 --- a/flagcx/core/flagcx_p2p.cc +++ b/flagcx/core/flagcx_p2p.cc @@ -44,8 +44,6 @@ extern struct flagcxNetAdaptor flagcxNetIbP2p; -struct FlagcxSlice; - extern "C" flagcxResult_t flagcxP2pSliceBatch(void *sendComm, struct ibv_qp *qp, int count, FlagcxSlice **slices); @@ -301,47 +299,6 @@ static std::unordered_map gXferMap; static std::mutex gXferMutex; static uint64_t gNextXferId = 1; -struct FlagcxTransferTask { - std::atomic sliceCount{0}; - std::atomic doneSliceCount{0}; - std::vector sliceList; - - bool isAllDone() const { - auto total = sliceCount.load(std::memory_order_acquire); - auto done = doneSliceCount.load(std::memory_order_acquire); - return total > 0 && done >= total; - } -}; - -enum FlagcxSliceOp : uint8_t { - FLAGCX_SLICE_OP_WRITE = 0, - FLAGCX_SLICE_OP_READ = 1, -}; - -struct FlagcxSlice { - // WRITE: local source VA; READ: local destination VA. - uint64_t srcVa = 0; - // WRITE: remote destination VA; READ: remote source VA. - uint64_t dstVa; - uint32_t length; - uint32_t lkey; - uint32_t rkey; - uint8_t opcode; - std::string peerNicPath; - FlagcxTransferTask *task; - volatile int *qpDepth; - - inline void markSuccess() { - if (task) - task->doneSliceCount.fetch_add(1, std::memory_order_release); - } - - inline void markFailed() { - if (task) - task->doneSliceCount.fetch_add(1, std::memory_order_release); - } -}; - struct FlagcxNixlSlicePolicy { static constexpr bool kFurtherCut = false; static constexpr size_t kBlockSize = SIZE_MAX; diff --git a/flagcx/include/flagcx_p2p.h b/flagcx/include/flagcx_p2p.h index f636a3dc..b3181675 100644 --- a/flagcx/include/flagcx_p2p.h +++ b/flagcx/include/flagcx_p2p.h @@ -13,9 +13,11 @@ #ifndef FLAGCX_P2P_H_ #define FLAGCX_P2P_H_ +#include #include #include #include +#include #include /* ------------------------------------------------------------------ */ @@ -89,6 +91,51 @@ inline void flagcxP2pDeserializeRdmaDesc(const char *buf, std::memcpy(desc->padding, buf + 32, sizeof(desc->padding)); } +/* ------------------------------------------------------------------ */ +/* Slice / transfer-task types (shared by engine and adaptor) */ +/* ------------------------------------------------------------------ */ + +struct FlagcxTransferTask { + std::atomic sliceCount{0}; + std::atomic doneSliceCount{0}; + std::vector sliceList; + + bool isAllDone() const { + auto total = sliceCount.load(std::memory_order_acquire); + auto done = doneSliceCount.load(std::memory_order_acquire); + return total > 0 && done >= total; + } +}; + +enum FlagcxSliceOp : uint8_t { + FLAGCX_SLICE_OP_WRITE = 0, + FLAGCX_SLICE_OP_READ = 1, +}; + +struct FlagcxSlice { + // WRITE: local source VA; READ: local destination VA. + uint64_t srcVa = 0; + // WRITE: remote destination VA; READ: remote source VA. + uint64_t dstVa; + uint32_t length; + uint32_t lkey; + uint32_t rkey; + uint8_t opcode; + std::string peerNicPath; + FlagcxTransferTask *task; + volatile int *qpDepth; + + inline void markSuccess() { + if (task) + task->doneSliceCount.fetch_add(1, std::memory_order_release); + } + + inline void markFailed() { + if (task) + task->doneSliceCount.fetch_add(1, std::memory_order_release); + } +}; + /* ------------------------------------------------------------------ */ /* Notification message */ /* ------------------------------------------------------------------ */ @@ -472,11 +519,4 @@ const FlagcxP2pGlobalConfig &flagcxP2pGlobalConfig(); flagcxP2pGlobalConfig() call. */ void flagcxP2pDumpGlobalConfig(); -/* Clamp size-limited fields against ibv_query_device() results — call - once from the adaptor's init path after IB attributes are known. The - four uint32 inputs are the obvious ibv_device_attr counterparts; we - take plain ints to keep verbs out of this header. */ -void flagcxP2pClampToDeviceLimits(uint32_t maxQpWr, uint32_t maxSge, - uint32_t maxCqe, uint32_t maxQp); - #endif /* FLAGCX_P2P_H_ */ From b0018deac889e39bcf5b3a08e1e2593a33c00cc1 Mon Sep 17 00:00:00 2001 From: leoda1 Date: Fri, 29 May 2026 14:23:45 +0800 Subject: [PATCH 10/11] fix for compile --- flagcx/adaptor/net/ibrc_p2p_adaptor.cc | 2 +- flagcx/core/flagcx_p2p.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc index 3c915db0..348e3c69 100644 --- a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc +++ b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc @@ -717,7 +717,7 @@ static flagcxResult_t flagcxP2pIputSignal(void *, uint64_t, uint64_t, size_t, /* Slice batch: pool worker passes the chosen QP. wr_id = ptr|1. */ /* ------------------------------------------------------------------ */ -static inline uint32_t flagcxSliceOpcodeToVerbs(uint8_t op) { +static inline enum ibv_wr_opcode flagcxSliceOpcodeToVerbs(uint8_t op) { return op == FLAGCX_SLICE_OP_READ ? IBV_WR_RDMA_READ : IBV_WR_RDMA_WRITE; } diff --git a/flagcx/core/flagcx_p2p.cc b/flagcx/core/flagcx_p2p.cc index 1d9be096..8abbb7d0 100644 --- a/flagcx/core/flagcx_p2p.cc +++ b/flagcx/core/flagcx_p2p.cc @@ -341,6 +341,8 @@ inline void flagcxBuildSlices(FlagcxTransferTask *task, uint64_t srcVa, } } +static void notifPollThreadFunc(FlagcxP2pEngine *engine); + namespace { struct PoolSubmitItem { @@ -348,8 +350,6 @@ struct PoolSubmitItem { FlagcxSlice *slice; // owned by caller (engine ReadVector/WriteVector) }; -static void notifPollThreadFunc(FlagcxP2pEngine *engine); - struct PoolQpEntry { struct ibv_qp *qp; void *sendComm; // owning conn (flagcxP2pSendComm/RecvComm) From c27df24bdaa13a86ede4ac85b38d847abb622ab9 Mon Sep 17 00:00:00 2001 From: leoda1 Date: Fri, 29 May 2026 16:37:41 +0800 Subject: [PATCH 11/11] fix when QP complete create, flow into one worker --- flagcx/core/flagcx_p2p.cc | 220 ++++++++++++++++++------------------ flagcx/include/flagcx_p2p.h | 5 +- 2 files changed, 111 insertions(+), 114 deletions(-) diff --git a/flagcx/core/flagcx_p2p.cc b/flagcx/core/flagcx_p2p.cc index 8abbb7d0..7f82961c 100644 --- a/flagcx/core/flagcx_p2p.cc +++ b/flagcx/core/flagcx_p2p.cc @@ -51,13 +51,12 @@ extern "C" flagcxResult_t flagcxP2pSliceBatch(void *sendComm, namespace { FLAGCX_PARAM(P2pQpsPerConn, "P2P_QPS_PER_CONN", 4); -FLAGCX_PARAM(P2pWorkersPerPool, "P2P_WORKERS_PER_POOL", 2); +FLAGCX_PARAM(P2pWorkersPerPool, "P2P_WORKERS_PER_POOL", 4); FLAGCX_PARAM(P2pShardCount, "P2P_SHARD_COUNT", 8); FLAGCX_PARAM(P2pCqDepth, "P2P_CQ_DEPTH", 4096); FLAGCX_PARAM(P2pMaxWrPerPost, "P2P_MAX_WR_PER_POST", 64); FLAGCX_PARAM(P2pMaxRequests, "P2P_MAX_REQUESTS", 256); -FLAGCX_PARAM(P2pBatchPollSize, "P2P_BATCH_POLL_SIZE", 32); -FLAGCX_PARAM(P2pReadBatchWindow, "P2P_READ_BATCH_WINDOW", 8); +FLAGCX_PARAM(P2pBatchPollSize, "P2P_BATCH_POLL_SIZE", 64); FLAGCX_PARAM(P2pSliceSize, "P2P_SLICE_SIZE", 65536); FLAGCX_PARAM(P2pFragmentLimit, "P2P_FRAGMENT_LIMIT", 4096); FLAGCX_PARAM(P2pMaxSge, "P2P_MAX_SGE", 4); @@ -83,13 +82,12 @@ inline T clampParam(int64_t v, T lo, T hi, T deft, const char *name) { void loadGlobalConfig(FlagcxP2pGlobalConfig &c) { c.qpsPerConn = clampParam(flagcxParamP2pQpsPerConn(), 1, kFlagcxP2pMaxQpsPerEngine, 4, "P2P_QPS_PER_CONN"); - c.workersPerPool = clampParam(flagcxParamP2pWorkersPerPool(), 1, 8, 2, "P2P_WORKERS_PER_POOL"); + c.workersPerPool = clampParam(flagcxParamP2pWorkersPerPool(), 1, 8, 4, "P2P_WORKERS_PER_POOL"); c.shardCount = clampParam(flagcxParamP2pShardCount(), 1, 64, 8, "P2P_SHARD_COUNT"); c.sharedCqDepth = clampParam(flagcxParamP2pCqDepth(), 1, 1u<<20, 4096, "P2P_CQ_DEPTH"); c.maxWrPerPost = clampParam(flagcxParamP2pMaxWrPerPost(),1, 1024, 256, "P2P_MAX_WR_PER_POST"); c.maxRequests = clampParam(flagcxParamP2pMaxRequests(), 1, 1u<<16, 256, "P2P_MAX_REQUESTS"); - c.batchPollSize = clampParam(flagcxParamP2pBatchPollSize(), 1, 256, 32, "P2P_BATCH_POLL_SIZE"); - c.readBatchWindow = clampParam(flagcxParamP2pReadBatchWindow(), 1, 256, 8, "P2P_READ_BATCH_WINDOW"); + c.batchPollSize = clampParam(flagcxParamP2pBatchPollSize(), 1, 256, 64, "P2P_BATCH_POLL_SIZE"); c.sliceSize = clampParam(flagcxParamP2pSliceSize(), 1024, 1u<<26, 65536, "P2P_SLICE_SIZE"); c.fragmentLimit = clampParam(flagcxParamP2pFragmentLimit(), 0, c.sliceSize, 4096, "P2P_FRAGMENT_LIMIT"); c.maxSge = clampParam(flagcxParamP2pMaxSge(), 1, 32, 4, "P2P_MAX_SGE"); @@ -117,9 +115,6 @@ void dumpGlobalConfigImpl(const FlagcxP2pGlobalConfig &c); FlagcxP2pGlobalConfig &mutableGlobalConfig() { static FlagcxP2pGlobalConfig cfg; static std::once_flag once; - // Important: dumpGlobalConfigImpl reads cfg directly (no recursion back - // through mutableGlobalConfig), so the lambda is safe to call on the - // same thread that holds the once_flag. std::call_once(once, [] { loadGlobalConfig(cfg); dumpGlobalConfigImpl(cfg); @@ -133,10 +128,8 @@ void dumpGlobalConfigImpl(const FlagcxP2pGlobalConfig &c) { "qpsPerConn=%d workersPerPool=%d shardCount=%d", c.qpsPerConn, c.workersPerPool, c.shardCount); INFO(FLAGCX_INIT, - "sharedCqDepth=%zu maxWrPerPost=%zu maxRequests=%zu " - "batchPollSize=%zu readBatchWindow=%zu", - c.sharedCqDepth, c.maxWrPerPost, c.maxRequests, c.batchPollSize, - c.readBatchWindow); + "sharedCqDepth=%zu maxWrPerPost=%zu maxRequests=%zu batchPollSize=%zu", + c.sharedCqDepth, c.maxWrPerPost, c.maxRequests, c.batchPollSize); INFO(FLAGCX_INIT, "sliceSize=%zu fragmentLimit=%zu", c.sliceSize, c.fragmentLimit); INFO(FLAGCX_INIT, @@ -345,11 +338,6 @@ static void notifPollThreadFunc(FlagcxP2pEngine *engine); namespace { -struct PoolSubmitItem { - void *sendComm; // adaptor sendComm view - FlagcxSlice *slice; // owned by caller (engine ReadVector/WriteVector) -}; - struct PoolQpEntry { struct ibv_qp *qp; void *sendComm; // owning conn (flagcxP2pSendComm/RecvComm) @@ -404,11 +392,14 @@ class FlagcxWorkerPool { std::unordered_map qpNumToIdx_; std::vector> workerQpIdx_; std::vector workerQpCursor_; - std::atomic qpRegisterCounter_{0}; - - std::vector> slice_queues_; + std::unordered_map connQpRegCount_; + std::vector>> + slice_queues_; std::unique_ptr slice_locks_; + std::unique_ptr[]> slice_queue_count_; std::atomic shardRoundRobin_{0}; + std::vector>> + collective_slice_queue_; std::atomic submitted_{0}; std::atomic processed_{0}; @@ -443,9 +434,9 @@ FlagcxWorkerPool::FlagcxWorkerPool(int ibDevN, struct ibv_context *ctx) } if (numWorkers_ > 0 && C.qpsPerConn % numWorkers_ != 0) { - WARN("NET/IB_P2P : pool[%d] qpsPerConn=%d not divisible by " - "workersPerPool=%d — some workers may starve for conns where " - "they own no QP", + WARN("NET/IB_P2P : pool[%d] qpsPerEngine=%d not divisible by " + "workersPerPool=%d — QPs spread per connection but unevenly; some " + "workers will own more QPs than others for each conn", ibDevN_, C.qpsPerConn, numWorkers_); } @@ -463,9 +454,13 @@ FlagcxWorkerPool::FlagcxWorkerPool(int ibDevN, struct ibv_context *ctx) slice_queues_.resize(numShards_); slice_locks_.reset(new std::mutex[numShards_]); + slice_queue_count_.reset(new std::atomic[numShards_]); + for (int s = 0; s < numShards_; s++) + slice_queue_count_[s].store(0, std::memory_order_relaxed); workerQpIdx_.resize(numWorkers_); workerQpCursor_.assign(numWorkers_, 0); + collective_slice_queue_.resize(numWorkers_); transferThreads_.reserve(numWorkers_); for (int t = 0; t < numWorkers_; t++) { @@ -553,9 +548,8 @@ void FlagcxWorkerPool::registerQp(void *sendComm, struct ibv_qp *qp) { qpEntries_.emplace_back(new PoolQpEntry(qp, sendComm)); qpNumToIdx_[qp->qp_num] = idx; - int slot = - qpRegisterCounter_.fetch_add(1, std::memory_order_relaxed) % - numWorkers_; + int connIdx = connQpRegCount_[sendComm]++; + int slot = connIdx % numWorkers_; workerQpIdx_[slot].push_back(idx); } @@ -568,6 +562,7 @@ void FlagcxWorkerPool::unregisterQp(struct ibv_qp *qp) { return; int idx = it->second; qpNumToIdx_.erase(it); + void *sc = qpEntries_[idx]->sendComm; for (auto &shard : workerQpIdx_) { auto vit = std::find(shard.begin(), shard.end(), idx); if (vit != shard.end()) { @@ -575,6 +570,9 @@ void FlagcxWorkerPool::unregisterQp(struct ibv_qp *qp) { break; } } + auto cit = connQpRegCount_.find(sc); + if (cit != connQpRegCount_.end() && --cit->second <= 0) + connQpRegCount_.erase(cit); // Slot kept alive (NULL'd) so any in-flight slice's qpDepth pointer stays valid. qpEntries_[idx]->qp = nullptr; qpEntries_[idx]->sendComm = nullptr; @@ -586,22 +584,28 @@ flagcxResult_t FlagcxWorkerPool::submitPostSend(void *sendComm, if (count <= 0 || slices == nullptr) return flagcxSuccess; - int shard = (int)(shardRoundRobin_.fetch_add(1, std::memory_order_relaxed) % - numShards_); + std::vector> perShard(numShards_); int enqueued = 0; - { - std::lock_guard lk(slice_locks_[shard]); - auto &q = slice_queues_[shard]; - q.reserve(q.size() + count); - for (int i = 0; i < count; i++) { - if (slices[i] == nullptr) - continue; - q.push_back({sendComm, slices[i]}); - enqueued++; - } + for (int i = 0; i < count; i++) { + if (slices[i] == nullptr) + continue; + int shard = (int)(shardRoundRobin_.fetch_add(1, std::memory_order_relaxed) % + numShards_); + perShard[shard].push_back(slices[i]); + enqueued++; } if (enqueued == 0) return flagcxSuccess; + + for (int s = 0; s < numShards_; s++) { + if (perShard[s].empty()) + continue; + std::lock_guard lk(slice_locks_[s]); + auto &vec = slice_queues_[s][sendComm]; + vec.insert(vec.end(), perShard[s].begin(), perShard[s].end()); + slice_queue_count_[s].fetch_add((int)perShard[s].size(), + std::memory_order_relaxed); + } submitted_.fetch_add(enqueued, std::memory_order_release); if (suspended_flag_.load(std::memory_order_acquire) > 0) { @@ -645,6 +649,21 @@ void FlagcxWorkerPool::performPostSend(int tid) { if (numWorkers_ <= 0) return; + auto &local = collective_slice_queue_[tid]; + for (int s = tid; s < numShards_; s += numWorkers_) { + if (slice_queue_count_[s].load(std::memory_order_relaxed) == 0) + continue; + std::lock_guard lk(slice_locks_[s]); + for (auto &entry : slice_queues_[s]) { + if (entry.second.empty()) + continue; + auto &dst = local[entry.first]; + dst.insert(dst.end(), entry.second.begin(), entry.second.end()); + entry.second.clear(); + } + slice_queue_count_[s].store(0, std::memory_order_relaxed); + } + std::vector myQpEntries; int curMaxDepth; { @@ -655,89 +674,68 @@ void FlagcxWorkerPool::performPostSend(int tid) { curMaxDepth = maxWrDepth_; } - for (int s = tid; s < numShards_; s += numWorkers_) { - std::vector local; - { - std::lock_guard lk(slice_locks_[s]); - if (slice_queues_[s].empty()) - continue; - local.swap(slice_queues_[s]); + size_t &cursor = workerQpCursor_[tid]; + for (auto &entry : local) { + void *sc = entry.first; + auto &pending = entry.second; + if (pending.empty()) + continue; + + std::vector myQpOnComm; + myQpOnComm.reserve(myQpEntries.size()); + for (PoolQpEntry *e : myQpEntries) { + if (e && e->qp != nullptr && e->sendComm == sc) + myQpOnComm.push_back(e); } - if (local.empty()) + if (myQpOnComm.empty()) { + WARN("NET/IB_P2P : pool[%d] worker %d owns no QP for Engine %p; " + "failing %zu slices", + ibDevN_, tid, sc, pending.size()); + for (auto *sl : pending) + sl->markFailed(); + processed_.fetch_add(pending.size(), std::memory_order_release); + pending.clear(); continue; + } + const size_t ringSz = myQpOnComm.size(); size_t i = 0; - while (i < local.size()) { - size_t j = i + 1; - while (j < local.size() && local[j].sendComm == local[i].sendComm && - local[j].slice->opcode == local[i].slice->opcode) { - j++; - } - void *sc = local[i].sendComm; - std::vector myQpOnComm; - myQpOnComm.reserve(myQpEntries.size()); - for (PoolQpEntry *e : myQpEntries) { - if (e && e->qp != nullptr && e->sendComm == sc) - myQpOnComm.push_back(e); - } - if (myQpOnComm.empty()) { - WARN("NET/IB_P2P : pool[%d] worker %d owns no QP for conn %p; " - "failing %zu slices", - ibDevN_, tid, sc, j - i); - for (size_t k = i; k < j; k++) - local[k].slice->markFailed(); - processed_.fetch_add(j - i, std::memory_order_release); - i = j; - continue; - } - - size_t &cursor = workerQpCursor_[tid]; - const size_t ringSz = myQpOnComm.size(); - - while (i < j) { - const size_t take = std::min(maxWrPerPost_, j - i); - PoolQpEntry *chosen = nullptr; - for (size_t k = 0; k < ringSz; k++) { - PoolQpEntry *e = myQpOnComm[(cursor + k) % ringSz]; - int cur = e->wrDepth; - if (curMaxDepth == 0 || cur + (int)take <= curMaxDepth) { - chosen = e; - cursor = (cursor + k + 1) % ringSz; - break; - } - } - - if (chosen == nullptr) { - { - std::lock_guard lk(slice_locks_[s]); - for (size_t k = i; k < j; k++) - slice_queues_[s].push_back(local[k]); - } - i = j; + while (i < pending.size()) { + const size_t take = std::min(maxWrPerPost_, pending.size() - i); + PoolQpEntry *chosen = nullptr; + for (size_t k = 0; k < ringSz; k++) { + PoolQpEntry *e = myQpOnComm[(cursor + k) % ringSz]; + int cur = e->wrDepth; + if (curMaxDepth == 0 || cur + (int)take <= curMaxDepth) { + chosen = e; + cursor = (cursor + k + 1) % ringSz; break; } + } - volatile int *depthPtr = &chosen->wrDepth; - __sync_fetch_and_add(depthPtr, (int)take); - - std::vector chunk; - chunk.reserve(take); - for (size_t k = 0; k < take; k++) { - FlagcxSlice *sl = local[i + k].slice; - sl->qpDepth = depthPtr; - chunk.push_back(sl); - } + if (chosen == nullptr) + break; // all of this worker's QPs for the engine are full; retry later - struct ibv_qp *qp = chosen->qp; - flagcxResult_t rc = flagcxP2pSliceBatch(sc, qp, (int)take, - chunk.data()); + volatile int *depthPtr = &chosen->wrDepth; + __sync_fetch_and_add(depthPtr, (int)take); - if (rc != flagcxSuccess) { - processed_.fetch_add(take, std::memory_order_release); - } - i += take; + std::vector chunk; + chunk.reserve(take); + for (size_t k = 0; k < take; k++) { + FlagcxSlice *sl = pending[i + k]; + sl->qpDepth = depthPtr; + chunk.push_back(sl); } + + flagcxResult_t rc = + flagcxP2pSliceBatch(sc, chosen->qp, (int)take, chunk.data()); + if (rc != flagcxSuccess) + processed_.fetch_add(take, std::memory_order_release); + i += take; } + // Drop the posted prefix; anything left stays for the next iteration. + if (i > 0) + pending.erase(pending.begin(), pending.begin() + i); } } diff --git a/flagcx/include/flagcx_p2p.h b/flagcx/include/flagcx_p2p.h index b3181675..3a5be2b8 100644 --- a/flagcx/include/flagcx_p2p.h +++ b/flagcx/include/flagcx_p2p.h @@ -480,15 +480,14 @@ int flagcxP2pEngineUpdateIpcInfo(char *ipcBuf, uintptr_t addr, struct FlagcxP2pGlobalConfig { /* Worker pool / QP topology */ int qpsPerConn = 4; /* FLAGCX_P2P_QPS_PER_CONN */ - int workersPerPool = 2; /* FLAGCX_P2P_WORKERS_PER_POOL */ + int workersPerPool = 4; /* FLAGCX_P2P_WORKERS_PER_POOL */ int shardCount = 8; /* FLAGCX_P2P_SHARD_COUNT */ /* CQ / WR / completion-queue depth */ size_t sharedCqDepth = 4096; /* FLAGCX_P2P_CQ_DEPTH */ size_t maxWrPerPost = 256; /* FLAGCX_P2P_MAX_WR_PER_POST */ size_t maxRequests = 256; /* FLAGCX_P2P_MAX_REQUESTS */ - size_t batchPollSize = 32; /* FLAGCX_P2P_BATCH_POLL_SIZE */ - size_t readBatchWindow = 8; /* FLAGCX_P2P_READ_BATCH_WINDOW */ + size_t batchPollSize = 64; /* FLAGCX_P2P_BATCH_POLL_SIZE */ /* Slice cut policy */ size_t sliceSize = 64 * 1024; /* FLAGCX_P2P_SLICE_SIZE */