Skip to content

Commit d7ca5db

Browse files
authored
Turn on device memory planning as default (#20239)
Recreates the PR because the ghexport bot did not work. Diff: D107597774. Original PR: #20214.
1 parent c5bf380 commit d7ca5db

15 files changed

Lines changed: 399 additions & 313 deletions

File tree

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 132 additions & 206 deletions
Large diffs are not rendered by default.

backends/cuda/runtime/utils.h

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,13 @@ inline void _strided_copy(
147147
}
148148

149149
// Copy data from SlimTensor to ETensor, rearranging if strides differ.
150-
// When stream is non-null, GPU copies use that stream (async fast path).
151-
// When stream is null, GPU copies are synchronous.
150+
// dst_device selects the destination memory space (CPU for D2H, a CUDA device
151+
// for D2D). When stream is non-null, GPU copies use that stream (async fast
152+
// path). When stream is null, GPU copies are synchronous.
152153
inline executorch::runtime::Error _copy_slimtensor_to_etensor_impl(
153154
const executorch::backends::aoti::slim::SlimTensor* slim_tensor,
154155
executorch::runtime::etensor::Tensor* etensor,
156+
const executorch::backends::aoti::slim::c10::Device& dst_device,
155157
cudaStream_t stream) {
156158
ET_CHECK_OK_OR_RETURN_ERROR(_check_tensor_metadata(slim_tensor, etensor));
157159

@@ -165,7 +167,7 @@ inline executorch::runtime::Error _copy_slimtensor_to_etensor_impl(
165167

166168
if (_strides_match(slim_tensor, etensor)) {
167169
// Fast path: strides match, raw byte copy
168-
if (slim_tensor->is_cpu()) {
170+
if (slim_tensor->is_cpu() && dst_device.is_cpu()) {
169171
std::memcpy(dst_data, src_data, nbytes);
170172
} else if (stream) {
171173
executorch::backends::aoti::slim::DeviceTraits<
@@ -174,23 +176,19 @@ inline executorch::runtime::Error _copy_slimtensor_to_etensor_impl(
174176
dst_data,
175177
src_data,
176178
nbytes,
177-
executorch::backends::aoti::slim::CPU_DEVICE,
179+
dst_device,
178180
slim_tensor->device(),
179181
stream);
180182
} else {
181183
executorch::backends::aoti::slim::DeviceTraits<
182184
executorch::backends::aoti::slim::c10::DeviceType::CUDA>::
183-
memcpy(
184-
dst_data,
185-
src_data,
186-
nbytes,
187-
executorch::backends::aoti::slim::CPU_DEVICE,
188-
slim_tensor->device());
185+
memcpy(dst_data, src_data, nbytes, dst_device, slim_tensor->device());
189186
}
190187
} else {
191188
// Slow path: strides differ (e.g., AOTI delegate output layout differs
192-
// from .pte's dim_order). Copy to a temp CPU buffer, then rearrange
193-
// element-by-element to match the ETensor's expected layout.
189+
// from .pte's dim_order). Copy to a temp CPU buffer, rearrange
190+
// element-by-element to match the ETensor's expected layout, then move the
191+
// result to the destination (CPU stays in place; GPU gets an H2D copy).
194192
std::vector<char> tmp(nbytes);
195193
if (slim_tensor->is_cpu()) {
196194
std::memcpy(tmp.data(), src_data, nbytes);
@@ -218,13 +216,38 @@ inline executorch::runtime::Error _copy_slimtensor_to_etensor_impl(
218216

219217
size_t elem_size = executorch::backends::aoti::slim::c10::elementSize(
220218
slim_tensor->dtype());
221-
_strided_copy(
222-
dst_data,
223-
tmp.data(),
224-
elem_size,
225-
sizes_vec,
226-
src_strides_vec,
227-
dst_strides_vec);
219+
220+
if (dst_device.is_cpu()) {
221+
_strided_copy(
222+
dst_data,
223+
tmp.data(),
224+
elem_size,
225+
sizes_vec,
226+
src_strides_vec,
227+
dst_strides_vec);
228+
} else {
229+
// Rearrange into a CPU staging buffer, then copy to the GPU destination.
230+
std::vector<char> rearranged(nbytes);
231+
_strided_copy(
232+
rearranged.data(),
233+
tmp.data(),
234+
elem_size,
235+
sizes_vec,
236+
src_strides_vec,
237+
dst_strides_vec);
238+
if (stream) {
239+
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
240+
dst_data,
241+
rearranged.data(),
242+
nbytes,
243+
cudaMemcpyHostToDevice,
244+
stream));
245+
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamSynchronize(stream));
246+
} else {
247+
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
248+
dst_data, rearranged.data(), nbytes, cudaMemcpyHostToDevice));
249+
}
250+
}
228251
}
229252

230253
return executorch::runtime::Error::Ok;
@@ -251,7 +274,39 @@ inline executorch::runtime::Error copy_slimtensor_to_etensor_async(
251274
const executorch::backends::aoti::slim::SlimTensor* slim_tensor,
252275
executorch::runtime::etensor::Tensor* etensor,
253276
cudaStream_t stream) {
254-
return _copy_slimtensor_to_etensor_impl(slim_tensor, etensor, stream);
277+
return _copy_slimtensor_to_etensor_impl(
278+
slim_tensor,
279+
etensor,
280+
executorch::backends::aoti::slim::CPU_DEVICE,
281+
stream);
282+
}
283+
284+
/**
285+
* Copies data from a SlimTensor to a GPU-resident ETensor asynchronously
286+
* (device-to-device).
287+
*
288+
* Used when the destination ETensor's storage lives in a planned GPU arena.
289+
* The destination device is taken from the source SlimTensor, so this only
290+
* supports same-device D2D copies (source and destination on the same GPU).
291+
*
292+
* When strides match (common case), performs a fast async D2D copy on the
293+
* provided stream. When strides differ, falls back to a staged copy with
294+
* element-by-element rearrangement on the host.
295+
*
296+
* NOTE: In the fast path the copy is asynchronous. The caller must synchronize
297+
* the stream before consuming the ETensor data.
298+
*
299+
* @param slim_tensor Pointer to the source SlimTensor (must not be null).
300+
* @param etensor Pointer to the destination GPU ETensor (must not be null).
301+
* @param stream The CUDA stream to use for async copy.
302+
* @return Error::Ok on success, or an appropriate error code on failure.
303+
*/
304+
inline executorch::runtime::Error copy_slimtensor_to_device_etensor_async(
305+
const executorch::backends::aoti::slim::SlimTensor* slim_tensor,
306+
executorch::runtime::etensor::Tensor* etensor,
307+
cudaStream_t stream) {
308+
return _copy_slimtensor_to_etensor_impl(
309+
slim_tensor, etensor, slim_tensor->device(), stream);
255310
}
256311

257312
/**
@@ -267,7 +322,11 @@ inline executorch::runtime::Error copy_slimtensor_to_etensor_async(
267322
inline executorch::runtime::Error copy_slimtensor_to_etensor(
268323
const executorch::backends::aoti::slim::SlimTensor* slim_tensor,
269324
executorch::runtime::etensor::Tensor* etensor) {
270-
return _copy_slimtensor_to_etensor_impl(slim_tensor, etensor, nullptr);
325+
return _copy_slimtensor_to_etensor_impl(
326+
slim_tensor,
327+
etensor,
328+
executorch::backends::aoti::slim::CPU_DEVICE,
329+
nullptr);
271330
}
272331

273332
/**

backends/cuda/tests/test_cuda_export.py

Lines changed: 70 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -328,18 +328,23 @@ def test_triton_kernel_mode_off(self):
328328

329329
def test_device_info_propagated_to_cuda_delegate_outputs(self):
330330
"""
331-
Test that device info is correctly propagated from export to serialization
332-
for CUDA delegate outputs.
333-
334-
This verifies the device propagation flow:
335-
1. CudaPartitioner adds target_device="cuda:0" CompileSpec
336-
2. PropagateDevicePass sets TensorSpec.device = CUDA for delegate outputs
337-
3. Emitter serializes device info into ExtraTensorInfo.device_type
338-
4. Serialized tensors have device_type = DeviceType.CUDA
339-
340-
Note: At this stage, the tensor memory is still on CPU. The CUDA backend
341-
will copy data to GPU device at runtime. Device info tagging is the first
342-
step toward full device-aware memory allocation.
331+
Verify that, for a CUDA-delegated graph, every memory-planned tensor's
332+
actual planned memory location matches its device_type tag.
333+
334+
With device memory planning (the default), the flow is:
335+
1. CudaPartitioner adds target_device="cuda:0" CompileSpec.
336+
2. PropagateDevicePass tags delegate IO TensorSpecs as CUDA and inserts
337+
et_copy._h2d_copy / _d2h_copy ops at the delegate boundary, so the
338+
method inputs/outputs stay on CPU while the delegate IO is CUDA.
339+
3. Device-aware memory planning allocates each non-CPU tensor into a CUDA
340+
buffer, recorded in ExecutionPlan.non_const_buffer_device.
341+
4. The emitter serializes device info into ExtraTensorInfo.device_type.
342+
343+
The core check: for each planned tensor, the device of the buffer it is
344+
allocated into (non_const_buffer_device) must agree with the tensor's
345+
own device_type. A CUDA-tagged tensor planned into a CPU buffer (or vice
346+
versa) means planning and device tagging disagree about where the
347+
tensor's real memory lives.
343348
"""
344349

345350
class AddModule(torch.nn.Module):
@@ -354,7 +359,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
354359
edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
355360
self.assertIsNotNone(edge_program_manager, "CUDA export failed")
356361

357-
# Convert to ExecuTorch and access the serialized program
362+
# Convert to ExecuTorch and access the serialized program. The default
363+
# config enables device memory planning, so delegate IO is GPU-resident.
358364
et_prog = edge_program_manager.to_executorch()
359365
program = et_prog._emitter_output.program
360366

@@ -366,32 +372,60 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
366372
"Expected at least one delegate in the execution plan",
367373
)
368374

369-
# Count tensors by device type
370-
cpu_tensors = []
371-
cuda_tensors = []
372-
375+
# Build buffer_idx -> device map from the per-buffer device mapping.
376+
# Buffers without an entry default to CPU.
377+
buffer_device: dict[int, schema.DeviceType] = {}
378+
for entry in plan.non_const_buffer_device or []:
379+
buffer_device[entry.buffer_idx] = entry.device_type
380+
381+
def tensor_device(t: schema.Tensor) -> schema.DeviceType:
382+
if t.extra_tensor_info is not None:
383+
return t.extra_tensor_info.device_type
384+
return schema.DeviceType.CPU
385+
386+
# Walk every memory-planned tensor in the graph and assert its declared
387+
# device_type matches the device of the buffer it lives in.
388+
cuda_planned = 0
389+
cpu_planned = 0
373390
for value in plan.values:
374-
if isinstance(value.val, schema.Tensor):
375-
tensor = value.val
376-
if (
377-
tensor.extra_tensor_info is not None
378-
and tensor.extra_tensor_info.device_type == schema.DeviceType.CUDA
379-
):
380-
cuda_tensors.append(tensor)
381-
else:
382-
# Either no extra_tensor_info or device_type is CPU (default)
383-
cpu_tensors.append(tensor)
384-
385-
# Both input and output tensors should be on CUDA device for now.
391+
if not isinstance(value.val, schema.Tensor):
392+
continue
393+
tensor = value.val
394+
# Only memory-planned (non-constant) tensors have allocation_info;
395+
# their memory_id indexes into the non_const buffers.
396+
if tensor.allocation_info is None:
397+
continue
398+
399+
declared = tensor_device(tensor)
400+
mem_id = tensor.allocation_info.memory_id
401+
planned = buffer_device.get(mem_id, schema.DeviceType.CPU)
402+
403+
self.assertEqual(
404+
planned,
405+
declared,
406+
f"Tensor planned into buffer {mem_id} has device_type="
407+
f"{declared.name} but the buffer is allocated on "
408+
f"{planned.name}; planned memory location and device tag "
409+
f"must agree.",
410+
)
411+
if declared == schema.DeviceType.CUDA:
412+
cuda_planned += 1
413+
else:
414+
cpu_planned += 1
415+
416+
# AddModule has 2 inputs + 1 output. With device memory planning the
417+
# delegate IO is CUDA-resident (2 h2d copies + 1 delegate output) and
418+
# the host-side method inputs/outputs stay on CPU (2 inputs + 1 d2h
419+
# output), giving exactly 3 CUDA- and 3 CPU-resident planned tensors.
386420
self.assertEqual(
387-
len(cpu_tensors),
388-
0,
389-
f"Expected no CPU tensors: method inputs/outputs should be tagged "
390-
f"CUDA, but found {len(cpu_tensors)}",
421+
cuda_planned,
422+
3,
423+
f"Expected exactly 3 CUDA-resident planned tensors (2 h2d copies + "
424+
f"1 delegate output), but found {cuda_planned}.",
391425
)
392426
self.assertEqual(
393-
len(cuda_tensors),
427+
cpu_planned,
394428
3,
395-
f"Expected 3 CUDA tensors (2 method inputs + 1 method output), "
396-
f"but found {len(cuda_tensors)}",
429+
f"Expected exactly 3 CPU-resident planned tensors (2 method inputs "
430+
f"+ 1 d2h output), but found {cpu_planned}.",
397431
)

examples/models/gemma4_31b/export.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,6 @@ def _export_cuda(
268268
do_quant_fusion_and_const_prop=True,
269269
memory_planning_pass=MemoryPlanningPass(
270270
alloc_graph_input=False,
271-
share_mutable_buffers=True,
272271
),
273272
emit_mutable_buffer_names=True,
274273
),

examples/models/gemma4_31b/main.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,7 @@ int main(int argc, char** argv) {
158158
Module::LoadMode::MmapUseMlockIgnoreErrors,
159159
/*event_tracer=*/nullptr,
160160
/*memory_allocator=*/nullptr,
161-
/*temp_allocator=*/nullptr,
162-
/*share_memory_arenas=*/true);
161+
/*temp_allocator=*/nullptr);
163162

