Skip to content

Commit 2bb145a

Browse files
committed
up
1 parent c806c00 commit 2bb145a

12 files changed

Lines changed: 747 additions & 41 deletions

File tree

.github/workflows/mlx.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,11 @@ jobs:
6666
echo "::endgroup::"
6767
6868
echo "::group::Build test runners"
69-
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
69+
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner mlx_mutable_state_test -j$(( $(sysctl -n hw.ncpu) - 1 ))
70+
echo "::endgroup::"
71+
72+
echo "::group::Run mutable-state (multi-session) unit test"
73+
./cmake-out/backends/mlx/test/mlx_mutable_state_test
7074
echo "::endgroup::"
7175
7276
echo "::group::Run op unit tests"

backends/mlx/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,10 @@ option(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION
255255
ON
256256
)
257257

258-
set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
259-
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
258+
set(_mlx_backend__srcs
259+
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
260+
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
261+
${CMAKE_CURRENT_SOURCE_DIR}/runtime/mlx_mutable_state.cpp
260262
)
261263

262264
add_library(mlxdelegate ${_mlx_backend__srcs})

backends/mlx/runtime/MLXBackend.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "MLXExecutor.h"
1010
#include "MLXInterpreter.h"
1111
#include "MLXLoader.h"
12+
#include "mlx_mutable_state.h"
1213

1314
#include <executorch/runtime/backend/interface.h>
1415
#include <executorch/runtime/core/error.h>
@@ -277,6 +278,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
277278
eval(handle->constants.tensors);
278279
}
279280

281+
// Register the handle with the per-session mutable-state manager. This is
282+
// a no-op unless a multi-session owner is active for this load (see
283+
// mlx_mutable_state.h); single-session execution is unaffected.
284+
mutable_state_note_handle(
285+
handle, &handle->program, &handle->mutable_buffers);
286+
280287
} catch (const std::exception& e) {
281288
ET_LOG(Error, "Failed to load MLX program: %s", e.what());
282289
handle->~MLXHandle();
@@ -366,6 +373,14 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
366373
}
367374
}
368375

376+
// Select the active session's mutable buffers (KV cache, recurrent/conv
377+
// state) before running. No-op for single-session handles; weights stay
378+
// shared via ExecutionState::constants.
379+
if (Error rebind_err = mutable_state_rebind_for_execute(h, h->state);
380+
rebind_err != Error::Ok) {
381+
return rebind_err;
382+
}
383+
369384
// Run the MLX program (builds lazy computation graph)
370385
h->interpreter.run(program, h->state, h->stream);
371386

