diff --git a/src/plugins/intel_cpu/src/nodes/kernels/registers_pool.hpp b/src/plugins/intel_cpu/src/nodes/kernels/registers_pool.hpp new file mode 100644 index 00000000000000..35fc795b8ee05a --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/registers_pool.hpp @@ -0,0 +1,349 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "cpu/x64/jit_generator.hpp" +#include +#include "ie_common.h" +#include "utils/cpu_utils.hpp" +#include + +namespace ov { +namespace intel_cpu { + +using namespace dnnl::impl::cpu; + +/** + * The RegistersPool is the base class for the IsaRegistersPool template: + * template + * class IsaRegistersPool : public RegistersPool; + * + * The registers pool must be created by instantiating the IsaRegistersPool template, like the next: + * RegistersPool::Ptr regPool = RegistersPool::create({ + * // the list of the registers to be excluded from pool + * Reg64(Operand::RAX), Reg64(Operand::RCX), Reg64(Operand::RDX), Reg64(Operand::RBX), + * Reg64(Operand::RSP), Reg64(Operand::RBP), Reg64(Operand::RSI), Reg64(Operand::RDI) + * }); + */ +class RegistersPool { +public: + using Ptr = std::shared_ptr; + static constexpr int anyIdx = -1; + + /** + * The scoped wrapper for the Xbyak registers. + * By creating it you are getting the register from the pool RegistersPool. + * It could be created by using constructor with RegistersPool as an argument, like the next: + * const RegistersPool::Reg reg {regPool}; + * The destructor will return the register to the pool. Or it could be returned manually: + * reg.release(); + * @tparam TReg Xbyak register class + */ + template + class Reg { + friend class RegistersPool; + public: + Reg() {} + Reg(const RegistersPool::Ptr& regPool) { initialize(regPool); } + Reg(const RegistersPool::Ptr& regPool, int requestedIdx) { initialize(regPool, requestedIdx); } + ~Reg() { release(); } + Reg& operator=(Reg&& other) noexcept { + release(); + reg = other.reg; + regPool = std::move(other.regPool); + return *this; + } + Reg(Reg&& other) noexcept : reg(other.reg), regPool(std::move(other.regPool)) {} + operator TReg&() { ensureValid(); return reg; } + operator const TReg&() const { ensureValid(); return reg; } + operator Xbyak::RegExp() const { ensureValid(); return reg; } + int getIdx() const { ensureValid(); return reg.getIdx(); } + friend Xbyak::RegExp operator+(const Reg& lhs, const Xbyak::RegExp& rhs) { + lhs.ensureValid(); + return lhs.operator Xbyak::RegExp() + rhs; + } + void release() { + if (regPool) { + regPool->returnToPool(reg); + regPool.reset(); + } + } + bool isInitialized() const { return static_cast(regPool); } + + private: + void ensureValid() const { + if (!isInitialized()) { + IE_THROW() << "RegistersPool::Reg is either not initialized or released"; + } + } + + void initialize(const RegistersPool::Ptr& pool, int requestedIdx = anyIdx) { + static_assert(is_any_of::value, + "Unsupported TReg by RegistersPool::Reg. Please, use the following Xbyak registers either " + "Reg8, Reg16, Reg32, Reg64, Xmm, Ymm, Zmm or Opmask"); + release(); + reg = TReg(pool->template getFree(requestedIdx)); + regPool = pool; + } + + private: + TReg reg; + RegistersPool::Ptr regPool; + }; + + virtual ~RegistersPool() { + checkUniqueAndUpdate(false); + } + + template + static Ptr create(std::initializer_list regsToExclude); + + static Ptr create(x64::cpu_isa_t isa, std::initializer_list regsToExclude); + + template + size_t countFree() const { + static_assert(is_any_of::value, + "Unsupported TReg by RegistersPool::Reg. Please, use the following Xbyak registers either " + "Reg8, Reg16, Reg32, Reg64, Xmm, Ymm, Zmm or Opmask"); + if (std::is_base_of::value) { + return simdSet.countUnused(); + } else if (std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value) { + return generalSet.countUnused(); + } else if (std::is_same::value) { + return countUnusedOpmask(); + } + } + +protected: + class PhysicalSet { + public: + PhysicalSet(int size) : isFreeIndexVector(size, true) {} + + void setAsUsed(int regIdx) { + if (regIdx >= isFreeIndexVector.size() || regIdx < 0) { + IE_THROW() << "regIdx is out of bounds in RegistersPool::PhysicalSet::setAsUsed()"; + } + if (!isFreeIndexVector[regIdx]) { + IE_THROW() << "Inconsistency in RegistersPool::PhysicalSet::setAsUsed()"; + } + isFreeIndexVector[regIdx] = false; + } + + void setAsUnused(int regIdx) { + if (regIdx >= isFreeIndexVector.size() || regIdx < 0) { + IE_THROW() << "regIdx is out of bounds in RegistersPool::PhysicalSet::setAsUsed()"; + } + if (isFreeIndexVector[regIdx]) { + IE_THROW() << "Inconsistency in RegistersPool::PhysicalSet::setAsUnused()"; + } + isFreeIndexVector[regIdx] = true; + } + + int getUnused(int requestedIdx) { + if (requestedIdx == anyIdx) { + return getFirstFreeIndex(); + } else { + if (requestedIdx >= isFreeIndexVector.size() || requestedIdx < 0) { + IE_THROW() << "requestedIdx is out of bounds in RegistersPool::PhysicalSet::getUnused()"; + } + if (!isFreeIndexVector[requestedIdx]) { + IE_THROW() << "The register with index #" << requestedIdx << " already used in the RegistersPool"; + } + return requestedIdx; + } + } + + void exclude(Xbyak::Reg reg) { + isFreeIndexVector.at(reg.getIdx()) = false; + } + + size_t countUnused() const { + size_t count = 0; + for (const auto& isFree : isFreeIndexVector) { + if (isFree) { + ++count; + } + } + return count; + } + + private: + int getFirstFreeIndex() { + for (int c = 0; c < isFreeIndexVector.size(); ++c) { + if (isFreeIndexVector[c]) { + return c; + } + } + IE_THROW() << "Not enough registers in the RegistersPool"; + } + + private: + std::vector isFreeIndexVector; + }; + + virtual int getFreeOpmask(int requestedIdx) { IE_THROW() << "getFreeOpmask: The Opmask is not supported in current instruction set"; } + virtual void returnOpmaskToPool(int idx) { IE_THROW() << "returnOpmaskToPool: The Opmask is not supported in current instruction set"; } + virtual size_t countUnusedOpmask() const { IE_THROW() << "countUnusedOpmask: The Opmask is not supported in current instruction set"; } + + RegistersPool(int simdRegistersNumber) + : simdSet(simdRegistersNumber) { + checkUniqueAndUpdate(); + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RSP)); + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RAX)); + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RCX)); + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RDI)); + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RBP)); + } + + RegistersPool(std::initializer_list regsToExclude, int simdRegistersNumber) + : simdSet(simdRegistersNumber) { + checkUniqueAndUpdate(); + for (auto& reg : regsToExclude) { + if (reg.isXMM() || reg.isYMM() || reg.isZMM()) { + simdSet.exclude(reg); + } else if (reg.isREG()) { + generalSet.exclude(reg); + } + } + generalSet.exclude(Xbyak::Reg64(Xbyak::Operand::RSP)); + } + +private: + template + int getFree(int requestedIdx) { + if (std::is_base_of::value) { + auto idx = simdSet.getUnused(requestedIdx); + simdSet.setAsUsed(idx); + return idx; + } else if (std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value) { + auto idx = generalSet.getUnused(requestedIdx); + generalSet.setAsUsed(idx); + return idx; + } else if (std::is_same::value) { + return getFreeOpmask(requestedIdx); + } + } + + template + void returnToPool(const TReg& reg) { + if (std::is_base_of::value) { + simdSet.setAsUnused(reg.getIdx()); + } else if (std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value) { + generalSet.setAsUnused(reg.getIdx()); + } else if (std::is_same::value) { + returnOpmaskToPool(reg.getIdx()); + } + } + + void checkUniqueAndUpdate(bool isCtor = true) { + static thread_local bool isCreated = false; + if (isCtor) { + if (isCreated) { + IE_THROW() << "There should be only one instance of RegistersPool per thread"; + } + isCreated = true; + } else { + isCreated = false; + } + } + + PhysicalSet generalSet {16}; + PhysicalSet simdSet; +}; + +template +class IsaRegistersPool : public RegistersPool { +public: + IsaRegistersPool(std::initializer_list regsToExclude) : RegistersPool(regsToExclude, x64::cpu_isa_traits::n_vregs) {} +}; + +template <> +class IsaRegistersPool : public RegistersPool { +public: + IsaRegistersPool() : RegistersPool(x64::cpu_isa_traits::n_vregs) { + opmaskSet.exclude(Xbyak::Opmask(0)); // the Opmask(0) has special meaning for some instructions, like gather instruction + } + + IsaRegistersPool(std::initializer_list regsToExclude) + : RegistersPool(regsToExclude, x64::cpu_isa_traits::n_vregs) { + for (auto& reg : regsToExclude) { + if (reg.isOPMASK()) { + opmaskSet.exclude(reg); + } + } + } + + int getFreeOpmask(int requestedIdx) override { + auto idx = opmaskSet.getUnused(requestedIdx); + opmaskSet.setAsUsed(idx); + return idx; + } + + void returnOpmaskToPool(int idx) override { + opmaskSet.setAsUnused(idx); + } + + size_t countUnusedOpmask() const override { + return opmaskSet.countUnused(); + } + +protected: + PhysicalSet opmaskSet {8}; +}; + +template <> +class IsaRegistersPool : public IsaRegistersPool { +public: + IsaRegistersPool(std::initializer_list regsToExclude) : IsaRegistersPool(regsToExclude) {} + IsaRegistersPool() : IsaRegistersPool() {} +}; + +template <> +class IsaRegistersPool : public IsaRegistersPool { +public: + IsaRegistersPool(std::initializer_list regsToExclude) : IsaRegistersPool(regsToExclude) {} + IsaRegistersPool() : IsaRegistersPool() {} +}; + +template +RegistersPool::Ptr RegistersPool::create(std::initializer_list regsToExclude) { + return std::make_shared>(regsToExclude); +} + +inline +RegistersPool::Ptr RegistersPool::create(x64::cpu_isa_t isa, std::initializer_list regsToExclude) { +#define ISA_SWITCH_CASE(isa) case isa: return std::make_shared>(regsToExclude); + switch (isa) { + ISA_SWITCH_CASE(x64::sse41) + ISA_SWITCH_CASE(x64::avx) + ISA_SWITCH_CASE(x64::avx2) + ISA_SWITCH_CASE(x64::avx2_vnni) + ISA_SWITCH_CASE(x64::avx512_core) + ISA_SWITCH_CASE(x64::avx512_core_vnni) + ISA_SWITCH_CASE(x64::avx512_core_bf16) + case x64::avx_vnni: return std::make_shared>(regsToExclude); + case x64::avx512_core_bf16_ymm: return std::make_shared>(regsToExclude); + case x64::avx512_core_bf16_amx_int8: return std::make_shared>(regsToExclude); + case x64::avx512_core_bf16_amx_bf16: return std::make_shared>(regsToExclude); + case x64::avx512_core_amx: return std::make_shared>(regsToExclude); + case x64::avx512_vpopcnt: return std::make_shared>(regsToExclude); + case x64::isa_any: + case x64::amx_tile: + case x64::amx_int8: + case x64::amx_bf16: + case x64::isa_all: + IE_THROW() << "Invalid isa argument in RegistersPool::create()"; + } + IE_THROW() << "Invalid isa argument in RegistersPool::create()"; +#undef ISA_SWITCH_CASE +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/utils/cpu_utils.hpp b/src/plugins/intel_cpu/src/utils/cpu_utils.hpp index 0bea750573b3d6..25963b1521a153 100644 --- a/src/plugins/intel_cpu/src/utils/cpu_utils.hpp +++ b/src/plugins/intel_cpu/src/utils/cpu_utils.hpp @@ -10,10 +10,23 @@ #include "ie_common.h" #include "ie_layouts.h" +#include "general_utils.h" namespace ov { namespace intel_cpu { +// helper struct to tell wheter type T is any of given types U... +// termination case when U... is empty -> return std::false_type +template +struct is_any_of : public std::false_type {}; + +// helper struct to tell whether type is any of given types (U, Rest...) +// recurrence case when at least one type U is present -> returns std::true_type if std::same::value is true, +// otherwise call is_any_of recurrently +template +struct is_any_of + : public std::conditional::value, std::true_type, is_any_of>::type {}; + /** * @brief Returns normalized by size dims where missing dimensions are filled with units from the beginning * Example: dims = {2, 3, 5}; ndims = 5; result = {1, 1, 2, 3, 5} diff --git a/src/plugins/intel_cpu/tests/unit/registers_pool.cpp b/src/plugins/intel_cpu/tests/unit/registers_pool.cpp new file mode 100644 index 00000000000000..fa6292c2773441 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/registers_pool.cpp @@ -0,0 +1,247 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "nodes/kernels/registers_pool.hpp" +#include "common/nstl.hpp" + +using namespace ov::intel_cpu; + +template +class RegPoolTest : public ::testing::Test { +protected: + void SetUp() override { + if (typename T::RegT(0).isREG()) { // for general purpose registers Reg8, Reg16, Reg32, Reg64 + regNumber = 15; // the RSP register excluded by default + } else if (typename T::RegT(0).isOPMASK()) { + regNumber = 8; + } else { // SIMD registers + regNumber = x64::cpu_isa_traits::n_vregs; + } + } + + int regNumber; +}; + +TYPED_TEST_SUITE_P(RegPoolTest); + +TYPED_TEST_P(RegPoolTest, get_return_by_scope) { + using XbyakRegT = typename TypeParam::RegT; + RegistersPool::Ptr regPool = RegistersPool::create({}); + ASSERT_EQ(regPool->countFree(), this->regNumber); + { + RegistersPool::Reg reg{regPool}; + ASSERT_NO_THROW([[maybe_unused]] auto val = static_cast(reg)); + ASSERT_EQ(regPool->countFree(), this->regNumber - 1); + } + ASSERT_EQ(regPool->countFree(), this->regNumber); +} + +TYPED_TEST_P(RegPoolTest, get_return_by_method) { + using XbyakRegT = typename TypeParam::RegT; + RegistersPool::Ptr regPool = RegistersPool::create({}); + ASSERT_EQ(regPool->countFree(), this->regNumber); + RegistersPool::Reg reg{regPool}; + ASSERT_NO_THROW([[maybe_unused]] auto val = static_cast(reg)); + ASSERT_EQ(regPool->countFree(), this->regNumber - 1); + reg.release(); + ASSERT_ANY_THROW([[maybe_unused]] auto val = static_cast(reg)); + ASSERT_EQ(regPool->countFree(), this->regNumber); + reg = RegistersPool::Reg{regPool}; + ASSERT_NO_THROW([[maybe_unused]] auto val = static_cast(reg)); + ASSERT_EQ(regPool->countFree(), this->regNumber - 1); +} + +TYPED_TEST_P(RegPoolTest, default_ctor) { + using XbyakRegT = typename TypeParam::RegT; + RegistersPool::Ptr regPool = RegistersPool::create({}); + RegistersPool::Reg reg; + ASSERT_EQ(regPool->countFree(), this->regNumber); + ASSERT_ANY_THROW([[maybe_unused]] auto val = static_cast(reg)); +} + +TYPED_TEST_P(RegPoolTest, get_all) { + using XbyakRegT = typename TypeParam::RegT; + RegistersPool::Ptr regPool = RegistersPool::create({}); + using Ptr = std::shared_ptr>; + std::vector regs(this->regNumber); + for (int c = 0; c < this->regNumber; ++c) { + regs[c] = std::make_shared>(regPool); + } + ASSERT_EQ(regPool->countFree(), 0); + ASSERT_ANY_THROW(RegistersPool::Reg{regPool}); + regs.clear(); + ASSERT_EQ(regPool->countFree(), this->regNumber); +} + +TYPED_TEST_P(RegPoolTest, move) { + using XbyakRegT = typename TypeParam::RegT; + RegistersPool::Ptr regPool = RegistersPool::create({}); + RegistersPool::Reg reg{regPool}; + ASSERT_NO_THROW([[maybe_unused]] auto val = static_cast(reg)); + ASSERT_EQ(regPool->countFree(), this->regNumber - 1); + RegistersPool::Reg reg2{regPool}; + ASSERT_EQ(regPool->countFree(), this->regNumber - 2); + auto regIdx = reg.getIdx(); + reg2 = std::move(reg); + ASSERT_EQ(reg2.getIdx(), regIdx); + ASSERT_EQ(regPool->countFree(), this->regNumber - 1); + ASSERT_ANY_THROW([[maybe_unused]] auto val = static_cast(reg)); + ASSERT_NO_THROW([[maybe_unused]] auto val = static_cast(reg2)); + ASSERT_EQ(regPool->countFree(), this->regNumber - 1); +} + + +TYPED_TEST_P(RegPoolTest, fixed_idx) { + using XbyakRegT = typename TypeParam::RegT; + RegistersPool::Ptr regPool = RegistersPool::create({}); + using Ptr = std::shared_ptr>; + std::vector regs(this->regNumber); + for (int c = 0; c < this->regNumber; ++c) { + if (c == Xbyak::Operand::RSP) continue; + regs[c] = std::make_shared>(regPool, c); + ASSERT_EQ(regs[c]->getIdx(), c); + } + regs[0]->release(); + ASSERT_ANY_THROW(RegistersPool::Reg(regPool, 1)); + ASSERT_NO_THROW(RegistersPool::Reg(regPool, 0)); +} + +TYPED_TEST_P(RegPoolTest, exclude) { + using XbyakRegT = typename TypeParam::RegT; + static constexpr int excludedIdx = 0; + RegistersPool::Ptr regPool = RegistersPool::create({ + XbyakRegT(excludedIdx) + }); + using Ptr = std::shared_ptr>; + std::vector regs(this->regNumber - 1); + std::set idxsInUse; + for (int c = 0; c < this->regNumber - 1; ++c) { + regs[c] = std::make_shared>(regPool); + idxsInUse.emplace(regs[c]->getIdx()); + } + ASSERT_EQ(regPool->countFree(), 0); + ASSERT_TRUE(idxsInUse.find(excludedIdx) == idxsInUse.end()); +} + +namespace combiner { + +template +struct Case { + using RegT = Reg; + using IsaParam = Isa; +}; + +template +struct make_case { + static constexpr std::size_t N = std::tuple_size::value; + + using type = Case::type, + typename std::tuple_element::type>; +}; + +template +struct make_combinations; + +template struct index_sequence { }; +template struct make_index_sequence_impl : make_index_sequence_impl { }; +template struct make_index_sequence_impl <0, S...> { using type = index_sequence; }; +template using make_index_sequence = typename make_index_sequence_impl::type; + +template +struct make_combinations> { + using tuples = std::tuple::type...>; +}; + +template +using Combinations_t = typename make_combinations + , + make_index_sequence<(std::tuple_size::value) *(sizeof...(Params))>>::tuples; + +template +struct TestTypesCombiner; + +template +struct TestTypesCombiner> { + using Types = ::testing::Types; +}; + +} // namespace combiner + +template +struct IsaParam { static constexpr x64::cpu_isa_t isa = Isa; }; + +using TestTypes = combiner::TestTypesCombiner, + IsaParam, + IsaParam, + IsaParam, + IsaParam >>::Types; + +using TestTypesAvx512 = combiner::TestTypesCombiner, + IsaParam, + IsaParam, + IsaParam >>::Types; + +REGISTER_TYPED_TEST_SUITE_P(RegPoolTest, + get_return_by_scope, + get_return_by_method, + default_ctor, + get_all, + move, + fixed_idx, + exclude); + +INSTANTIATE_TYPED_TEST_SUITE_P(testIsaAndRegTypes, RegPoolTest, TestTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(testIsaAndRegTypesAvx512, RegPoolTest, TestTypesAvx512); + + +const int simdRegNumber = x64::cpu_isa_traits::n_vregs; +const int freeGeneralRegNumber = 15; + +TEST(RegistersPoolTests, simd_and_general) { + RegistersPool::Ptr regPool = RegistersPool::create({}); + ASSERT_EQ(regPool->countFree(), simdRegNumber); + ASSERT_EQ(regPool->countFree(), freeGeneralRegNumber); + { + RegistersPool::Reg reg{regPool}; + ASSERT_NO_THROW([[maybe_unused]] auto val = static_cast(reg)); + ASSERT_EQ(regPool->countFree(), simdRegNumber - 1); + ASSERT_EQ(regPool->countFree(), simdRegNumber - 1); + ASSERT_EQ(regPool->countFree(), freeGeneralRegNumber); + } + ASSERT_EQ(regPool->countFree(), simdRegNumber); + ASSERT_EQ(regPool->countFree(), freeGeneralRegNumber); +} + +TEST(RegistersPoolTests, second_pool_exception) { + RegistersPool::Ptr regPool = RegistersPool::create({}); + ASSERT_ANY_THROW(RegistersPool::create({})); +} + +TEST(RegistersPoolTests, get_all) { + RegistersPool::Ptr regPool = RegistersPool::create({}); + using Ptr = std::shared_ptr>; + std::vector regs(simdRegNumber); + std::vector idxs(simdRegNumber); + for (int c = 0; c < simdRegNumber; ++c) { + regs[c] = std::make_shared>(regPool); + idxs[c] = regs[c]->getIdx(); + } + ASSERT_EQ(regPool->countFree(), 0); + ASSERT_ANY_THROW(RegistersPool::Reg{regPool}); + std::sort(idxs.begin(), idxs.end()); + for (int c = 0; c < simdRegNumber; ++c) { + ASSERT_EQ(c, idxs[c]); + } + regs.clear(); + ASSERT_EQ(regPool->countFree(), simdRegNumber); +} +