diff --git a/src/plugins/intel_cpu/src/nodes/gated_delta_net.cpp b/src/plugins/intel_cpu/src/nodes/gated_delta_net.cpp index b2c727d7fe1a..2e0ce163fe1a 100644 --- a/src/plugins/intel_cpu/src/nodes/gated_delta_net.cpp +++ b/src/plugins/intel_cpu/src/nodes/gated_delta_net.cpp @@ -5,6 +5,7 @@ #include "gated_delta_net.h" #include +#include #include #include #include @@ -24,12 +25,139 @@ #include "openvino/op/gated_delta_net.hpp" #include "shape_inference/shape_inference_cpu.hpp" #include "utils/plain_tensor.hpp" +#if defined(OPENVINO_ARCH_X86_64) +# include "cpu_parallel.hpp" +# include "kernels/x64/gdn_jit_kernel.hpp" +using namespace dnnl::impl::cpu::x64; +#endif using namespace ov::Extensions::Cpu; using namespace ov::Extensions::Cpu::XARCH; namespace ov::intel_cpu::node { +#if defined(OPENVINO_ARCH_X86_64) +namespace { +struct GatedDeltaNetKey { + ov::element::Type precision; + size_t qk_head_size; + size_t v_tile; + bool fuse_qk_l2norm; + float q_l2_norm_eps; + float k_l2_norm_eps; + + [[nodiscard]] size_t hash() const { + size_t seed = 0; + seed = dnnl::impl::hash_combine(seed, precision.hash()); + seed = dnnl::impl::hash_combine(seed, qk_head_size); + seed = dnnl::impl::hash_combine(seed, v_tile); + seed = dnnl::impl::hash_combine(seed, fuse_qk_l2norm); + seed = dnnl::impl::hash_combine(seed, q_l2_norm_eps); + seed = dnnl::impl::hash_combine(seed, k_l2_norm_eps); + return seed; + } + + bool operator==(const GatedDeltaNetKey& rhs) const { + return precision == rhs.precision && qk_head_size == rhs.qk_head_size && v_tile == rhs.v_tile && + fuse_qk_l2norm == rhs.fuse_qk_l2norm && q_l2_norm_eps == rhs.q_l2_norm_eps && + k_l2_norm_eps == rhs.k_l2_norm_eps; + } +}; + +void recurrent_linear_attn_jit(const ov::intel_cpu::PlainTensor& query, + const ov::intel_cpu::PlainTensor& key, + const ov::intel_cpu::PlainTensor& value, + const ov::intel_cpu::PlainTensor& recurrent_state, + const ov::intel_cpu::PlainTensor& gate, + const ov::intel_cpu::PlainTensor& beta, + 
ov::intel_cpu::PlainTensor& output_attn, + ov::intel_cpu::PlainTensor& output_recurrent_state, + uint8_t* temp_buffer, + const ov::intel_cpu::CpuParallelPtr& cpu_parallel, + const std::shared_ptr& jit_kernel, + const size_t gdn_jit_v_tile) { + OPENVINO_ASSERT(jit_kernel, "GDN JIT kernel is not created"); + + const size_t B = query.m_dims[0]; + const size_t T = query.m_dims[1]; + const size_t qk_heads = query.m_dims[2]; + const size_t K = query.m_dims[3]; + const size_t v_heads = value.m_dims[2]; + const size_t V = value.m_dims[3]; + const auto data_prc = query.m_dt; + const size_t elem_size = ov::element::Type(data_prc).size(); + OPENVINO_ASSERT(ov::intel_cpu::any_of(data_prc, ov::element::f16, ov::element::bf16), + "GDN JIT supports only f16/bf16 state copy path"); + const size_t group_size = v_heads / qk_heads; + OPENVINO_ASSERT(V % gdn_jit_v_tile == 0, "GDN JIT requires V divisible by ", gdn_jit_v_tile, ", got V=", V); + const size_t v_tiles = V / gdn_jit_v_tile; + const size_t state_tile_size = gdn_jit_v_tile * K * elem_size; + const size_t thread_buffer_size = (gdn_jit_v_tile + 2) * K * elem_size; + cpu_parallel->parallel_for3d(B, v_heads, v_tiles, [&](size_t i_b, size_t i_h, size_t i_v_tile) { + const size_t tid = parallel_get_thread_num(); + const size_t i_v_begin = i_v_tile * gdn_jit_v_tile; + + // Per-thread layout: [state tile][key tmp][query tmp] + uint8_t* state_buffer = temp_buffer + tid * thread_buffer_size; + uint8_t* b_k = state_buffer + state_tile_size; + uint8_t* b_q = b_k + K * elem_size; + + const size_t hk = i_h / group_size; + auto* q_ptr = query.ptr_v(i_b, 0, hk); + auto* k_ptr = key.ptr_v(i_b, 0, hk); + auto* v_ptr = value.ptr_v(i_b, 0, i_h, i_v_begin); + auto* out_ptr = output_attn.ptr_v(i_b, 0, i_h, i_v_begin); + + auto* gate_ptr = gate.ptr_v(i_b, 0, i_h); + auto* beta_ptr = beta.ptr_v(i_b, 0, i_h); + + const size_t recurrent_state_stride_k = recurrent_state.stride(2); + const size_t recurrent_state_stride_v = recurrent_state.stride(3); + 
const size_t output_state_stride_k = output_recurrent_state.stride(2); + const size_t output_state_stride_v = output_recurrent_state.stride(3); + + // JIT path stores state in 2-byte elements (f16/bf16). + auto* init_state_u16 = reinterpret_cast(state_buffer); + auto* recurrent_state_u16 = reinterpret_cast(recurrent_state.ptr_v(i_b, i_h, 0, i_v_begin)); + for (size_t j = 0; j < K; j++) { + const auto* src_row = recurrent_state_u16 + j * recurrent_state_stride_k; + for (size_t v_idx = 0; v_idx < gdn_jit_v_tile; v_idx++) { + init_state_u16[v_idx * K + j] = src_row[v_idx * recurrent_state_stride_v]; + } + } + + kernel::jit_gdn_call_args args{}; + args.state = state_buffer; + args.key_seq = reinterpret_cast(k_ptr); + args.query_seq = reinterpret_cast(q_ptr); + args.value_seq = reinterpret_cast(v_ptr); + args.gate_seq = reinterpret_cast(gate_ptr); + args.beta_seq = reinterpret_cast(beta_ptr); + args.t_size = T; + args.key_query_stride = qk_heads * K; + args.gate_beta_stride = v_heads; + args.value_stride = v_heads * V; + args.output_stride = v_heads * V; + args.key_tmp = b_k; + args.query_tmp = b_q; + args.output_seq = reinterpret_cast(out_ptr); + (*jit_kernel)(&args); + + // Copy final state tile back (2-byte elements: f16/bf16). 
+ auto* final_state_u16 = reinterpret_cast(state_buffer); + auto* output_state_u16 = reinterpret_cast(output_recurrent_state.ptr_v(i_b, i_h, 0, i_v_begin)); + for (size_t j = 0; j < K; j++) { + auto* dst_row = output_state_u16 + j * output_state_stride_k; + for (size_t v_idx = 0; v_idx < gdn_jit_v_tile; v_idx++) { + dst_row[v_idx * output_state_stride_v] = final_state_u16[v_idx * K + j]; + } + } + }); +} + +} // namespace +#endif + GatedDeltaNet::GatedDeltaNet(const std::shared_ptr& op, const GraphContext::CPtr& context) : Node(op, context, NgraphShapeInferFactory(op)) { std::string errorMessage; @@ -44,7 +172,18 @@ GatedDeltaNet::GatedDeltaNet(const std::shared_ptr& op, const GraphCon void GatedDeltaNet::initSupportedPrimitiveDescriptors() { // TODO: support other precision CVS-182464 - auto dataPrecision = ov::element::f32; + auto dataPrecision = getOriginalOutputPrecisionAtPort(0); + auto implType = impl_desc_type::ref_any; +#if defined(OPENVINO_ARCH_X86_64) + const auto queryDims = getInputShapeAtPort(0).getDims(); + auto headSize = *(queryDims.end() - 1); + if (ov::intel_cpu::any_of(getOriginalOutputPrecisionAtPort(0), ov::element::f16, ov::element::bf16) && + (mayiuse(avx512_core_bf16) || mayiuse(avx512_core_fp16)) && headSize % 32 == 0) { + implType = impl_desc_type::jit_avx512; + m_enableJit = true; + } +#endif + std::vector inPortConfigs; for (size_t i = 0; i < getParentEdges().size(); ++i) { inPortConfigs.emplace_back(LayoutType::ncsp, dataPrecision, getInputShapeAtPort(i), false, -1); @@ -52,16 +191,43 @@ void GatedDeltaNet::initSupportedPrimitiveDescriptors() { std::vector outPortConfigs = { PortConfigurator{LayoutType::ncsp, dataPrecision, getOutputShapeAtPort(0), false, -1}, PortConfigurator{LayoutType::ncsp, dataPrecision, getOutputShapeAtPort(1), false, -1}}; - addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any); + addSupportedPrimDesc(inPortConfigs, outPortConfigs, implType); } void GatedDeltaNet::createPrimitive() { const 
auto queryDims = getInputShapeAtPort(0).getDims(); auto headSize = *(queryDims.end() - 1); + size_t scratchRows = 3; + auto scratchPrecision = ov::element::f32; + // if head_size is not multiple of 32, fallbacks to intrinsic kernel +#if defined(OPENVINO_ARCH_X86_64) + if (m_enableJit) { + const auto precision = getOriginalOutputPrecisionAtPort(0); + GatedDeltaNetKey key{precision, headSize, m_gdnJitVTile, m_fuse_qk_l2norm, m_q_l2_norm_eps, m_k_l2_norm_eps}; + + auto builder = [&](const GatedDeltaNetKey& compile_key) -> std::shared_ptr { + return kernel::create_gdn_jit_kernel(compile_key.precision, + compile_key.qk_head_size, + compile_key.v_tile, + compile_key.fuse_qk_l2norm, + compile_key.q_l2_norm_eps, + compile_key.k_l2_norm_eps); + }; + + auto cache = context->getParamsCache(); + auto result = cache->getOrCreate(key, builder); + m_gdnJitKernel = result.first; + if (m_gdnJitKernel) { + scratchPrecision = precision; + scratchRows = m_gdnJitVTile + 2; + } + } +#endif + const auto numWorkerThreads = context->getCpuParallel()->get_num_worker_threads(); auto newMemDesc = std::make_shared( - ov::element::f32, - ov::intel_cpu::Shape{static_cast(numWorkerThreads), 3 * headSize}); + scratchPrecision, + ov::intel_cpu::Shape{static_cast(numWorkerThreads), scratchRows * headSize}); m_tmpInpBuffer = context->getScratchPad()->createScratchPadMem(newMemDesc); } @@ -86,7 +252,31 @@ void GatedDeltaNet::execute([[maybe_unused]] const dnnl::stream& strm) { PlainTensor output_attn(outputs[0]); PlainTensor output_recurrent_state(outputs[1]); - auto* temp_buffer = m_tmpInpBuffer->getDataAs(); + auto* temp_buffer = reinterpret_cast(m_tmpInpBuffer->getData()); +#if defined(OPENVINO_ARCH_X86_64) + if (m_gdnJitKernel) { + OPENVINO_ASSERT(value.m_dims[3] % m_gdnJitVTile == 0, + "GDN JIT requires V divisible by ", + m_gdnJitVTile, + ", got V=", + value.m_dims[3]); + + recurrent_linear_attn_jit(query, + key, + value, + recurrent_state, + gate, + beta, + output_attn, + 
output_recurrent_state, + temp_buffer, + context->getCpuParallel(), + m_gdnJitKernel, + m_gdnJitVTile); + return; + } +#endif + recurrent_linear_attn(query, key, value, @@ -98,7 +288,7 @@ void GatedDeltaNet::execute([[maybe_unused]] const dnnl::stream& strm) { m_fuse_qk_l2norm, output_attn, output_recurrent_state, - temp_buffer, + reinterpret_cast(temp_buffer), context->getCpuParallel()); } diff --git a/src/plugins/intel_cpu/src/nodes/gated_delta_net.h b/src/plugins/intel_cpu/src/nodes/gated_delta_net.h index b7f908ae06ce..a118890a5c33 100644 --- a/src/plugins/intel_cpu/src/nodes/gated_delta_net.h +++ b/src/plugins/intel_cpu/src/nodes/gated_delta_net.h @@ -15,6 +15,10 @@ #include "openvino/core/node.hpp" #include "openvino/core/type/element_type.hpp" +namespace ov::intel_cpu::kernel { +class JitKernelBase; +} + namespace ov::intel_cpu::node { class GatedDeltaNet : public Node { @@ -48,6 +52,11 @@ class GatedDeltaNet : public Node { bool m_fuse_qk_l2norm = false; float m_q_l2_norm_eps = 1e-6F; float m_k_l2_norm_eps = 1e-6F; +#if defined(OPENVINO_ARCH_X86_64) + std::shared_ptr m_gdnJitKernel; + size_t m_gdnJitVTile = 16; + bool m_enableJit = false; +#endif }; } // namespace ov::intel_cpu::node diff --git a/src/plugins/intel_cpu/src/nodes/kernels/linear_attn/recurrent_linear_attn.cpp b/src/plugins/intel_cpu/src/nodes/kernels/linear_attn/recurrent_linear_attn.cpp index f0e93e59cc32..58b8e90aeaf7 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/linear_attn/recurrent_linear_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/linear_attn/recurrent_linear_attn.cpp @@ -78,21 +78,22 @@ static inline void l2norm(float* a, size_t n, float eps) { #endif } -void recurrent_linear_attn(const ov::intel_cpu::PlainTensor& query, - const ov::intel_cpu::PlainTensor& key, - const ov::intel_cpu::PlainTensor& value, - const ov::intel_cpu::PlainTensor& recurrent_state, - const ov::intel_cpu::PlainTensor& gate, - const ov::intel_cpu::PlainTensor& beta, - float q_l2_norm_eps, - float 
k_l2_norm_eps, - bool use_qk_l2norm, - ov::intel_cpu::PlainTensor& output_attn, - ov::intel_cpu::PlainTensor& output_recurrent_state, - float* temp_buffer, - const ov::intel_cpu::CpuParallelPtr& cpu_parallel) { +template +static void recurrent_linear_attn_impl(const ov::intel_cpu::PlainTensor& query, + const ov::intel_cpu::PlainTensor& key, + const ov::intel_cpu::PlainTensor& value, + const ov::intel_cpu::PlainTensor& recurrent_state, + const ov::intel_cpu::PlainTensor& gate, + const ov::intel_cpu::PlainTensor& beta, + float q_l2_norm_eps, + float k_l2_norm_eps, + bool use_qk_l2norm, + ov::intel_cpu::PlainTensor& output_attn, + ov::intel_cpu::PlainTensor& output_recurrent_state, + float* temp_buffer, + const ov::intel_cpu::CpuParallelPtr& cpu_parallel) { size_t B = query.m_dims[0]; - size_t T = query.m_dims[1]; + size_t timesteps = query.m_dims[1]; size_t qk_heads = query.m_dims[2]; size_t K = query.m_dims[3]; size_t v_heads = value.m_dims[2]; @@ -108,24 +109,25 @@ void recurrent_linear_attn(const ov::intel_cpu::PlainTensor& query, float* b_q = temp_buffer + tid * 3 * K_HEAD_DIMS + 2 * K_HEAD_DIMS; const size_t hk = i_h / group_size; // B, T, qk, K - float* q_ptr = query.ptr(i_b, 0, hk); - float* k_ptr = key.ptr(i_b, 0, hk); + T* q_ptr = query.ptr(i_b, 0, hk); + T* k_ptr = key.ptr(i_b, 0, hk); // B, T, v_heads, V - float* v_ptr = value.ptr(i_b, 0, i_h); + T* v_ptr = value.ptr(i_b, 0, i_h); // B, v_heads, K, V + // Load recurrent state with stride V in K dimension + T* state_ptr = recurrent_state.ptr(i_b, i_h, 0, i_v); for (size_t j = 0; j < K_HEAD_DIMS; j++) { - init_state[j] = recurrent_state.at({i_b, i_h, j, i_v}); + init_state[j] = static_cast(state_ptr[j * V_HEAD_DIMS]); } - for (size_t i = 0; i < T; i++) { + for (size_t i = 0; i < timesteps; i++) { // gate: B, T, v_heads - float b_g = gate.at({i_b, i, i_h}); - float b_beta = beta.at({i_b, i, i_h}); - b_g = exp(b_g); - for (size_t j = 0; j < K_HEAD_DIMS; j++) { - b_k[j] = k_ptr[i * qk_heads * K_HEAD_DIMS + j]; 
- b_q[j] = q_ptr[i * qk_heads * K_HEAD_DIMS + j]; - } + float b_g = static_cast(gate.at({i_b, i, i_h})); + float b_beta = static_cast(beta.at({i_b, i, i_h})); + b_g = std::exp(b_g); + // Vectorized load of contiguous k and q + cvt_copy(b_k, k_ptr + i * qk_heads * K_HEAD_DIMS, 1, K_HEAD_DIMS, 0, 0); + cvt_copy(b_q, q_ptr + i * qk_heads * K_HEAD_DIMS, 1, K_HEAD_DIMS, 0, 0); if (use_qk_l2norm) { l2norm(b_k, K_HEAD_DIMS, k_l2_norm_eps); l2norm(b_q, K_HEAD_DIMS, q_l2_norm_eps); @@ -135,7 +137,7 @@ void recurrent_linear_attn(const ov::intel_cpu::PlainTensor& query, multiply_scalar(init_state, init_state, b_g, K_HEAD_DIMS); float h_k = dot_product(init_state, b_k, K_HEAD_DIMS, nullptr, nullptr, nullptr, 0); // B, T, v_heads, V - float b_v = v_ptr[i_v + i * v_heads * V_HEAD_DIMS]; + float b_v = static_cast(v_ptr[i_v + i * v_heads * V_HEAD_DIMS]); b_v -= h_k; // b_v * b_k b_v *= b_beta; @@ -143,14 +145,78 @@ void recurrent_linear_attn(const ov::intel_cpu::PlainTensor& query, // h = h0 + update cvt_add(init_state, init_state, b_k, 1, K_HEAD_DIMS, 0, 0, 0); float b_output = dot_product(init_state, b_q, K_HEAD_DIMS, nullptr, nullptr, nullptr, 0); - output_attn.at({i_b, i, i_h, i_v}) = b_output; + output_attn.at({i_b, i, i_h, i_v}) = static_cast(b_output); } + // Store recurrent state with stride V in K dimension + T* state_out_ptr = output_recurrent_state.ptr(i_b, i_h, 0, i_v); for (size_t j = 0; j < K_HEAD_DIMS; j++) { - output_recurrent_state.at({i_b, i_h, j, i_v}) = init_state[j]; + state_out_ptr[j * V_HEAD_DIMS] = static_cast(init_state[j]); } }); } +void recurrent_linear_attn(const ov::intel_cpu::PlainTensor& query, + const ov::intel_cpu::PlainTensor& key, + const ov::intel_cpu::PlainTensor& value, + const ov::intel_cpu::PlainTensor& recurrent_state, + const ov::intel_cpu::PlainTensor& gate, + const ov::intel_cpu::PlainTensor& beta, + float q_l2_norm_eps, + float k_l2_norm_eps, + bool use_qk_l2norm, + ov::intel_cpu::PlainTensor& output_attn, + ov::intel_cpu::PlainTensor& 
output_recurrent_state, + float* temp_buffer, + const ov::intel_cpu::CpuParallelPtr& cpu_parallel) { + const auto data_prc = query.get_precision(); + + if (data_prc == ov::element::f32) { + recurrent_linear_attn_impl(query, + key, + value, + recurrent_state, + gate, + beta, + q_l2_norm_eps, + k_l2_norm_eps, + use_qk_l2norm, + output_attn, + output_recurrent_state, + temp_buffer, + cpu_parallel); + } else if (data_prc == ov::element::f16) { + recurrent_linear_attn_impl(query, + key, + value, + recurrent_state, + gate, + beta, + q_l2_norm_eps, + k_l2_norm_eps, + use_qk_l2norm, + output_attn, + output_recurrent_state, + temp_buffer, + cpu_parallel); + } else if (data_prc == ov::element::bf16) { + recurrent_linear_attn_impl(query, + key, + value, + recurrent_state, + gate, + beta, + q_l2_norm_eps, + k_l2_norm_eps, + use_qk_l2norm, + output_attn, + output_recurrent_state, + temp_buffer, + cpu_parallel); + } else { + OPENVINO_ASSERT(false, "[CPU] gdn: unsupported precision", data_prc); + } +} + template static void recurrent_linear_attn_paged_impl(const ov::intel_cpu::PlainTensor& query, const ov::intel_cpu::PlainTensor& key, diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/gdn_jit_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/gdn_jit_kernel.cpp new file mode 100644 index 000000000000..f445ebd22689 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/gdn_jit_kernel.cpp @@ -0,0 +1,666 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gdn_jit_kernel.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "emitters/plugin/x64/jit_load_store_emitters.hpp" +#include "jit_kernel_base.hpp" +#include "openvino/core/type/element_type.hpp" + +using namespace dnnl::impl::cpu; +using namespace dnnl::impl::cpu::x64; + +namespace ov::intel_cpu::kernel { + +#define GET_OFF(field) offsetof(jit_gdn_call_args, field) + +template +void 
jit_gdn_kernel::load(const Vmm& vmm_dst, + const Xbyak::Reg64& reg_src, + ov::element::Type src_prc, + const int& elt_num, + bool fill, + size_t offset) { + // Typed load helper (src_prc -> f32 VMM via jit emitter) + const auto seed = load_emitter_params(src_prc, ov::element::f32, elt_num, fill, "float_min").hash(); + if (!emitters[seed]) { + constexpr cpu_isa_t load_isa = ((isa & zmm_bit) != 0) ? avx512_core : isa; + emitters[seed] = std::make_unique(this, + load_isa, + src_prc, + ov::element::f32, + elt_num, + ov::element::f32, + fill, + "float_min"); + } + emitters[seed]->emit_code({static_cast(reg_src.getIdx()), offset}, + {static_cast(vmm_dst.getIdx())}, + pool_aux_vmm_idxs, + pool_aux_gpr_idxs); +} + +template +void jit_gdn_kernel::store(const Xbyak::Reg64& reg_dst, + const Vmm& vmm_src, + ov::element::Type dst_prc, + const int& elt_num, + size_t offset) { + // Typed store helper (f32 VMM -> dst_prc via jit emitter) + const auto seed = store_emitter_params(ov::element::f32, dst_prc, elt_num).hash(); + if (!emitters[seed]) { + constexpr cpu_isa_t store_isa = ((isa & zmm_bit) != 0) ? 
avx512_core : isa; + emitters[seed] = std::make_unique(this, store_isa, ov::element::f32, dst_prc, elt_num); + } + emitters[seed]->emit_code({static_cast(vmm_src.getIdx())}, + {static_cast(reg_dst.getIdx()), offset}, + pool_aux_vmm_idxs, + pool_aux_gpr_idxs); +} + +template +void jit_gdn_kernel::reduce_zmm_f32_to_xmm_scalar(const Xbyak::Zmm& zmm_src, + const Xbyak::Xmm& xmm_dst, + const Xbyak::Xmm& xmm_tmp0, + const Xbyak::Xmm& xmm_tmp1) { + // Horizontal reduce 16x f32 (ZMM) into scalar lane of xmm_dst + vextractf32x8(Xbyak::Ymm(xmm_tmp1.getIdx()), zmm_src, 1); + vaddps(Xbyak::Ymm(xmm_tmp0.getIdx()), Xbyak::Ymm(zmm_src.getIdx()), Xbyak::Ymm(xmm_tmp1.getIdx())); + vextractf128(xmm_tmp1, Xbyak::Ymm(xmm_tmp0.getIdx()), 1); + vaddps(xmm_tmp0, xmm_tmp0, xmm_tmp1); + vhaddps(xmm_tmp0, xmm_tmp0, xmm_tmp0); + vhaddps(xmm_tmp0, xmm_tmp0, xmm_tmp0); + vaddss(xmm_dst, xmm_dst, xmm_tmp0); +} + +// ============================================ +// Native xf16 helpers +// ============================================ + +template +void jit_gdn_kernel::load_vector_native_xf16(Vmm* vmm_array, const Xbyak::Reg64& reg_src, int num_regs) { + // Load fp16 vector (up to 4 ZMMs) directly from memory + for (int i = 0; i < num_regs; i++) { + vmovups(vmm_array[i], ptr[reg_src + i * 64]); + } +} + +template +void jit_gdn_kernel::store_vector_native_xf16(const Xbyak::Reg64& reg_dst, Vmm* vmm_array, int num_regs) { + // Store fp16 vector (up to 4 ZMMs) to memory + for (int i = 0; i < num_regs; i++) { + vmovups(ptr[reg_dst + i * 64], vmm_array[i]); + } +} + +template +void jit_gdn_kernel::dot_product_native_xf16(const Xbyak::Xmm& xmm_dst, Vmm* vmm_a, Vmm* vmm_b, int num_regs) { + if (m_jcp.data_prc == ov::element::bf16) { + // bf16 path: accumulate directly in fp32 with vdpbf16ps + uni_vpxor(v_aux0, v_aux0, v_aux0); + for (int i = 0; i < num_regs; i++) { + vdpbf16ps(v_aux0, vmm_a[i], vmm_b[i]); + } + uni_vpxor(xmm_dst, xmm_dst, xmm_dst); + 
reduce_zmm_f32_to_xmm_scalar(Xbyak::Zmm(v_aux0.getIdx()), xmm_dst, x_tmp0, x_tmp1); + return; + } + + // f16 path: native fp16 accumulation then fp32 reduction + uni_vpxor(v_tmp0, v_tmp0, v_tmp0); // fp16 accumulator (32 lanes) + + for (int i = 0; i < num_regs; i++) { + vfmadd231ph(v_tmp0, vmm_a[i], vmm_b[i]); + } + + vcvtph2ps(v_aux0, Xbyak::Ymm(v_tmp0.getIdx())); + vextractf32x8(Xbyak::Ymm(x_tmp0.getIdx()), Xbyak::Zmm(v_tmp0.getIdx()), 1); + vcvtph2ps(v_aux1, Xbyak::Ymm(x_tmp0.getIdx())); + vaddps(v_aux0, v_aux0, v_aux1); + + uni_vpxor(xmm_dst, xmm_dst, xmm_dst); + reduce_zmm_f32_to_xmm_scalar(Xbyak::Zmm(v_aux0.getIdx()), xmm_dst, x_tmp0, x_tmp1); +} + +template +void jit_gdn_kernel::scale_vector_native_xf16(Vmm* vmm_array, const Xbyak::Xmm& xmm_scalar, int num_regs) { + if (m_jcp.data_prc == ov::element::bf16) { + // bf16 path: unpack->fp32 mul->pack bf16 per half + vbroadcastss(v_aux2, xmm_scalar); + + for (int i = 0; i < num_regs; i++) { + // lower 16 bf16 -> fp32 + vpmovzxwd(v_aux0, Xbyak::Ymm(vmm_array[i].getIdx())); + vpslld(v_aux0, v_aux0, 16); + vmulps(v_aux0, v_aux0, v_aux2); + vcvtneps2bf16(Xbyak::Ymm(x_tmp0.getIdx()), v_aux0); + + // upper 16 bf16 -> fp32 + vextractf32x8(Xbyak::Ymm(v_aux1.getIdx()), Xbyak::Zmm(vmm_array[i].getIdx()), 1); + vpmovzxwd(v_aux1, Xbyak::Ymm(v_aux1.getIdx())); + vpslld(v_aux1, v_aux1, 16); + vmulps(v_aux1, v_aux1, v_aux2); + vcvtneps2bf16(Xbyak::Ymm(x_tmp1.getIdx()), v_aux1); + + vinsertf32x8(Xbyak::Zmm(vmm_array[i].getIdx()), + Xbyak::Zmm(x_tmp0.getIdx()), + Xbyak::Ymm(x_tmp1.getIdx()), + 1); + } + return; + } + + // f16 path + vcvtps2ph(x_tmp0, xmm_scalar, 0); + vpbroadcastw(v_aux2, x_tmp0); + + for (int i = 0; i < num_regs; i++) { + vmulph(vmm_array[i], vmm_array[i], v_aux2); + } +} + +template +void jit_gdn_kernel::fmadd_vector_native_xf16(Vmm* vmm_dst, + Vmm* vmm_src, + const Xbyak::Xmm& xmm_scalar, + int num_regs) { + if (m_jcp.data_prc == ov::element::bf16) { + // bf16 path: unpack dst/src -> fp32 fma -> pack bf16 per 
half + vbroadcastss(v_aux2, xmm_scalar); + + for (int i = 0; i < num_regs; i++) { + // lower half + vpmovzxwd(v_aux0, Xbyak::Ymm(vmm_dst[i].getIdx())); + vpslld(v_aux0, v_aux0, 16); + vpmovzxwd(v_aux1, Xbyak::Ymm(vmm_src[i].getIdx())); + vpslld(v_aux1, v_aux1, 16); + vfmadd231ps(v_aux0, v_aux1, v_aux2); + vcvtneps2bf16(Xbyak::Ymm(x_tmp0.getIdx()), v_aux0); + + // upper half + vextractf32x8(Xbyak::Ymm(v_aux0.getIdx()), Xbyak::Zmm(vmm_dst[i].getIdx()), 1); + vpmovzxwd(v_aux0, Xbyak::Ymm(v_aux0.getIdx())); + vpslld(v_aux0, v_aux0, 16); + vextractf32x8(Xbyak::Ymm(v_aux1.getIdx()), Xbyak::Zmm(vmm_src[i].getIdx()), 1); + vpmovzxwd(v_aux1, Xbyak::Ymm(v_aux1.getIdx())); + vpslld(v_aux1, v_aux1, 16); + vfmadd231ps(v_aux0, v_aux1, v_aux2); + vcvtneps2bf16(Xbyak::Ymm(x_tmp1.getIdx()), v_aux0); + + vinsertf32x8(Xbyak::Zmm(vmm_dst[i].getIdx()), Xbyak::Zmm(x_tmp0.getIdx()), Xbyak::Ymm(x_tmp1.getIdx()), 1); + } + return; + } + + // f16 path + vcvtps2ph(x_tmp0, xmm_scalar, 0); + vpbroadcastw(v_aux2, x_tmp0); + + for (int i = 0; i < num_regs; i++) { + vfmadd231ph(vmm_dst[i], vmm_src[i], v_aux2); + } +} + +template +void jit_gdn_kernel::l2norm_inplace_native_xf16(Vmm* vmm_array, const Xbyak::Xmm& xmm_eps, int num_regs) { + // L2 normalization: vmm /= sqrt(sum(vmm^2) + eps) + uni_vpxor(v_aux0, v_aux0, v_aux0); // fp32 accumulator + + if (m_jcp.data_prc == ov::element::bf16) { + for (int i = 0; i < num_regs; i++) { + vdpbf16ps(v_aux0, vmm_array[i], vmm_array[i]); + } + } else { + for (int i = 0; i < num_regs; i++) { + // lower 16 fp16 lanes + vcvtph2ps(v_aux1, Xbyak::Ymm(vmm_array[i].getIdx())); + vfmadd231ps(v_aux0, v_aux1, v_aux1); + + // upper 16 fp16 lanes + vextractf32x8(Xbyak::Ymm(v_tmp0.getIdx()), Xbyak::Zmm(vmm_array[i].getIdx()), 1); + vcvtph2ps(v_aux1, Xbyak::Ymm(v_tmp0.getIdx())); + vfmadd231ps(v_aux0, v_aux1, v_aux1); + } + } + + // Reduce to scalar: sqrt(sum + eps), then compute reciprocal + uni_vpxor(x_hk, x_hk, x_hk); + 
reduce_zmm_f32_to_xmm_scalar(Xbyak::Zmm(v_aux0.getIdx()), x_hk, x_tmp0, x_tmp1); + vaddss(x_hk, x_hk, xmm_eps); + vsqrtss(x_hk, x_hk, x_hk); + + mov(reg_aux.cvt32(), float2int(1.0F)); + vmovd(x_tmp1, reg_aux.cvt32()); + vdivss(x_tmp1, x_tmp1, x_hk); // reciprocal + + // Scale vector by reciprocal + scale_vector_native_xf16(vmm_array, x_tmp1, num_regs); +} + +// ============================================ +// Buffer-based L2 norm helper for qk_head_size > 128 +// ============================================ + +template +void jit_gdn_kernel::l2norm_buffer_compute_scale_native_xf16(const Xbyak::Reg64& reg_buffer, + const Xbyak::Xmm& xmm_eps, + const Xbyak::Xmm& xmm_scale_out, + int num_regs, + int num_chunks) { + // Compute L2 norm scale: 1/sqrt(sum(x^2) + eps) + // Accumulates across all chunks from buffer, returns scale factor + uni_vpxor(v_aux0, v_aux0, v_aux0); + if (m_jcp.data_prc == ov::element::f16) { + uni_vpxor(v_tmp0, v_tmp0, v_tmp0); + } + + // Accumulate sum of squares across all chunks + for (int chunk = 0; chunk < num_chunks; chunk++) { + const int chunk_start = chunk * MAX_REGS_PER_VEC; + const int chunk_regs = std::min(MAX_REGS_PER_VEC, num_regs - chunk_start); + const size_t chunk_offset = chunk_start * XF16_ELEMS_PER_ZMM * m_jcp.data_prc.size(); + + mov(reg_aux2, reg_buffer); + add(reg_aux2, chunk_offset); + load_vector_native_xf16(const_cast(v_k), reg_aux2, chunk_regs); + + if (m_jcp.data_prc == ov::element::bf16) { + for (int i = 0; i < chunk_regs; i++) { + vdpbf16ps(v_aux0, v_k[i], v_k[i]); + } + } else { + // fp16 path - accumulate in native fp16 + for (int i = 0; i < chunk_regs; i++) { + vfmadd231ph(v_tmp0, v_k[i], v_k[i]); + } + } + } + + // Convert fp16 to fp32 after all chunks + if (m_jcp.data_prc == ov::element::f16) { + vcvtph2ps(v_aux1, Xbyak::Ymm(v_tmp0.getIdx())); + vextractf32x8(Xbyak::Ymm(x_tmp0.getIdx()), Xbyak::Zmm(v_tmp0.getIdx()), 1); + vcvtph2ps(v_aux2, Xbyak::Ymm(x_tmp0.getIdx())); + vaddps(v_aux0, v_aux1, v_aux2); + } + + // 
Compute reciprocal: 1/sqrt(sum + eps) + uni_vpxor(xmm_scale_out, xmm_scale_out, xmm_scale_out); + reduce_zmm_f32_to_xmm_scalar(Xbyak::Zmm(v_aux0.getIdx()), xmm_scale_out, x_tmp0, x_tmp1); + vaddss(xmm_scale_out, xmm_scale_out, xmm_eps); + vsqrtss(xmm_scale_out, xmm_scale_out, xmm_scale_out); + mov(reg_aux.cvt32(), float2int(1.0F)); + vmovd(x_value, reg_aux.cvt32()); + vdivss(xmm_scale_out, x_value, xmm_scale_out); +} + +template +void jit_gdn_kernel::scale_buffer_native_xf16(const Xbyak::Reg64& reg_buffer, + const Xbyak::Xmm& xmm_scale, + Vmm* vmm_temp, + int num_regs, + int num_chunks) { + // Scale all chunks of a buffer by a scalar + for (int chunk = 0; chunk < num_chunks; chunk++) { + const int chunk_start = chunk * MAX_REGS_PER_VEC; + const int chunk_regs = std::min(MAX_REGS_PER_VEC, num_regs - chunk_start); + const size_t chunk_offset = chunk_start * XF16_ELEMS_PER_ZMM * m_jcp.data_prc.size(); + + mov(reg_aux2, reg_buffer); + add(reg_aux2, chunk_offset); + load_vector_native_xf16(vmm_temp, reg_aux2, chunk_regs); + scale_vector_native_xf16(vmm_temp, xmm_scale, chunk_regs); + store_vector_native_xf16(reg_aux2, vmm_temp, chunk_regs); + } +} + +// ============================================ +// Main native xf16 kernel +// ============================================ + +template +void jit_gdn_kernel::generate() { + // JIT codegen for native xf16 path + // For qk_head_size <= 128: register-resident Q/K/H + // For qk_head_size > 128: use temp buffers + + auto exp_injector = std::make_shared>(this, + dnnl::impl::alg_kind::eltwise_exp, + 0.F, + 0.F, + 1.F, + dnnl::impl::data_type::f32, + true, + Xbyak::Reg64(Xbyak::Operand::RCX), + Xbyak::Opmask(1), + true, + false, + false, + false); + + this->preamble(); + + Xbyak::Label l_t_loop; + Xbyak::Label l_end; + + mov(reg_args, abi_param1); + + // Determine if we use registers or temp buffers + const size_t qk = m_jcp.qk_head_size; + const bool use_registers = (qk <= 128); + const auto num_regs = static_cast(qk / 
XF16_ELEMS_PER_ZMM); + const auto num_chunks = (num_regs + MAX_REGS_PER_VEC - 1) / MAX_REGS_PER_VEC; + + // One-time setup + exp_injector->load_table_addr(); + + mov(reg_key_seq, ptr[reg_args + GET_OFF(key_seq)]); + mov(reg_query_seq, ptr[reg_args + GET_OFF(query_seq)]); + mov(reg_gate_seq, ptr[reg_args + GET_OFF(gate_seq)]); + mov(reg_beta_seq, ptr[reg_args + GET_OFF(beta_seq)]); + mov(reg_value_seq, ptr[reg_args + GET_OFF(value_seq)]); + mov(reg_out_seq, ptr[reg_args + GET_OFF(output_seq)]); + mov(reg_t, ptr[reg_args + GET_OFF(t_size)]); + + if (!use_registers) { + mov(reg_key_tmp, ptr[reg_args + GET_OFF(key_tmp)]); + mov(reg_query_tmp, ptr[reg_args + GET_OFF(query_tmp)]); + } + + test(reg_t, reg_t); + jz(l_end, T_NEAR); + + L(l_t_loop); + { + // Reload scalar constants each iteration + mov(reg_aux.cvt32(), float2int(m_jcp.k_l2_norm_eps)); + vmovd(x_eps_k, reg_aux.cvt32()); + mov(reg_aux.cvt32(), float2int(m_jcp.q_l2_norm_eps)); + vmovd(x_eps_q, reg_aux.cvt32()); + mov(reg_aux.cvt32(), float2int(m_jcp.q_scale)); + vmovd(x_qscale, reg_aux.cvt32()); + + if (use_registers) { + // Load K, Q directly into registers + load_vector_native_xf16(const_cast(v_k), reg_key_seq, num_regs); + load_vector_native_xf16(const_cast(v_q), reg_query_seq, num_regs); + + // Optional L2 normalization + if (m_jcp.fuse_qk_l2norm) { + l2norm_inplace_native_xf16(const_cast(v_k), x_eps_k, num_regs); + l2norm_inplace_native_xf16(const_cast(v_q), x_eps_q, num_regs); + } + + // Scale query + scale_vector_native_xf16(const_cast(v_q), x_qscale, num_regs); + } else { + // Large head_size: use temp buffers and process in chunks + // Reset temp buffer pointers + mov(reg_key_tmp, ptr[reg_args + GET_OFF(key_tmp)]); + mov(reg_query_tmp, ptr[reg_args + GET_OFF(query_tmp)]); + + // Copy K, Q to temp buffers + for (int chunk = 0; chunk < num_chunks; chunk++) { + const int chunk_start = chunk * MAX_REGS_PER_VEC; + const int chunk_regs = std::min(MAX_REGS_PER_VEC, num_regs - chunk_start); + const size_t 
chunk_offset = chunk_start * XF16_ELEMS_PER_ZMM * m_jcp.data_prc.size(); + + mov(reg_aux2, reg_key_seq); + add(reg_aux2, chunk_offset); + load_vector_native_xf16(const_cast(v_k), reg_aux2, chunk_regs); + + mov(reg_aux2, reg_key_tmp); + add(reg_aux2, chunk_offset); + store_vector_native_xf16(reg_aux2, const_cast(v_k), chunk_regs); + + mov(reg_aux2, reg_query_seq); + add(reg_aux2, chunk_offset); + load_vector_native_xf16(const_cast(v_q), reg_aux2, chunk_regs); + + mov(reg_aux2, reg_query_tmp); + add(reg_aux2, chunk_offset); + store_vector_native_xf16(reg_aux2, const_cast(v_q), chunk_regs); + } + + // Optional L2 normalization (process in chunks) + if (m_jcp.fuse_qk_l2norm) { + // Normalize K + l2norm_buffer_compute_scale_native_xf16(reg_key_tmp, x_eps_k, x_beta, num_regs, num_chunks); + scale_buffer_native_xf16(reg_key_tmp, x_beta, const_cast(v_k), num_regs, num_chunks); + + // Normalize Q and combine with q_scale + l2norm_buffer_compute_scale_native_xf16(reg_query_tmp, x_eps_q, x_beta, num_regs, num_chunks); + vmulss(x_beta, x_beta, x_qscale); // Combine: l2norm_scale * q_scale + scale_buffer_native_xf16(reg_query_tmp, x_beta, const_cast(v_q), num_regs, num_chunks); + } else { + // No L2 norm, just scale Q by q_scale + scale_buffer_native_xf16(reg_query_tmp, x_qscale, const_cast(v_q), num_regs, num_chunks); + } + } + + // Compute gate and beta once per timestep and share across V lanes + load(Vmm(x_gate.getIdx()), reg_gate_seq, m_jcp.data_prc, 1, false); + exp_injector->compute_vector_range(x_gate.getIdx(), x_gate.getIdx() + 1); + load(Vmm(x_beta.getIdx()), reg_beta_seq, m_jcp.data_prc, 1, false); + + // accumulate dot product of two vectors into v_aux0 + auto accumulate_dot_product = [&](Vmm* vmm_a, Vmm* vmm_b, int chunk_regs) { + if (m_jcp.data_prc == ov::element::bf16) { + for (int i = 0; i < chunk_regs; i++) { + vdpbf16ps(v_aux0, vmm_a[i], vmm_b[i]); + } + } else { + for (int i = 0; i < chunk_regs; i++) { + // lower 16 elements + vcvtph2ps(v_aux1, 
Xbyak::Ymm(vmm_a[i].getIdx())); + vcvtph2ps(v_aux2, Xbyak::Ymm(vmm_b[i].getIdx())); + vfmadd231ps(v_aux0, v_aux1, v_aux2); + // upper 16 elements + vextractf32x8(Xbyak::Ymm(v_aux1.getIdx()), Xbyak::Zmm(vmm_a[i].getIdx()), 1); + vcvtph2ps(v_aux1, Xbyak::Ymm(v_aux1.getIdx())); + vextractf32x8(Xbyak::Ymm(v_aux2.getIdx()), Xbyak::Zmm(vmm_b[i].getIdx()), 1); + vcvtph2ps(v_aux2, Xbyak::Ymm(v_aux2.getIdx())); + vfmadd231ps(v_aux0, v_aux1, v_aux2); + } + } + }; + + // Unrolled V lanes: all lanes share same K/Q path above + for (size_t v_idx = 0; v_idx < m_jcp.v_tile; ++v_idx) { + mov(reg_state, ptr[reg_args + GET_OFF(state)]); + add(reg_state, static_cast(v_idx * m_jcp.qk_head_size * m_jcp.data_prc.size())); + + // Preload value scalar to overlap memory latency with hk reduction work. + mov(reg_aux2, reg_value_seq); + add(reg_aux2, static_cast(v_idx * m_jcp.data_prc.size())); + load(Vmm(x_value.getIdx()), reg_aux2, m_jcp.data_prc, 1, false); + + if (use_registers) { + // Load H for current V lane + load_vector_native_xf16(const_cast(v_h), reg_state, num_regs); + + // Scale hidden state by exp(gate) + scale_vector_native_xf16(const_cast(v_h), x_gate, num_regs); + + // Compute hk = dot(H, K) + dot_product_native_xf16(x_hk, const_cast(v_h), const_cast(v_k), num_regs); + + // delta = (value - hk) * beta + vsubss(x_delta, x_value, x_hk); + vmulss(x_delta, x_delta, x_beta); + + // Update: H += K * delta + fmadd_vector_native_xf16(const_cast(v_h), const_cast(v_k), x_delta, num_regs); + + // Output: out = dot(H, Q) + dot_product_native_xf16(x_out, const_cast(v_h), const_cast(v_q), num_regs); + + // Store output and H state for current V lane + mov(reg_aux2, reg_out_seq); + add(reg_aux2, static_cast(v_idx * m_jcp.data_prc.size())); + store(reg_aux2, Vmm(x_out.getIdx()), m_jcp.data_prc, 1); + store_vector_native_xf16(reg_state, const_cast(v_h), num_regs); + } else { + // Scale H by exp(gate) + scale_buffer_native_xf16(reg_state, x_gate, const_cast(v_h), num_regs, num_chunks); + + // 
Compute hk = dot(H, K) + uni_vpxor(v_aux0, v_aux0, v_aux0); + for (int chunk = 0; chunk < num_chunks; chunk++) { + const int chunk_start = chunk * MAX_REGS_PER_VEC; + const int chunk_regs = std::min(MAX_REGS_PER_VEC, num_regs - chunk_start); + const size_t chunk_offset = chunk_start * XF16_ELEMS_PER_ZMM * m_jcp.data_prc.size(); + + mov(reg_aux2, reg_state); + add(reg_aux2, chunk_offset); + load_vector_native_xf16(const_cast(v_h), reg_aux2, chunk_regs); + + mov(reg_aux2, reg_key_tmp); + add(reg_aux2, chunk_offset); + load_vector_native_xf16(const_cast(v_k), reg_aux2, chunk_regs); + + accumulate_dot_product(const_cast(v_h), const_cast(v_k), chunk_regs); + } + uni_vpxor(x_hk, x_hk, x_hk); + reduce_zmm_f32_to_xmm_scalar(Xbyak::Zmm(v_aux0.getIdx()), x_hk, x_tmp0, x_tmp1); + + // delta = (value - hk) * beta + vsubss(x_delta, x_value, x_hk); + vmulss(x_delta, x_delta, x_beta); + + // Update: H += K * delta + for (int chunk = 0; chunk < num_chunks; chunk++) { + const int chunk_start = chunk * MAX_REGS_PER_VEC; + const int chunk_regs = std::min(MAX_REGS_PER_VEC, num_regs - chunk_start); + const size_t chunk_offset = chunk_start * XF16_ELEMS_PER_ZMM * m_jcp.data_prc.size(); + + mov(reg_aux2, reg_state); + add(reg_aux2, chunk_offset); + load_vector_native_xf16(const_cast(v_h), reg_aux2, chunk_regs); + + mov(reg_aux2, reg_key_tmp); + add(reg_aux2, chunk_offset); + load_vector_native_xf16(const_cast(v_k), reg_aux2, chunk_regs); + + fmadd_vector_native_xf16(const_cast(v_h), const_cast(v_k), x_delta, chunk_regs); + + mov(reg_aux2, reg_state); + add(reg_aux2, chunk_offset); + store_vector_native_xf16(reg_aux2, const_cast(v_h), chunk_regs); + } + + // Output: out = dot(H, Q) + uni_vpxor(v_aux0, v_aux0, v_aux0); + for (int chunk = 0; chunk < num_chunks; chunk++) { + const int chunk_start = chunk * MAX_REGS_PER_VEC; + const int chunk_regs = std::min(MAX_REGS_PER_VEC, num_regs - chunk_start); + const size_t chunk_offset = chunk_start * XF16_ELEMS_PER_ZMM * m_jcp.data_prc.size(); + + 
// Factory for the Gated Delta Net JIT kernel.
//
// Packs the arguments into a jit_gdn_compile_params, picks the kernel variant that matches
// `data_prc` and the ISA reported by mayiuse(), JIT-compiles it via create_kernel() and
// returns it.
//
// An empty pointer is returned (and nothing is compiled) when any of the following holds:
//   - data_prc is neither bf16 nor f16;
//   - qk_head_size is 0 or not a multiple of 32;
//   - v_tile is 0;
//   - the required ISA is unavailable (avx512_core_bf16 for bf16, avx512_core_fp16 for f16).
// Callers therefore have to check the result before use.
//
// NOTE(review): in this view of the file the template argument lists (return type, `res`,
// static_cast and make_shared targets) are absent — confirm them against the repository.
std::shared_ptr create_gdn_jit_kernel(ov::element::Type data_prc,
                                      size_t qk_head_size,
                                      size_t v_tile,
                                      bool fuse_qk_l2norm,
                                      float q_l2_norm_eps,
                                      float k_l2_norm_eps) {
    // Stays empty on every rejection path below.
    std::shared_ptr res;
    jit_gdn_compile_params jcp;
    jcp.data_prc = data_prc;
    jcp.qk_head_size = qk_head_size;
    jcp.v_tile = v_tile;
    jcp.fuse_qk_l2norm = fuse_qk_l2norm;
    jcp.q_l2_norm_eps = q_l2_norm_eps;
    jcp.k_l2_norm_eps = k_l2_norm_eps;
    // Query scale is 1/sqrt(qk_head_size). It is computed before the validation below;
    // for qk_head_size == 0 the value is never used because the function returns early.
    jcp.q_scale = 1.0F / std::sqrt(static_cast(qk_head_size));

    // Only the 16-bit floating point precisions are handled by this kernel.
    if (data_prc != ov::element::bf16 && data_prc != ov::element::f16) {
        return res;
    }
    // The head size must be a non-zero multiple of 32.
    if (qk_head_size == 0 || qk_head_size % 32 != 0) {
        return res;
    }
    // generate() unrolls over v_tile lanes, so a zero tile would emit no per-lane work.
    if (v_tile == 0) {
        return res;
    }

    // Instantiate the variant for the precision, but only if the CPU supports the ISA.
    if (data_prc == ov::element::bf16) {
        if (mayiuse(avx512_core_bf16)) {
            res = std::make_shared>(jcp);
        }
    } else if (data_prc == ov::element::f16) {
        if (mayiuse(avx512_core_fp16)) {
            res = std::make_shared>(jcp);
        }
    }

    // Emit machine code only when a kernel object was actually created.
    if (res) {
        res->create_kernel();
    }

    return res;
}
// Runtime arguments passed to the generated GDN kernel on every invocation.
//
// The generated code reads the fields through GET_OFF(field) relative to the argument
// pointer (presumably offsetof-based — confirm), so the member order and types are part
// of the kernel ABI and must not change. All `*_stride` values are expressed in ELEMENTS:
// the kernel multiplies them by the element size of the data precision before advancing
// a pointer.
//
// Every member carries a default initializer so that a default-constructed instance is
// fully defined (null pointers, zero sizes) instead of holding indeterminate values.
// This matches jit_gdn_compile_params, which initializes all of its fields, and keeps the
// struct a standard-layout aggregate. With t_size == 0 the kernel skips its time loop.
struct jit_gdn_call_args {
    // State tile, updated in place. Lane v_idx starts at
    // v_idx * qk_head_size * element_size bytes from this pointer.
    uint8_t* state = nullptr;
    // Key / query vectors of the first timestep; both advance by key_query_stride.
    const uint8_t* key_seq = nullptr;
    const uint8_t* query_seq = nullptr;
    // Values of the first timestep; v_tile consecutive scalars are read per timestep.
    const uint8_t* value_seq = nullptr;
    // One gate scalar (passed through exp by the kernel) and one beta scalar per
    // timestep; both advance by gate_beta_stride.
    const uint8_t* gate_seq = nullptr;
    const uint8_t* beta_seq = nullptr;
    // Number of timesteps to process.
    size_t t_size = 0;
    // Per-timestep pointer increments, in elements.
    size_t key_query_stride = 0;
    size_t gate_beta_stride = 0;
    size_t value_stride = 0;
    size_t output_stride = 0;
    // Scratch buffers for the processed key / query of the current timestep. They are
    // loaded only on the kernel's chunked (non-register) path.
    uint8_t* key_tmp = nullptr;
    uint8_t* query_tmp = nullptr;
    // Output of the first timestep; v_tile consecutive scalars are written per timestep.
    uint8_t* output_seq = nullptr;
};
// JIT generator for the recurrent Gated Delta Net step.
//
// For every timestep and every one of the v_tile value lanes, generate() emits code for:
//   H     *= exp(gate)
//   hk     = dot(H, K)
//   delta  = (value - hk) * beta
//   H     += K * delta
//   out    = dot(H, Q)
// K and Q are optionally L2-normalized, and Q is multiplied by q_scale, once per timestep.
//
// The members below are a fixed, hand-assigned register map: generate() and the helpers
// rely on these exact indices not overlapping.
//
// NOTE(review): in this view of the file the template argument lists (template header,
// JitKernel base, conditional_t, cpu_isa_traits_t, container element types) are absent —
// confirm them against the repository.
template
struct jit_gdn_kernel : public JitKernel {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_gdn_kernel)

    // Forwards the compile-time parameters and the ISA to the JitKernel base.
    explicit jit_gdn_kernel(const jit_gdn_compile_params& jcp) : JitKernel(jit_name(), jcp, isa) {}

