diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp index 18787ee6e404fe..f8b02194624924 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp @@ -76,6 +76,22 @@ auto available_pred = [](const program_node& input) { return true; }; +// Primitives that read input by explicit tensor coordinate and therefore correctly skip +// padding on the input side. +auto reads_padded_input_safely = [](const program_node& user) { + if (user.can_be_optimized()) + return false; + if (user.is_type()) { + auto broadcast_type = user.as().get_primitive()->broadcast_spec.m_type; + return broadcast_type == ov::op::AutoBroadcastType::NONE; + } + return user.is_type() || + user.is_type() || + user.is_type() || + user.is_type() || + user.is_type(); +}; + bool concat_in_place_optimization::match(const program_node& concat_node, kernel_impl_params& concat_params, std::vector& pred_params, @@ -148,9 +164,16 @@ bool concat_in_place_optimization::match(const program_node& concat_node, // TODO: handle optimized reshape if (pred.first->is_type() && pred.first->can_be_optimized()) return false; - // TODO: Investigate if this condition is needed - if (pred.first->get_users().size() > 2) - return false; + // A predecessor with more than two users can still be fused if all non-concat + // users correctly handle a padded input buffer. + if (pred.first->get_users().size() > 2) { + for (const auto& user : pred.first->get_users()) { + if (user->is_type()) + continue; + if (!reads_padded_input_safely(*user)) + return false; + } + } // Check that input isn't optimized out concatenation along different axis. if (pred.first->is_type() && pred.first->can_be_optimized()) { @@ -236,7 +259,7 @@ bool concat_in_place_optimization::match(const program_node& concat_node, idx++; } - // Implicit concat for onednn only when use_usm and batch 1. + // Implicit concat for onednn only when use_usm and batch 1 on the batch axis. if (is_onednn_impl) { bool use_usm = concat_node.get_program().get_engine().use_unified_shared_memory(); const layout& concat_out_l = concat_params.get_output_layout(); @@ -246,7 +269,11 @@ bool concat_in_place_optimization::match(const program_node& concat_node, // Return true in build time, it will be checked again in runtime return true; } else { - if (concat_out_l.batch() > 1) + // Block formats (b_fs_yx_fsv16 etc.) are not contiguous along the batch axis, + // so batch-axis (axis=0) concat with batch>1 cannot safely alias buffers. + // Feature-axis and other axes are fine — the 64-byte alignment check above is + // the correctness gate for those cases. + if (concat_axis_index == 0 && concat_out_l.batch() > 1) return false; const auto& dims_order = concat_out_l.format.dims_order(); for (auto dim : dims_order) { diff --git a/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp b/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp index c64bffc5ee5490..79f87577e74ead 100644 --- a/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp @@ -1650,6 +1650,236 @@ TEST(prepare_buffer_fusing, in_place_onednn_concat_static) { ASSERT_EQ(ref_output[x], output_ptr[x]); } } + +// Verifies that feature-axis concat with static batch > 1 and non-uniform feature counts +// can be optimized in-place by oneDNN +TEST(prepare_buffer_fusing, in_place_onednn_concat_static_batch_gt1) { + auto& engine = get_test_engine(); + if (!engine.get_device_info().supports_immad) + return; + + // Three inputs with batch=3 and non-uniform feature counts [16, 32, 16] at spatial 2×3. + auto in_layout1 = layout{ ov::PartialShape{3, 16, 2, 3}, data_types::f32, format::bfyx }; + auto in_layout2 = layout{ ov::PartialShape{3, 32, 2, 3}, data_types::f32, format::bfyx }; + auto in_layout3 = layout{ ov::PartialShape{3, 16, 2, 3}, data_types::f32, format::bfyx }; + + auto build_topology = [](bool use_block_format) { + auto fmt = use_block_format ? format::b_fs_yx_fsv16 : format::bfyx; + topology topo; + topo.add(input_layout("input1", layout{ ov::PartialShape{3, 16, 2, 3}, data_types::f32, format::bfyx })); + topo.add(input_layout("input2", layout{ ov::PartialShape{3, 32, 2, 3}, data_types::f32, format::bfyx })); + topo.add(input_layout("input3", layout{ ov::PartialShape{3, 16, 2, 3}, data_types::f32, format::bfyx })); + topo.add(reorder("reorder1", input_info("input1"), fmt, data_types::f16)); + topo.add(reorder("reorder2", input_info("input2"), fmt, data_types::f16)); + topo.add(reorder("reorder3", input_info("input3"), fmt, data_types::f16)); + topo.add(concatenation("concat", { input_info("reorder1"), input_info("reorder2"), input_info("reorder3") }, 1)); + topo.add(reorder("output", input_info("concat"), format::bfyx, data_types::f32)); + return topo; + }; + + auto input_memory1 = engine.allocate_memory(in_layout1); + auto input_memory2 = engine.allocate_memory(in_layout2); + auto input_memory3 = engine.allocate_memory(in_layout3); + tests::random_generator rg(GET_SUITE_NAME); + auto vals1 = rg.generate_random_1d(3 * 16 * 2 * 3, -1, 1); + auto vals2 = rg.generate_random_1d(3 * 32 * 2 * 3, -1, 1); + auto vals3 = rg.generate_random_1d(3 * 16 * 2 * 3, -1, 1); + set_values(input_memory1, vals1); + set_values(input_memory2, vals2); + set_values(input_memory3, vals3); + + // implicit concat — runtime selects the preferred impl (onednn on immad devices) for the path + ExecutionConfig cfg_implicit = get_test_default_config(engine); + cfg_implicit.set_property(ov::intel_gpu::optimize_data(true)); + cfg_implicit.set_property(ov::intel_gpu::allow_new_shape_infer(false)); + network net_implicit(engine, build_topology(true), cfg_implicit); + net_implicit.set_input_data("input1", input_memory1); + net_implicit.set_input_data("input2", input_memory2); + net_implicit.set_input_data("input3", input_memory3); + auto out_implicit = net_implicit.execute(); + + const auto& concat_node = net_implicit.get_primitive("concat")->get_node(); + ASSERT_TRUE(concat_node.can_be_optimized()); + + // explicit concat — reference without in-place optimisation (bfyx to avoid format aliasing) + ExecutionConfig cfg_explicit = get_test_default_config(engine); + cfg_explicit.set_property(ov::intel_gpu::optimize_data(false)); + network net_explicit(engine, build_topology(false), cfg_explicit); + net_explicit.set_input_data("input1", input_memory1); + net_explicit.set_input_data("input2", input_memory2); + net_explicit.set_input_data("input3", input_memory3); + auto out_explicit = net_explicit.execute(); + + auto mem_implicit = out_implicit.at("output").get_memory(); + auto mem_explicit = out_explicit.at("output").get_memory(); + cldnn::mem_lock ptr_implicit(mem_implicit, get_test_stream()); + cldnn::mem_lock ptr_explicit(mem_explicit, get_test_stream()); + + ASSERT_EQ(ptr_implicit.size(), ptr_explicit.size()); + for (size_t i = 0; i < ptr_implicit.size(); i++) + ASSERT_NEAR(ptr_implicit[i], ptr_explicit[i], 1e-3f) << "mismatch at index " << i; +} + +// Verifies that oneDNN in-place concat remains safe when a shared predecessor has multiple users. +TEST(prepare_buffer_fusing, in_place_onednn_concat_multi_user_safe_type) { + auto& engine = get_test_engine(); + if (!engine.get_device_info().supports_immad) + return; + + auto in_layout = layout{ ov::PartialShape{1, 16, 4, 4}, data_types::f32, format::bfyx }; + auto in_layout2 = layout{ ov::PartialShape{1, 16, 4, 4}, data_types::f32, format::bfyx }; + + topology topology; + topology.add(input_layout("shared_in", in_layout)); + topology.add(input_layout("other_in", in_layout2)); + // shared_r: the predecessor node that will have 3 users + topology.add(reorder("shared_r", input_info("shared_in"), format::bfyx, data_types::f16)); + topology.add(reorder("other_r", input_info("other_in"), format::bfyx, data_types::f16)); + + // User 1 of shared_r: concat (the node we want to fuse) — other_r first so shared_r is in the second slot + topology.add(concatenation("concat", { input_info("other_r"), input_info("shared_r") }, 1)); + // User 2 of shared_r: activation relu (safe type — in available_pred, never reads padding) + topology.add(activation("act1", input_info("shared_r"), activation_func::relu)); + // User 3 of shared_r: activation abs (safe type — preserves full value range) + topology.add(activation("act2", input_info("shared_r"), activation_func::abs)); + + topology.add(reorder("out_concat", input_info("concat"), format::bfyx, data_types::f32)); + topology.add(reorder("out_act1", input_info("act1"), format::bfyx, data_types::f32)); + topology.add(reorder("out_act2", input_info("act2"), format::bfyx, data_types::f32)); + + ExecutionConfig config = get_test_default_config(engine); + config.set_property(ov::intel_gpu::optimize_data(true)); + config.set_property(ov::intel_gpu::allow_new_shape_infer(false)); + config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ + {"shared_r", ov::intel_gpu::ImplementationDesc{format::any, "", impl_types::onednn}}, + {"other_r", ov::intel_gpu::ImplementationDesc{format::any, "", impl_types::onednn}}, + })); + network network(engine, topology, config); + + auto input_memory = engine.allocate_memory(in_layout); + auto input_memory2 = engine.allocate_memory(in_layout2); + const size_t N = 16 * 4 * 4; // 256 + std::vector d1(N), d2(N); + for (size_t i = 0; i < N; i++) { d1[i] = static_cast(i); d2[i] = static_cast(512 + i); } + set_values(input_memory, d1); + set_values(input_memory2, d2); + + network.set_input_data("shared_in", input_memory); + network.set_input_data("other_in", input_memory2); + + std::map output; + EXPECT_NO_THROW(output = network.execute()); + + const auto& concat_node = network.get_primitive("concat")->get_node(); + ASSERT_TRUE(concat_node.can_be_optimized()); + + auto out_concat_mem = output.at("out_concat").get_memory(); + cldnn::mem_lock concat_ptr(out_concat_mem, get_test_stream()); + ASSERT_EQ(concat_ptr.size(), 2 * N); + for (size_t i = 0; i < N; i++) + ASSERT_NEAR(concat_ptr[i], static_cast(512 + i), 1e-3f) << "out_concat first half mismatch at index " << i; + for (size_t i = 0; i < N; i++) + ASSERT_NEAR(concat_ptr[N + i], static_cast(i), 1e-3f) << "out_concat second half mismatch at index " << i; + + // out_act1 = relu(shared_r) = relu(0..N-1) = 0..N-1 (all non-negative) + auto out_act1_mem = output.at("out_act1").get_memory(); + cldnn::mem_lock act1_ptr(out_act1_mem, get_test_stream()); + for (size_t i = 0; i < act1_ptr.size(); i++) + ASSERT_NEAR(act1_ptr[i], static_cast(i), 1e-3f) << "out_act1 mismatch at index " << i; + + // out_act2 = abs(shared_r) = abs(0..N-1) = 0..N-1 (all non-negative, full range preserved) + auto out_act2_mem = output.at("out_act2").get_memory(); + cldnn::mem_lock act2_ptr(out_act2_mem, get_test_stream()); + for (size_t i = 0; i < act2_ptr.size(); i++) + ASSERT_NEAR(act2_ptr[i], static_cast(i), 1e-3f) << "out_act2 mismatch at index " << i; +} + +// Verifies that in-place concat fuses when an oneDNN conv is among the shared predecessor's users. +// +// shared_in → shared_r (b_fs_yx_fsv16, oneDNN) ─┬→ concat → out_concat +// other_in → other_r (b_fs_yx_fsv16) ─┘ +// └→ conv (oneDNN, b_fs_yx_fsv16) → out_conv +// └→ act1 → out_act1 +TEST(prepare_buffer_fusing, in_place_onednn_concat_multi_user_conv_as_user) { + auto& engine = get_test_engine(); + if (!engine.get_device_info().supports_immad) + return; + + auto in_layout = layout{ ov::PartialShape{1, 16, 4, 4}, data_types::f32, format::bfyx }; + auto in_layout2 = layout{ ov::PartialShape{1, 16, 4, 4}, data_types::f32, format::bfyx }; + + auto weights_layout = layout{ ov::PartialShape{16, 16, 1, 1}, data_types::f16, format::bfyx }; + auto weights_mem = engine.allocate_memory(weights_layout); + std::vector wdata(16 * 16, ov::float16(0.f)); + for (int i = 0; i < 16; ++i) + wdata[i * 16 + i] = ov::float16(1.f); + set_values(weights_mem, wdata); + + topology topology; + topology.add(input_layout("shared_in", in_layout)); + topology.add(input_layout("other_in", in_layout2)); + topology.add(data("conv_w", weights_mem)); + + // shared_r must match the conv preferred format so reorder_inputs does not + // insert an intermediate reorder that would break the fusing path. + topology.add(reorder("shared_r", input_info("shared_in"), format::b_fs_yx_fsv16, data_types::f16)); + topology.add(reorder("other_r", input_info("other_in"), format::b_fs_yx_fsv16, data_types::f16)); + + topology.add(concatenation("concat", { input_info("other_r"), input_info("shared_r") }, 1)); + topology.add(convolution("conv", input_info("shared_r"), "conv_w", "", 1, {1, 1}, {1, 1}, {0, 0}, {0, 0}, false)); + topology.add(activation("act1", input_info("shared_r"), activation_func::relu)); + + topology.add(reorder("out_concat", input_info("concat"), format::bfyx, data_types::f32)); + topology.add(reorder("out_conv", input_info("conv"), format::bfyx, data_types::f32)); + topology.add(reorder("out_act1", input_info("act1"), format::bfyx, data_types::f32)); + + ExecutionConfig config = get_test_default_config(engine); + config.set_property(ov::intel_gpu::optimize_data(true)); + config.set_property(ov::intel_gpu::allow_new_shape_infer(false)); + config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ + {"shared_r", ov::intel_gpu::ImplementationDesc{format::b_fs_yx_fsv16, "", impl_types::onednn}}, + {"conv", ov::intel_gpu::ImplementationDesc{format::b_fs_yx_fsv16, "", impl_types::onednn}}, + })); + network network(engine, topology, config); + + auto input_memory = engine.allocate_memory(in_layout); + auto input_memory2 = engine.allocate_memory(in_layout2); + // Natural-number sequences — d1 starts at 0, d2 starts at 512 so the two halves of + // the concat output are unambiguously distinguishable + const size_t N = 16 * 4 * 4; // 256 + std::vector d1(N), d2(N); + for (size_t i = 0; i < N; i++) { d1[i] = static_cast(i); d2[i] = static_cast(512 + i); } + set_values(input_memory, d1); + set_values(input_memory2, d2); + + network.set_input_data("shared_in", input_memory); + network.set_input_data("other_in", input_memory2); + + std::map output; + EXPECT_NO_THROW(output = network.execute()); + + // Confirm concat was fused + const auto& concat_node = network.get_primitive("concat")->get_node(); + ASSERT_TRUE(concat_node.can_be_optimized()); + + auto out_concat_mem = output.at("out_concat").get_memory(); + cldnn::mem_lock concat_ptr(out_concat_mem, get_test_stream()); + ASSERT_EQ(concat_ptr.size(), 2 * N); + for (size_t i = 0; i < N; i++) + ASSERT_NEAR(concat_ptr[i], static_cast(512 + i), 1e-2f) << "out_concat first half mismatch at index " << i; + for (size_t i = 0; i < N; i++) + ASSERT_NEAR(concat_ptr[N + i], static_cast(i), 1e-2f) << "out_concat second half mismatch at index " << i; + + auto out_conv_mem = output.at("out_conv").get_memory(); + cldnn::mem_lock conv_ptr(out_conv_mem, get_test_stream()); + for (size_t i = 0; i < conv_ptr.size(); i++) + ASSERT_NEAR(conv_ptr[i], static_cast(i), 1e-2f) << "out_conv mismatch at index " << i; + + auto out_act1_mem = output.at("out_act1").get_memory(); + cldnn::mem_lock act1_ptr(out_act1_mem, get_test_stream()); + for (size_t i = 0; i < act1_ptr.size(); i++) + ASSERT_NEAR(act1_ptr[i], static_cast(i), 1e-2f) << "out_act1 mismatch at index " << i; +} #endif // ENABLE_ONEDNN_FOR_GPU TEST(prepare_buffer_fusing, in_place_concat_with_fsv32_to_fsv16_reorder_regression) { diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/concatenation_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/concatenation_gpu_test.cpp index a3a6439ead349b..fd669e5e07d3ed 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/concatenation_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/concatenation_gpu_test.cpp @@ -2304,7 +2304,7 @@ INSTANTIATE_TEST_SUITE_P(smoke, concat_no_implicit_gpu_onednn_4d_f16, ::testing::Values( TestParamType_implicit_concat(1, { 16 }, 2, 2, format::b_fs_yx_fsv16, true, false), - TestParamType_implicit_concat(2, { 16 }, 2, 2, format::b_fs_yx_fsv16, false, false) + TestParamType_implicit_concat(2, { 16 }, 2, 2, format::b_fs_yx_fsv16, true, false) ), concat_gpu_implicit::PrintToStringParamName);