Extend moe_3gemm to all oneDNN aware GPUs#35335
Open
peterchen-intel wants to merge 22 commits into
Open
Conversation
causing it to be skipped on MTL-class iGPU (12.70.x, XeHPG, no DPAS). This left raw FP32 weight-decompression chains that overwhelmed propagate_constants with ~56 GB of constant-folding memory. Root cause of inference failure: moe_3gemm_swiglu_opt uses oneDNN internally (onednn_linear for gate/up/down matrix multiplications). OneDNN requires an in-order OCL queue. MTL uses out-of-order queue by default because use_onednn is false when supports_immad=false. Fix: three MoE transformation passes (FuseVectorizedMOE3GEMM, ConvertMOEToMOECompressed, FuseMOE3GemmCompressed) run on all architectures. FuseMOE3GemmCompressed creates MOE3GemmFusedCompressed which the OCL moe_3gemm_swiglu_opt kernel executes. - Detect MOE3GemmFusedCompressed in apply_model_specific_options and force use_onednn=true so finalize_impl sets queue_type=in_order, satisfying the oneDNN in-order queue requirement. - Fix moe_gather validate_impl to accept rank-2 input for models where the batch dimension is pre-flattened (Qwen3-style). - Re-apply iGPU transfer skip (usm_shared -> usm_device) in network.cpp and program.cpp for integrated GPUs where both allocation types share system DRAM (xe2+ or 12.7x-class MTL/ARL-S). Tested on machine (GPU uArch 12.70.4 / XeHPG / System memory 64 GB): model loads in 14 s, generates meaningful tokens, Unevictable stays below 120 MB. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Chen Peter <peter.chen@intel.com>
Instead of setting use_onednn=true when MOE3GemmFusedCompressed is detected, set m_queue_type=in_order directly. This is more precise: the only requirement is an in-order OCL command queue (for onednn_linear in moe_3gemm_swiglu_opt.cpp), not full oneDNN enablement for the whole model. Leaving use_onednn=false on non-systolic hardware (MTL, 12.70.x) ensures that oneDNN implementations for FC, convolution, GEMM etc. are not activated on hardware without DPAS units. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The MTL-class (12.70.x) iGPU has a separate GPU L3 cache from the CPU, so copying usm_shared -> usm_device does improve GPU access performance. Reverts the MTL condition added in the prior fix commit, keeping only the original xe2+ integrated GPU skip (which has true unified memory). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
In the correct fix, GEMM3_SWIGLU (Qwen3) always goes through FuseMOE3GemmCompressed -> MOE3GemmFusedCompressed, which creates a single fused primitive with no standalone moe_gather node. The rank-2 accept was only needed during an intermediate broken debug state where FuseMOE3GemmCompressed was wrongly blocked. moe_gather is only used by GEMM2_BIAS_SWIGLU_CLAMP models, whose input is rank-3. Restore original: input_pshapes.rank() != 3 || input_pshapes[2].is_dynamic() Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The previous fix set m_queue_type = in_order directly in apply_model_specific_options, but this left m_use_onednn = false. On non-systolic hardware (supports_immad=false), program.cpp only calls lo.enable_onednn_for<lstm_seq/gru_seq>() (making the onednn_impls_optimization_attribute non-empty, which triggers create_onednn_engine() in select_preferred_formats.cpp) when use_onednn=true. With use_onednn=false, the engine is never initialized, causing moe_3gemm_fused_compressed to crash at inference time with 'oneDNN engine not initialized'. Fix: set m_use_onednn = true (not queue_type) when MOE3GemmFusedCompressed is detected. finalize_impl then sets queue_type = in_order because use_onednn=true, and the create_onednn_engine() call is correctly triggered. This is safe on non-systolic hardware: FuseVectorizedFC (systolic FC) is gated independently on supports_immad, so no systolic ops are introduced by enabling use_onednn for the MoE path. Verified: all 3 prompts pass with correct output on MTL iGPU (GPU_UARCH_VERSION=12.70.4, supports_immad=false). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Contributor
There was a problem hiding this comment.
Pull request overview
Extends the Intel GPU plugin’s MoE GEMM3 (Qwen3-style GEMM3_SWIGLU) conversion/execution path so it is no longer limited to systolic-array devices, ensuring the MoE graph is structurally converted early and executed via the fused OpenCL kernel path.
Changes:
- Always registers/runs
ov::pass::FuseVectorizedMOE3GEMM(removes the non-systolic skip), enabling the downstreamMOE → MOECompressed → MOE3GemmFusedCompressedpipeline across architectures. - Ensures
use_onednn=trueis enabled whenMOE3GemmFusedCompressedis present so the GPU plugin uses an in-order queue and initializes oneDNN engine required by the fused MoE kernel. - Makes
MOECompressed(GEMM3_SWIGLU)a hard error at program build time if it reaches primitive creation, catching missing fusion/pipeline misconfiguration early.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
src/plugins/intel_gpu/src/runtime/execution_config.cpp |
Enables use_onednn when MOE3GemmFusedCompressed is detected to force in-order queue + oneDNN engine init. |
src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp |
Removes architecture gating so FuseVectorizedMOE3GEMM runs unconditionally. |
src/plugins/intel_gpu/src/plugin/ops/moe.cpp |
Throws if MOECompressed(GEMM3_SWIGLU) reaches program build, enforcing the fused-kernel execution path. |
Contributor
Author
No. It is a fixing, but still has issue for multi-chunk. |
paged_attention_opt__multi_tokens allocates a tmp_out scratch buffer sized total_tokens * heads_num * v_head_size * num_of_partitions * sizeof(float). For Qwen3-30B with chunk_size=4096 and 8K KV context this is 2 GB per layer. With 48 layers all executing sequentially, this totalled 96 GB of demand-paged USM device allocation. On Intel iGPU (ARLS, i915 driver), the driver pins the entire allocation as Unevictable on first GPU access regardless of pages touched, causing CL_OUT_OF_RESOURCES on a 31 GB machine. Root cause: can_share_internal_buffer(false) in paged_attention_node unconditionally blocked the memory pool for ALL internal buffers. This was added in PR openvinotoolkit#33204 to prevent CPU/GPU races on lockable buffers (blocks_indexes_start/end, blocked_gws_subseq_mapping) written by prepare_internal_buffers(). However it also blocked pool reuse for non-lockable GPU-only buffers (exp_sums, max_logits, tmp_out) which are safe to share across sequential layers. Fix: - Remove can_share_internal_buffer(false) from paged_attention_node; per-buffer lockability already tracked via BufferDescriptor::m_lockable. so CPU-written (lockable=true, usm_host) buffers remain non-shareable while GPU-only (lockable=false, usm_device) buffers can be reused from the pool. - In allocate_internal_buffers(): pass buffer_descs[i].m_lockable to the call (previously dropped, causing wrong alloc type on initial allocation). Result: 48 layers share one 2 GB tmp_out buffer instead of allocating 48 separate 2 GB buffers. Peak Unevictable drops from OOM crash (~28+ GB) to ~18.9 GB on ARLS (Intel Arc 8086:7d67, Arrow Lake-S iGPU, 31 GB). Verified: Qwen3-30B-A3B-Instruct-2507-int4-ov with chunk_size=4096, 8K prompt, ContinuousBatching on ARLS completes successfully with exit code 0 and 20 coherent output tokens. Not affected on ARLH (supports_immad=true takes micro_sdpa path which does not allocate tmp_out at all). Signed-off-by: Chen Peter <peter.chen@intel.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
8575a7d to
1498dd1
Compare
Contributor
Author
|
Since #34974 merged today. Will test QWen3 and GPT-OSS again. The results are in CVS-182696 |
Contributor
ef2519b to
2ee0c08
Compare
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
peterchen-intel
commented
Jun 2, 2026
peterchen-intel
commented
Jun 2, 2026
peterchen-intel
commented
Jun 2, 2026
Co-authored-by: Chen Peter <peter.chen@intel.com>
peterchen-intel
commented
Jun 2, 2026
Comment on lines
+231
to
+242
| // moe_3gemm_fused_compressed uses oneDNN internally for matrix multiplications | ||
| // (onednn_linear wrappers in moe_3gemm_swiglu_opt.cpp), which requires: | ||
| // 1. use_onednn=true so create_onednn_engine() is called during program build | ||
| // (see program.cpp: lo.enable_onednn_for<lstm_seq/gru_seq> path which makes | ||
| // onednn_impls_optimization_attribute non-empty, triggering engine init). | ||
| // 2. in-order OCL command queue (finalize_impl sets this when use_onednn=true). | ||
| // Auto-enable this only on architectures with oneDNN support, consistent with | ||
| // the LSTM/GRU path above, to avoid initializing oneDNN on unsupported devices. | ||
| if (ov::is_type<ov::intel_gpu::op::MOE3GemmFusedCompressed>(op) && | ||
| info.arch >= cldnn::gpu_arch::xe_lp) { | ||
| m_use_onednn = true; | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Details:
This pull request enhances support for the compressed Mixture-of-Experts (MoE) fusion chain in the Intel GPU plugin, particularly improving compatibility on non-systolic (non-IMAD) hardware. The changes ensure that the MoE fusion passes and their dependencies are applied more broadly, and that the required oneDNN engine is enabled when using the new
MOE3GemmFusedCompressedoperation, even on devices that do not natively support systolic operations.MoE Fusion Pipeline Improvements:
ConvertTiledMoeBlockToGatherMatmulsandConvertGatherMatmulToGatherMatmulCompressed) is now run on all devices, not just those with systolic (IMAD) support, improving model compatibility and performance across a wider range of hardware.oneDNN Integration for MoE:
use_onednnwhen theMOE3GemmFusedCompressedoperation is detected, ensuring the oneDNN engine is initialized for models that require it, even on non-systolic hardware. This is necessary for correct operation and performance of the fused MoE kernels.Dependency Updates:
moe_3gemm_fused_compressed.hppto the runtime configuration source, ensuring the new operation is recognized and available during execution.Limitation
Tickets:
AI Assistance: