From 35a9fe41afd638b5230ad3d7a6ab60acbbbee458 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Mon, 7 May 2018 15:35:23 -0700 Subject: [PATCH] update --- paddle/fluid/operators/matmul_op.cc | 4 +-- paddle/fluid/operators/matmul_op.h | 44 ++++++++++++++++------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index c285d461e85619..0440994f53b6f6 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -37,9 +37,9 @@ class MatMulOp : public framework::OperatorWithKernel { auto dim_x = context->GetInputDim("X"); auto dim_y = context->GetInputDim("Y"); - auto mat_dim_x = math::GetMatDim(GetXDim(dim_x), 0, + auto mat_dim_x = math::GetMatDim(ReshapeVectorToRowMatrix(dim_x), 0, context->Attrs().Get("transpose_X")); - auto mat_dim_y = math::GetMatDim(GetYDim(dim_y), 0, + auto mat_dim_y = math::GetMatDim(ReshapeVectorToColumnMatrix(dim_y), 0, context->Attrs().Get("transpose_Y")); PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_); diff --git a/paddle/fluid/operators/matmul_op.h b/paddle/fluid/operators/matmul_op.h index 7b484d124a7bc9..a8a5b16dbaef0e 100644 --- a/paddle/fluid/operators/matmul_op.h +++ b/paddle/fluid/operators/matmul_op.h @@ -11,12 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once + #include #include #include #include + #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/math/blas.h" @@ -24,14 +25,16 @@ limitations under the License. */ namespace paddle { namespace operators { -inline framework::DDim GetXDim(const framework::DDim& x_dim) { + +inline framework::DDim ReshapeVectorToRowMatrix(const framework::DDim& x_dim) { if (x_dim.size() > 1) { return x_dim; } return framework::make_ddim({1, x_dim[0]}); } -inline framework::DDim GetYDim(const framework::DDim& y_dim) { +inline framework::DDim ReshapeVectorToColumnMatrix( + const framework::DDim& y_dim) { if (y_dim.size() > 1) { return y_dim; } @@ -50,17 +53,17 @@ class MatMulKernel : public framework::OpKernel { out->mutable_data(context.GetPlace()); auto blas = math::GetBlas(context); - auto mat_dim_a = math::GetMatDim(GetXDim(x.dims()), 0, + auto mat_dim_a = math::GetMatDim(ReshapeVectorToRowMatrix(x.dims()), 0, context.Attr("transpose_X")); - auto mat_dim_b = math::GetMatDim(GetYDim(y.dims()), 0, + auto mat_dim_b = math::GetMatDim(ReshapeVectorToColumnMatrix(y.dims()), 0, context.Attr("transpose_Y")); blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0)); } }; -// Reshape a rank-3 tensor from P x M x N to (P * M) x N. -// Identity op if the tensor is not of rank 3. -inline framework::Tensor CombineBatchAndM(const framework::Tensor& input) { +// If input is a 3-dimensional tensor, reshape it from PxMxN into +// 2-dimensional one of size (PxM)xN; otherwise, return input. +inline framework::Tensor UnfoldFirstTwoDims(const framework::Tensor& input) { auto output = input; auto in_dims = input.dims(); if (in_dims.size() == 3) { @@ -69,12 +72,13 @@ inline framework::Tensor CombineBatchAndM(const framework::Tensor& input) { return output; } -// Reshape a rank-3 tensor from P x M x N to M x (P * N). -// (Warning: This requires transposing data and writes into new memory.) -// Identity op if the tensor is not of rank 3. -template -inline framework::Tensor CombineBatchAndN(const DeviceContext& context, - const framework::Tensor& input) { +// If input is a 3-dimensional tensor, reshape it from PxMxN into a +// 2-dimensional one of size Mx(PxN); otherwise, return input. +// Because this transofrmation depends on transposing and writing into +// new memory buffer, a DeviceContext is required. +template +inline framework::Tensor UnfoldLastTwoDims(const DeviceContext& context, + const framework::Tensor& input) { auto in_dims = input.dims(); if (in_dims.size() != 3) { return input; @@ -109,8 +113,8 @@ inline void NormalizeXYOutTensorShape(framework::Tensor* x, framework::Tensor* y, framework::Tensor* out, bool trans_a, bool trans_b) { - auto x_dim = GetXDim(x->dims()); - auto y_dim = GetYDim(y->dims()); + auto x_dim = ReshapeVectorToRowMatrix(x->dims()); + auto y_dim = ReshapeVectorToColumnMatrix(y->dims()); auto mat_dim_x = math::GetMatDim(x_dim, 0, trans_a); auto mat_dim_y = math::GetMatDim(y_dim, 0, trans_b); if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { @@ -176,10 +180,10 @@ class MatMulGradKernel : public framework::OpKernel { } else { auto& ctx = context.template device_context(); MatMul( - context, is_combine_m_a ? CombineBatchAndM(a) - : CombineBatchAndN(ctx, a), - trans_a, is_combine_m_b ? CombineBatchAndM(b) - : CombineBatchAndN(ctx, b), + context, is_combine_m_a ? UnfoldFirstTwoDims(a) + : UnfoldLastTwoDims(ctx, a), + trans_a, is_combine_m_b ? UnfoldFirstTwoDims(b) + : UnfoldLastTwoDims(ctx, b), trans_b, out); } }