Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/gated_delta_net.hpp"

#include "common_test_utils/common_utils.hpp"
#include "common_test_utils/ov_tensor_utils.hpp"
#include "common_test_utils/ov_test_utils.hpp"
#include "common_test_utils/test_common.hpp"
#include "common_test_utils/common_utils.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/gated_delta_net.hpp"
#include "openvino/runtime/core.hpp"

namespace {

using GatedDeltaNetParams = std::tuple<
std::vector<ov::Shape>, // Input shapes: query, key, value, state, gate, beta
ov::element::Type, // Input precision
bool>; // fuse_qk_l2norm
using GatedDeltaNetParams = std::tuple<std::vector<ov::Shape>, // Input shapes: query, key, value, state, gate, beta
ov::element::Type, // Input precision
bool>; // fuse_qk_l2norm

class GatedDeltaNetStaticTest : public testing::WithParamInterface<GatedDeltaNetParams>,
public ov::test::TestsCommon {
class GatedDeltaNetStaticTest : public testing::WithParamInterface<GatedDeltaNetParams>, public ov::test::TestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<GatedDeltaNetParams>& obj) {
const auto& [input_shapes, precision, fuse_qk_l2norm] = obj.param;
Expand Down Expand Up @@ -46,16 +46,12 @@ class GatedDeltaNetStaticTest : public testing::WithParamInterface<GatedDeltaNet
auto gate = std::make_shared<ov::op::v0::Parameter>(precision, input_shapes[4]);
auto beta = std::make_shared<ov::op::v0::Parameter>(precision, input_shapes[5]);

auto gdn = std::make_shared<ov::op::internal::GatedDeltaNet>(
query, key, value, state, gate, beta, fuse_qk_l2norm);
auto gdn = std::make_shared<ov::op::internal::GatedDeltaNet>(query, key, value, state, gate, beta, fuse_qk_l2norm);

auto result0 = std::make_shared<ov::op::v0::Result>(gdn->output(0));
auto result1 = std::make_shared<ov::op::v0::Result>(gdn->output(1));

model = std::make_shared<ov::Model>(
ov::ResultVector{result0, result1},
ov::ParameterVector{query, key, value, state, gate, beta},
"GatedDeltaNetTest");
model = std::make_shared<ov::Model>(ov::ResultVector{result0, result1}, ov::ParameterVector{query, key, value, state, gate, beta}, "GatedDeltaNetTest");
}

std::map<std::shared_ptr<ov::op::v0::Parameter>, ov::Tensor> generate_inputs() {
Expand All @@ -66,29 +62,28 @@ class GatedDeltaNetStaticTest : public testing::WithParamInterface<GatedDeltaNet
if (i == 4) {
in_data = ov::test::utils::InputGenerateData(-1, 1, 1000, 1);
}
inputs[params[i]] = ov::test::utils::create_and_fill_tensor(
params[i]->get_element_type(), params[i]->get_shape(), in_data);
inputs[params[i]] = ov::test::utils::create_and_fill_tensor(params[i]->get_element_type(), params[i]->get_shape(), in_data);
}
return inputs;
}

std::shared_ptr<ov::Model> model;
};