private:
    using Xmm = Xbyak::Xmm;
    using Vmm = std::conditional_t;

    // f32 lanes per vector register, the same size in bytes, and a shift that is 3 for
    // avx2 and 4 otherwise. None of the three is referenced in the visible part of
    // generate().
    static constexpr size_t vec_size = dnnl::impl::cpu::x64::cpu_isa_traits_t::vlen / sizeof(float);
    static constexpr size_t vec_bytes = vec_size * sizeof(float);
    static constexpr int vec_shift = isa == dnnl::impl::cpu::x64::avx2 ? 3 : 4;

    // GPR map
    const Xbyak::Reg64 reg_args = rbx;       // base of the call-args struct (GET_OFF offsets)
    const Xbyak::Reg64 reg_state = r8;       // state pointer of the current V lane
    const Xbyak::Reg64 reg_key_tmp = r9;     // key scratch buffer (chunked path)
    const Xbyak::Reg64 reg_query_tmp = r10;  // query scratch buffer (chunked path)
    const Xbyak::Reg64 reg_t = r12;          // remaining timesteps, counted down to zero
    const Xbyak::Reg64 reg_key_seq = r13;    // key of the current timestep
    const Xbyak::Reg64 reg_query_seq = r14;  // query of the current timestep
    const Xbyak::Reg64 reg_value_seq = r15;  // values of the current timestep
    const Xbyak::Reg64 reg_aux = r11;        // scratch: float constants moved into XMM
    const Xbyak::Reg64 reg_gate_seq = rsi;   // gate scalar of the current timestep
    const Xbyak::Reg64 reg_beta_seq = rdi;   // beta scalar of the current timestep
    const Xbyak::Reg64 reg_out_seq = rbp;    // output of the current timestep
    const Xbyak::Reg64 reg_aux2 = rax;       // scratch: address arithmetic and strides

    // XMM map (scalar values)
    const Xmm x_hk = Xmm(0);       // hk = dot(H, K)
    const Xmm x_tmp0 = Xmm(1);     // reduction scratch
    const Xmm x_tmp1 = Xmm(2);     // reduction scratch
    const Xmm x_delta = Xmm(3);    // (value - hk) * beta
    const Xmm x_out = Xmm(4);      // out = dot(H, Q)
    const Xmm x_gate = Xmm(5);     // exp(gate)
    const Xmm x_beta = Xmm(6);     // beta; also temporary L2-norm scale on the chunked path
    const Xmm x_value = Xmm(7);    // value scalar of the current V lane
    const Xmm x_eps_k = Xmm(8);    // k_l2_norm_eps
    const Xmm x_eps_q = Xmm(9);    // q_l2_norm_eps
    const Xmm x_qscale = Xmm(10);  // q_scale

    // Full-width views of x_tmp0 / x_tmp1 (same register indices).
    const Vmm v_tmp0 = Vmm(x_tmp0.getIdx());
    const Vmm v_tmp1 = Vmm(x_tmp1.getIdx());
    const Vmm v_aux0 = Vmm(11);  // dot-product accumulator on the chunked path
    const Vmm v_aux1 = Vmm(12);  // scratch for the f16 -> f32 conversion of operand a
    const Vmm v_aux2 = Vmm(13);  // scratch for the f16 -> f32 conversion of operand b

    // Register-based Q/K/H storage for native f16/bf16.
    // Each vector is held in up to MAX_REGS_PER_VEC ZMMs of XF16_ELEMS_PER_ZMM elements,
    // i.e. up to 4 * 32 = 128 elements at once. generate() processes longer heads in
    // chunks of that size through the key/query scratch buffers.
    static constexpr int XF16_ELEMS_PER_ZMM = 32;  // 32 xf16 elements per ZMM register
    static constexpr int MAX_REGS_PER_VEC = 4;     // Max ZMMs per vector (for head_dims=128)

    const Vmm v_q[MAX_REGS_PER_VEC] = {Vmm(14), Vmm(15), Vmm(16), Vmm(17)};  // Query
    const Vmm v_k[MAX_REGS_PER_VEC] = {Vmm(18), Vmm(19), Vmm(20), Vmm(21)};  // Key
    const Vmm v_h[MAX_REGS_PER_VEC] = {Vmm(22), Vmm(23), Vmm(24), Vmm(25)};  // Hidden state

    // Emits the whole kernel: argument loading, the timestep loop and the unrolled V lanes.
    void generate() override;

    // Native xf16 helpers - f16/bf16, head_dims must be multiple of 32.
    // Each one operates on `num_regs` consecutive registers of the given array.
    void load_vector_native_xf16(Vmm* vmm_array, const Xbyak::Reg64& reg_src, int num_regs);
    void store_vector_native_xf16(const Xbyak::Reg64& reg_dst, Vmm* vmm_array, int num_regs);
    void dot_product_native_xf16(const Xbyak::Xmm& xmm_dst, Vmm* vmm_a, Vmm* vmm_b, int num_regs);
    void scale_vector_native_xf16(Vmm* vmm_array, const Xbyak::Xmm& xmm_scalar, int num_regs);
    void fmadd_vector_native_xf16(Vmm* vmm_dst, Vmm* vmm_src, const Xbyak::Xmm& xmm_scalar, int num_regs);
    void l2norm_inplace_native_xf16(Vmm* vmm_array, const Xbyak::Xmm& xmm_eps, int num_regs);

    // Buffer-based helpers for qk_head_size > 128: the vector lives in memory at
    // reg_buffer and is processed in `num_chunks` chunks.
    // Writes the normalization scale to xmm_scale_out; generate() applies it afterwards
    // with scale_buffer_native_xf16().
    void l2norm_buffer_compute_scale_native_xf16(const Xbyak::Reg64& reg_buffer,
                                                 const Xbyak::Xmm& xmm_eps,
                                                 const Xbyak::Xmm& xmm_scale_out,
                                                 int num_regs,
                                                 int num_chunks);
    // Scales the buffer by xmm_scale, using vmm_temp as working registers.
    void scale_buffer_native_xf16(const Xbyak::Reg64& reg_buffer,
                                  const Xbyak::Xmm& xmm_scale,
                                  Vmm* vmm_temp,
                                  int num_regs,
                                  int num_chunks);

    // Horizontal reduction of the f32 accumulator into a scalar. generate() zeroes
    // xmm_dst before calling it.
    void reduce_zmm_f32_to_xmm_scalar(const Xbyak::Zmm& zmm_src,
                                      const Xbyak::Xmm& xmm_dst,
                                      const Xbyak::Xmm& xmm_tmp0,
                                      const Xbyak::Xmm& xmm_tmp1);
    // Precision-aware store / load of `elt_num` elements at reg + offset; generate()
    // uses them with elt_num == 1 for the scalar gate, beta, value and output accesses.
    void store(const Xbyak::Reg64& reg_dst,
               const Vmm& vmm_src,
               ov::element::Type dst_prc,
               const int& elt_num,
               size_t offset = 0);
    void load(const Vmm& vmm_dst,
              const Xbyak::Reg64& reg_src,
              ov::element::Type src_prc,
              const int& elt_num,
              bool fill,
              size_t offset = 0);

    // Emitter cache and auxiliary register pools. Their use is not visible in this part
    // of the file — presumably consumed by load()/store(); confirm in the .cpp.
    std::unordered_map> emitters;
    const std::vector pool_aux_gpr_idxs;
    const std::vector pool_aux_vmm_idxs;
};
0ad3d1058258..6a15f966775c 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/gated_delta_net.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/gated_delta_net.cpp @@ -25,6 +25,20 @@ std::vector test_cases = { {1, 2, 2, 2, 15, 15, ov::element::f32, "CPU"}, {1, 2, 2, 2, 31, 31, ov::element::f32, "CPU"}, {1, 2, 2, 2, 1, 1, ov::element::f32, "CPU"}, + // f16 cases + {1, 32, 2, 2, 128, 128, ov::element::f16, "CPU"}, + {1, 32, 4, 4, 128, 128, ov::element::f16, "CPU"}, + {1, 32, 2, 4, 128, 128, ov::element::f16, "CPU"}, + {1, 16, 2, 2, 64, 128, ov::element::f16, "CPU"}, + {1, 32, 4, 4, 256, 256, ov::element::f16, "CPU"}, + {1, 32, 2, 4, 256, 256, ov::element::f16, "CPU"}, + // bf16 cases + {1, 32, 2, 2, 128, 128, ov::element::bf16, "CPU"}, + {1, 32, 4, 4, 128, 128, ov::element::bf16, "CPU"}, + {1, 32, 2, 4, 128, 128, ov::element::bf16, "CPU"}, + {1, 16, 2, 2, 64, 128, ov::element::bf16, "CPU"}, + {1, 32, 4, 4, 256, 256, ov::element::bf16, "CPU"}, + {1, 32, 2, 4, 256, 256, ov::element::bf16, "CPU"}, }; INSTANTIATE_TEST_SUITE_P(smoke_GatedDeltaNet, GatedDeltaNet,