diff --git a/flagcx/adaptor/include/flagcx_net_adaptor.h b/flagcx/adaptor/include/flagcx_net_adaptor.h index 169e5571..248e5c32 100644 --- a/flagcx/adaptor/include/flagcx_net_adaptor.h +++ b/flagcx/adaptor/include/flagcx_net_adaptor.h @@ -25,6 +25,7 @@ typedef enum { // regMr, regMrDmaBuf, deregMr, isend, irecv, iflush, test, // iput, iget, iputSignal, getDevFromName // v2 — adds iputBatch (optional one-sided batch WRITE) +// latest — adds optional batch helpers for one-sided transfers struct flagcxNetAdaptor_v1 { // Basic functions @@ -135,6 +136,17 @@ struct flagcxNetAdaptor_latest { const size_t *sizes, int srcRank, int dstRank, void **srcHandles, void **dstHandles, void **requests, int *posted); + // Optional batch completion test — polls CQ once for multiple requests. + // If NULL, caller falls back to per-request test(). + flagcxResult_t (*testBatch)(void **requests, int nRequests, int *doneFlags, + int *doneCount); + // Optional one-side batch READ. Success returns one logical request for the + // full batch. + flagcxResult_t (*igetBatch)(void *sendComm, int count, + const uint64_t *srcOffs, const uint64_t *dstOffs, + const size_t *sizes, int srcRank, int dstRank, + void *const *srcHandles, void *const *dstHandles, + void **request); }; #define flagcxNetAdaptor flagcxNetAdaptor_latest diff --git a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc index 4860a20e..348e3c69 100644 --- a/flagcx/adaptor/net/ibrc_p2p_adaptor.cc +++ b/flagcx/adaptor/net/ibrc_p2p_adaptor.cc @@ -4,22 +4,36 @@ * IBRC P2P Net Adaptor — implements flagcxNetAdaptor for one-sided RDMA * (P2P) use cases. Shares IB device discovery and utility code with the * existing IBRC adaptor but uses P2P-native handle formats, eager PD - * allocation, and simplified (single-QP, no-FIFO) connection setup. + * allocation, and simplified (no-FIFO) connection setup. ************************************************************************/ #include "flagcx_common.h" #include "flagcx_net_adaptor.h" +#include "flagcx_p2p.h" #include "ib_common.h" #include "ibvwrap.h" #include "socket.h" #include #include +#include #include +#include #include +#include +#include #include #include #include +#include + +extern struct ibv_cq *flagcxP2pPoolGetSharedCq(int ibDevN, + struct ibv_context *ctx); +extern void flagcxP2pPoolRegisterQp(int ibDevN, void *sendComm, + struct ibv_qp *qp); +extern void flagcxP2pPoolUnregisterQp(int ibDevN, struct ibv_qp *qp); +extern flagcxResult_t flagcxP2pPoolSubmit(int ibDevN, void *sendComm, + FlagcxSlice **slices, int count); /* ------------------------------------------------------------------ */ /* Internal structs */ @@ -67,39 +81,28 @@ struct flagcxP2pConnMeta { enum ibv_mtu mtu; }; -// P2P request — simplified from flagcxIbRequest -#define FLAGCX_P2P_MAX_REQUESTS 128 -#define FLAGCX_P2P_REQ_UNUSED 0 -#define FLAGCX_P2P_REQ_IPUT 1 -#define FLAGCX_P2P_REQ_IGET 2 - -struct flagcxP2pRequest { - int type; - int events; // outstanding CQEs expected - struct ibv_cq *cq; // CQ to poll for this request - struct flagcxP2pRequest *reqs; // back-pointer to owning reqs[] array +struct flagcxP2pSliceReq { + FlagcxTransferTask task; + FlagcxSlice slice; }; -// P2P send comm — one QP, one CQ, blocking connect +// Field order through `sock` mirrors core's FlagcxP2pCommView — do not reorder. struct flagcxP2pSendComm { - int ibDevN; // MUST be first field + int ibDevN; struct flagcxIbNetCommDevBase base; - struct flagcxIbQp qp; + struct flagcxIbQp qp_list_[kFlagcxP2pMaxQpsPerEngine]; struct flagcxSocket sock; - struct flagcxP2pRequest reqs[FLAGCX_P2P_MAX_REQUESTS]; - uint64_t putSignalScratchpad; - struct ibv_mr *putSignalScratchpadMr; + std::atomic nextChannel{0}; + int numQps{0}; // resolved from flagcxP2pGlobalConfig().qpsPerConn at connect/accept }; -// P2P recv comm — symmetric with send comm so both sides can initiate transfers struct flagcxP2pRecvComm { - int ibDevN; // MUST be first field + int ibDevN; struct flagcxIbNetCommDevBase base; - struct flagcxIbQp qp; + struct flagcxIbQp qp_list_[kFlagcxP2pMaxQpsPerEngine]; struct flagcxSocket sock; - struct flagcxP2pRequest reqs[FLAGCX_P2P_MAX_REQUESTS]; - uint64_t putSignalScratchpad; - struct ibv_mr *putSignalScratchpadMr; + std::atomic nextChannel{0}; + int numQps{0}; // resolved from flagcxP2pGlobalConfig().qpsPerConn at connect/accept }; /* ------------------------------------------------------------------ */ @@ -110,32 +113,6 @@ static struct flagcxP2pDevCtx flagcxP2pDevCtxs[MAX_IB_DEVS]; static int flagcxP2pInitialized = 0; static pthread_mutex_t flagcxP2pInitLock = PTHREAD_MUTEX_INITIALIZER; -/* ------------------------------------------------------------------ */ -/* Request helpers */ -/* ------------------------------------------------------------------ */ - -static flagcxResult_t flagcxP2pGetRequest(struct flagcxP2pRequest *reqs, - struct ibv_cq *cq, int type, - struct flagcxP2pRequest **req) { - for (int i = 0; i < FLAGCX_P2P_MAX_REQUESTS; i++) { - if (reqs[i].type == FLAGCX_P2P_REQ_UNUSED) { - reqs[i].type = type; - reqs[i].events = 0; - reqs[i].cq = cq; - reqs[i].reqs = reqs; - *req = &reqs[i]; - return flagcxSuccess; - } - } - WARN("NET/IB_P2P : unable to allocate request"); - *req = NULL; - return flagcxInternalError; -} - -static inline void flagcxP2pFreeRequest(struct flagcxP2pRequest *req) { - req->type = FLAGCX_P2P_REQ_UNUSED; -} - /* ------------------------------------------------------------------ */ /* Init / Devices / Properties */ /* ------------------------------------------------------------------ */ @@ -275,11 +252,13 @@ static flagcxResult_t flagcxP2pListen(int dev, void *opaqueHandle, return flagcxSuccess; } -// Helper: set up PD (from eager init), CQ, QP, and GID for a connection -static flagcxResult_t flagcxP2pSetupConn(int dev, +static flagcxResult_t flagcxP2pReleasePd(int ibDevN); + +// Helper: set up PD (from eager init), CQs, QPs, and GID for a connection +static flagcxResult_t flagcxP2pSetupConn(int dev, void *outerComm, struct flagcxIbNetCommDevBase *base, - struct flagcxIbQp *qp, - int *outIbDevN) { + struct flagcxIbQp *qp_list, + int *outIbDevN, int numQps) { struct flagcxIbMergedDev *mergedDev = flagcxIbMergedDevs + dev; int ibDevN = mergedDev->devs[0]; // v1: single physical NIC *outIbDevN = ibDevN; @@ -293,26 +272,56 @@ static flagcxResult_t flagcxP2pSetupConn(int dev, base->pd = ibDev->pd; pthread_mutex_unlock(&ibDev->lock); - // Create CQ for this connection - FLAGCXCHECK(flagcxWrapIbvCreateCq( - &base->cq, ibDev->context, 2 * FLAGCX_P2P_MAX_REQUESTS, NULL, NULL, 0)); + // Step 0: pull the shared CQ from the per-ibDev WorkerPool. The pool is + // lazily created on first call (and lives for the process lifetime). + struct ibv_cq *sharedCq = + flagcxP2pPoolGetSharedCq(ibDevN, ibDev->context); + if (sharedCq == NULL) { + WARN("NET/IB_P2P : pool[%d] returned NULL shared CQ", ibDevN); + flagcxP2pReleasePd(ibDevN); + base->pd = NULL; + return flagcxInternalError; + } + base->cq = sharedCq; + + int accessFlags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | + IBV_ACCESS_REMOTE_ATOMIC; // Get GID info - FLAGCXCHECK(flagcxIbGetGidIndex(ibDev->context, ibDev->portNum, - ibDev->portAttr.gid_tbl_len, - &base->gidInfo.localGidIndex)); - FLAGCXCHECK(flagcxWrapIbvQueryGid(ibDev->context, ibDev->portNum, - base->gidInfo.localGidIndex, - &base->gidInfo.localGid)); + flagcxResult_t res; + FLAGCXCHECKGOTO(flagcxIbGetGidIndex(ibDev->context, ibDev->portNum, + ibDev->portAttr.gid_tbl_len, + &base->gidInfo.localGidIndex), + res, setup_fail); + FLAGCXCHECKGOTO(flagcxWrapIbvQueryGid(ibDev->context, ibDev->portNum, + base->gidInfo.localGidIndex, + &base->gidInfo.localGid), + res, setup_fail); base->gidInfo.linkLayer = ibDev->link; - // Create RC QP with remote write, read, and atomic access - int accessFlags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | - IBV_ACCESS_REMOTE_ATOMIC; - FLAGCXCHECK(flagcxIbCreateQp(ibDev->portNum, base, accessFlags, qp)); - qp->devIndex = 0; + for (int i = 0; i < numQps; i++) { + FLAGCXCHECKGOTO( + flagcxIbCreateQp(ibDev->portNum, base, accessFlags, &qp_list[i]), + res, setup_fail); + qp_list[i].devIndex = 0; + flagcxP2pPoolRegisterQp(ibDevN, outerComm, qp_list[i].qp); + } return flagcxSuccess; + +setup_fail: + for (int i = 0; i < numQps; i++) { + if (qp_list[i].qp) { + flagcxP2pPoolUnregisterQp(ibDevN, qp_list[i].qp); + flagcxWrapIbvDestroyQp(qp_list[i].qp); + qp_list[i].qp = NULL; + } + } + // Do not destroy sharedCq — owned by the pool. + base->cq = NULL; + flagcxP2pReleasePd(ibDevN); + base->pd = NULL; + return res; } // Helper: build local connection metadata @@ -355,63 +364,128 @@ flagcxP2pTransitionQp(struct flagcxIbQp *qp, return flagcxSuccess; } +static flagcxResult_t flagcxP2pDestroyQps(int ibDevN, + struct flagcxIbQp *qp_list, + int numQps) { + for (int i = 0; i < numQps; i++) { + if (qp_list[i].qp) { + flagcxP2pPoolUnregisterQp(ibDevN, qp_list[i].qp); + FLAGCXCHECK(flagcxWrapIbvDestroyQp(qp_list[i].qp)); + qp_list[i].qp = NULL; + } + } + return flagcxSuccess; +} + +static inline struct flagcxIbQp * +flagcxP2pNextQp(struct flagcxIbQp *qp_list, + std::atomic *nextChannel, int qpCount) { + uint32_t mod = (qpCount > 0) ? (uint32_t)qpCount : 1u; + uint32_t idx = nextChannel->fetch_add(1, std::memory_order_relaxed); + return qp_list + (idx % mod); +} + static flagcxResult_t flagcxP2pConnect(int dev, void *opaqueHandle, void **sendComm) { struct flagcxP2pListenHandle *handle = (struct flagcxP2pListenHandle *)opaqueHandle; + flagcxResult_t res; *sendComm = NULL; // Allocate send comm struct flagcxP2pSendComm *comm; FLAGCXCHECK(flagcxCalloc(&comm, 1)); + int ready = 0; + auto connectStart = std::chrono::steady_clock::time_point(); + struct flagcxP2pConnMeta localMeta[kFlagcxP2pMaxQpsPerEngine]; + struct flagcxP2pConnMeta remoteMeta[kFlagcxP2pMaxQpsPerEngine]; + int localReady = 1, remoteReady = 0; + uint32_t localNumQps = 0, remoteNumQps = 0, agreedNumQps = 0; // TCP connect (blocking with timeout) - FLAGCXCHECK(flagcxSocketInit(&comm->sock, &handle->connectAddr, handle->magic, - flagcxSocketTypeNetIb, NULL, 1)); - FLAGCXCHECK(flagcxSocketConnect(&comm->sock)); - int ready = 0; - auto connectStart = std::chrono::steady_clock::now(); + FLAGCXCHECKGOTO(flagcxSocketInit(&comm->sock, &handle->connectAddr, + handle->magic, flagcxSocketTypeNetIb, NULL, + 1), + res, connect_fail); + FLAGCXCHECKGOTO(flagcxSocketConnect(&comm->sock), res, connect_fail); + connectStart = std::chrono::steady_clock::now(); while (!ready) { - FLAGCXCHECK(flagcxSocketReady(&comm->sock, &ready)); + FLAGCXCHECKGOTO(flagcxSocketReady(&comm->sock, &ready), res, connect_fail); if (!ready) { if (std::chrono::steady_clock::now() - connectStart > std::chrono::seconds(30)) { WARN("NET/IB_P2P : connect socket ready timed out after 30s"); - free(comm); - return flagcxSystemError; + res = flagcxSystemError; + goto connect_fail; } std::this_thread::sleep_for(std::chrono::milliseconds(1)); } } - // Set up PD, CQ, QP - FLAGCXCHECK(flagcxP2pSetupConn(dev, &comm->base, &comm->qp, &comm->ibDevN)); - - // Exchange connection metadata - struct flagcxP2pConnMeta localMeta, remoteMeta; - flagcxP2pBuildConnMeta(&localMeta, &comm->base, &comm->qp, comm->ibDevN); - FLAGCXCHECK(flagcxSocketSend(&comm->sock, &localMeta, sizeof(localMeta))); - FLAGCXCHECK(flagcxSocketRecv(&comm->sock, &remoteMeta, sizeof(remoteMeta))); - - // Transition QP to RTR then RTS - FLAGCXCHECK( - flagcxP2pTransitionQp(&comm->qp, &comm->base, &remoteMeta, comm->ibDevN)); - - // Register putSignal scratchpad MR - comm->putSignalScratchpad = 0; - FLAGCXCHECK(flagcxWrapIbvRegMr( - &comm->putSignalScratchpadMr, comm->base.pd, &comm->putSignalScratchpad, - sizeof(comm->putSignalScratchpad), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC)); + // numQps negotiation must happen before setup so we only create the + // QPs we'll actually use; both peers agree on min(). + localNumQps = (uint32_t)flagcxP2pGlobalConfig().qpsPerConn; + if (localNumQps == 0 || localNumQps > (uint32_t)kFlagcxP2pMaxQpsPerEngine) + localNumQps = (uint32_t)kFlagcxP2pMaxQpsPerEngine; + FLAGCXCHECKGOTO( + flagcxSocketSend(&comm->sock, &localNumQps, sizeof(localNumQps)), res, + connect_fail); + FLAGCXCHECKGOTO( + flagcxSocketRecv(&comm->sock, &remoteNumQps, sizeof(remoteNumQps)), res, + connect_fail); + if (remoteNumQps == 0 || remoteNumQps > (uint32_t)kFlagcxP2pMaxQpsPerEngine) { + WARN("NET/IB_P2P : peer advertised invalid numQps=%u (max=%d)", + remoteNumQps, kFlagcxP2pMaxQpsPerEngine); + res = flagcxInternalError; + goto connect_fail; + } + agreedNumQps = std::min(localNumQps, remoteNumQps); + if (localNumQps != remoteNumQps) { + INFO(FLAGCX_NET, + "NET/IB_P2P : numQps mismatch (local=%u remote=%u) — using min=%u", + localNumQps, remoteNumQps, agreedNumQps); + } + comm->numQps = (int)agreedNumQps; + + FLAGCXCHECKGOTO(flagcxP2pSetupConn(dev, comm, &comm->base, comm->qp_list_, + &comm->ibDevN, comm->numQps), + res, connect_fail); + + for (int i = 0; i < comm->numQps; i++) + flagcxP2pBuildConnMeta(&localMeta[i], &comm->base, &comm->qp_list_[i], + comm->ibDevN); + FLAGCXCHECKGOTO(flagcxSocketSend(&comm->sock, localMeta, + comm->numQps * sizeof(localMeta[0])), + res, connect_fail); + FLAGCXCHECKGOTO(flagcxSocketRecv(&comm->sock, remoteMeta, + comm->numQps * sizeof(remoteMeta[0])), + res, connect_fail); + + // Transition each matched QP to RTR then RTS. + for (int i = 0; i < comm->numQps; i++) + FLAGCXCHECKGOTO(flagcxP2pTransitionQp(&comm->qp_list_[i], &comm->base, + &remoteMeta[i], comm->ibDevN), + res, connect_fail); // Exchange ready - int localReady = 1, remoteReady = 0; - FLAGCXCHECK(flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady))); - FLAGCXCHECK(flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady))); + FLAGCXCHECKGOTO( + flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady)), res, + connect_fail); + FLAGCXCHECKGOTO( + flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady)), res, + connect_fail); + *sendComm = comm; return flagcxSuccess; + +connect_fail: + flagcxP2pDestroyQps(comm->ibDevN, comm->qp_list_, comm->numQps); + if (comm->base.pd) + flagcxP2pReleasePd(comm->ibDevN); + flagcxSocketClose(&comm->sock); + free(comm); + return res; } static flagcxResult_t flagcxP2pAccept(void *listenComm, void **recvComm) { @@ -425,6 +499,10 @@ static flagcxResult_t flagcxP2pAccept(void *listenComm, void **recvComm) { // TCP accept (blocking, no timeout) flagcxResult_t res; int ready; + struct flagcxP2pConnMeta localMeta[kFlagcxP2pMaxQpsPerEngine]; + struct flagcxP2pConnMeta remoteMeta[kFlagcxP2pMaxQpsPerEngine]; + int localReady = 1, remoteReady = 0; + uint32_t localNumQps = 0, remoteNumQps = 0, agreedNumQps = 0; FLAGCXCHECKGOTO(flagcxSocketInit(&comm->sock), res, accept_fail); FLAGCXCHECKGOTO(flagcxSocketAccept(&comm->sock, &lComm->sock), res, accept_fail); @@ -441,187 +519,279 @@ static flagcxResult_t flagcxP2pAccept(void *listenComm, void **recvComm) { return res; } - // Set up PD, CQ, QP - FLAGCXCHECK( - flagcxP2pSetupConn(lComm->dev, &comm->base, &comm->qp, &comm->ibDevN)); - - // Exchange connection metadata (accept receives first, then sends) - struct flagcxP2pConnMeta localMeta, remoteMeta; - flagcxP2pBuildConnMeta(&localMeta, &comm->base, &comm->qp, comm->ibDevN); - FLAGCXCHECK(flagcxSocketRecv(&comm->sock, &remoteMeta, sizeof(remoteMeta))); - FLAGCXCHECK(flagcxSocketSend(&comm->sock, &localMeta, sizeof(localMeta))); - - // Transition QP to RTR then RTS - FLAGCXCHECK( - flagcxP2pTransitionQp(&comm->qp, &comm->base, &remoteMeta, comm->ibDevN)); - - // Register putSignal scratchpad MR (symmetric with connect) - comm->putSignalScratchpad = 0; - FLAGCXCHECK(flagcxWrapIbvRegMr( - &comm->putSignalScratchpadMr, comm->base.pd, &comm->putSignalScratchpad, - sizeof(comm->putSignalScratchpad), - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC)); + // accept side mirrors connect: recv numQps first, then send. + FLAGCXCHECKGOTO( + flagcxSocketRecv(&comm->sock, &remoteNumQps, sizeof(remoteNumQps)), res, + accept_cleanup); + localNumQps = (uint32_t)flagcxP2pGlobalConfig().qpsPerConn; + if (localNumQps == 0 || localNumQps > (uint32_t)kFlagcxP2pMaxQpsPerEngine) + localNumQps = (uint32_t)kFlagcxP2pMaxQpsPerEngine; + FLAGCXCHECKGOTO( + flagcxSocketSend(&comm->sock, &localNumQps, sizeof(localNumQps)), res, + accept_cleanup); + if (remoteNumQps == 0 || remoteNumQps > (uint32_t)kFlagcxP2pMaxQpsPerEngine) { + WARN("NET/IB_P2P : peer advertised invalid numQps=%u (max=%d)", + remoteNumQps, kFlagcxP2pMaxQpsPerEngine); + res = flagcxInternalError; + goto accept_cleanup; + } + agreedNumQps = std::min(localNumQps, remoteNumQps); + if (localNumQps != remoteNumQps) { + INFO(FLAGCX_NET, + "NET/IB_P2P : numQps mismatch (local=%u remote=%u) — using min=%u", + localNumQps, remoteNumQps, agreedNumQps); + } + comm->numQps = (int)agreedNumQps; + + FLAGCXCHECKGOTO(flagcxP2pSetupConn(lComm->dev, comm, &comm->base, + comm->qp_list_, &comm->ibDevN, + comm->numQps), + res, accept_cleanup); + + for (int i = 0; i < comm->numQps; i++) + flagcxP2pBuildConnMeta(&localMeta[i], &comm->base, &comm->qp_list_[i], + comm->ibDevN); + FLAGCXCHECKGOTO(flagcxSocketRecv(&comm->sock, remoteMeta, + comm->numQps * sizeof(remoteMeta[0])), + res, accept_cleanup); + FLAGCXCHECKGOTO(flagcxSocketSend(&comm->sock, localMeta, + comm->numQps * sizeof(localMeta[0])), + res, accept_cleanup); + + // Transition each matched QP to RTR then RTS. + for (int i = 0; i < comm->numQps; i++) + FLAGCXCHECKGOTO(flagcxP2pTransitionQp(&comm->qp_list_[i], &comm->base, + &remoteMeta[i], comm->ibDevN), + res, accept_cleanup); // Exchange ready - int localReady = 1, remoteReady = 0; - FLAGCXCHECK(flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady))); - FLAGCXCHECK(flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady))); + FLAGCXCHECKGOTO( + flagcxSocketRecv(&comm->sock, &remoteReady, sizeof(remoteReady)), res, + accept_cleanup); + FLAGCXCHECKGOTO( + flagcxSocketSend(&comm->sock, &localReady, sizeof(localReady)), res, + accept_cleanup); + *recvComm = comm; return flagcxSuccess; + +accept_cleanup: + flagcxP2pDestroyQps(comm->ibDevN, comm->qp_list_, comm->numQps); + if (comm->base.pd) + flagcxP2pReleasePd(comm->ibDevN); + flagcxSocketClose(&comm->sock); + free(comm); + return res; } /* ------------------------------------------------------------------ */ /* One-sided transfers: iput / iget / iputSignal */ /* ------------------------------------------------------------------ */ +static flagcxResult_t flagcxP2pBuildSingleSliceReq( + struct flagcxP2pSendComm *comm, uint64_t localVa, uint64_t remoteVa, + size_t size, uint32_t lkey, uint32_t rkey, uint8_t opcode, + void **request) { + if ((uint32_t)size != size) { + WARN("NET/IB_P2P : single-op size %zu exceeds 32-bit limit", size); + return flagcxInternalError; + } + + auto *req = new struct flagcxP2pSliceReq; + req->slice.srcVa = localVa; + req->slice.dstVa = remoteVa; + req->slice.length = (uint32_t)size; + req->slice.lkey = lkey; + req->slice.rkey = rkey; + req->slice.opcode = opcode; + req->slice.peerNicPath = std::string(); + req->slice.task = &req->task; + req->slice.qpDepth = NULL; + req->task.sliceList.push_back(&req->slice); + req->task.sliceCount.fetch_add(1, std::memory_order_release); + + FlagcxSlice *slicePtr = &req->slice; + flagcxResult_t rc = + flagcxP2pPoolSubmit(comm->ibDevN, comm, &slicePtr, 1); + if (rc != flagcxSuccess) { + delete req; + return rc; + } + + *request = req; + return flagcxSuccess; +} + static flagcxResult_t flagcxP2pIput(void *sendComm, uint64_t srcOff, uint64_t dstOff, size_t size, int srcRank, int dstRank, void **srcHandles, void **dstHandles, void **request) { + (void)srcRank; + (void)dstRank; struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; struct flagcxP2pMrHandle *src = (struct flagcxP2pMrHandle *)srcHandles; struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles; - - struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->base.cq, - FLAGCX_P2P_REQ_IPUT, &req)); - - struct ibv_sge sge; - memset(&sge, 0, sizeof(sge)); - sge.addr = src->baseVa + srcOff; - sge.length = (uint32_t)size; - if ((size_t)sge.length != size) { - WARN("NET/IB_P2P : iput size %zu exceeds 32-bit limit", size); - flagcxP2pFreeRequest(req); - return flagcxInternalError; - } - sge.lkey = src->lkey; - - struct ibv_send_wr wr; - memset(&wr, 0, sizeof(wr)); - wr.opcode = IBV_WR_RDMA_WRITE; - wr.send_flags = IBV_SEND_SIGNALED; - wr.wr_id = req - comm->reqs; - wr.wr.rdma.remote_addr = dst->baseVa + dstOff; - wr.wr.rdma.rkey = dst->rkey; - wr.sg_list = &sge; - wr.num_sge = 1; - - struct ibv_send_wr *bad_wr; - FLAGCXCHECK(flagcxWrapIbvPostSend(comm->qp.qp, &wr, &bad_wr)); - req->events = 1; - - *request = req; - return flagcxSuccess; + return flagcxP2pBuildSingleSliceReq( + comm, src->baseVa + srcOff, dst->baseVa + dstOff, size, src->lkey, + dst->rkey, FLAGCX_SLICE_OP_WRITE, request); } static flagcxResult_t flagcxP2pIget(void *sendComm, uint64_t srcOff, uint64_t dstOff, size_t size, int srcRank, int dstRank, void **srcHandles, void **dstHandles, void **request) { + (void)srcRank; + (void)dstRank; struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; struct flagcxP2pMrHandle *src = (struct flagcxP2pMrHandle *)srcHandles; struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles; + return flagcxP2pBuildSingleSliceReq( + comm, dst->baseVa + dstOff, src->baseVa + srcOff, size, dst->lkey, + src->rkey, FLAGCX_SLICE_OP_READ, request); +} - struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->base.cq, - FLAGCX_P2P_REQ_IGET, &req)); - - struct ibv_sge sge; - memset(&sge, 0, sizeof(sge)); - sge.addr = dst->baseVa + dstOff; - sge.length = (uint32_t)size; - if ((size_t)sge.length != size) { - WARN("NET/IB_P2P : iget size %zu exceeds 32-bit limit", size); - flagcxP2pFreeRequest(req); +static flagcxResult_t +flagcxP2pIgetBatch(void *sendComm, int count, const uint64_t *srcOffs, + const uint64_t *dstOffs, const size_t *sizes, int srcRank, + int dstRank, void *const *srcHandles, + void *const *dstHandles, void **request) { + (void)srcRank; + (void)dstRank; + struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; + const int maxWrPerPost = (int)flagcxP2pGlobalConfig().maxWrPerPost; + if (count <= 0 || count > maxWrPerPost || srcOffs == NULL || + dstOffs == NULL || sizes == NULL || srcHandles == NULL || + dstHandles == NULL || request == NULL) { + WARN("NET/IB_P2P : invalid igetBatch arguments, count %d (max %d)", count, + maxWrPerPost); return flagcxInternalError; } - sge.lkey = dst->lkey; - - struct ibv_send_wr wr; - memset(&wr, 0, sizeof(wr)); - wr.opcode = IBV_WR_RDMA_READ; - wr.send_flags = IBV_SEND_SIGNALED; - wr.wr_id = req - comm->reqs; - wr.wr.rdma.remote_addr = src->baseVa + srcOff; - wr.wr.rdma.rkey = src->rkey; - wr.sg_list = &sge; - wr.num_sge = 1; - - struct ibv_send_wr *bad_wr; - FLAGCXCHECK(flagcxWrapIbvPostSend(comm->qp.qp, &wr, &bad_wr)); - req->events = 1; + auto *req = new struct flagcxP2pSliceReq; + req->task.sliceList.reserve(count); + for (int i = 0; i < count; i++) { + if (srcHandles[i] == NULL || dstHandles[i] == NULL || + (uint32_t)sizes[i] != sizes[i]) { + WARN("NET/IB_P2P : igetBatch slice %d invalid", i); + for (auto *s : req->task.sliceList) + delete s; + delete req; + return flagcxInternalError; + } + auto *src = (struct flagcxP2pMrHandle *)srcHandles[i]; + auto *dst = (struct flagcxP2pMrHandle *)dstHandles[i]; + auto *s = new FlagcxSlice{dst->baseVa + dstOffs[i], + src->baseVa + srcOffs[i], + (uint32_t)sizes[i], + dst->lkey, + src->rkey, + FLAGCX_SLICE_OP_READ, + std::string(), + &req->task, + NULL}; + req->task.sliceList.push_back(s); + req->task.sliceCount.fetch_add(1, std::memory_order_release); + } + + flagcxResult_t rc = flagcxP2pPoolSubmit(comm->ibDevN, comm, + req->task.sliceList.data(), count); + if (rc != flagcxSuccess) { + for (auto *s : req->task.sliceList) + delete s; + delete req; + return rc; + } *request = req; return flagcxSuccess; } -static flagcxResult_t -flagcxP2pIputSignal(void *sendComm, uint64_t srcOff, uint64_t dstOff, - size_t size, int srcRank, int dstRank, void **srcHandles, - void **dstHandles, uint64_t signalOff, void **signalHandles, - uint64_t signalValue, void **request) { +static flagcxResult_t flagcxP2pIputSignal(void *, uint64_t, uint64_t, size_t, + int, int, void **, void **, uint64_t, + void **, uint64_t, void **) { + WARN("NET/IB_P2P : iputSignal not supported"); + return flagcxInternalError; +} + +/* ------------------------------------------------------------------ */ +/* Slice batch: pool worker passes the chosen QP. wr_id = ptr|1. */ +/* ------------------------------------------------------------------ */ + +static inline enum ibv_wr_opcode flagcxSliceOpcodeToVerbs(uint8_t op) { + return op == FLAGCX_SLICE_OP_READ ? IBV_WR_RDMA_READ : IBV_WR_RDMA_WRITE; +} + +extern "C" flagcxResult_t flagcxP2pSliceBatch(void *sendComm, + struct ibv_qp *qp, int count, + FlagcxSlice **slices) { struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; - struct flagcxP2pMrHandle *signalInfo = - (struct flagcxP2pMrHandle *)signalHandles; - - struct flagcxP2pRequest *req; - FLAGCXCHECK(flagcxP2pGetRequest(comm->reqs, comm->base.cq, - FLAGCX_P2P_REQ_IPUT, &req)); - - bool chainData = (size > 0 && srcHandles != NULL && dstHandles != NULL); - - struct ibv_sge sge[2]; - struct ibv_send_wr wr[2]; - memset(sge, 0, sizeof(sge)); - memset(wr, 0, sizeof(wr)); - - // wr[0]: RDMA WRITE for data (unsignaled, chained to wr[1]) - if (chainData) { - struct flagcxP2pMrHandle *src = (struct flagcxP2pMrHandle *)srcHandles; - struct flagcxP2pMrHandle *dst = (struct flagcxP2pMrHandle *)dstHandles; - - sge[0].addr = src->baseVa + srcOff; - sge[0].length = (uint32_t)size; - if ((size_t)sge[0].length != size) { - WARN("NET/IB_P2P : iputSignal size %zu exceeds 32-bit limit", size); - flagcxP2pFreeRequest(req); + const char *opLabel = + (slices != NULL && count > 0 && slices[0] != NULL && + slices[0]->opcode == FLAGCX_SLICE_OP_READ) + ? "READ" + : "WRITE"; + const int maxWrPerPost = (int)flagcxP2pGlobalConfig().maxWrPerPost; + if (count <= 0 || count > maxWrPerPost || slices == NULL || + qp == NULL || comm == NULL) { + WARN("NET/IB_P2P : invalid sliceBatch arguments (op=%s, count=%d, qp=%p, " + "max=%d)", + opLabel, count, (void *)qp, maxWrPerPost); + return flagcxInternalError; + } + + // count can be up to flagcxP2pGlobalConfig().maxWrPerPost (default 256, + // bounded at 1024). Heap-allocate to keep the stack small. + std::vector wrs(count); + std::vector sges(count); + + for (int i = 0; i < count; i++) { + FlagcxSlice *s = slices[i]; + if (s == NULL) { + WARN("NET/IB_P2P : sliceBatch slice[%d] is NULL", i); + for (int k = 0; k < i; k++) + slices[k]->markFailed(); + for (int k = i; k < count; k++) + if (slices[k]) + slices[k]->markFailed(); return flagcxInternalError; } - sge[0].lkey = src->lkey; - - wr[0].opcode = IBV_WR_RDMA_WRITE; - wr[0].send_flags = 0; // unsignaled - wr[0].wr.rdma.remote_addr = dst->baseVa + dstOff; - wr[0].wr.rdma.rkey = dst->rkey; - wr[0].sg_list = &sge[0]; - wr[0].num_sge = 1; - wr[0].next = &wr[1]; // chain to atomic + + sges[i].addr = s->srcVa; + sges[i].length = s->length; + sges[i].lkey = s->lkey; + + wrs[i].opcode = flagcxSliceOpcodeToVerbs(s->opcode); + wrs[i].send_flags = IBV_SEND_SIGNALED; + wrs[i].wr_id = ((uintptr_t)s) | 1ull; + wrs[i].wr.rdma.remote_addr = s->dstVa; + wrs[i].wr.rdma.rkey = s->rkey; + wrs[i].sg_list = &sges[i]; + wrs[i].num_sge = 1; + wrs[i].next = (i + 1 < count) ? &wrs[i + 1] : NULL; } - // wr[1]: ATOMIC FETCH_AND_ADD for signal (signaled) - sge[1].addr = (uintptr_t)&comm->putSignalScratchpad; - sge[1].length = sizeof(comm->putSignalScratchpad); - sge[1].lkey = comm->putSignalScratchpadMr->lkey; - - wr[1].opcode = IBV_WR_ATOMIC_FETCH_AND_ADD; - wr[1].send_flags = IBV_SEND_SIGNALED; - wr[1].wr_id = req - comm->reqs; - wr[1].wr.atomic.remote_addr = signalInfo->baseVa + signalOff; - wr[1].wr.atomic.rkey = signalInfo->rkey; - wr[1].wr.atomic.compare_add = signalValue; - wr[1].sg_list = &sge[1]; - wr[1].num_sge = 1; - wr[1].next = NULL; - - struct ibv_send_wr *bad_wr; - FLAGCXCHECK( - flagcxWrapIbvPostSend(comm->qp.qp, chainData ? &wr[0] : &wr[1], &bad_wr)); - req->events = 1; + struct ibv_send_wr *bad_wr = NULL; + flagcxResult_t res = flagcxWrapIbvPostSend(qp, &wrs[0], &bad_wr); + if (res != flagcxSuccess) { + int failedFrom = 0; + if (bad_wr != NULL) { + ptrdiff_t off = bad_wr - &wrs[0]; + if (off >= 0 && off < count) + failedFrom = (int)off; + } + // Slices in [failedFrom..count) never went on the wire — roll back + // their share of the pool's qpDepth pre-bump so the gate doesn't leak. + for (int k = failedFrom; k < count; k++) { + if (slices[k]->qpDepth != NULL) + __sync_fetch_and_sub(slices[k]->qpDepth, 1); + slices[k]->markFailed(); + } + WARN("NET/IB_P2P : sliceBatch ibv_post_send failed (op=%s, count=%d, " + "failedFrom=%d)", + opLabel, count, failedFrom); + return res; + } - *request = req; return flagcxSuccess; } @@ -629,46 +799,58 @@ flagcxP2pIputSignal(void *sendComm, uint64_t srcOff, uint64_t dstOff, /* Test */ /* ------------------------------------------------------------------ */ +// Single-slice path uses the wrapper's embedded `slice`; batch path +// heap-allocates each — distinguish by address. +static inline void flagcxP2pFreeSliceReq(struct flagcxP2pSliceReq *req) { + if (!req) + return; + for (auto *s : req->task.sliceList) { + if (s != &req->slice) + delete s; + } + delete req; +} + static flagcxResult_t flagcxP2pTest(void *request, int *done, int *sizes) { *done = 0; - struct flagcxP2pRequest *req = (struct flagcxP2pRequest *)request; - if (req == NULL || req->type == FLAGCX_P2P_REQ_UNUSED) { + if (sizes) + *sizes = 0; + if (request == NULL) { *done = 1; return flagcxSuccess; } - - int nCqe = 0; - struct ibv_wc wc; - FLAGCXCHECK(flagcxWrapIbvPollCq(req->cq, 1, &wc, &nCqe)); - - if (nCqe == 0) - return flagcxSuccess; - - if (wc.status != IBV_WC_SUCCESS) { - WARN("NET/IB_P2P : CQ error: status=%d opcode=%d wr_id=%lu", wc.status, - wc.opcode, wc.wr_id); - return flagcxRemoteError; - } - - // Map CQE back to the correct request via wr_id - uint32_t reqIdx = wc.wr_id; - if (reqIdx >= FLAGCX_P2P_MAX_REQUESTS) { - WARN("NET/IB_P2P : invalid wr_id %u in CQE", reqIdx); - return flagcxInternalError; - } - struct flagcxP2pRequest *completedReq = &req->reqs[reqIdx]; - - completedReq->events--; - if (completedReq->events == 0) { - completedReq->type = FLAGCX_P2P_REQ_UNUSED; + auto *req = static_cast(request); + if (req->task.isAllDone()) { + *done = 1; + if (sizes) { + uint64_t total = 0; + for (auto *s : req->task.sliceList) + total += s->length; + *sizes = (int)std::min(total, (uint64_t)INT32_MAX); + } + flagcxP2pFreeSliceReq(req); } + return flagcxSuccess; +} - // Check if the originally requested op is done - if (req->type == FLAGCX_P2P_REQ_UNUSED) { - *done = 1; - if (sizes) - *sizes = 0; +static flagcxResult_t flagcxP2pTestBatch(void **requests, int nRequests, + int *doneFlags, int *doneCount) { + int completed = 0; + for (int i = 0; i < nRequests; i++) { + doneFlags[i] = 0; + auto *req = static_cast(requests[i]); + if (req == NULL) { + doneFlags[i] = 1; + completed++; + continue; + } + if (req->task.isAllDone()) { + doneFlags[i] = 1; + completed++; + flagcxP2pFreeSliceReq(req); + } } + *doneCount = completed; return flagcxSuccess; } @@ -694,29 +876,10 @@ static flagcxResult_t flagcxP2pReleasePd(int ibDevN) { return flagcxSuccess; } -// Helper: drain CQ before destroying resources -static void flagcxP2pDrainCq(struct ibv_cq *cq) { - if (!cq) - return; - struct ibv_wc wcs[64]; - int nCqe = 0; - for (int i = 0; i < 16; i++) { - flagcxWrapIbvPollCq(cq, 64, wcs, &nCqe); - if (nCqe == 0) - break; - } -} - static flagcxResult_t flagcxP2pCloseSend(void *sendComm) { struct flagcxP2pSendComm *comm = (struct flagcxP2pSendComm *)sendComm; if (comm) { - flagcxP2pDrainCq(comm->base.cq); - if (comm->qp.qp) - FLAGCXCHECK(flagcxWrapIbvDestroyQp(comm->qp.qp)); - if (comm->putSignalScratchpadMr) - FLAGCXCHECK(flagcxWrapIbvDeregMr(comm->putSignalScratchpadMr)); - if (comm->base.cq) - FLAGCXCHECK(flagcxWrapIbvDestroyCq(comm->base.cq)); + FLAGCXCHECK(flagcxP2pDestroyQps(comm->ibDevN, comm->qp_list_, comm->numQps)); FLAGCXCHECK(flagcxP2pReleasePd(comm->ibDevN)); FLAGCXCHECK(flagcxSocketClose(&comm->sock)); free(comm); @@ -727,13 +890,7 @@ static flagcxResult_t flagcxP2pCloseSend(void *sendComm) { static flagcxResult_t flagcxP2pCloseRecv(void *recvComm) { struct flagcxP2pRecvComm *comm = (struct flagcxP2pRecvComm *)recvComm; if (comm) { - flagcxP2pDrainCq(comm->base.cq); - if (comm->qp.qp) - FLAGCXCHECK(flagcxWrapIbvDestroyQp(comm->qp.qp)); - if (comm->putSignalScratchpadMr) - FLAGCXCHECK(flagcxWrapIbvDeregMr(comm->putSignalScratchpadMr)); - if (comm->base.cq) - FLAGCXCHECK(flagcxWrapIbvDestroyCq(comm->base.cq)); + FLAGCXCHECK(flagcxP2pDestroyQps(comm->ibDevN, comm->qp_list_, comm->numQps)); FLAGCXCHECK(flagcxP2pReleasePd(comm->ibDevN)); FLAGCXCHECK(flagcxSocketClose(&comm->sock)); free(comm); @@ -793,20 +950,40 @@ static flagcxResult_t flagcxP2pGetDevFromName(char *name, int *dev) { struct flagcxNetAdaptor flagcxNetIbP2p = { // Basic functions - "IB_P2P", flagcxP2pInit, flagcxP2pDevices, flagcxP2pGetProperties, + "IB_P2P", + flagcxP2pInit, + flagcxP2pDevices, + flagcxP2pGetProperties, // Setup functions - flagcxP2pListen, flagcxP2pConnect, flagcxP2pAccept, flagcxP2pCloseSend, - flagcxP2pCloseRecv, flagcxP2pCloseListen, + flagcxP2pListen, + flagcxP2pConnect, + flagcxP2pAccept, + flagcxP2pCloseSend, + flagcxP2pCloseRecv, + flagcxP2pCloseListen, // Memory region functions - flagcxP2pRegMr, flagcxP2pRegMrDmaBuf, flagcxP2pDeregMr, + flagcxP2pRegMr, + flagcxP2pRegMrDmaBuf, + flagcxP2pDeregMr, // Two-sided functions (stubs) - flagcxP2pIsend, flagcxP2pIrecv, flagcxP2pIflush, flagcxP2pTest, + flagcxP2pIsend, + flagcxP2pIrecv, + flagcxP2pIflush, + flagcxP2pTest, // One-sided functions - flagcxP2pIput, flagcxP2pIget, flagcxP2pIputSignal, + flagcxP2pIput, + flagcxP2pIget, + flagcxP2pIputSignal, // Device name lookup - flagcxP2pGetDevFromName}; + flagcxP2pGetDevFromName, + + // Optional batch operations + nullptr, // iputBatch + flagcxP2pTestBatch, // testBatch + flagcxP2pIgetBatch, // igetBatch +}; diff --git a/flagcx/core/flagcx_p2p.cc b/flagcx/core/flagcx_p2p.cc index 21deb1bc..7f82961c 100644 --- a/flagcx/core/flagcx_p2p.cc +++ b/flagcx/core/flagcx_p2p.cc @@ -12,15 +12,19 @@ #include "flagcx_p2p.h" #include "adaptor.h" +#include "debug.h" #include "flagcx_net.h" #include "flagcx_net_adaptor.h" #include "ib_common.h" +#include "ibvwrap.h" #include "p2p_topo.h" +#include "param.h" #include "socket.h" #include #include #include +#include #include #include #include @@ -40,6 +44,113 @@ extern struct flagcxNetAdaptor flagcxNetIbP2p; +extern "C" flagcxResult_t flagcxP2pSliceBatch(void *sendComm, + struct ibv_qp *qp, int count, + FlagcxSlice **slices); + +namespace { + +FLAGCX_PARAM(P2pQpsPerConn, "P2P_QPS_PER_CONN", 4); +FLAGCX_PARAM(P2pWorkersPerPool, "P2P_WORKERS_PER_POOL", 4); +FLAGCX_PARAM(P2pShardCount, "P2P_SHARD_COUNT", 8); +FLAGCX_PARAM(P2pCqDepth, "P2P_CQ_DEPTH", 4096); +FLAGCX_PARAM(P2pMaxWrPerPost, "P2P_MAX_WR_PER_POST", 64); +FLAGCX_PARAM(P2pMaxRequests, "P2P_MAX_REQUESTS", 256); +FLAGCX_PARAM(P2pBatchPollSize, "P2P_BATCH_POLL_SIZE", 64); +FLAGCX_PARAM(P2pSliceSize, "P2P_SLICE_SIZE", 65536); +FLAGCX_PARAM(P2pFragmentLimit, "P2P_FRAGMENT_LIMIT", 4096); +FLAGCX_PARAM(P2pMaxSge, "P2P_MAX_SGE", 4); +FLAGCX_PARAM(P2pMaxInline, "P2P_MAX_INLINE", 64); +FLAGCX_PARAM(P2pIbPort, "P2P_IB_PORT", 1); +FLAGCX_PARAM(P2pGidIndex, "P2P_GID_INDEX", -1); +FLAGCX_PARAM(P2pMtu, "P2P_MTU", 4096); +FLAGCX_PARAM(P2pIbTc, "P2P_IB_TC", -1); +FLAGCX_PARAM(P2pRetryCnt, "P2P_RETRY_CNT", 7); +FLAGCX_PARAM(P2pNotifMaxPeers, "P2P_NOTIF_MAX_PEERS", 64); +FLAGCX_PARAM(P2pDestDevAffinity, "P2P_DEST_DEV_AFFINITY", 0); + +template +inline T clampParam(int64_t v, T lo, T hi, T deft, const char *name) { + if (v < (int64_t)lo || v > (int64_t)hi) { + INFO(FLAGCX_INIT, + "Ignore FLAGCX_%s=%lld (out of [%lld,%lld]); using default %lld", + name, (long long)v, (long long)lo, (long long)hi, (long long)deft); + return deft; + } + return (T)v; +} + +void loadGlobalConfig(FlagcxP2pGlobalConfig &c) { + c.qpsPerConn = clampParam(flagcxParamP2pQpsPerConn(), 1, kFlagcxP2pMaxQpsPerEngine, 4, "P2P_QPS_PER_CONN"); + c.workersPerPool = clampParam(flagcxParamP2pWorkersPerPool(), 1, 8, 4, "P2P_WORKERS_PER_POOL"); + c.shardCount = clampParam(flagcxParamP2pShardCount(), 1, 64, 8, "P2P_SHARD_COUNT"); + c.sharedCqDepth = clampParam(flagcxParamP2pCqDepth(), 1, 1u<<20, 4096, "P2P_CQ_DEPTH"); + c.maxWrPerPost = clampParam(flagcxParamP2pMaxWrPerPost(),1, 1024, 256, "P2P_MAX_WR_PER_POST"); + c.maxRequests = clampParam(flagcxParamP2pMaxRequests(), 1, 1u<<16, 256, "P2P_MAX_REQUESTS"); + c.batchPollSize = clampParam(flagcxParamP2pBatchPollSize(), 1, 256, 64, "P2P_BATCH_POLL_SIZE"); + c.sliceSize = clampParam(flagcxParamP2pSliceSize(), 1024, 1u<<26, 65536, "P2P_SLICE_SIZE"); + c.fragmentLimit = clampParam(flagcxParamP2pFragmentLimit(), 0, c.sliceSize, 4096, "P2P_FRAGMENT_LIMIT"); + c.maxSge = clampParam(flagcxParamP2pMaxSge(), 1, 32, 4, "P2P_MAX_SGE"); + c.maxInline = clampParam(flagcxParamP2pMaxInline(), 0, 1024, 64, "P2P_MAX_INLINE"); + c.ibPort = clampParam(flagcxParamP2pIbPort(), 1, 255, 1, "P2P_IB_PORT"); + c.gidIndex = clampParam(flagcxParamP2pGidIndex(), -1, 255, -1, "P2P_GID_INDEX"); + { + int64_t mv = flagcxParamP2pMtu(); + if (mv == 512 || mv == 1024 || mv == 2048 || mv == 4096) { + c.mtuLength = (int)mv; + } else { + WARN("Ignore FLAGCX_P2P_MTU=%lld (must be 512/1024/2048/4096); using 4096", + (long long)mv); + c.mtuLength = 4096; + } + } + c.ibTrafficClass = clampParam(flagcxParamP2pIbTc(), -1, 255, -1, "P2P_IB_TC"); + c.retryCnt = clampParam(flagcxParamP2pRetryCnt(), 0, 7, 7, "P2P_RETRY_CNT"); + c.notifMaxPeers = clampParam(flagcxParamP2pNotifMaxPeers(), 1, 1024, 64, "P2P_NOTIF_MAX_PEERS"); + c.enableDestDeviceAffinity = (flagcxParamP2pDestDevAffinity() != 0); +} + +void dumpGlobalConfigImpl(const FlagcxP2pGlobalConfig &c); + +FlagcxP2pGlobalConfig &mutableGlobalConfig() { + static FlagcxP2pGlobalConfig cfg; + static std::once_flag once; + std::call_once(once, [] { + loadGlobalConfig(cfg); + dumpGlobalConfigImpl(cfg); + }); + return cfg; +} + +void dumpGlobalConfigImpl(const FlagcxP2pGlobalConfig &c) { + INFO(FLAGCX_INIT, "=== FlagCX P2P GlobalConfig ==="); + INFO(FLAGCX_INIT, + "qpsPerConn=%d workersPerPool=%d shardCount=%d", + c.qpsPerConn, c.workersPerPool, c.shardCount); + INFO(FLAGCX_INIT, + "sharedCqDepth=%zu maxWrPerPost=%zu maxRequests=%zu batchPollSize=%zu", + c.sharedCqDepth, c.maxWrPerPost, c.maxRequests, c.batchPollSize); + INFO(FLAGCX_INIT, "sliceSize=%zu fragmentLimit=%zu", + c.sliceSize, c.fragmentLimit); + INFO(FLAGCX_INIT, + "ibPort=%u gidIndex=%d mtu=%d tc=%d retry=%d " + "maxSge=%zu maxInline=%zu", + (unsigned)c.ibPort, c.gidIndex, c.mtuLength, c.ibTrafficClass, + c.retryCnt, c.maxSge, c.maxInline); + INFO(FLAGCX_INIT, "notifMaxPeers=%d destDevAffinity=%d", + c.notifMaxPeers, (int)c.enableDestDeviceAffinity); +} + +} // namespace + +const FlagcxP2pGlobalConfig &flagcxP2pGlobalConfig() { + return mutableGlobalConfig(); +} + +void flagcxP2pDumpGlobalConfig() { + dumpGlobalConfigImpl(flagcxP2pGlobalConfig()); +} + struct FlagcxP2pMrHandleView { uintptr_t baseVa; uint32_t lkey; @@ -58,7 +169,7 @@ static_assert(sizeof(FlagcxP2pListenHandleView) <= FLAGCX_NET_HANDLE_MAXSIZE, struct FlagcxP2pCommView { int ibDevN; struct flagcxIbNetCommDevBase base; - struct flagcxIbQp qp; + struct flagcxIbQp qp_list_[kFlagcxP2pMaxQpsPerEngine]; struct flagcxSocket sock; }; @@ -122,7 +233,6 @@ struct FlagcxP2pEngine { #if defined(__linux__) int notifEpollFd; #endif - std::thread notifThread; std::atomic stopNotif; std::unordered_map notifPeers; std::mutex notifPeerMutex; @@ -174,6 +284,7 @@ static std::vector gNotifyList; static std::mutex gNotifyMutex; static std::unordered_map gMemRegInfo; +static std::unordered_map gMrToBaseAddr; static std::mutex gMemMutex; static uint64_t gNextMrId = 1; @@ -181,182 +292,624 @@ static std::unordered_map gXferMap; static std::mutex gXferMutex; static uint64_t gNextXferId = 1; -/* ------------------------------------------------------------------ */ -/* Async Transfer Worker Infrastructure */ -/* ------------------------------------------------------------------ */ +struct FlagcxNixlSlicePolicy { + static constexpr bool kFurtherCut = false; + static constexpr size_t kBlockSize = SIZE_MAX; + static constexpr size_t kFragmentSize = 0; +}; + +struct FlagcxConnectorSlicePolicy { + static constexpr bool kFurtherCut = true; + static constexpr size_t kBlockSize = 64 * 1024; + static constexpr size_t kFragmentSize = 4 * 1024; +}; + +template +inline void flagcxBuildSlices(FlagcxTransferTask *task, uint64_t srcVa, + uint64_t dstVa, size_t totalLen, uint32_t lkey, + uint32_t rkey, uint8_t opcode, + const std::string &peerNicPath) { + if (!Policy::kFurtherCut) { + auto *s = new FlagcxSlice{srcVa, dstVa, (uint32_t)totalLen, + lkey, rkey, opcode, + peerNicPath, task, nullptr}; + task->sliceList.push_back(s); + task->sliceCount.fetch_add(1, std::memory_order_release); + return; + } -static constexpr int kWindowSize = 64; -static constexpr int kBatchPollCqe = 32; + size_t off = 0; + while (off < totalLen) { + bool merge = + (totalLen - off) <= Policy::kBlockSize + Policy::kFragmentSize; + size_t len = merge ? (totalLen - off) : Policy::kBlockSize; + auto *s = new FlagcxSlice{srcVa + off, dstVa + off, (uint32_t)len, + lkey, rkey, opcode, + peerNicPath, task, nullptr}; + task->sliceList.push_back(s); + task->sliceCount.fetch_add(1, std::memory_order_release); + off += len; + if (merge) + break; + } +} -enum AsyncXferOp { ASYNC_XFER_READ, ASYNC_XFER_WRITE }; +static void notifPollThreadFunc(FlagcxP2pEngine *engine); -struct AsyncTransferTask { - FlagcxP2pConn *conn; - AsyncXferOp op; - int numIovs; - std::vector dataVec; - std::vector sizeVec; - std::vector descs; - std::vector localEntries; - std::atomic done{false}; - std::atomic result{0}; +namespace { + +struct PoolQpEntry { + struct ibv_qp *qp; + void *sendComm; // owning conn (flagcxP2pSendComm/RecvComm) + volatile int wrDepth; + + PoolQpEntry(struct ibv_qp *q, void *sc) : qp(q), sendComm(sc), wrDepth(0) {} + PoolQpEntry(const PoolQpEntry &) = delete; + PoolQpEntry &operator=(const PoolQpEntry &) = delete; }; -struct AsyncWorker { - std::thread thread; - pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; - pthread_cond_t cv = PTHREAD_COND_INITIALIZER; - std::deque> queue; - std::atomic stop{false}; +class FlagcxWorkerPool { +public: + FlagcxWorkerPool(int ibDevN, struct ibv_context *ctx); + ~FlagcxWorkerPool(); + FlagcxWorkerPool(const FlagcxWorkerPool &) = delete; + FlagcxWorkerPool &operator=(const FlagcxWorkerPool &) = delete; + + struct ibv_cq *getSharedCq() const { return shared_cq_; } + void registerQp(void *sendComm, struct ibv_qp *qp); + void unregisterQp(struct ibv_qp *qp); + + flagcxResult_t submitPostSend(void *sendComm, FlagcxSlice **slices, + int count); + + void startNotif(FlagcxP2pEngine *engine); + void stopNotif(); + +private: + void transferWorkerLoop(int tid); + void performPostSend(int tid); + void performPollCq(); + void notifWorkerLoop(); + + static uint64_t nowNs() { + using clk = std::chrono::steady_clock; + return std::chrono::duration_cast( + clk::now().time_since_epoch()) + .count(); + } + + int ibDevN_; + struct ibv_cq *shared_cq_ = nullptr; + + int numWorkers_; + int numShards_; + size_t maxWrPerPost_; + size_t batchPollSize_; + int maxWrDepth_ = 0; + + std::mutex qp_mu_; + std::vector> qpEntries_; + std::unordered_map qpNumToIdx_; + std::vector> workerQpIdx_; + std::vector workerQpCursor_; + std::unordered_map connQpRegCount_; + std::vector>> + slice_queues_; + std::unique_ptr slice_locks_; + std::unique_ptr[]> slice_queue_count_; + std::atomic shardRoundRobin_{0}; + std::vector>> + collective_slice_queue_; + + std::atomic submitted_{0}; + std::atomic processed_{0}; + std::atomic suspended_flag_{0}; + std::condition_variable cv_; + std::mutex cv_mu_; + + std::atomic running_{true}; + std::vector transferThreads_; + + FlagcxP2pEngine *engine_ = nullptr; + std::thread notifThread_; + std::atomic notifSpawned_{false}; }; -static AsyncWorker gAsyncWorker; +FlagcxWorkerPool::FlagcxWorkerPool(int ibDevN, struct ibv_context *ctx) + : ibDevN_(ibDevN) { + const auto &C = flagcxP2pGlobalConfig(); + numWorkers_ = C.workersPerPool; + numShards_ = C.shardCount; + maxWrPerPost_ = C.maxWrPerPost; + batchPollSize_ = C.batchPollSize; + + if (numWorkers_ > 0 && numShards_ % numWorkers_ != 0) { + int rounded = + ((numShards_ + numWorkers_ - 1) / numWorkers_) * numWorkers_; + INFO(FLAGCX_INIT, + "NET/IB_P2P : pool[%d] rounded shardCount %d → %d for even " + "worker assignment (W=%d)", + ibDevN_, numShards_, rounded, numWorkers_); + numShards_ = rounded; + } + + if (numWorkers_ > 0 && C.qpsPerConn % numWorkers_ != 0) { + WARN("NET/IB_P2P : pool[%d] qpsPerEngine=%d not divisible by " + "workersPerPool=%d — QPs spread per connection but unevenly; some " + "workers will own more QPs than others for each conn", + ibDevN_, C.qpsPerConn, numWorkers_); + } + + flagcxResult_t res = flagcxWrapIbvCreateCq( + &shared_cq_, ctx, (int)C.sharedCqDepth, NULL, NULL, 0); + if (res != flagcxSuccess) { + WARN("NET/IB_P2P : pool[%d] failed to create shared CQ", ibDevN_); + shared_cq_ = nullptr; + return; + } + INFO(FLAGCX_INIT, + "NET/IB_P2P : pool[%d] shared CQ created (depth=%zu, workers=%d, " + "shards=%d, qpsPerConn=%d)", + ibDevN_, C.sharedCqDepth, numWorkers_, numShards_, C.qpsPerConn); -static FlagcxP2pCommView *getCommView(void *comm) { - return reinterpret_cast(comm); + slice_queues_.resize(numShards_); + slice_locks_.reset(new std::mutex[numShards_]); + slice_queue_count_.reset(new std::atomic[numShards_]); + for (int s = 0; s < numShards_; s++) + slice_queue_count_[s].store(0, std::memory_order_relaxed); + + workerQpIdx_.resize(numWorkers_); + workerQpCursor_.assign(numWorkers_, 0); + collective_slice_queue_.resize(numWorkers_); + + transferThreads_.reserve(numWorkers_); + for (int t = 0; t < numWorkers_; t++) { + transferThreads_.emplace_back([this, t] { this->transferWorkerLoop(t); }); + } } -static void asyncWorkerFunc() { - while (true) { - std::shared_ptr task; - pthread_mutex_lock(&gAsyncWorker.mutex); - while (gAsyncWorker.queue.empty() && !gAsyncWorker.stop.load()) { - pthread_cond_wait(&gAsyncWorker.cv, &gAsyncWorker.mutex); +FlagcxWorkerPool::~FlagcxWorkerPool() { + running_.store(false, std::memory_order_release); + cv_.notify_all(); + for (auto &t : transferThreads_) { + if (t.joinable()) + t.join(); + } + // notifThread_ is joined explicitly via stopNotif() in EngineDestroy; + // by the time ~pool runs at process exit it should already be joined. + if (notifThread_.joinable()) { + notifThread_.join(); + } + // CQ destruction skipped: pool is process-lived; OS reclaims. +} + +void FlagcxWorkerPool::startNotif(FlagcxP2pEngine *engine) { + if (engine == nullptr) + return; + // Compare-exchange ensures only the first attach spawns the thread. + bool expected = false; + if (!notifSpawned_.compare_exchange_strong(expected, true)) { + // Already attached — keep existing engine pointer (or update if NULL). + if (engine_ == nullptr) + engine_ = engine; + return; + } + engine_ = engine; + notifThread_ = std::thread([this] { this->notifWorkerLoop(); }); + INFO(FLAGCX_INIT, "NET/IB_P2P : pool[%d] notifWorker spawned", ibDevN_); +} + +void FlagcxWorkerPool::stopNotif() { + // Caller must have set engine_->stopNotif before calling — that breaks + // the epoll loop in notifPollThreadFunc. + if (notifThread_.joinable()) { + notifThread_.join(); + } + engine_ = nullptr; + notifSpawned_.store(false, std::memory_order_release); +} + +void FlagcxWorkerPool::notifWorkerLoop() { + if (engine_ == nullptr) + return; + // Reuse the original engine-side body — same behavior, just owned by + // the pool's thread. + notifPollThreadFunc(engine_); +} + +void FlagcxWorkerPool::registerQp(void *sendComm, struct ibv_qp *qp) { + if (!qp || numWorkers_ <= 0) + return; + + std::lock_guard lk(qp_mu_); + + if (maxWrDepth_ == 0) { + struct ibv_qp_attr attr; + struct ibv_qp_init_attr initAttr; + memset(&attr, 0, sizeof(attr)); + memset(&initAttr, 0, sizeof(initAttr)); + if (flagcxWrapIbvQueryQp(qp, &attr, IBV_QP_CAP, &initAttr) == + flagcxSuccess) { + int cap = (int)initAttr.cap.max_send_wr; + if (cap > 0) { + maxWrDepth_ = cap; + INFO(FLAGCX_INIT, + "NET/IB_P2P : pool[%d] resolved max_wr_depth=%d from first QP", + ibDevN_, cap); + } + } else { + WARN("NET/IB_P2P : pool[%d] ibv_query_qp failed; max_wr_depth " + "stays unresolved (slice posts will fall back to no gate)", + ibDevN_); } - if (gAsyncWorker.stop.load() && gAsyncWorker.queue.empty()) { - pthread_mutex_unlock(&gAsyncWorker.mutex); - return; + } + + int idx = (int)qpEntries_.size(); + qpEntries_.emplace_back(new PoolQpEntry(qp, sendComm)); + qpNumToIdx_[qp->qp_num] = idx; + + int connIdx = connQpRegCount_[sendComm]++; + int slot = connIdx % numWorkers_; + workerQpIdx_[slot].push_back(idx); +} + +void FlagcxWorkerPool::unregisterQp(struct ibv_qp *qp) { + if (!qp) + return; + std::lock_guard lk(qp_mu_); + auto it = qpNumToIdx_.find(qp->qp_num); + if (it == qpNumToIdx_.end()) + return; + int idx = it->second; + qpNumToIdx_.erase(it); + void *sc = qpEntries_[idx]->sendComm; + for (auto &shard : workerQpIdx_) { + auto vit = std::find(shard.begin(), shard.end(), idx); + if (vit != shard.end()) { + shard.erase(vit); + break; } - task = gAsyncWorker.queue.front(); - gAsyncWorker.queue.pop_front(); - pthread_mutex_unlock(&gAsyncWorker.mutex); - - FlagcxP2pConn *conn = task->conn; - struct flagcxNetAdaptor *adaptor = conn->engine->adaptor; - const int numIovs = task->numIovs; - const int connIbDevN = getCommView(conn->sendComm)->ibDevN; - - std::vector inflightReqs(kWindowSize, nullptr); - int issued = 0, completed = 0; - bool error = false; - - while (completed < numIovs && !error) { - // Post up to kWindowSize ahead of completed - while (issued < numIovs && (issued - completed) < kWindowSize) { - if (task->localEntries[issued].ibDevN != connIbDevN) { - error = true; - break; - } + } + auto cit = connQpRegCount_.find(sc); + if (cit != connQpRegCount_.end() && --cit->second <= 0) + connQpRegCount_.erase(cit); + // Slot kept alive (NULL'd) so any in-flight slice's qpDepth pointer stays valid. + qpEntries_[idx]->qp = nullptr; + qpEntries_[idx]->sendComm = nullptr; +} - FlagcxP2pMrHandleView *localMr = - reinterpret_cast( - task->localEntries[issued].mhandle); - - FlagcxP2pMrHandleView remoteMr; - memset(&remoteMr, 0, sizeof(remoteMr)); - remoteMr.baseVa = task->descs[issued].addr; - remoteMr.rkey = task->descs[issued].rkey; - - void *request = NULL; - flagcxResult_t rc; - - if (task->op == ASYNC_XFER_READ) { - const uint64_t srcOff = 0; - const uint64_t dstOff = - (uintptr_t)task->dataVec[issued] - localMr->baseVa; - rc = adaptor->iget(conn->sendComm, srcOff, dstOff, - task->sizeVec[issued], 0, 0, (void **)&remoteMr, - (void **)task->localEntries[issued].mhandle, - &request); - } else { - const uint64_t srcOff = - (uintptr_t)task->dataVec[issued] - localMr->baseVa; - const uint64_t dstOff = 0; - rc = adaptor->iput(conn->sendComm, srcOff, dstOff, - task->sizeVec[issued], 0, 0, - (void **)task->localEntries[issued].mhandle, - (void **)&remoteMr, &request); - } +flagcxResult_t FlagcxWorkerPool::submitPostSend(void *sendComm, + FlagcxSlice **slices, + int count) { + if (count <= 0 || slices == nullptr) + return flagcxSuccess; - if (rc != flagcxSuccess) { - error = true; - break; - } + std::vector> perShard(numShards_); + int enqueued = 0; + for (int i = 0; i < count; i++) { + if (slices[i] == nullptr) + continue; + int shard = (int)(shardRoundRobin_.fetch_add(1, std::memory_order_relaxed) % + numShards_); + perShard[shard].push_back(slices[i]); + enqueued++; + } + if (enqueued == 0) + return flagcxSuccess; - inflightReqs[issued % kWindowSize] = request; - issued++; - } + for (int s = 0; s < numShards_; s++) { + if (perShard[s].empty()) + continue; + std::lock_guard lk(slice_locks_[s]); + auto &vec = slice_queues_[s][sendComm]; + vec.insert(vec.end(), perShard[s].begin(), perShard[s].end()); + slice_queue_count_[s].fetch_add((int)perShard[s].size(), + std::memory_order_relaxed); + } + submitted_.fetch_add(enqueued, std::memory_order_release); - // Batch-poll completions for in-flight requests - int newlyCompleted = 0; - for (int i = completed; i < issued; i++) { - int slot = i % kWindowSize; - if (inflightReqs[slot] == nullptr) { - if (i == completed) - completed++; - continue; - } - int done = 0, sizes = 0; - flagcxResult_t res = adaptor->test(inflightReqs[slot], &done, &sizes); - if (res != flagcxSuccess) { - inflightReqs[slot] = nullptr; - if (i == completed) - completed++; - newlyCompleted++; - continue; - } - if (done) { - inflightReqs[slot] = nullptr; - if (i == completed) - completed++; - newlyCompleted++; + if (suspended_flag_.load(std::memory_order_acquire) > 0) { + std::lock_guard lk(cv_mu_); + cv_.notify_all(); + } + return flagcxSuccess; +} + +void FlagcxWorkerPool::transferWorkerLoop(int tid) { + const static uint64_t kWaitPeriodInNano = 100ull * 1000 * 1000; // 100ms + uint64_t last_wait_ts = nowNs(); + + while (running_.load(std::memory_order_relaxed)) { + auto processed_slice_count = + processed_.load(std::memory_order_relaxed); + auto submitted_slice_count = + submitted_.load(std::memory_order_relaxed); + + if (processed_slice_count == submitted_slice_count) { + uint64_t curr_wait_ts = nowNs(); + if (curr_wait_ts - last_wait_ts > kWaitPeriodInNano) { + std::unique_lock lock(cv_mu_); + suspended_flag_.fetch_add(1); + if (processed_.load(std::memory_order_relaxed) == + submitted_.load(std::memory_order_relaxed)) { + cv_.wait_for(lock, std::chrono::seconds(1)); } + suspended_flag_.fetch_sub(1); + last_wait_ts = curr_wait_ts; } + continue; + } + + performPostSend(tid); + performPollCq(); + } +} + +void FlagcxWorkerPool::performPostSend(int tid) { + if (numWorkers_ <= 0) + return; - // Advance completed pointer over contiguous completions - while (completed < issued && - inflightReqs[completed % kWindowSize] == nullptr) { - completed++; + auto &local = collective_slice_queue_[tid]; + for (int s = tid; s < numShards_; s += numWorkers_) { + if (slice_queue_count_[s].load(std::memory_order_relaxed) == 0) + continue; + std::lock_guard lk(slice_locks_[s]); + for (auto &entry : slice_queues_[s]) { + if (entry.second.empty()) + continue; + auto &dst = local[entry.first]; + dst.insert(dst.end(), entry.second.begin(), entry.second.end()); + entry.second.clear(); + } + slice_queue_count_[s].store(0, std::memory_order_relaxed); + } + + std::vector myQpEntries; + int curMaxDepth; + { + std::lock_guard lk(qp_mu_); + myQpEntries.reserve(workerQpIdx_[tid].size()); + for (int idx : workerQpIdx_[tid]) + myQpEntries.push_back(qpEntries_[idx].get()); + curMaxDepth = maxWrDepth_; + } + + size_t &cursor = workerQpCursor_[tid]; + for (auto &entry : local) { + void *sc = entry.first; + auto &pending = entry.second; + if (pending.empty()) + continue; + + std::vector myQpOnComm; + myQpOnComm.reserve(myQpEntries.size()); + for (PoolQpEntry *e : myQpEntries) { + if (e && e->qp != nullptr && e->sendComm == sc) + myQpOnComm.push_back(e); + } + if (myQpOnComm.empty()) { + WARN("NET/IB_P2P : pool[%d] worker %d owns no QP for Engine %p; " + "failing %zu slices", + ibDevN_, tid, sc, pending.size()); + for (auto *sl : pending) + sl->markFailed(); + processed_.fetch_add(pending.size(), std::memory_order_release); + pending.clear(); + continue; + } + + const size_t ringSz = myQpOnComm.size(); + size_t i = 0; + while (i < pending.size()) { + const size_t take = std::min(maxWrPerPost_, pending.size() - i); + PoolQpEntry *chosen = nullptr; + for (size_t k = 0; k < ringSz; k++) { + PoolQpEntry *e = myQpOnComm[(cursor + k) % ringSz]; + int cur = e->wrDepth; + if (curMaxDepth == 0 || cur + (int)take <= curMaxDepth) { + chosen = e; + cursor = (cursor + k + 1) % ringSz; + break; + } } - // Yield briefly if no progress was made - if (newlyCompleted == 0 && issued >= numIovs) { - std::this_thread::sleep_for(std::chrono::microseconds(1)); + if (chosen == nullptr) + break; // all of this worker's QPs for the engine are full; retry later + + volatile int *depthPtr = &chosen->wrDepth; + __sync_fetch_and_add(depthPtr, (int)take); + + std::vector chunk; + chunk.reserve(take); + for (size_t k = 0; k < take; k++) { + FlagcxSlice *sl = pending[i + k]; + sl->qpDepth = depthPtr; + chunk.push_back(sl); } + + flagcxResult_t rc = + flagcxP2pSliceBatch(sc, chosen->qp, (int)take, chunk.data()); + if (rc != flagcxSuccess) + processed_.fetch_add(take, std::memory_order_release); + i += take; } + // Drop the posted prefix; anything left stays for the next iteration. + if (i > 0) + pending.erase(pending.begin(), pending.begin() + i); + } +} + +void FlagcxWorkerPool::performPollCq() { + if (shared_cq_ == nullptr) + return; + + constexpr int kMaxPollBatch = 256; + struct ibv_wc wcs[kMaxPollBatch]; + int batch = (int)std::min(batchPollSize_, kMaxPollBatch); + int n = 0; + if (flagcxWrapIbvPollCq(shared_cq_, batch, wcs, &n) != flagcxSuccess) { + WARN("NET/IB_P2P : ibv_poll_cq failed on shared CQ %p", shared_cq_); + return; + } + if (n == 0) + return; + + uint64_t sliceProgressed = 0; + std::unordered_map qpDepthSet; + for (int i = 0; i < n; i++) { + uintptr_t raw = (uintptr_t)wcs[i].wr_id; + if (raw == 0 || (raw & 1ull) == 0) + continue; - task->result.store(error ? -1 : 0, std::memory_order_release); - task->done.store(true, std::memory_order_release); + FlagcxSlice *slice = + reinterpret_cast(raw & ~(uintptr_t)1ull); + if (slice->qpDepth != NULL) + qpDepthSet[slice->qpDepth]++; + if (wcs[i].status != IBV_WC_SUCCESS) { + WARN("NET/IB_P2P : pool poll error status %d for slice %p", + wcs[i].status, slice); + slice->markFailed(); + } else { + slice->markSuccess(); + } + sliceProgressed++; } + for (auto &entry : qpDepthSet) + __sync_fetch_and_sub(entry.first, entry.second); + if (sliceProgressed > 0) + processed_.fetch_add(sliceProgressed, std::memory_order_release); +} + +// ---- Per-ibDev singleton plumbing ----------------------------------- + +static std::unique_ptr gPools[MAX_IB_DEVS]; +static std::mutex gPoolMu; + +static FlagcxWorkerPool *getOrCreatePool(int ibDevN, struct ibv_context *ctx) { + if (ibDevN < 0 || ibDevN >= MAX_IB_DEVS || ctx == NULL) + return NULL; + std::lock_guard lk(gPoolMu); + if (!gPools[ibDevN]) + gPools[ibDevN].reset(new FlagcxWorkerPool(ibDevN, ctx)); + return gPools[ibDevN].get(); +} + +static FlagcxWorkerPool *lookupPool(int ibDevN) { + if (ibDevN < 0 || ibDevN >= MAX_IB_DEVS) + return nullptr; + std::lock_guard lk(gPoolMu); + return gPools[ibDevN].get(); +} + +} // namespace + +// ---- Hooks consumed by ibrc_p2p_adaptor.cc (forward-declared there). ---- +struct ibv_cq *flagcxP2pPoolGetSharedCq(int ibDevN, struct ibv_context *ctx) { + FlagcxWorkerPool *pool = getOrCreatePool(ibDevN, ctx); + return pool ? pool->getSharedCq() : NULL; } -static std::mutex gAsyncWorkerLifecycleMutex; +void flagcxP2pPoolRegisterQp(int ibDevN, void *sendComm, struct ibv_qp *qp) { + if (qp == nullptr) + return; + FlagcxWorkerPool *pool = lookupPool(ibDevN); + if (pool) + pool->registerQp(sendComm, qp); +} + +void flagcxP2pPoolUnregisterQp(int ibDevN, struct ibv_qp *qp) { + if (qp == nullptr) + return; + FlagcxWorkerPool *pool = lookupPool(ibDevN); + if (pool) + pool->unregisterQp(qp); +} -static void ensureAsyncWorkerStarted() { - std::lock_guard lock(gAsyncWorkerLifecycleMutex); - if (gAsyncWorker.thread.joinable() && !gAsyncWorker.stop.load()) - return; // already running - // If previously stopped, join the old thread before restarting - if (gAsyncWorker.thread.joinable()) { - gAsyncWorker.thread.join(); +flagcxResult_t flagcxP2pPoolSubmit(int ibDevN, void *sendComm, + FlagcxSlice **slices, int count) { + FlagcxWorkerPool *pool = lookupPool(ibDevN); + if (pool == nullptr) { + WARN("NET/IB_P2P : flagcxP2pPoolSubmit on uninitialized pool[%d]", ibDevN); + return flagcxInternalError; } - gAsyncWorker.stop.store(false); - gAsyncWorker.thread = std::thread(asyncWorkerFunc); + return pool->submitPostSend(sendComm, slices, count); } -static void stopAsyncWorker() { - std::lock_guard lock(gAsyncWorkerLifecycleMutex); - gAsyncWorker.stop.store(true); - pthread_cond_broadcast(&gAsyncWorker.cv); - if (gAsyncWorker.thread.joinable()) { - gAsyncWorker.thread.join(); +void flagcxP2pPoolStartNotif(int ibDevN, struct ibv_context *ctx, + FlagcxP2pEngine *engine) { + FlagcxWorkerPool *pool = getOrCreatePool(ibDevN, ctx); + if (pool == nullptr) { + WARN("NET/IB_P2P : pool[%d] cannot be created for notif", ibDevN); + return; } + pool->startNotif(engine); } -// Map from transfer ID to async task (for XferStatus polling) -static std::unordered_map> - gAsyncXferMap; -static std::mutex gAsyncXferMutex; +void flagcxP2pPoolStopNotif() { + // Stop notif on whichever pool currently owns it (only one does — the + // first that StartNotif touched). + std::lock_guard lk(gPoolMu); + for (int i = 0; i < MAX_IB_DEVS; i++) { + if (gPools[i]) + gPools[i]->stopNotif(); + } +} + +struct PoolTransferTask { + FlagcxTransferTask fx; + FlagcxP2pConn *conn; + std::atomic postOk{true}; +}; + +static FlagcxP2pCommView *getCommView(void *comm) { + return reinterpret_cast(comm); +} + +static bool buildAndSubmitToPool(PoolTransferTask *task, + const std::vector &dataVec, + const std::vector &sizeVec, + const std::vector &descs, + const std::vector + &localEntries, + int numIovs, void *sendComm, + int connIbDevN, uint8_t opcode) { + for (int i = 0; i < numIovs; i++) { + if (localEntries[i].ibDevN != connIbDevN) { + WARN("NET/IB_P2P : iov[%d] ibDevN mismatch (%d vs conn %d)", i, + localEntries[i].ibDevN, connIbDevN); + for (auto *s : task->fx.sliceList) + s->markFailed(); + return false; + } + auto *localMr = + reinterpret_cast(localEntries[i].mhandle); + uint64_t localVa = (uintptr_t)dataVec[i]; + uint64_t remoteVa = descs[i].addr; + flagcxBuildSlices( + &task->fx, localVa, remoteVa, sizeVec[i], localMr->lkey, + descs[i].rkey, opcode, std::string()); + } + + if (task->fx.sliceList.empty()) { + return false; + } + + flagcxResult_t rc = flagcxP2pPoolSubmit(connIbDevN, sendComm, + task->fx.sliceList.data(), + (int)task->fx.sliceList.size()); + if (rc != flagcxSuccess) { + task->postOk.store(false, std::memory_order_release); + for (auto *s : task->fx.sliceList) + s->markFailed(); + return false; + } + return true; +} + +static std::unordered_map> + gPoolXferMap; +static std::mutex gPoolXferMutex; + static bool findMemReg(uintptr_t addr, FlagcxP2pMemRegEntry *out) { for (std::unordered_map::const_iterator it = @@ -374,15 +927,28 @@ static bool findMemReg(uintptr_t addr, FlagcxP2pMemRegEntry *out) { } static FlagcxP2pMemRegEntry *findMemRegByMr(FlagcxP2pMr mr) { - for (std::unordered_map::iterator it = - gMemRegInfo.begin(); - it != gMemRegInfo.end(); ++it) { - if (it->second.mrId == mr) - return &it->second; - } + std::unordered_map::const_iterator mrIt = + gMrToBaseAddr.find(mr); + if (mrIt == gMrToBaseAddr.end()) + return NULL; + + std::unordered_map::iterator entryIt = + gMemRegInfo.find(mrIt->second); + if (entryIt != gMemRegInfo.end()) + return &entryIt->second; + return NULL; } +static bool memRegContains(const FlagcxP2pMemRegEntry &entry, uintptr_t addr, + size_t size) { + if (addr < entry.baseAddr) + return false; + + const uintptr_t offset = addr - entry.baseAddr; + return offset <= entry.size && size <= entry.size - offset; +} + static int resolveIbDevN(int netDev) { if (netDev < 0 || netDev >= flagcxNMergedIbDevs) return 0; @@ -1048,8 +1614,8 @@ FlagcxP2pEngine *flagcxP2pEngineCreate() { #endif } - if (engine->notifListenActive) { - engine->notifThread = std::thread(notifPollThreadFunc, engine); + if (engine->notifListenActive && flagcxNIbDevs > 0) { + flagcxP2pPoolStartNotif(0, flagcxIbDevs[0].context, engine); } return engine; } @@ -1058,16 +1624,12 @@ void flagcxP2pEngineDestroy(FlagcxP2pEngine *engine) { if (engine == NULL) return; - stopAsyncWorker(); - engine->stopNotif = true; if (engine->notifListenActive) { flagcxSocketClose(&engine->notifListenSock); engine->notifListenActive = false; } - if (engine->notifThread.joinable()) { - engine->notifThread.join(); - } + flagcxP2pPoolStopNotif(); { std::lock_guard lock(engine->notifPeerMutex); @@ -1113,6 +1675,7 @@ void flagcxP2pEngineDestroy(FlagcxP2pEngine *engine) { engine->adaptor->deregMr(&devCtx, it->second.mhandle); } gMemRegInfo.clear(); + gMrToBaseAddr.clear(); } if (engine->topoMgr) { @@ -1308,6 +1871,7 @@ int flagcxP2pEngineReg(FlagcxP2pEngine *engine, uintptr_t data, size_t size, gMemRegInfo.find(data); if (existing != gMemRegInfo.end()) { mrId = existing->second.mrId; + gMrToBaseAddr[mrId] = existing->first; return 0; } @@ -1337,6 +1901,7 @@ int flagcxP2pEngineReg(FlagcxP2pEngine *engine, uintptr_t data, size_t size, } gMemRegInfo[data] = entry; + gMrToBaseAddr[entry.mrId] = data; mrId = entry.mrId; return 0; } @@ -1346,18 +1911,24 @@ void flagcxP2pEngineMrDestroy(FlagcxP2pEngine *engine, FlagcxP2pMr mr) { return; std::lock_guard lock(gMemMutex); - for (std::unordered_map::iterator it = - gMemRegInfo.begin(); - it != gMemRegInfo.end(); ++it) { - if (it->second.mrId == mr) { - struct { - int ibDevN; - } devCtx = {it->second.ibDevN}; - engine->adaptor->deregMr(&devCtx, it->second.mhandle); - gMemRegInfo.erase(it); - return; - } + std::unordered_map::iterator mrIt = + gMrToBaseAddr.find(mr); + if (mrIt == gMrToBaseAddr.end()) + return; + + std::unordered_map::iterator entryIt = + gMemRegInfo.find(mrIt->second); + if (entryIt == gMemRegInfo.end()) { + gMrToBaseAddr.erase(mrIt); + return; } + + struct { + int ibDevN; + } devCtx = {entryIt->second.ibDevN}; + engine->adaptor->deregMr(&devCtx, entryIt->second.mhandle); + gMemRegInfo.erase(entryIt); + gMrToBaseAddr.erase(mrIt); } int flagcxP2pEnginePrepareDesc(FlagcxP2pEngine *engine, FlagcxP2pMr mr, @@ -1457,49 +2028,91 @@ int flagcxP2pEngineReadVector(FlagcxP2pConn *conn, std::vector descs, int numIovs, uint64_t *transferId, std::vector ipcBufs) { - (void)mrIds; - if (conn == NULL || numIovs <= 0 || transferId == NULL) + if (conn == NULL || numIovs <= 0 || transferId == NULL) { + fprintf(stderr, + "[FlagCX P2P] ReadVector early exit: invalid args (conn=%p, " + "numIovs=%d, transferId=%p)\n", + conn, numIovs, (void *)transferId); return -1; + } + + if (dstVec.size() < static_cast(numIovs) || + sizeVec.size() < static_cast(numIovs) || + descs.size() < static_cast(numIovs)) { + fprintf(stderr, + "[FlagCX P2P] ReadVector early exit: vector length mismatch " + "(numIovs=%d)\n", + numIovs); + return -1; + } if (conn->isLocal && (conn->sameProcess || !ipcBufs.empty())) { - return startLocalTransfer(conn, dstVec, sizeVec, descs, numIovs, transferId, - ipcBufs, false); + fprintf(stderr, + "[FlagCX P2P] ReadVector taking local transfer path: numIovs=%d\n", + numIovs); + int rc = startLocalTransfer(conn, dstVec, sizeVec, descs, numIovs, + transferId, ipcBufs, false); + fprintf(stderr, "[FlagCX P2P] ReadVector local transfer returned: rc=%d\n", + rc); + return rc; + } + + if (mrIds.size() < static_cast(numIovs)) { + fprintf(stderr, + "[FlagCX P2P] ReadVector early exit: mrIds length mismatch " + "(numIovs=%d)\n", + numIovs); + return -1; } std::vector localEntries(numIovs); { std::lock_guard memLock(gMemMutex); for (int i = 0; i < numIovs; i++) { - if (!findMemReg((uintptr_t)dstVec[i], &localEntries[i])) + FlagcxP2pMemRegEntry *entry = findMemRegByMr(mrIds[i]); + if (entry == NULL) { + fprintf( + stderr, + "[FlagCX P2P] ReadVector memReg lookup failed: iov=%d, mr=%lu\n", i, + (unsigned long)mrIds[i]); + return -1; + } + + if (!memRegContains(*entry, reinterpret_cast(dstVec[i]), + sizeVec[i])) { + fprintf(stderr, + "[FlagCX P2P] ReadVector memReg bounds check failed: iov=%d, " + "mr=%lu, addr=%p, size=%zu\n", + i, (unsigned long)mrIds[i], dstVec[i], sizeVec[i]); return -1; + } + + localEntries[i] = *entry; } } - ensureAsyncWorkerStarted(); - - auto task = std::make_shared(); + const int connIbDevN = getCommView(conn->sendComm)->ibDevN; + auto task = std::make_shared(); task->conn = conn; - task->op = ASYNC_XFER_READ; - task->numIovs = numIovs; - task->dataVec = std::move(dstVec); - task->sizeVec = std::move(sizeVec); - task->descs = std::move(descs); - task->localEntries = std::move(localEntries); - - const uint64_t xferId = [&] { - std::lock_guard lock(gAsyncXferMutex); - uint64_t id = gNextXferId++; - gAsyncXferMap[id] = task; - return id; - }(); - { - pthread_mutex_lock(&gAsyncWorker.mutex); - gAsyncWorker.queue.push_back(task); - pthread_mutex_unlock(&gAsyncWorker.mutex); + if (!buildAndSubmitToPool(task.get(), dstVec, sizeVec, descs, localEntries, + numIovs, conn->sendComm, connIbDevN, + FLAGCX_SLICE_OP_READ)) { + // sentinel so isAllDone() converges (needs total>0) + auto *sentinel = new FlagcxSlice{0, 0, 0, 0, 0, FLAGCX_SLICE_OP_READ, + std::string(), &task->fx, nullptr}; + task->fx.sliceList.push_back(sentinel); + task->fx.sliceCount.fetch_add(1, std::memory_order_release); + sentinel->markFailed(); + task->postOk.store(false, std::memory_order_release); } - pthread_cond_signal(&gAsyncWorker.cv); + uint64_t xferId; + { + std::lock_guard lock(gPoolXferMutex); + xferId = gNextXferId++; + gPoolXferMap[xferId] = task; + } *transferId = xferId; return 0; } @@ -1571,49 +2184,59 @@ int flagcxP2pEngineWriteVector(FlagcxP2pConn *conn, std::vector descs, int numIovs, uint64_t *transferId, std::vector ipcBufs) { - (void)mrIds; if (conn == NULL || numIovs <= 0 || transferId == NULL) return -1; + if (dstVec.size() < static_cast(numIovs) || + sizeVec.size() < static_cast(numIovs) || + descs.size() < static_cast(numIovs)) + return -1; + if (conn->isLocal && (conn->sameProcess || !ipcBufs.empty())) { return startLocalTransfer(conn, dstVec, sizeVec, descs, numIovs, transferId, ipcBufs, true); } + if (mrIds.size() < static_cast(numIovs)) + return -1; + std::vector localEntries(numIovs); { std::lock_guard memLock(gMemMutex); for (int i = 0; i < numIovs; i++) { - if (!findMemReg((uintptr_t)dstVec[i], &localEntries[i])) + FlagcxP2pMemRegEntry *entry = findMemRegByMr(mrIds[i]); + if (entry == NULL) return -1; + + if (!memRegContains(*entry, reinterpret_cast(dstVec[i]), + sizeVec[i])) + return -1; + + localEntries[i] = *entry; } } - ensureAsyncWorkerStarted(); - - auto task = std::make_shared(); + const int connIbDevN = getCommView(conn->sendComm)->ibDevN; + auto task = std::make_shared(); task->conn = conn; - task->op = ASYNC_XFER_WRITE; - task->numIovs = numIovs; - task->dataVec = std::move(dstVec); - task->sizeVec = std::move(sizeVec); - task->descs = std::move(descs); - task->localEntries = std::move(localEntries); - - const uint64_t xferId = [&] { - std::lock_guard lock(gAsyncXferMutex); - uint64_t id = gNextXferId++; - gAsyncXferMap[id] = task; - return id; - }(); - { - pthread_mutex_lock(&gAsyncWorker.mutex); - gAsyncWorker.queue.push_back(task); - pthread_mutex_unlock(&gAsyncWorker.mutex); + if (!buildAndSubmitToPool(task.get(), dstVec, sizeVec, descs, localEntries, + numIovs, conn->sendComm, connIbDevN, + FLAGCX_SLICE_OP_WRITE)) { + auto *sentinel = new FlagcxSlice{0, 0, 0, 0, 0, FLAGCX_SLICE_OP_WRITE, + std::string(), &task->fx, nullptr}; + task->fx.sliceList.push_back(sentinel); + task->fx.sliceCount.fetch_add(1, std::memory_order_release); + sentinel->markFailed(); + task->postOk.store(false, std::memory_order_release); } - pthread_cond_signal(&gAsyncWorker.cv); + uint64_t xferId; + { + std::lock_guard lock(gPoolXferMutex); + xferId = gNextXferId++; + gPoolXferMap[xferId] = task; + } *transferId = xferId; return 0; } @@ -1655,13 +2278,16 @@ bool flagcxP2pEngineXferStatus(FlagcxP2pConn *conn, uint64_t transferId) { if (conn == NULL) return true; - // Check async transfer map first (for vectored transfers) { - std::lock_guard lock(gAsyncXferMutex); - auto it = gAsyncXferMap.find(transferId); - if (it != gAsyncXferMap.end()) { - if (it->second->done.load(std::memory_order_acquire)) { - gAsyncXferMap.erase(it); + std::lock_guard lock(gPoolXferMutex); + auto it = gPoolXferMap.find(transferId); + if (it != gPoolXferMap.end()) { + auto &task = it->second; + if (task->fx.isAllDone()) { + for (auto *s : task->fx.sliceList) + delete s; + task->fx.sliceList.clear(); + gPoolXferMap.erase(it); return true; } return false; diff --git a/flagcx/include/flagcx_p2p.h b/flagcx/include/flagcx_p2p.h index 42226526..3a5be2b8 100644 --- a/flagcx/include/flagcx_p2p.h +++ b/flagcx/include/flagcx_p2p.h @@ -13,9 +13,11 @@ #ifndef FLAGCX_P2P_H_ #define FLAGCX_P2P_H_ +#include #include #include #include +#include #include /* ------------------------------------------------------------------ */ @@ -26,6 +28,8 @@ #define FLAGCX_P2P_DESC_SIZE 64 #define FLAGCX_P2P_IPC_INFO_SIZE 128 +constexpr int kFlagcxP2pMaxQpsPerEngine = 8; + /* ------------------------------------------------------------------ */ /* Opaque handle types */ /* ------------------------------------------------------------------ */ @@ -87,6 +91,51 @@ inline void flagcxP2pDeserializeRdmaDesc(const char *buf, std::memcpy(desc->padding, buf + 32, sizeof(desc->padding)); } +/* ------------------------------------------------------------------ */ +/* Slice / transfer-task types (shared by engine and adaptor) */ +/* ------------------------------------------------------------------ */ + +struct FlagcxTransferTask { + std::atomic sliceCount{0}; + std::atomic doneSliceCount{0}; + std::vector sliceList; + + bool isAllDone() const { + auto total = sliceCount.load(std::memory_order_acquire); + auto done = doneSliceCount.load(std::memory_order_acquire); + return total > 0 && done >= total; + } +}; + +enum FlagcxSliceOp : uint8_t { + FLAGCX_SLICE_OP_WRITE = 0, + FLAGCX_SLICE_OP_READ = 1, +}; + +struct FlagcxSlice { + // WRITE: local source VA; READ: local destination VA. + uint64_t srcVa = 0; + // WRITE: remote destination VA; READ: remote source VA. + uint64_t dstVa; + uint32_t length; + uint32_t lkey; + uint32_t rkey; + uint8_t opcode; + std::string peerNicPath; + FlagcxTransferTask *task; + volatile int *qpDepth; + + inline void markSuccess() { + if (task) + task->doneSliceCount.fetch_add(1, std::memory_order_release); + } + + inline void markFailed() { + if (task) + task->doneSliceCount.fetch_add(1, std::memory_order_release); + } +}; + /* ------------------------------------------------------------------ */ /* Notification message */ /* ------------------------------------------------------------------ */ @@ -424,4 +473,49 @@ int flagcxP2pEngineGetIpcInfo(FlagcxP2pEngine *engine, uintptr_t addr, int flagcxP2pEngineUpdateIpcInfo(char *ipcBuf, uintptr_t addr, uintptr_t baseAddr, size_t size); +/* ================================================================== */ +/* Global runtime configuration */ +/* ================================================================== */ + +struct FlagcxP2pGlobalConfig { + /* Worker pool / QP topology */ + int qpsPerConn = 4; /* FLAGCX_P2P_QPS_PER_CONN */ + int workersPerPool = 4; /* FLAGCX_P2P_WORKERS_PER_POOL */ + int shardCount = 8; /* FLAGCX_P2P_SHARD_COUNT */ + + /* CQ / WR / completion-queue depth */ + size_t sharedCqDepth = 4096; /* FLAGCX_P2P_CQ_DEPTH */ + size_t maxWrPerPost = 256; /* FLAGCX_P2P_MAX_WR_PER_POST */ + size_t maxRequests = 256; /* FLAGCX_P2P_MAX_REQUESTS */ + size_t batchPollSize = 64; /* FLAGCX_P2P_BATCH_POLL_SIZE */ + + /* Slice cut policy */ + size_t sliceSize = 64 * 1024; /* FLAGCX_P2P_SLICE_SIZE */ + size_t fragmentLimit = 4 * 1024; /* FLAGCX_P2P_FRAGMENT_LIMIT */ + + /* IB QP attributes — verbs-clean (plain int) so this header does + not pull . */ + size_t maxSge = 4; /* FLAGCX_P2P_MAX_SGE */ + size_t maxInline = 64; /* FLAGCX_P2P_MAX_INLINE */ + uint8_t ibPort = 1; /* FLAGCX_P2P_IB_PORT */ + int gidIndex = -1; /* FLAGCX_P2P_GID_INDEX (-1=auto) */ + int mtuLength = 4096; /* FLAGCX_P2P_MTU */ + int ibTrafficClass = -1; /* FLAGCX_P2P_IB_TC (-1=off) */ + int retryCnt = 7; /* FLAGCX_P2P_RETRY_CNT */ + + /* Notification */ + int notifMaxPeers = 64; /* FLAGCX_P2P_NOTIF_MAX_PEERS */ + + /* Misc */ + bool enableDestDeviceAffinity = false; /* FLAGCX_P2P_DEST_DEV_AFFINITY */ +}; + +/* Returns the lazy-loaded singleton (mooncake::globalConfig() shape). + First call materializes the struct and parses env vars exactly once. */ +const FlagcxP2pGlobalConfig &flagcxP2pGlobalConfig(); + +/* Logs the resolved config once. Implicitly invoked at first + flagcxP2pGlobalConfig() call. */ +void flagcxP2pDumpGlobalConfig(); + #endif /* FLAGCX_P2P_H_ */