Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/fluid/operators/matmul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class MatMulOp : public framework::OperatorWithKernel {
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");

auto mat_dim_x = math::GetMatDim(GetXDim(dim_x), 0,
auto mat_dim_x = math::GetMatDim(ReshapeVectorToRowMatrix(dim_x), 0,
context->Attrs().Get<bool>("transpose_X"));
auto mat_dim_y = math::GetMatDim(GetYDim(dim_y), 0,
auto mat_dim_y = math::GetMatDim(ReshapeVectorToColumnMatrix(dim_y), 0,
context->Attrs().Get<bool>("transpose_Y"));

PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
Expand Down
44 changes: 24 additions & 20 deletions paddle/fluid/operators/matmul_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,30 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {
inline framework::DDim GetXDim(const framework::DDim& x_dim) {

inline framework::DDim ReshapeVectorToRowMatrix(const framework::DDim& x_dim) {
if (x_dim.size() > 1) {
return x_dim;
}
return framework::make_ddim({1, x_dim[0]});
}

inline framework::DDim GetYDim(const framework::DDim& y_dim) {
inline framework::DDim ReshapeVectorToColumnMatrix(
const framework::DDim& y_dim) {
if (y_dim.size() > 1) {
return y_dim;
}
Expand All @@ -50,17 +53,17 @@ class MatMulKernel : public framework::OpKernel<T> {
out->mutable_data<T>(context.GetPlace());

auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::GetMatDim(GetXDim(x.dims()), 0,
auto mat_dim_a = math::GetMatDim(ReshapeVectorToRowMatrix(x.dims()), 0,
context.Attr<bool>("transpose_X"));
auto mat_dim_b = math::GetMatDim(GetYDim(y.dims()), 0,
auto mat_dim_b = math::GetMatDim(ReshapeVectorToColumnMatrix(y.dims()), 0,
context.Attr<bool>("transpose_Y"));
blas.MatMul(x, mat_dim_a, y, mat_dim_b, T(1), out, T(0));
}
};

// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
inline framework::Tensor CombineBatchAndM(const framework::Tensor& input) {
// If input is a 3-dimensional tensor, reshape it from PxMxN into
// 2-dimensional one of size (PxM)xN; otherwise, return input.
inline framework::Tensor UnfoldFirstTwoDims(const framework::Tensor& input) {
auto output = input;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
Expand All @@ -69,12 +72,13 @@ inline framework::Tensor CombineBatchAndM(const framework::Tensor& input) {
return output;
}

// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename DeviceContext, typename T>
inline framework::Tensor CombineBatchAndN(const DeviceContext& context,
const framework::Tensor& input) {
// If input is a 3-dimensional tensor, reshape it from PxMxN into a
// 2-dimensional one of size Mx(PxN); otherwise, return input.
// Because this transofrmation depends on transposing and writing into
// new memory buffer, a DeviceContext is required.
template <typename DeviceContext>
inline framework::Tensor UnfoldLastTwoDims(const DeviceContext& context,
const framework::Tensor& input) {
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
Expand Down Expand Up @@ -109,8 +113,8 @@ inline void NormalizeXYOutTensorShape(framework::Tensor* x,
framework::Tensor* y,
framework::Tensor* out, bool trans_a,
bool trans_b) {
auto x_dim = GetXDim(x->dims());
auto y_dim = GetYDim(y->dims());
auto x_dim = ReshapeVectorToRowMatrix(x->dims());
auto y_dim = ReshapeVectorToColumnMatrix(y->dims());
auto mat_dim_x = math::GetMatDim(x_dim, 0, trans_a);
auto mat_dim_y = math::GetMatDim(y_dim, 0, trans_b);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
Expand Down Expand Up @@ -176,10 +180,10 @@ class MatMulGradKernel : public framework::OpKernel<T> {
} else {
auto& ctx = context.template device_context<DeviceContext>();
MatMul(
context, is_combine_m_a ? CombineBatchAndM(a)
: CombineBatchAndN<DeviceContext, T>(ctx, a),
trans_a, is_combine_m_b ? CombineBatchAndM(b)
: CombineBatchAndN<DeviceContext, T>(ctx, b),
context, is_combine_m_a ? UnfoldFirstTwoDims(a)
: UnfoldLastTwoDims<DeviceContext, T>(ctx, a),
trans_a, is_combine_m_b ? UnfoldFirstTwoDims(b)
: UnfoldLastTwoDims<DeviceContext, T>(ctx, b),
trans_b, out);
}
}
Expand Down