Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions tests/ap/facade_matmul_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def infer_symbolic(infer_ctx, inputs, attrs):
return inputs


def infer_meta(inputs, attrs, mut_outputs):
mut_outputs[0].dims = inputs[0].dims
mut_outputs[0].dtype = inputs[0].dtype
1 change: 1 addition & 0 deletions tests/ap/make_axpr.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ FILENAMES_ARRAY=(
"matmul_epilogue_pass"
"test_matmul_binary"
"test_matmul_epilogue"
"facade_matmul_utils"
)
for filename in "${FILENAMES_ARRAY[@]}"
do
Expand Down
60 changes: 46 additions & 14 deletions tests/ap/paddle-tests/test_matmul_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,47 @@

import paddle
from paddle.static import InputSpec
import paddle.incubate.cc as pcc


def matmul_add_relu(x, y, b):
out = paddle.matmul(x, y)
return paddle.nn.functional.relu(out + b)


def matmul_add_gelu_true(x, y, b):
def matmul_add_gelu(x, y, b):
out = paddle.matmul(x, y)
return paddle.nn.functional.gelu(out + b, True)
return paddle.nn.functional.gelu(out + b, False)


class FacadeMatmulOp(pcc.ap.FacadeOp):
def __init__(self):
super().__init__()

def custom_op_name(self) -> str:
return "ap_custom_op.facade_matmul"

def infer_meta(self) -> str:
return "facade_matmul_utils.infer_meta"

def infer_symbolic(self) -> str:
return "facade_matmul_utils.infer_symbolic"

def num_inputs(self) -> int:
return 2

def num_outputs(self, args) -> int:
return len(args)

def attributes_schema(self):
# annotations matter.
pass


def facade_matmul_add(x, y, b):
facade_matmul_op = FacadeMatmulOp()
out = facade_matmul_op([x, y])
return out[0] + b


class CINNSubGraphNet(paddle.nn.Layer):
Expand All @@ -56,15 +87,15 @@ def setUp(self):
def prepare_data(self):
self.dtype = "float16"

self.x_shape = [4, 65536, 128]
self.x_shape = [4, 64, 64]
self.x = paddle.randn(self.x_shape, dtype=self.dtype)
self.x.stop_gradient = False

self.y_shape = [128, 32]
self.y_shape = [64, 64]
self.y = paddle.randn(self.y_shape, dtype=self.dtype)
self.y.stop_gradient = False

self.b_shape = [32]
self.b_shape = [64]
self.b = paddle.randn(self.b_shape, dtype=self.dtype)
self.b.stop_gradient = False

Expand All @@ -81,19 +112,20 @@ def eval_symbolic(self, net, use_cinn, profile):

def test_matmul_add_relu(self):
profile = False
net = CINNSubGraphNet(matmul_add_relu)
# net = CINNSubGraphNet(matmul_add_relu)
net = CINNSubGraphNet(facade_matmul_add)
cinn_out = self.eval_symbolic(net, use_cinn=True, profile=profile)
dy2st_out = self.eval_symbolic(net, use_cinn=False, profile=profile)
if not profile:
utils.check_result(self.dtype, cinn_out, dy2st_out)
# dy2st_out = self.eval_symbolic(net, use_cinn=False, profile=profile)
# if not profile:
# utils.check_result(self.dtype, cinn_out, dy2st_out)

def notest_matmul_add_gelu(self):
profile = False
net = CINNSubGraphNet(matmul_add_gelu_true)
profile = True
net = CINNSubGraphNet(matmul_add_gelu)
cinn_out = self.eval_symbolic(net, use_cinn=True, profile=profile)
dy2st_out = self.eval_symbolic(net, use_cinn=False, profile=profile)
if not profile:
utils.check_result(self.dtype, cinn_out, dy2st_out)
# dy2st_out = self.eval_symbolic(net, use_cinn=False, profile=profile)
# if not profile:
# utils.check_result(self.dtype, cinn_out, dy2st_out)


if __name__ == "__main__":
Expand Down
5 changes: 4 additions & 1 deletion tests/ap/test_matmul_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
@abstract_drr.register_drr_pass("matmul_binary_fusion", nice=0)
class MatmulBinaryFusion(abstract_drr.DrrPass):
def source_pattern(self, o, t):
o.matmul_op = o.ap_native_op("pd_op.matmul")
#o.matmul_op = o.ap_native_op("pd_op.matmul")
o.matmul_op = o.ap_native_op("ap_op.facade")
o.matmul_op.custom_op_name = pir.a_str("ap_custom_op.facade_matmul")
o.matmul_op(
[t.input0, t.input1],
[t.mm_out]
Expand Down Expand Up @@ -128,6 +130,7 @@ def _make_index_func_unique_id2index_program(
self, compute_program, anchor_data_op_name, input_names, output_names):
full_index_program = compute_program.clone()
self._apply_topo_access_passes(full_index_program, anchor_data_op_name)
print("full_index_program:", full_index_program)
def MatchAndCopyInputIndex(dst_input_name):
pass_manager = ir_tools.create_pass_manager()
removed_programs = MutableList()
Expand Down