164163
// Get metadata
165164
auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get());

examples/models/gemma4_31b/model.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,13 @@ Decoder norms per layer: `input_layernorm`, `post_attention_layernorm`,
109109
| `decode` | tokens `(1, 1)` + input_pos `(1,)` + temperature `(1,)` | `(1, 1)` float |
110110
| `prefill` | tokens `(1, T)` + input_pos `(T,)` + temperature `(1,)`, T∈[5, min(max_seq_len-1, 2×sliding_window)] | `(1, 1)` float |
111111

112-
Both methods share the same KV-cache buffers via
113-
`MemoryPlanningPass(share_mutable_buffers=True)` and
114-
`emit_mutable_buffer_names=True`. The exported program performs Gumbel-max
115-
sampling on-device and returns a single token ID per call so the C++ runner
116-
only has to feed tokens.
112+
Both methods share the same KV-cache buffers. On the CUDA/AOTI backend the
113+
stateful buffers are lifted into the delegate as constants and shared across
114+
`decode`/`prefill` at runtime via the backend's per-FQN buffer cache, so the
115+
CUDA export leaves `share_mutable_buffers` off (other backends, e.g. MLX, instead
116+
share graph-level buffers via `share_mutable_buffers`). The exported program
117+
performs Gumbel-max sampling on-device and returns a single token ID per call so
118+
the C++ runner only has to feed tokens.
117119

118120
### MLX (`--backend mlx`)
119121

examples/models/gemma4_31b/model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
Gemma 4 31B-IT — export-friendly reference implementation for ExecuTorch.
99
1010
Model definition designed for torch.export(strict=True) with the CUDA backend.
11-
All stateful buffers (KV cache, RoPE inv_freq) are registered buffers so they
12-
are captured by share_mutable_buffers across prefill/decode. The numerically
11+
All stateful buffers (KV cache, RoPE inv_freq) are registered buffers with
12+
in-place updates. On the CUDA/AOTI backend they are lifted into the delegate as
13+
constants and shared across prefill/decode at runtime via the backend's per-FQN
14+
buffer cache (so the CUDA export leaves share_mutable_buffers off); backends that
15+
keep these buffers at the graph level (e.g. MLX) instead share them via
16+
share_mutable_buffers. The numerically
1317
sensitive primitives — RMSNorm, GELU-tanh MLP, proportional/full RoPE, and
1418
the BHSD KV cache — are imported from ``examples.models.gemma4.text_decoder``
1519
so the 31B and E2B/E4B paths share them.

examples/models/qwen3_5_moe/export.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -623,8 +623,10 @@ def _materialize_buffers(model, config):
623623
624624
Replaces meta buffers with real tensors on CPU, recomputes RoPE
625625
inv_freq and causal masks. State buffers (KV cache, conv/recurrent
626-
state) are zero-initialized registered buffers that will be shared
627-
across methods via share_mutable_buffers.
626+
state) are zero-initialized registered buffers. On the CUDA/AOTI backend
627+
they are lifted into the delegate as constants and shared across methods at
628+
runtime via the backend's per-FQN buffer cache; backends that keep them at
629+
the graph level instead share them via share_mutable_buffers.
628630
"""
629631
# Masks stay bool, inv_freq stays float32.
630632
for fqn, buf in list(model.named_buffers()):
@@ -922,8 +924,12 @@ def _export_cuda(model, config, args):
922924
via fused_moe_batched_gemm, with dynamic sequence length.
923925
924926
Both methods share mutable state buffers (KV cache, conv_state,
925-
recurrent_state) via share_mutable_buffers=True. The model uses
926-
registered buffers with in-place updates — no state in/out args.
927+
recurrent_state): the model uses registered buffers with in-place
928+
updates (no state in/out args). On the CUDA/AOTI backend these buffers
929+
are lifted into the delegate as constants and shared across the
930+
decode/prefill methods at runtime via the backend's per-FQN buffer cache
931+
(share_mutable_buffers is left off for CUDA); backends that keep them at
932+
the graph level instead share them via share_mutable_buffers.
927933
"""
928934
import torch._inductor.config as inductor_config
929935

@@ -1031,10 +1037,7 @@ def _export_cuda(model, config, args):
10311037
config=ExecutorchBackendConfig(
10321038
extract_delegate_segments=True,
10331039
do_quant_fusion_and_const_prop=True,
1034-
memory_planning_pass=MemoryPlanningPass(
1035-
alloc_graph_input=False,
1036-
share_mutable_buffers=True,
1037-
),
1040+
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
10381041
emit_mutable_buffer_names=True,
10391042
),
10401043
)

examples/models/qwen3_5_moe/main.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,6 @@ int main(int argc, char** argv) {
144144

145145
stats.model_load_start_ms = llm::time_in_ms();
146146

147-
// Create Module with share_memory_arenas=true so prefill and decode
148-
// share mutable buffers (KV cache, conv_state, recurrent_state).
149147
std::vector<std::string> data_files;
150148
if (!FLAGS_data_path.empty()) {
151149
data_files.push_back(FLAGS_data_path);
@@ -156,8 +154,7 @@ int main(int argc, char** argv) {
156154
Module::LoadMode::File,
157155
/*event_tracer=*/nullptr,
158156
/*memory_allocator=*/nullptr,
159-
/*temp_allocator=*/nullptr,
160-
/*share_memory_arenas=*/true);
157+
/*temp_allocator=*/nullptr);
161158

162159
// Get metadata
163160
auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get());

0 commit comments

Comments
 (0)