Skip to content

Commit 182be0e

Browse files
authored
cuda: add per-session mutable state rebinding (#20241)
Local agent serving needs to host multiple logical conversations on one CUDA-resident model without multiplying the model weights. Loading one AOTI module per conversation is not viable for large local models, while sharing the default mutable state across conversations would let KV/recurrent/conv buffers bleed between users. This adds the CUDA-private foundation for separating those concerns: weights remain owned by the loaded AOTI container, while mutable buffer FQNs can be registered as per-session state and rebound before execution. The path is fail-closed and dormant until a model opts in by creating a mutable-state context and validating coverage, so existing CUDA models keep their current behavior. The branch also wires the new source and fail-closed unit test into both Buck and CMake so the primitive can land independently before any model-specific engine consumes it. #20001
1 parent 99ca02f commit 182be0e

6 files changed

Lines changed: 1783 additions & 6 deletions

File tree

backends/cuda/CMakeLists.txt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,9 @@ install(
184184
)
185185

186186
# CUDA backend implementation
187-
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp)
187+
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp
188+
runtime/cuda_mutable_state.cpp
189+
)
188190
if(_cuda_is_msvc_toolchain)
189191
# MSVC links aoti_cuda_backend into portable_lib without relying on C++
190192
# symbols exported from aoti_cuda_shims.dll.
@@ -236,3 +238,13 @@ install(
236238
EXPORT ExecuTorchTargets
237239
DESTINATION lib
238240
)
241+
242+
if(BUILD_TESTING)
243+
include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
244+
245+
et_cxx_test(
246+
test_cuda_mutable_state SOURCES runtime/test/test_cuda_mutable_state.cpp
247+
EXTRA_LIBS aoti_cuda_backend
248+
)
249+
target_compile_definitions(test_cuda_mutable_state PRIVATE CUDA_AVAILABLE=1)
250+
endif()

backends/cuda/runtime/TARGETS

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
3+
load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")
24
load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args")
35

46
oncall("executorch")
@@ -105,9 +107,11 @@ runtime.cxx_library(
105107
name = "cuda_backend",
106108
srcs = [
107109
"cuda_backend.cpp",
110+
"cuda_mutable_state.cpp",
108111
],
109112
headers = [
110113
"cuda_delegate_handle.h",
114+
"cuda_mutable_state.h",
111115
],
112116
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
113117
link_whole = True,
@@ -135,3 +139,26 @@ runtime.cxx_library(
135139
("cuda", None, "cuda-lazy"),
136140
],
137141
)
142+
143+
cpp_unittest(
144+
name = "test_cuda_mutable_state",
145+
srcs = [
146+
"test/test_cuda_mutable_state.cpp",
147+
],
148+
deps = [
149+
":cuda_backend",
150+
"//executorch/backends/aoti:aoti_common_slim",
151+
"//executorch/backends/aoti/slim/core:slimtensor",
152+
"//executorch/backends/aoti/slim/factory:from_blob",
153+
"//executorch/runtime/core:core",
154+
"//executorch/runtime/platform:platform",
155+
],
156+
external_deps = [
157+
("cuda", None, "cuda-lazy"),
158+
],
159+
preprocessor_flags = ["-DCUDA_AVAILABLE=1"],
160+
keep_gpu_sections = True,
161+
remote_execution = re_test_utils.remote_execution(
162+
platform = "gpu-remote-execution",
163+
),
164+
)

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include <executorch/backends/aoti/utils.h>
4545
#include <executorch/backends/cuda/runtime/cuda_allocator.h>
4646
#include <executorch/backends/cuda/runtime/cuda_delegate_handle.h>
47+
#include <executorch/backends/cuda/runtime/cuda_mutable_state.h>
4748
#include <executorch/backends/cuda/runtime/platform/platform.h>
4849
#include <executorch/backends/cuda/runtime/shims/memory.h>
4950
#include <executorch/backends/cuda/runtime/utils.h>
@@ -436,6 +437,8 @@ class ET_EXPERIMENTAL CudaBackend final
436437
kCudaGraphWarmupSteps);
437438
}
438439

440+
mutable_state_note_handle(handle);
441+
439442
return (DelegateHandle*)handle; // Return the handle post-processing
440443
}
441444

@@ -539,6 +542,8 @@ class ET_EXPERIMENTAL CudaBackend final
539542
}
540543
}
541544

545+
ET_CHECK_OK_OR_RETURN_ERROR(mutable_state_rebind_for_execute(handle));
546+
542547
// ---------------------------------------------------------------
543548
// CUDA graph REPLAY path — skip all tensor setup and just replay
544549
// ---------------------------------------------------------------
@@ -826,6 +831,8 @@ class ET_EXPERIMENTAL CudaBackend final
826831
}
827832
cuda::CudaDelegateHandle* handle = (cuda::CudaDelegateHandle*)handle_;
828833

834+
mutable_state_forget_handle(handle);
835+
829836
// The CUDA stream is managed by shared_ptr in the handle.
830837
// It will be automatically destroyed when the last handle using it
831838
// is destroyed. Just reset our reference.
@@ -899,11 +906,12 @@ class ET_EXPERIMENTAL CudaBackend final
899906
// * Constants are assumed to be IMMUTABLE (parameters or read-only
900907
// buffers). The AOTI shim today does not expose a mutability bit
901908
// through GetConstantOriginalFQN, so we cannot detect or refuse
902-
// to share mutable buffers (e.g. a per-method KV cache). If a
903-
// future model exports the same FQN as a mutable buffer in
904-
// multiple methods, mutations from one method WILL be visible to
905-
// the other through the shared GPU memory. Callers that need
906-
// per-method mutable state must currently use distinct FQNs.
909+
// to share mutable buffers (for example, runtime caches). If a
910+
// model exports the same FQN as a mutable buffer in multiple
911+
// methods, mutations from one method WILL be visible to the other
912+
// through the shared GPU memory. Callers that need isolated mutable
913+
// state for shared FQNs must opt into cuda_mutable_state or use
914+
// distinct FQNs.
907915
// TODO: when AOTInductor exposes a constant-type / mutability
908916
// query, refuse to share entries that are not PARAMETER or
909917
// non-mutable BUFFER.

0 commit comments

Comments
 (0)