From 8eee03812e715757ff8c5a79daf60e6c1afb346a Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 16 May 2025 09:42:38 +0800 Subject: [PATCH] Add example of facade_matmul. --- tests/ap/facade_matmul_utils.py | 7 +++ tests/ap/make_axpr.sh | 1 + tests/ap/paddle-tests/test_matmul_binary.py | 60 ++++++++++++++++----- tests/ap/test_matmul_binary.py | 5 +- 4 files changed, 58 insertions(+), 15 deletions(-) create mode 100644 tests/ap/facade_matmul_utils.py diff --git a/tests/ap/facade_matmul_utils.py b/tests/ap/facade_matmul_utils.py new file mode 100644 index 0000000..0fe0c76 --- /dev/null +++ b/tests/ap/facade_matmul_utils.py @@ -0,0 +1,7 @@ +def infer_symbolic(infer_ctx, inputs, attrs): + return inputs + + +def infer_meta(inputs, attrs, mut_outputs): + mut_outputs[0].dims = inputs[0].dims + mut_outputs[0].dtype = inputs[0].dtype diff --git a/tests/ap/make_axpr.sh b/tests/ap/make_axpr.sh index 77588e5..a038aac 100644 --- a/tests/ap/make_axpr.sh +++ b/tests/ap/make_axpr.sh @@ -22,6 +22,7 @@ FILENAMES_ARRAY=( "matmul_epilogue_pass" "test_matmul_binary" "test_matmul_epilogue" + "facade_matmul_utils" ) for filename in "${FILENAMES_ARRAY[@]}" do diff --git a/tests/ap/paddle-tests/test_matmul_binary.py b/tests/ap/paddle-tests/test_matmul_binary.py index 059b4ae..9305e4e 100644 --- a/tests/ap/paddle-tests/test_matmul_binary.py +++ b/tests/ap/paddle-tests/test_matmul_binary.py @@ -22,6 +22,7 @@ import paddle from paddle.static import InputSpec +import paddle.incubate.cc as pcc def matmul_add_relu(x, y, b): @@ -29,9 +30,39 @@ def matmul_add_relu(x, y, b): return paddle.nn.functional.relu(out + b) -def matmul_add_gelu_true(x, y, b): +def matmul_add_gelu(x, y, b): out = paddle.matmul(x, y) - return paddle.nn.functional.gelu(out + b, True) + return paddle.nn.functional.gelu(out + b, False) + + +class FacadeMatmulOp(pcc.ap.FacadeOp): + def __init__(self): + super().__init__() + + def custom_op_name(self) -> str: + return "ap_custom_op.facade_matmul" + + def infer_meta(self) -> str: + return "facade_matmul_utils.infer_meta" + + def infer_symbolic(self) -> str: + return "facade_matmul_utils.infer_symbolic" + + def num_inputs(self) -> int: + return 2 + + def num_outputs(self, args) -> int: + return len(args) + + def attributes_schema(self): + # annotations matter. + pass + + +def facade_matmul_add(x, y, b): + facade_matmul_op = FacadeMatmulOp() + out = facade_matmul_op([x, y]) + return out[0] + b class CINNSubGraphNet(paddle.nn.Layer): @@ -56,15 +87,15 @@ def setUp(self): def prepare_data(self): self.dtype = "float16" - self.x_shape = [4, 65536, 128] + self.x_shape = [4, 64, 64] self.x = paddle.randn(self.x_shape, dtype=self.dtype) self.x.stop_gradient = False - self.y_shape = [128, 32] + self.y_shape = [64, 64] self.y = paddle.randn(self.y_shape, dtype=self.dtype) self.y.stop_gradient = False - self.b_shape = [32] + self.b_shape = [64] self.b = paddle.randn(self.b_shape, dtype=self.dtype) self.b.stop_gradient = False @@ -81,19 +112,20 @@ def eval_symbolic(self, net, use_cinn, profile): def test_matmul_add_relu(self): profile = False - net = CINNSubGraphNet(matmul_add_relu) + # net = CINNSubGraphNet(matmul_add_relu) + net = CINNSubGraphNet(facade_matmul_add) cinn_out = self.eval_symbolic(net, use_cinn=True, profile=profile) - dy2st_out = self.eval_symbolic(net, use_cinn=False, profile=profile) - if not profile: - utils.check_result(self.dtype, cinn_out, dy2st_out) + # dy2st_out = self.eval_symbolic(net, use_cinn=False, profile=profile) + # if not profile: + # utils.check_result(self.dtype, cinn_out, dy2st_out) def notest_matmul_add_gelu(self): - profile = False - net = CINNSubGraphNet(matmul_add_gelu_true) + profile = True + net = CINNSubGraphNet(matmul_add_gelu) cinn_out = self.eval_symbolic(net, use_cinn=True, profile=profile) - dy2st_out = self.eval_symbolic(net, use_cinn=False, profile=profile) - if not profile: - utils.check_result(self.dtype, cinn_out, dy2st_out) + # dy2st_out = self.eval_symbolic(net, use_cinn=False, profile=profile) + # if not profile: + # utils.check_result(self.dtype, cinn_out, dy2st_out) if __name__ == "__main__": diff --git a/tests/ap/test_matmul_binary.py b/tests/ap/test_matmul_binary.py index 1827b77..16be87d 100644 --- a/tests/ap/test_matmul_binary.py +++ b/tests/ap/test_matmul_binary.py @@ -17,7 +17,9 @@ @abstract_drr.register_drr_pass("matmul_binary_fusion", nice=0) class MatmulBinaryFusion(abstract_drr.DrrPass): def source_pattern(self, o, t): - o.matmul_op = o.ap_native_op("pd_op.matmul") + #o.matmul_op = o.ap_native_op("pd_op.matmul") + o.matmul_op = o.ap_native_op("ap_op.facade") + o.matmul_op.custom_op_name = pir.a_str("ap_custom_op.facade_matmul") o.matmul_op( [t.input0, t.input1], [t.mm_out] @@ -128,6 +130,7 @@ def _make_index_func_unique_id2index_program( self, compute_program, anchor_data_op_name, input_names, output_names): full_index_program = compute_program.clone() self._apply_topo_access_passes(full_index_program, anchor_data_op_name) + print("full_index_program:", full_index_program) def MatchAndCopyInputIndex(dst_input_name): pass_manager = ir_tools.create_pass_manager() removed_programs = MutableList()