1919#include < cmath>
2020#include < cstring>
2121
22+ #include < algorithm>
23+
2224#ifdef EXECUTORCH_BUILD_CUDA
2325#include < cuda_runtime.h>
2426#include < executorch/backends/cuda/runtime/cuda_mutable_state.h>
@@ -39,6 +41,20 @@ using SizesType = executorch::aten::SizesType;
3941
4042namespace {
4143
44+ #ifdef EXECUTORCH_BUILD_MLX
45+ // The MLX export emits a single dynamic-seq `forward` method that handles both
46+ // prefill (T>=2) and decode (T=1). Mirror gemma4_31b's MLX runner, which loads
47+ // and calls `forward` for both phases.
48+ constexpr const char * kPrefillMethod = " forward" ;
49+ constexpr const char * kDecodeMethod = " forward" ;
50+ // Prefill is chunked on MLX to cap peak memory and the compiled prefill shape.
51+ constexpr int64_t kPrefillChunkSize = 1024 ;
52+ #else
53+ // CUDA/Metal exports emit two separate methods.
54+ constexpr const char * kPrefillMethod = " prefill" ;
55+ constexpr const char * kDecodeMethod = " decode" ;
56+ #endif
57+
4258Result<uint64_t > read_sampled_token (
4359 const executorch::aten::Tensor& output,
4460 float temperature) {
@@ -98,8 +114,10 @@ Result<std::unique_ptr<Module>> build_qwen_module(
98114 }
99115#endif
100116
101- ET_CHECK_OK_OR_RETURN_ERROR (module ->load_method (" prefill" ));
102- ET_CHECK_OK_OR_RETURN_ERROR (module ->load_method (" decode" ));
117+ ET_CHECK_OK_OR_RETURN_ERROR (module ->load_method (kPrefillMethod ));
118+ if (std::string (kDecodeMethod ) != std::string (kPrefillMethod )) {
119+ ET_CHECK_OK_OR_RETURN_ERROR (module ->load_method (kDecodeMethod ));
120+ }
103121 return module ;
104122}
105123
@@ -240,34 +258,51 @@ class Qwen35MoESession : public LLMSession {
240258 }
241259
242260 stop_.store (false , std::memory_order_relaxed);
243- std::vector<int64_t > token_data (tokens.begin (), tokens.end ());
244- std::vector<int64_t > pos_data (T);
245- for (int64_t i = 0 ; i < T; ++i) {
246- pos_data[i] = pos_ + i;
247- }
248- auto tokens_tensor = from_blob (
249- token_data.data (),
250- {1 , static_cast <SizesType>(T)},
251- executorch::aten::ScalarType::Long);
252- auto pos_tensor = from_blob (
253- pos_data.data (),
254- {static_cast <SizesType>(T)},
255- executorch::aten::ScalarType::Long);
256-
257- const char * method = (T >= 2 ) ? " prefill" : " decode" ;
258- std::vector<EValue> inputs;
259- inputs.push_back (tokens_tensor);
260- inputs.push_back (pos_tensor);
261+
262+ // On MLX, run prefill in fixed-size chunks (caps peak memory and the
263+ // compiled prefill shape). Other backends prefill the whole prompt in one
264+ // pass. Only the final chunk's sampled token is kept; the recurrence/KV
265+ // state from earlier chunks persists via pos_ advancement.
266+ #ifdef EXECUTORCH_BUILD_MLX
267+ const int64_t chunk_size = kPrefillChunkSize ;
268+ #else
269+ const int64_t chunk_size = T;
270+ #endif
271+
272+ uint64_t sampled_token = 0 ;
273+ for (int64_t off = 0 ; off < T; off += chunk_size) {
274+ const int64_t len = std::min (chunk_size, T - off);
275+ std::vector<int64_t > token_data (
276+ tokens.begin () + off, tokens.begin () + off + len);
277+ std::vector<int64_t > pos_data (len);
278+ for (int64_t i = 0 ; i < len; ++i) {
279+ pos_data[i] = pos_ + i;
280+ }
281+ auto tokens_tensor = from_blob (
282+ token_data.data (),
283+ {1 , static_cast <SizesType>(len)},
284+ executorch::aten::ScalarType::Long);
285+ auto pos_tensor = from_blob (
286+ pos_data.data (),
287+ {static_cast <SizesType>(len)},
288+ executorch::aten::ScalarType::Long);
289+
290+ const char * method = (len >= 2 ) ? kPrefillMethod : kDecodeMethod ;
291+ std::vector<EValue> inputs;
292+ inputs.push_back (tokens_tensor);
293+ inputs.push_back (pos_tensor);
261294#ifdef EXECUTORCH_BUILD_CUDA
262- set_temp (first_token_temp);
263- inputs.push_back (EValue (temp_tensor_));
295+ set_temp (first_token_temp);
296+ inputs.push_back (EValue (temp_tensor_));
264297#endif
265- auto sampled =
266- run_locked (method, inputs, first_token_temp, /* sync_after=*/ true );
267- ET_CHECK_OK_OR_RETURN_ERROR (sampled.error ());
268- pending_ = sampled.get ();
298+ auto sampled =
299+ run_locked (method, inputs, first_token_temp, /* sync_after=*/ true );
300+ ET_CHECK_OK_OR_RETURN_ERROR (sampled.error ());
301+ sampled_token = sampled.get ();
302+ pos_ += len;
303+ }
304+ pending_ = sampled_token;
269305 prev_decode_token_.reset ();
270- pos_ += T;
271306 return Error::Ok;
272307 }
273308
@@ -334,7 +369,7 @@ class Qwen35MoESession : public LLMSession {
334369 inputs.push_back (EValue (temp_tensor_));
335370#endif
336371 auto sampled =
337- run_locked (" decode " , inputs, temperature_, /* sync_after=*/ false );
372+ run_locked (kDecodeMethod , inputs, temperature_, /* sync_after=*/ false );
338373 ET_CHECK_OK_OR_RETURN_ERROR (sampled.error ());
339374 pending_ = sampled.get ();
340375 prev_decode_token_ = token;
0 commit comments