diff --git a/src/plugins/intel_cpu/src/nodes/rnn.cpp b/src/plugins/intel_cpu/src/nodes/rnn.cpp index 0e6afc3e978e..da11478d3b1b 100644 --- a/src/plugins/intel_cpu/src/nodes/rnn.cpp +++ b/src/plugins/intel_cpu/src/nodes/rnn.cpp @@ -627,9 +627,11 @@ void RNN::fillCellDesc() { inCandidate.emplace_back(std::make_shared(shapeS, inDataTypes[cIdx], memory::format_tag::nc)); outCandidate.emplace_back(std::make_shared(shapeS, outDataTypes[coIdx], memory::format_tag::nc)); } - + // The weight and weights_iter would expose nc layout to avoid unnecessary reorder. + // The onednn would determine the final layout when prepareParams. inCandidate.emplace_back(std::make_shared(WShape, inDataTypes[wIdx], memory::format_tag::nc)); inCandidate.emplace_back(std::make_shared(RShape, inDataTypes[rIdx], memory::format_tag::nc)); + inCandidate.emplace_back(std::make_shared(BShape, inDataTypes[bIdx], memory::format_tag::x)); if (haveAttention(cell_type)) { @@ -732,8 +734,11 @@ void RNN::fillSequenceDesc() { } inCandidate.emplace_back(std::make_shared(TShape, inDataTypes[sIdx], memory::format_tag::x)); // sequence lengths - inCandidate.emplace_back(std::make_shared(WShape, inDataTypes[wIdx], memory::format_tag::ntc)); // W - inCandidate.emplace_back(std::make_shared(RShape, inDataTypes[rIdx], memory::format_tag::ntc)); // R + // The weight and weights_iter would expose tnc layout to avoid unnecessary reorder. + // The onednn would determine the final layout when prepareParams. + inCandidate.emplace_back(std::make_shared(WShape, inDataTypes[wIdx], memory::format_tag::tnc)); // W + inCandidate.emplace_back(std::make_shared(RShape, inDataTypes[rIdx], memory::format_tag::tnc)); // R + inCandidate.emplace_back(std::make_shared(BShape, inDataTypes[bIdx], memory::format_tag::nc)); // B if (haveAttention(cell_type)) { @@ -901,9 +906,6 @@ void RNN::copyWeightsData() { if (one_of(dataType, memory::data_type::bf16, memory::data_type::f16)) { fillWeights(gate_map, wIdx, rIdx); } else if (dataType == memory::data_type::f32) { - // WA To avoid different weights layer and iter formats in FP32 case - if (T.minVal > 1 || N.maxVal < optimalBatchSize) - wFormat = dnnl::memory::format_tag::ldigo; fillWeights(gate_map, wIdx, rIdx); } else if (dataType == memory::data_type::u8 || dataType == memory::data_type::s8) { fillWeights(gate_map, wIdx, rIdx); @@ -1042,9 +1044,11 @@ void RNN::createDescriptor(const std::vector &inputDesc, since internalBlobs are used for the execution, not the initial weights */ const auto& targetWeightDataType = weightsByinputDataType.at(inDataTypes[xIdx]); auto weightsDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, DC, G, SC }); - wDescs[0] = dnnl::memory::desc(weightsDims, targetWeightDataType, wFormat); + //onednn determines the preferred weight layout. + wDescs[0] = dnnl::memory::desc(weightsDims, targetWeightDataType, memory::format_tag::any); auto statesDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, SC, G, SC }); - wDescs[1] = dnnl::memory::desc(statesDims, targetWeightDataType, wFormat); + //onednn determines the preferred weights_iter layout. + wDescs[1] = dnnl::memory::desc(statesDims, targetWeightDataType, memory::format_tag::any); auto biasDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, Gb, SC }); wDescs[2] = dnnl::memory::desc(biasDims, inDataTypes[bIdx], memory::format_tag::ldgo); @@ -1119,27 +1123,6 @@ void RNN::prepareParams() { inDataDescs[2] = std::make_shared(Shape{SL, B, 1}, inDataTypes[aIdx], memory::format_tag::tnc); } - bool wFormatWasChanged = false; - // WA To avoid different weights layer and iter formats in FP32 case. - if (one_of(inDataTypes[xIdx], memory::data_type::f32) && - (SL != 1 || B < optimalBatchSize)) { - if (wFormat != dnnl::memory::format_tag::ldigo) { - wFormat = dnnl::memory::format_tag::ldigo; - wFormatWasChanged = true; - } - } else if (wFormat != dnnl::memory::format_tag::any) { - wFormat = dnnl::memory::format_tag::any; - wFormatWasChanged = true; - } - - if (wFormatWasChanged) { - auto weightsDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, DC, G, SC }); - const auto& targetWeightDataType = weightsByinputDataType.at(inDataTypes[xIdx]); - wDescs[0] = dnnl::memory::desc(weightsDims, targetWeightDataType, wFormat); - auto statesDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, SC, G, SC }); - wDescs[1] = dnnl::memory::desc(statesDims, targetWeightDataType, wFormat); - } - const auto attr = initPrimitiveAttr(); RNNKey key = { inDataDescs, outDataDescs, wDescs, cell_type, cell_act, direction, *attr }; diff --git a/src/plugins/intel_cpu/src/nodes/rnn.h b/src/plugins/intel_cpu/src/nodes/rnn.h index f163b85c88e1..dcc5faed48ec 100644 --- a/src/plugins/intel_cpu/src/nodes/rnn.h +++ b/src/plugins/intel_cpu/src/nodes/rnn.h @@ -105,9 +105,6 @@ class RNN : public Node { /** activation type for vanilla RNN cell */ dnnl::algorithm cell_act = dnnl::algorithm::undef; - /** Weights data and state memory format: ldigo or any */ - dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::any; - struct Interval { Interval() = default;