diff --git a/flagcx/runner/c2c_algo.cc b/flagcx/runner/c2c_algo.cc index 47cf06171..a281f3958 100644 --- a/flagcx/runner/c2c_algo.cc +++ b/flagcx/runner/c2c_algo.cc @@ -2368,8 +2368,22 @@ flagcxResult_t flagcxC2cPlanner::execute(const void *sendbuff, void *recvbuff, ? const_cast(sendbuff) : recvTmpBuff; - flagcxStream_t het_stream; - deviceAdaptor->streamCreate(&het_stream); + flagcxStream_t hetStream; + FLAGCXCHECK(deviceAdaptor->streamCreate(&hetStream)); + + // Create events for async stream synchronization (avoid CPU-blocking + // streamSync). + // syncEvent/hetSyncEvent: used for cross-stream barriers at end of each + // pipeline iteration. innerSyncEvent: used inside the per-step inner loop to + // let hetStream wait for stream — kept separate so the cross-wait re-record + // of syncEvent cannot update a pending inner-loop streamWaitEvent on + // hetStream (which would create a circular dependency / deadlock). + flagcxEvent_t syncEvent, hetSyncEvent, innerSyncEvent; + FLAGCXCHECK(deviceAdaptor->eventCreate(&syncEvent, flagcxEventDisableTiming)); + FLAGCXCHECK( + deviceAdaptor->eventCreate(&hetSyncEvent, flagcxEventDisableTiming)); + FLAGCXCHECK( + deviceAdaptor->eventCreate(&innerSyncEvent, flagcxEventDisableTiming)); // execute sequential preHomoFunc steps cclAdaptors[flagcxCCLAdaptorDevice]->groupStart(); @@ -2382,7 +2396,6 @@ flagcxResult_t flagcxC2cPlanner::execute(const void *sendbuff, void *recvbuff, } } cclAdaptors[flagcxCCLAdaptorDevice]->groupEnd(); - deviceAdaptor->streamSynchronize(stream); // execute pipelined preHomoFunc and heteroFunc steps // execute refreshFunc @@ -2395,7 +2408,10 @@ flagcxResult_t flagcxC2cPlanner::execute(const void *sendbuff, void *recvbuff, } } refreshFunc_.run(recvbuff, scratchBuffer_, datatype, stream); - deviceAdaptor->streamSynchronize(stream); + // Record event on stream after refreshFunc; hetStream waits before using + // refreshed data + FLAGCXCHECK(deviceAdaptor->eventRecord(syncEvent, stream)); + FLAGCXCHECK(deviceAdaptor->streamWaitEvent(hetStream, syncEvent)); for (int s = 0; s < nPipePreSteps_; ++s) { cclAdaptors[flagcxCCLAdaptorDevice]->groupStart(); for (int i = 0; i < preHomoFuncSteps_[nSeqPreSteps_ + s].size(); ++i) { @@ -2407,28 +2423,35 @@ flagcxResult_t flagcxC2cPlanner::execute(const void *sendbuff, void *recvbuff, cclAdaptors[flagcxCCLAdaptorDevice]->groupEnd(); flagcxHeteroGroupStart(); for (int i = 0; i < heteroFuncSteps_[s].size(); ++i) { - // TODO: use stream wait rather than stream sync to avoid cpu blocking - // deviceAdaptor->streamSynchronize(stream); + // Ensure hetStream waits for preHomoFunc completion on stream. + // Use innerSyncEvent (not syncEvent) to avoid the cross-wait re-record of + // syncEvent updating this pending wait — which would create a circular + // dependency: hetStream waits for A_cross, stream waits for B1 from + // hetStream, hetStream can't produce B1 until past A_cross. Deadlock. + FLAGCXCHECK(deviceAdaptor->eventRecord(innerSyncEvent, stream)); + FLAGCXCHECK(deviceAdaptor->streamWaitEvent(hetStream, innerSyncEvent)); // execute heteroFuncs heteroFuncSteps_[s][i].run(sendTmpBuff, recvTmpBuff, datatype, comm_, - het_stream); + hetStream); if (homoInterFuncSteps_[s].size() > i) { - // TODO: use stream wait rather than stream sync to avoid cpu blocking - deviceAdaptor->streamSynchronize(het_stream); - // execute homoInterFuncs homoInterFuncSteps_[s][i].run( sendbuff, recvbuff, scratchBuffer_, datatype, redOp_, - comm_->globalRank2HomoRank[root], comm_, het_stream); + comm_->globalRank2HomoRank[root], comm_, hetStream); + // Ensure stream waits for hetStream before refreshFunc + FLAGCXCHECK(deviceAdaptor->eventRecord(hetSyncEvent, hetStream)); + FLAGCXCHECK(deviceAdaptor->streamWaitEvent(stream, hetSyncEvent)); refreshFunc_.run(recvbuff, scratchBuffer_, datatype, stream); } } flagcxHeteroGroupEnd(); - // todo: double-check the synchronization logic - deviceAdaptor->streamSynchronize(stream); - deviceAdaptor->streamSynchronize(het_stream); + // Cross-wait both streams before next pipeline iteration + FLAGCXCHECK(deviceAdaptor->eventRecord(syncEvent, stream)); + FLAGCXCHECK(deviceAdaptor->eventRecord(hetSyncEvent, hetStream)); + FLAGCXCHECK(deviceAdaptor->streamWaitEvent(stream, hetSyncEvent)); + FLAGCXCHECK(deviceAdaptor->streamWaitEvent(hetStream, syncEvent)); } // execute sequential heteroFunc steps @@ -2440,17 +2463,12 @@ flagcxResult_t flagcxC2cPlanner::execute(const void *sendbuff, void *recvbuff, refreshFunc_.run(recvbuff, scratchBuffer_, datatype, stream); } - // TODO: use stream wait rather than stream sync to avoid cpu blocking - // deviceAdaptor->streamSynchronize(stream); - - // execute heteroFuncs + // execute heteroFuncs (all run on stream here; same-stream FIFO provides + // ordering — no cross-stream event needed) heteroFuncSteps_[nPipePreSteps_ + s][i].run(sendTmpBuff, recvTmpBuff, datatype, comm_, stream); if (homoInterFuncSteps_[nPipePreSteps_ + s].size() > i) { - // TODO: use stream wait rather than stream sync to avoid cpu blocking - deviceAdaptor->streamSynchronize(stream); - // execute homoInterFuncs homoInterFuncSteps_[nPipePreSteps_ + s][i].run( sendbuff, recvbuff, scratchBuffer_, datatype, redOp_, @@ -2462,7 +2480,9 @@ flagcxResult_t flagcxC2cPlanner::execute(const void *sendbuff, void *recvbuff, } } } - deviceAdaptor->streamSynchronize(stream); + // Record event on stream for downstream hetStream dependency + FLAGCXCHECK(deviceAdaptor->eventRecord(syncEvent, stream)); + FLAGCXCHECK(deviceAdaptor->streamWaitEvent(hetStream, syncEvent)); // execute pipelined heteroFunc and postHomoFunc steps for (int s = 0; s < nPipePostSteps_; ++s) { @@ -2479,28 +2499,30 @@ flagcxResult_t flagcxC2cPlanner::execute(const void *sendbuff, void *recvbuff, for (int i = 0; i < heteroFuncSteps_[nPipePreSteps_ + nSeqInterSteps_ + s].size(); ++i) { - // TODO: use stream wait rather than stream sync to avoid cpu blocking - // deviceAdaptor->streamSynchronize(stream); + // Ensure hetStream waits for postHomoFunc completion on stream. + // Use innerSyncEvent for the same reason as in pipelined pre-steps. + FLAGCXCHECK(deviceAdaptor->eventRecord(innerSyncEvent, stream)); + FLAGCXCHECK(deviceAdaptor->streamWaitEvent(hetStream, innerSyncEvent)); // execute heteroFuncs heteroFuncSteps_[nPipePreSteps_ + nSeqInterSteps_ + s][i].run( - sendTmpBuff, recvTmpBuff, datatype, comm_, het_stream); + sendTmpBuff, recvTmpBuff, datatype, comm_, hetStream); if (homoInterFuncSteps_[nPipePreSteps_ + nSeqInterSteps_ + s].size() > i) { - // TODO: use stream wait rather than stream sync to avoid cpu blocking - deviceAdaptor->streamSynchronize(het_stream); - // execute homoInterFuncs homoInterFuncSteps_[nPipePreSteps_ + nSeqInterSteps_ + s][i].run( sendbuff, recvbuff, scratchBuffer_, datatype, redOp_, - comm_->globalRank2HomoRank[root], comm_, het_stream); + comm_->globalRank2HomoRank[root], comm_, hetStream); } } flagcxHeteroGroupEnd(); - deviceAdaptor->streamSynchronize(stream); - deviceAdaptor->streamSynchronize(het_stream); + // Cross-wait both streams before next pipeline iteration + FLAGCXCHECK(deviceAdaptor->eventRecord(syncEvent, stream)); + FLAGCXCHECK(deviceAdaptor->eventRecord(hetSyncEvent, hetStream)); + FLAGCXCHECK(deviceAdaptor->streamWaitEvent(stream, hetSyncEvent)); + FLAGCXCHECK(deviceAdaptor->streamWaitEvent(hetStream, syncEvent)); } // execute sequential postHomoFunc steps @@ -2526,8 +2548,11 @@ flagcxResult_t flagcxC2cPlanner::execute(const void *sendbuff, void *recvbuff, deviceAdaptor->deviceFree(scratchBuffer_, flagcxMemDevice, stream); } - // destroy temporary hetero comm stream - deviceAdaptor->streamDestroy(het_stream); + // destroy sync events and temporary hetero comm stream + FLAGCXCHECK(deviceAdaptor->eventDestroy(innerSyncEvent)); + FLAGCXCHECK(deviceAdaptor->eventDestroy(syncEvent)); + FLAGCXCHECK(deviceAdaptor->eventDestroy(hetSyncEvent)); + FLAGCXCHECK(deviceAdaptor->streamDestroy(hetStream)); return flagcxSuccess; }