Skip to content

Commit ef29abd

Browse files
committed
gemma4_31b: stage decode positions on device (pos-array + per-step D2D)
Kill the per-decode-round position H2D (the last per-round host->device copy left after Option A): upload the full decode position array to the device once (a single H2D), then at each step copy that step's position from the array into the fixed position input slot with an on-device D2D copy. The token stays aliased on device (Option A). Per-round HtoD is now 0, independent of decode length; the fixed input slot keeps the decode path cuda-graph-safe (with cuda graph on, the D2D becomes a captured cudaMemcpyAsync on the decode stream into the same slot). Measured (int6/gguf, cuda graph OFF, p19/d128): post-load HtoD 132->5 (per-round H2D=0); DtoD 129->257 (+128 per-round position D2D copies, the intended H2D->D2D trade); DtoH unchanged (129). Greedy output is byte-identical to prior runs. Runner-only change; reuses the int64-output export (no re-export).
1 parent b486c56 commit ef29abd

1 file changed

Lines changed: 32 additions & 11 deletions

File tree

examples/models/gemma4_31b/main.cpp

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -390,10 +390,27 @@ int main(int argc, char** argv) {
390390
auto decode_pos_cpu = from_blob(
391391
decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long);
392392
#ifdef EXECUTORCH_BUILD_CUDA
393-
// The token input is the aliased on-device output (device_out_token); only
394-
// the position still needs a fixed device buffer refreshed by a per-round
395-
// H2D (one int64 / round).
393+
// Fixed device-resident position input slot: the decode method always reads
394+
// the position from this same address every step (cuda-graph-safe). Seeded
395+
// once here with a one-time H2D; refreshed each step by an on-device D2D.
396396
auto decode_pos = clone_tensor_ptr_to(decode_pos_cpu, cuda_device);
397+
// Upload the FULL decode position array to device ONCE (a single H2D - the
398+
// one-time copy we keep). Each step copies its position from here into the
399+
// fixed slot with a device-to-device copy, so there is NO per-round pos H2D.
400+
std::vector<int64_t> pos_seq_data(FLAGS_max_new_tokens);
401+
for (int32_t i = 0; i < FLAGS_max_new_tokens; i++) {
402+
pos_seq_data[i] = num_prompt_tokens + i;
403+
}
404+
auto pos_seq_dev = clone_tensor_ptr_to(
405+
from_blob(
406+
pos_seq_data.data(),
407+
{S(FLAGS_max_new_tokens)},
408+
executorch::aten::ScalarType::Long),
409+
cuda_device);
410+
auto* pos_seq_dev_ptr =
411+
static_cast<int64_t*>(pos_seq_dev->mutable_data_ptr());
412+
auto* decode_pos_slot_ptr =
413+
static_cast<int64_t*>(decode_pos->mutable_data_ptr());
397414
#else
398415
// Non-CUDA (MLX) path: keep host token/pos buffers; the backend stages them
399416
// and the host samples from the returned logits.
@@ -406,19 +423,23 @@ int main(int argc, char** argv) {
406423
uint64_t prev_token = cur_token;
407424
bool hit_eos = eos_ids.find(cur_token) != eos_ids.end();
408425
for (int32_t step = 0; step < FLAGS_max_new_tokens && !hit_eos; step++) {
409-
decode_pos_data[0] = pos;
410-
411426
#ifdef EXECUTORCH_BUILD_CUDA
412-
// Token stays on device (aliased from the previous forward's output); only
413-
// the 8-byte position is uploaded each round. No token D2H->H2D round-trip.
427+
// No per-round H2D: copy this step's position from the pre-uploaded device
428+
// position array into the fixed position slot with an on-device D2D. With
429+
// the token aliased on device (Option A) and the position staged via D2D,
430+
// the per-round HtoD count is zero (independent of decode length).
431+
// NOTE(review): cudaMemcpy D2D may return before the copy completes (no host
432+
// sync); confirm the decode stream orders after it. With cuda graph enabled this
433+
// becomes a captured cudaMemcpyAsync on the decode stream into this same fixed slot.
414434
ET_CHECK_MSG(
415435
cudaMemcpy(
416-
decode_pos->mutable_data_ptr(),
417-
decode_pos_data.data(),
436+
decode_pos_slot_ptr,
437+
pos_seq_dev_ptr + step,
418438
sizeof(int64_t),
419-
cudaMemcpyHostToDevice) == cudaSuccess,
420-
"Failed to upload decode position H2D");
439+
cudaMemcpyDeviceToDevice) == cudaSuccess,
440+
"Failed to copy decode position D2D");
421441
#else
442+
decode_pos_data[0] = pos;
422443
decode_token_data[0] = static_cast<int64_t>(cur_token);
423444
#endif
424445

0 commit comments

Comments
 (0)