feat: use async queue to enable the ovelaping of weight loading and RDMA transferring#16
Conversation
b4c1717 to
5c1bb4f
Compare
|
I just have one concern about the logic: for a q_proj, k_proj, v_proj this will call the async task 3 times and there is no guarantee that the right weight will be transferred in order. We need to only submit those parameter once all the shards are updated. Let's combine the two PRs to achieve this. Otherwise, if this async queue impl is clearly better, we should remove the old execute_each. What was blocking by the way? is it about the async execute not having it's own async loop? |
JD-ETH
left a comment
There was a problem hiding this comment.
let's discuss tomorrow and try to merge them together, otherwise the results will be wrong.
| if self.pipelined_transfer: | ||
| transfer_bundle.execute_each(updated_name) | ||
| # Use executable queue for async transfer operations | ||
| transfer_bundle.execute_each(updated_name, self.executable_queue) |
There was a problem hiding this comment.
let's default to execute with the queue if it's clearly better
Agree. The updating order does matter. But it should not a problem here according to the reply from letian
I think it is. Not sure why the |
JD-ETH
left a comment
There was a problem hiding this comment.
good for me to merge right now to get better profiling numbers
Description
Before this PR, the
update_weightsin RDMA mode is sth like :. while
execute_each()is actually an asynchronous operation. After theload_weightsof the model_replica is done, the related updated weights could be transferred, and another roundload_weightscould be started. That's what this pr does: build a queue to execute the RDMA transferring in a asynchronous way, and reduce the latency of the single_update_bucket_weights_from_remote()from 70+ ms ->16ms , with a 10%-20% e2e time cost saving of updating weightsBefore :

After:
