@@ -147,11 +147,13 @@ inline void _strided_copy(
147147}
148148
149149// Copy data from SlimTensor to ETensor, rearranging if strides differ.
150- // When stream is non-null, GPU copies use that stream (async fast path).
151- // When stream is null, GPU copies are synchronous.
150+ // dst_device selects the destination memory space (CPU for D2H, a CUDA device
151+ // for D2D). When stream is non-null, GPU copies use that stream (async fast
152+ // path). When stream is null, GPU copies are synchronous.
152153inline executorch::runtime::Error _copy_slimtensor_to_etensor_impl (
153154 const executorch::backends::aoti::slim::SlimTensor* slim_tensor,
154155 executorch::runtime::etensor::Tensor* etensor,
156+ const executorch::backends::aoti::slim::c10::Device& dst_device,
155157 cudaStream_t stream) {
156158 ET_CHECK_OK_OR_RETURN_ERROR (_check_tensor_metadata (slim_tensor, etensor));
157159
@@ -165,7 +167,7 @@ inline executorch::runtime::Error _copy_slimtensor_to_etensor_impl(
165167
166168 if (_strides_match (slim_tensor, etensor)) {
167169 // Fast path: strides match, raw byte copy
168- if (slim_tensor->is_cpu ()) {
170+ if (slim_tensor->is_cpu () && dst_device. is_cpu () ) {
169171 std::memcpy (dst_data, src_data, nbytes);
170172 } else if (stream) {
171173 executorch::backends::aoti::slim::DeviceTraits<
@@ -174,23 +176,19 @@ inline executorch::runtime::Error _copy_slimtensor_to_etensor_impl(
174176 dst_data,
175177 src_data,
176178 nbytes,
177- executorch::backends::aoti::slim:: CPU_DEVICE ,
179+ dst_device ,
178180 slim_tensor->device (),
179181 stream);
180182 } else {
181183 executorch::backends::aoti::slim::DeviceTraits<
182184 executorch::backends::aoti::slim::c10::DeviceType::CUDA >::
183- memcpy (
184- dst_data,
185- src_data,
186- nbytes,
187- executorch::backends::aoti::slim::CPU_DEVICE ,
188- slim_tensor->device ());
185+ memcpy (dst_data, src_data, nbytes, dst_device, slim_tensor->device ());
189186 }
190187 } else {
191188 // Slow path: strides differ (e.g., AOTI delegate output layout differs
192- // from .pte's dim_order). Copy to a temp CPU buffer, then rearrange
193- // element-by-element to match the ETensor's expected layout.
189+ // from .pte's dim_order). Copy to a temp CPU buffer, rearrange
190+ // element-by-element to match the ETensor's expected layout, then move the
191+ // result to the destination (CPU stays in place; GPU gets an H2D copy).
194192 std::vector<char > tmp (nbytes);
195193 if (slim_tensor->is_cpu ()) {
196194 std::memcpy (tmp.data (), src_data, nbytes);
@@ -218,13 +216,38 @@ inline executorch::runtime::Error _copy_slimtensor_to_etensor_impl(
218216
219217 size_t elem_size = executorch::backends::aoti::slim::c10::elementSize (
220218 slim_tensor->dtype ());
221- _strided_copy (
222- dst_data,
223- tmp.data (),
224- elem_size,
225- sizes_vec,
226- src_strides_vec,
227- dst_strides_vec);
219+
220+ if (dst_device.is_cpu ()) {
221+ _strided_copy (
222+ dst_data,
223+ tmp.data (),
224+ elem_size,
225+ sizes_vec,
226+ src_strides_vec,
227+ dst_strides_vec);
228+ } else {
229+ // Rearrange into a CPU staging buffer, then copy to the GPU destination.
230+ std::vector<char > rearranged (nbytes);
231+ _strided_copy (
232+ rearranged.data (),
233+ tmp.data (),
234+ elem_size,
235+ sizes_vec,
236+ src_strides_vec,
237+ dst_strides_vec);
238+ if (stream) {
239+ ET_CUDA_CHECK_OR_RETURN_ERROR (cudaMemcpyAsync (
240+ dst_data,
241+ rearranged.data (),
242+ nbytes,
243+ cudaMemcpyHostToDevice,
244+ stream));
245+ ET_CUDA_CHECK_OR_RETURN_ERROR (cudaStreamSynchronize (stream));
246+ } else {
247+ ET_CUDA_CHECK_OR_RETURN_ERROR (cudaMemcpy (
248+ dst_data, rearranged.data (), nbytes, cudaMemcpyHostToDevice));
249+ }
250+ }
228251 }
229252
230253 return executorch::runtime::Error::Ok;
@@ -251,7 +274,39 @@ inline executorch::runtime::Error copy_slimtensor_to_etensor_async(
251274 const executorch::backends::aoti::slim::SlimTensor* slim_tensor,
252275 executorch::runtime::etensor::Tensor* etensor,
253276 cudaStream_t stream) {
254- return _copy_slimtensor_to_etensor_impl (slim_tensor, etensor, stream);
277+ return _copy_slimtensor_to_etensor_impl (
278+ slim_tensor,
279+ etensor,
280+ executorch::backends::aoti::slim::CPU_DEVICE ,
281+ stream);
282+ }
283+
284+ /* *
285+ * Copies data from a SlimTensor to a GPU-resident ETensor asynchronously
286+ * (device-to-device).
287+ *
288+ * Used when the destination ETensor's storage lives in a planned GPU arena.
289+ * The destination device is taken from the source SlimTensor, so this only
290+ * supports same-device D2D copies (source and destination on the same GPU).
291+ *
292+ * When strides match (common case), performs a fast async D2D copy on the
293+ * provided stream. When strides differ, falls back to a staged copy with
294+ * element-by-element rearrangement on the host.
295+ *
296+ * NOTE: In the fast path the copy is asynchronous. The caller must synchronize
297+ * the stream before consuming the ETensor data.
298+ *
299+ * @param slim_tensor Pointer to the source SlimTensor (must not be null).
300+ * @param etensor Pointer to the destination GPU ETensor (must not be null).
301+ * @param stream The CUDA stream to use for async copy.
302+ * @return Error::Ok on success, or an appropriate error code on failure.
303+ */
304+ inline executorch::runtime::Error copy_slimtensor_to_device_etensor_async (
305+ const executorch::backends::aoti::slim::SlimTensor* slim_tensor,
306+ executorch::runtime::etensor::Tensor* etensor,
307+ cudaStream_t stream) {
308+ return _copy_slimtensor_to_etensor_impl (
309+ slim_tensor, etensor, slim_tensor->device (), stream);
255310}
256311
257312/* *
@@ -267,7 +322,11 @@ inline executorch::runtime::Error copy_slimtensor_to_etensor_async(
267322inline executorch::runtime::Error copy_slimtensor_to_etensor (
268323 const executorch::backends::aoti::slim::SlimTensor* slim_tensor,
269324 executorch::runtime::etensor::Tensor* etensor) {
270- return _copy_slimtensor_to_etensor_impl (slim_tensor, etensor, nullptr );
325+ return _copy_slimtensor_to_etensor_impl (
326+ slim_tensor,
327+ etensor,
328+ executorch::backends::aoti::slim::CPU_DEVICE ,
329+ nullptr );
271330}
272331
273332/* *
0 commit comments