TEST_P(GatedDeltaNetStaticTest, CompareWithCPU) {
TEST_P(GatedDeltaNetStaticTest, CompareWithTemplate) {
auto inputs = generate_inputs();

ov::Core core;

// Run on CPU (reference)
auto compiled_cpu = core.compile_model(model, "CPU");
auto req_cpu = compiled_cpu.create_infer_request();
for (const auto& [param, tensor] : inputs) {
req_cpu.set_tensor(param->output(0), tensor);
// Build input tensor vector for infer_on_template
ov::TensorVector input_tensors;
for (const auto& param : model->get_parameters()) {
input_tensors.push_back(inputs.at(param));
}
req_cpu.infer();

// Run on TEMPLATE (reference)
auto ref_outputs = ov::test::utils::infer_on_template(model, input_tensors);

// Run on GPU
ov::Core core;
auto compiled_gpu = core.compile_model(model, "GPU");
auto req_gpu = compiled_gpu.create_infer_request();
for (const auto& [param, tensor] : inputs) {
Expand All @@ -98,30 +93,28 @@ TEST_P(GatedDeltaNetStaticTest, CompareWithCPU) {

// Compare outputs
for (size_t i = 0; i < model->get_output_size(); i++) {
auto out_cpu = req_cpu.get_output_tensor(i);
auto out_gpu = req_gpu.get_output_tensor(i);
ov::test::utils::compare(out_cpu, out_gpu, 1e-2, 1e-2);
ov::test::utils::compare(ref_outputs[i], out_gpu, 1e-2, 1e-2);
}
}

// Shapes: query[B,S,H,D], key[B,S,H,D], value[B,S,H,Dv], state[B,H,D,Dv], gate[B,S,H], beta[B,S,H]
// Shapes: query[B,S,qk_H,D], key[B,S,qk_H,D], value[B,S,v_H,Dv], state[B,v_H,D,Dv], gate[B,S,v_H], beta[B,S,v_H]
const std::vector<std::vector<ov::Shape>> static_shapes = {
// B=1, S=1, H=4, D=16, Dv=16 (minimal)
// B=1, S=1, qk_H=4, v_H=4, D=16, Dv=16 (minimal)
{{1, 1, 4, 16}, {1, 1, 4, 16}, {1, 1, 4, 16}, {1, 4, 16, 16}, {1, 1, 4}, {1, 1, 4}},
// B=1, S=1, H=32, D=128, Dv=128 (typical LLM decode)
// B=1, S=1, qk_H=32, v_H=32, D=128, Dv=128 (typical LLM decode)
{{1, 1, 32, 128}, {1, 1, 32, 128}, {1, 1, 32, 128}, {1, 32, 128, 128}, {1, 1, 32}, {1, 1, 32}},
// B=1, S=16, H=2, D=16, Dv=32 (seq_len > 1, different D and Dv)
// B=1, S=16, qk_H=2, v_H=2, D=16, Dv=32 (seq_len > 1, different D and Dv)
{{1, 16, 2, 16}, {1, 16, 2, 16}, {1, 16, 2, 32}, {1, 2, 16, 32}, {1, 16, 2}, {1, 16, 2}},
// B=2, S=1, H=8, D=64, Dv=64 (batch > 1)
// B=2, S=1, qk_H=8, v_H=8, D=64, Dv=64 (batch > 1)
{{2, 1, 8, 64}, {2, 1, 8, 64}, {2, 1, 8, 64}, {2, 8, 64, 64}, {2, 1, 8}, {2, 1, 8}},
// B=1, S=4, qk_H=2, v_H=8, D=16, Dv=16 (GQA: v_H is multiple of qk_H)
{{1, 4, 2, 16}, {1, 4, 2, 16}, {1, 4, 8, 16}, {1, 8, 16, 16}, {1, 4, 8}, {1, 4, 8}},
};

INSTANTIATE_TEST_SUITE_P(
smoke_GatedDeltaNetStatic,
GatedDeltaNetStaticTest,
::testing::Combine(::testing::ValuesIn(static_shapes),
::testing::Values(ov::element::f32),
::testing::Values(false, true)),
GatedDeltaNetStaticTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_GatedDeltaNetStatic,
GatedDeltaNetStaticTest,
::testing::Combine(::testing::ValuesIn(static_shapes), ::testing::Values(ov::element::f32), ::testing::Values(false, true)),
GatedDeltaNetStaticTest::getTestCaseName);

} // namespace
156 changes: 156 additions & 0 deletions src/plugins/template/backend/ops/gated_delta_net.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright (C) 2018-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/gated_delta_net.hpp"

#include <cmath>
#include <vector>

#include "evaluate_node.hpp"
#include "openvino/core/type/element_type_traits.hpp"

/// Reference evaluation of the internal GatedDeltaNet operation.
///
/// Tensor layouts, as indexed by the code below:
///   inputs[0]  query  [B, S, qk_H, D]
///   inputs[1]  key    [B, S, qk_H, D]
///   inputs[2]  value  [B, S, v_H, Dv]
///   inputs[3]  state  [B, v_H, D, Dv]
///   inputs[4]  gate   [B, S, v_H]   (exponentiated before it is applied to the state)
///   inputs[5]  beta   [B, S, v_H]
///   outputs[0] result, reshaped to the shape of value
///   outputs[1] final state, reshaped to the shape of the input state
///
/// Each value head h_v reads query/key head h_v / (v_H / qk_H), so several value heads
/// may share one query/key head.
///
/// @param op      Node providing the fuse_qk_l2norm flag and the two L2-norm epsilons.
/// @param outputs Output tensors; their shapes are set here before being written.
/// @param inputs  Input tensors in the order listed above.
/// @return Always true; contract violations are reported through OPENVINO_ASSERT.
template <ov::element::Type_t ET>
bool evaluate(const std::shared_ptr<ov::op::internal::GatedDeltaNet>& op,
              ov::TensorVector& outputs,
              const ov::TensorVector& inputs) {
    using T = typename ov::element_type_traits<ET>::value_type;

    // inputs[0..5] and outputs[0..1] are indexed unconditionally below.
    OPENVINO_ASSERT(inputs.size() >= 6 && outputs.size() >= 2,
                    "GatedDeltaNet evaluate: expected 6 inputs and 2 outputs, got ",
                    inputs.size(),
                    " inputs and ",
                    outputs.size(),
                    " outputs");

    const auto& q_shape = inputs[0].get_shape();
    const auto& v_shape = inputs[2].get_shape();
    const auto& state_shape = inputs[3].get_shape();

    const size_t B = q_shape[0];
    const size_t S = q_shape[1];
    const size_t qk_H = q_shape[2];
    const size_t D = q_shape[3];
    const size_t v_H = v_shape[2];
    const size_t Dv = v_shape[3];

    OPENVINO_ASSERT(qk_H > 0 && v_H >= qk_H && v_H % qk_H == 0,
                    "GatedDeltaNet evaluate: v_H (",
                    v_H,
                    ") must be a positive multiple of qk_H (",
                    qk_H,
                    ")");
    const size_t group_size = v_H / qk_H;

    outputs[0].set_shape(v_shape);
    outputs[1].set_shape(state_shape);

    const T* q_data = inputs[0].data<const T>();
    const T* k_data = inputs[1].data<const T>();
    const T* v_data = inputs[2].data<const T>();
    const T* state_data = inputs[3].data<const T>();
    const T* gate_data = inputs[4].data<const T>();
    const T* beta_data = inputs[5].data<const T>();

    T* out_state = outputs[1].data<T>();
    T* out_data = outputs[0].data<T>();
    const T attn_scale = static_cast<T>(1) / std::sqrt(static_cast<T>(D));

    const size_t qk_stride_batch = S * qk_H * D;
    const size_t v_stride_batch = S * v_H * Dv;
    const size_t gate_beta_stride_batch = S * v_H;

    const bool fuse_qk_l2norm = op->get_fuse_qk_l2norm();
    const T q_l2_norm_eps = static_cast<T>(op->get_q_l2_norm_eps());
    const T k_l2_norm_eps = static_cast<T>(op->get_k_l2_norm_eps());

    auto dot_product = [](const T* a, const T* b, size_t n) {
        T result = static_cast<T>(0);
        for (size_t i = 0; i < n; i++) {
            result += a[i] * b[i];
        }
        return result;
    };

    // Scales vec in place by 1 / sqrt(sum(vec^2) + eps).
    auto l2norm = [](std::vector<T>& vec, T eps) {
        T sum = static_cast<T>(0);
        for (size_t i = 0; i < vec.size(); i++)
            sum += vec[i] * vec[i];
        sum = static_cast<T>(1) / std::sqrt(sum + eps);
        for (size_t i = 0; i < vec.size(); i++)
            vec[i] *= sum;
    };

    // Scratch buffers are allocated once and refilled on every iteration, so the
    // innermost loops perform no heap allocation.
    std::vector<T> local_state(D);
    std::vector<T> q_vec(D);
    std::vector<T> k_vec(D);

    for (size_t b = 0; b < B; b++) {
        for (size_t h_v = 0; h_v < v_H; h_v++) {
            const size_t h_qk = h_v / group_size;
            for (size_t d_v = 0; d_v < Dv; d_v++) {
                // state layout: [B, v_H, D, Dv]; one column (fixed d_v) is handled at a time,
                // so consecutive D entries are Dv elements apart.
                const size_t state_offset = b * v_H * D * Dv + h_v * D * Dv + d_v;
                T* state_ptr = out_state + state_offset;

                // Load initial state from input
                const T* src_state = state_data + state_offset;
                for (size_t d = 0; d < D; d++) {
                    local_state[d] = src_state[d * Dv];
                }

                for (size_t t = 0; t < S; t++) {
                    const size_t qk_offset = b * qk_stride_batch + t * qk_H * D + h_qk * D;
                    const T* q_ptr = q_data + qk_offset;
                    const T* k_ptr = k_data + qk_offset;

                    // assign() reuses the existing capacity (size stays D).
                    q_vec.assign(q_ptr, q_ptr + D);
                    k_vec.assign(k_ptr, k_ptr + D);

                    if (fuse_qk_l2norm) {
                        l2norm(q_vec, q_l2_norm_eps);
                        l2norm(k_vec, k_l2_norm_eps);
                    }

                    // Scale q
                    for (size_t i = 0; i < D; i++)
                        q_vec[i] *= attn_scale;

                    // gate[b, t, h_v] and beta[b, t, h_v] share the layout [B, S, v_H]
                    const size_t gate_beta_idx = b * gate_beta_stride_batch + t * v_H + h_v;
                    const T g = std::exp(gate_data[gate_beta_idx]);
                    const T bt = beta_data[gate_beta_idx];

                    // Decay state: state *= g
                    for (size_t d = 0; d < D; d++) {
                        local_state[d] *= g;
                    }

                    // h_k = dot(state, k)
                    const T h_k = dot_product(local_state.data(), k_vec.data(), D);

                    // value[b, t, h_v, d_v] and out[b, t, h_v, d_v] share the same flat index
                    const size_t v_idx = b * v_stride_batch + t * v_H * Dv + h_v * Dv + d_v;

                    // delta: v_val = value[b, t, h_v, d_v] - h_k
                    const T v_val = v_data[v_idx] - h_k;

                    // Update state: state += k * (v_val * beta)
                    const T update_scale = v_val * bt;
                    for (size_t d = 0; d < D; d++) {
                        local_state[d] += k_vec[d] * update_scale;
                    }

                    // Output: out[b, t, h_v, d_v] = dot(state, q)
                    out_data[v_idx] = dot_product(local_state.data(), q_vec.data(), D);
                }

                // Write final state back
                for (size_t d = 0; d < D; d++) {
                    state_ptr[d * Dv] = local_state[d];
                }
            }
        }
    }
    return true;
}

/// Entry point used by the template backend for GatedDeltaNet: selects the typed
/// evaluate<> implementation from the element type of input 0.
/// Only f32 is handled; any other element type raises via OPENVINO_THROW.
template <>
bool evaluate_node<ov::op::internal::GatedDeltaNet>(std::shared_ptr<ov::Node> node,
                                                    ov::TensorVector& outputs,
                                                    const ov::TensorVector& inputs) {
    const auto& precision = node->get_input_element_type(0);

    if (precision == ov::element::f32) {
        const auto gdn_node = ov::as_type_ptr<ov::op::internal::GatedDeltaNet>(node);
        return evaluate<ov::element::f32>(gdn_node, outputs, inputs);
    }

    OPENVINO_THROW("Unhandled data type ", precision, " in evaluate_node<GatedDeltaNet>()");
}
5 changes: 5 additions & 0 deletions src/plugins/template/backend/ops/ops_evaluates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#pragma once
#include "evaluate_node.hpp"
#include "openvino/op/gated_delta_net.hpp"
#include "openvino/op/ops.hpp"
#include "openvino/op/paged_attention.hpp"
#include "openvino/op/rms_norm.hpp"
Expand Down Expand Up @@ -549,6 +550,10 @@ extern template bool evaluate_node<ov::op::internal::AUGRUSequence>(std::shared_
ov::TensorVector& outputs,
const ov::TensorVector& inputs);

extern template bool evaluate_node<ov::op::internal::GatedDeltaNet>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);

extern template bool evaluate_node<ov::op::internal::RMS>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);
Expand Down
1 change: 1 addition & 0 deletions src/plugins/template/backend/opset_int_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ _OPENVINO_OP_REG(OneHot, ov::op::v16)

_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)
_OPENVINO_OP_REG(GatedDeltaNet, ov::op::internal)
_OPENVINO_OP_REG(RMS, ov::op::internal)
_OPENVINO_OP_REG(RMSNorm, ov::op::internal)
_OPENVINO_OP_REG(PagedAttentionExtension, ov::op)
Loading