Skip to content
41 changes: 12 additions & 29 deletions src/plugins/intel_cpu/src/nodes/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,9 +627,11 @@ void RNN::fillCellDesc() {
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS, inDataTypes[cIdx], memory::format_tag::nc));
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS, outDataTypes[coIdx], memory::format_tag::nc));
}

// The weights and weights_iter inputs are exposed in nc layout to avoid an unnecessary reorder.
// oneDNN determines the final layout during prepareParams.
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(WShape, inDataTypes[wIdx], memory::format_tag::nc));
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(RShape, inDataTypes[rIdx], memory::format_tag::nc));

inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(BShape, inDataTypes[bIdx], memory::format_tag::x));

if (haveAttention(cell_type)) {
Expand Down Expand Up @@ -732,8 +734,11 @@ void RNN::fillSequenceDesc() {
}

inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(TShape, inDataTypes[sIdx], memory::format_tag::x)); // sequence lengths
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(WShape, inDataTypes[wIdx], memory::format_tag::ntc)); // W
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(RShape, inDataTypes[rIdx], memory::format_tag::ntc)); // R
// The weights and weights_iter inputs are exposed in tnc layout to avoid an unnecessary reorder.
// oneDNN determines the final layout during prepareParams.
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(WShape, inDataTypes[wIdx], memory::format_tag::tnc)); // W
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(RShape, inDataTypes[rIdx], memory::format_tag::tnc)); // R

inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(BShape, inDataTypes[bIdx], memory::format_tag::nc)); // B

if (haveAttention(cell_type)) {
Expand Down Expand Up @@ -901,9 +906,6 @@ void RNN::copyWeightsData() {
if (one_of(dataType, memory::data_type::bf16, memory::data_type::f16)) {
fillWeights<uint16_t>(gate_map, wIdx, rIdx);
} else if (dataType == memory::data_type::f32) {
// WA To avoid different weights layer and iter formats in FP32 case
if (T.minVal > 1 || N.maxVal < optimalBatchSize)
wFormat = dnnl::memory::format_tag::ldigo;
fillWeights<float>(gate_map, wIdx, rIdx);
} else if (dataType == memory::data_type::u8 || dataType == memory::data_type::s8) {
fillWeights<int8_t>(gate_map, wIdx, rIdx);
Expand Down Expand Up @@ -1042,9 +1044,11 @@ void RNN::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
since internalBlobs are used for the execution, not the initial weights */
const auto& targetWeightDataType = weightsByinputDataType.at(inDataTypes[xIdx]);
auto weightsDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, DC, G, SC });
wDescs[0] = dnnl::memory::desc(weightsDims, targetWeightDataType, wFormat);
// Let oneDNN choose the preferred weights layout.
wDescs[0] = dnnl::memory::desc(weightsDims, targetWeightDataType, memory::format_tag::any);
auto statesDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, SC, G, SC });
wDescs[1] = dnnl::memory::desc(statesDims, targetWeightDataType, wFormat);
// Let oneDNN choose the preferred weights_iter layout.
wDescs[1] = dnnl::memory::desc(statesDims, targetWeightDataType, memory::format_tag::any);
auto biasDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, Gb, SC });
wDescs[2] = dnnl::memory::desc(biasDims, inDataTypes[bIdx], memory::format_tag::ldgo);

Expand Down Expand Up @@ -1119,27 +1123,6 @@ void RNN::prepareParams() {
inDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, 1}, inDataTypes[aIdx], memory::format_tag::tnc);
}

bool wFormatWasChanged = false;
// WA To avoid different weights layer and iter formats in FP32 case.
if (one_of(inDataTypes[xIdx], memory::data_type::f32) &&
(SL != 1 || B < optimalBatchSize)) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this heuristic is not needed anymore?

if (wFormat != dnnl::memory::format_tag::ldigo) {
wFormat = dnnl::memory::format_tag::ldigo;
wFormatWasChanged = true;
}
} else if (wFormat != dnnl::memory::format_tag::any) {
wFormat = dnnl::memory::format_tag::any;
wFormatWasChanged = true;
}

if (wFormatWasChanged) {
auto weightsDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, DC, G, SC });
const auto& targetWeightDataType = weightsByinputDataType.at(inDataTypes[xIdx]);
wDescs[0] = dnnl::memory::desc(weightsDims, targetWeightDataType, wFormat);
auto statesDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, SC, G, SC });
wDescs[1] = dnnl::memory::desc(statesDims, targetWeightDataType, wFormat);
}

const auto attr = initPrimitiveAttr();
RNNKey key = { inDataDescs, outDataDescs, wDescs, cell_type, cell_act, direction, *attr };

Expand Down
3 changes: 0 additions & 3 deletions src/plugins/intel_cpu/src/nodes/rnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@ class RNN : public Node {
/** activation type for vanilla RNN cell */
dnnl::algorithm cell_act = dnnl::algorithm::undef;

/** Weights data and state memory format: ldigo or any */
dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::any;

struct Interval {
Interval() = default;

Expand Down