@@ -431,6 +446,7 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
431446
void destroy(DelegateHandle* handle) const override {
432447
std::lock_guard<std::mutex> lock(mlx_global_mutex());
433448
if (handle != nullptr) {
449+
mutable_state_forget_handle(handle);
434450
auto* mlx_handle = static_cast<MLXHandle*>(handle);
435451
mlx_handle->~MLXHandle();
436452
}
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "mlx_mutable_state.h"
10+
11+
#include "MLXExecutor.h"
12+
#include "MLXLoader.h"
13+
14+
#include <executorch/runtime/platform/log.h>
15+
16+
#include <mutex>
17+
#include <unordered_map>
18+
19+
namespace executorch {
20+
namespace backends {
21+
namespace mlx {
22+
23+
using ::executorch::runtime::Error;
24+
using ::executorch::runtime::Result;
25+
26+
namespace {
27+
28+
struct HandleInfo {
29+
const MLXProgram* program{nullptr};
30+
MutableBufferData* default_buffers{nullptr};
31+
};
32+
33+
struct Context {
34+
// Delegate handles associated with this loaded program (one per loaded
35+
// method). Keyed by opaque MLXHandle pointer.
36+
std::unordered_map<const void*, HandleInfo> handles;
37+
// Per-session mutable buffers: token -> (handle -> buffers). Allocated lazily
38+
// on first execute for a given (session, handle).
39+
std::unordered_map<int, std::unordered_map<const void*, MutableBufferData>>
40+
sessions;
41+
int next_token{0};
42+
};
43+
44+
// Process-global registry. MLX serializes execution via its own global mutex and
45+
// the engine serializes per session, but the registry itself is guarded here so
46+
// context/session lifecycle calls from other threads are safe.
47+
std::mutex& registry_mutex() {
48+
static std::mutex m;
49+
return m;
50+
}
51+
52+
std::unordered_map<MutableStateContext, Context>& contexts() {
53+
static std::unordered_map<MutableStateContext, Context> c;
54+
return c;
55+
}
56+
57+
std::unordered_map<const void*, MutableStateContext>& handle_ctx() {
58+
static std::unordered_map<const void*, MutableStateContext> m;
59+
return m;
60+
}
61+
62+
MutableStateContext g_next_ctx = 1; // 0 is reserved as invalid.
63+
64+
// Thread-local load scope and active (ctx, session) selection.
65+
thread_local MutableStateContext tl_loading_ctx = kInvalidMutableContext;
66+
thread_local MutableStateContext tl_active_ctx = kInvalidMutableContext;
67+
thread_local int tl_active_token = kNoMutableSession;
68+
69+
} // namespace
70+
71+
namespace detail {
72+
73+
MutableStateContext mutable_state_create_context() {
74+
std::lock_guard<std::mutex> g(registry_mutex());
75+
MutableStateContext ctx = g_next_ctx++;
76+
if (ctx == kInvalidMutableContext) {
77+
ctx = g_next_ctx++;
78+
}
79+
contexts()[ctx];
80+
return ctx;
81+
}
82+
83+
void mutable_state_destroy_context(MutableStateContext ctx) {
84+
std::lock_guard<std::mutex> g(registry_mutex());
85+
auto it = contexts().find(ctx);
86+
if (it == contexts().end()) {
87+
return;
88+
}
89+
for (const auto& kv : it->second.handles) {
90+
handle_ctx().erase(kv.first);
91+
}
92+
contexts().erase(it);
93+
}
94+
95+
void mutable_state_begin_load(MutableStateContext ctx) {
96+
tl_loading_ctx = ctx;
97+
}
98+
99+
void mutable_state_end_load() {
100+
tl_loading_ctx = kInvalidMutableContext;
101+
}
102+
103+
bool mutable_state_available(MutableStateContext ctx) {
104+
if (ctx == kInvalidMutableContext) {
105+
return false;
106+
}
107+
std::lock_guard<std::mutex> g(registry_mutex());
108+
return contexts().count(ctx) != 0;
109+
}
110+
111+
int64_t mutable_state_bytes_per_session(MutableStateContext ctx) {
112+
std::lock_guard<std::mutex> g(registry_mutex());
113+
auto it = contexts().find(ctx);
114+
if (it == contexts().end()) {
115+
return 0;
116+
}
117+
int64_t total = 0;
118+
for (const auto& kv : it->second.handles) {
119+
const MutableBufferData* bufs = kv.second.default_buffers;
120+
if (bufs == nullptr) {
121+
continue;
122+
}
123+
for (const auto& t : bufs->tensors) {
124+
if (t.has_value()) {
125+
total += static_cast<int64_t>(t->nbytes());
126+
}
127+
}
128+
}
129+
return total;
130+
}
131+
132+
Error mutable_state_validate_coverage(MutableStateContext ctx) {
133+
// MLX clones all mutable buffers by tid; there is no FQN coverage to verify.
134+
(void)ctx;
135+
return Error::Ok;
136+
}
137+
138+
Result<int> mutable_state_create_session(MutableStateContext ctx) {
139+
std::lock_guard<std::mutex> g(registry_mutex());
140+
auto it = contexts().find(ctx);
141+
if (it == contexts().end()) {
142+
ET_LOG(Error, "mutable_state_create_session: unknown context %d", ctx);
143+
return Error::InvalidState;
144+
}
145+
int token = it->second.next_token++;
146+
// Per-handle buffers are allocated lazily on first execute.
147+
it->second.sessions[token];
148+
return token;
149+
}
150+
151+
void mutable_state_destroy_session(MutableStateContext ctx, int token) {
152+
std::lock_guard<std::mutex> g(registry_mutex());
153+
auto it = contexts().find(ctx);
154+
if (it == contexts().end()) {
155+
return;
156+
}
157+
it->second.sessions.erase(token);
158+
}
159+
160+
void mutable_state_set_active(MutableStateContext ctx, int token) {
161+
tl_active_ctx = ctx;
162+
tl_active_token = token;
163+
}
164+
165+
} // namespace detail
166+
167+
void mutable_state_note_handle(
168+
const void* handle,
169+
const MLXProgram* program,
170+
MutableBufferData* default_buffers) {
171+
if (tl_loading_ctx == kInvalidMutableContext) {
172+
return; // No multi-session owner active during this load: single-session.
173+
}
174+
std::lock_guard<std::mutex> g(registry_mutex());
175+
auto it = contexts().find(tl_loading_ctx);
176+
if (it == contexts().end()) {
177+
return;
178+
}
179+
it->second.handles[handle] = HandleInfo{program, default_buffers};
180+
handle_ctx()[handle] = tl_loading_ctx;
181+
}
182+
183+
void mutable_state_forget_handle(const void* handle) {
184+
std::lock_guard<std::mutex> g(registry_mutex());
185+
auto hit = handle_ctx().find(handle);
186+
if (hit == handle_ctx().end()) {
187+
return;
188+
}
189+
auto cit = contexts().find(hit->second);
190+
if (cit != contexts().end()) {
191+
cit->second.handles.erase(handle);
192+
for (auto& session : cit->second.sessions) {
193+
session.second.erase(handle);
194+
}
195+
}
196+
handle_ctx().erase(hit);
197+
}
198+
199+
Error mutable_state_rebind_for_execute(
200+
const void* handle,
201+
ExecutionState& state) {
202+
std::lock_guard<std::mutex> g(registry_mutex());
203+
auto hit = handle_ctx().find(handle);
204+
if (hit == handle_ctx().end()) {
205+
// Handle was not loaded under a multi-session owner: keep default buffers.
206+
return Error::Ok;
207+
}
208+
auto cit = contexts().find(hit->second);
209+
if (cit == contexts().end()) {
210+
return Error::Ok;
211+
}
212+
Context& ctx = cit->second;
213+
HandleInfo& info = ctx.handles[handle];
214+
215+
const bool active_for_this_ctx =
216+
tl_active_token != kNoMutableSession && tl_active_ctx == hit->second;
217+
218+
if (!active_for_this_ctx) {
219+
// No session selected. Refuse if sessions exist (running against the default
220+
// buffers here would not isolate state from created sessions).
221+
if (!ctx.sessions.empty()) {
222+
ET_LOG(
223+
Error,
224+
"mutable_state_rebind_for_execute: no active session selected but "
225+
"sessions exist for this program");
226+
return Error::InvalidState;
227+
}
228+
state.mutable_buffers = info.default_buffers;
229+
return Error::Ok;
230+
}
231+
232+
auto sit = ctx.sessions.find(tl_active_token);
233+
if (sit == ctx.sessions.end()) {
234+
ET_LOG(
235+
Error,
236+
"mutable_state_rebind_for_execute: unknown session token %d",
237+
tl_active_token);
238+
return Error::InvalidState;
239+
}
240+
241+
auto& per_handle = sit->second;
242+
auto bit = per_handle.find(handle);
243+
if (bit == per_handle.end()) {
244+
// First execute for this (session, handle): allocate fresh zeroed buffers.
245+
// Constants/weights stay shared (ExecutionState::constants is untouched);
246+
// only the mutable buffers are per-session.
247+
MutableBufferData buffers;
248+
try {
249+
load_mutable_buffers(*info.program, buffers);
250+
} catch (const std::exception& e) {
251+
ET_LOG(
252+
Error,
253+
"mutable_state_rebind_for_execute: failed to allocate session "
254+
"buffers: %s",
255+
e.what());
256+
return Error::MemoryAllocationFailed;
257+
}
258+
bit = per_handle.emplace(handle, std::move(buffers)).first;
259+
}
260+
// unordered_map keeps element pointers stable across rehash, so this remains
261+
// valid for the duration of the execute.
262+
state.mutable_buffers = &bit->second;
263+
return Error::Ok;
264+
}
265+
266+
} // namespace mlx
267+
} // namespace backends
268+
} // namespace executorch

0 commit comments

Comments
 (0)