diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 9a695191359..3336a394510 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -27,6 +27,7 @@ from .decompose_floor_divide import DecomposeFloorDivide from .decompose_glu import DecomposeGlu from .decompose_hardsigmoid import DecomposeHardsigmoid +from .decompose_hyperbolic_variants import DecomposeHyperbolicVariants from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm from .decompose_log_variants import DecomposeLogVariants from .decompose_maxpool3d import DecomposeMaxPool3d @@ -89,6 +90,7 @@ DecomposeFill, DecomposeFloorDivide, DecomposeGlu, + DecomposeHyperbolicVariants, DecomposeHardsigmoid, DecomposeLinalgVectorNorm, DecomposeLogVariants, diff --git a/backends/qualcomm/_passes/decompose_acos.py b/backends/qualcomm/_passes/decompose_acos.py index f83b18f11fc..d546cf6d92d 100644 --- a/backends/qualcomm/_passes/decompose_acos.py +++ b/backends/qualcomm/_passes/decompose_acos.py @@ -9,7 +9,7 @@ from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, PassResult -from .utils import copy_meta, get_const_node +from .utils import copy_meta, create_const_node class DecomposeAcos(ExportPass): @@ -52,7 +52,7 @@ def call(self, graph_module: torch.fx.GraphModule): ) if is_edge and pi_half_node is None: - pi_half_node = get_const_node( + pi_half_node = create_const_node( graph, graph_module, "_pi_half_constant", pi_half, node ) diff --git a/backends/qualcomm/_passes/decompose_addmm.py b/backends/qualcomm/_passes/decompose_addmm.py index 674daa9d550..8cd4f9a920f 100644 --- a/backends/qualcomm/_passes/decompose_addmm.py +++ b/backends/qualcomm/_passes/decompose_addmm.py @@ -9,7 +9,7 @@ from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, PassResult -from .utils import copy_meta, get_const_node +from .utils import copy_meta, create_const_node class DecomposeAddmm(ExportPass): @@ -70,7 +70,7 @@ def call(self, graph_module: torch.fx.GraphModule): mm_node.meta = copy_meta(meta) if alpha != 1: - alpha_node = get_const_node( + alpha_node = create_const_node( graph, graph_module, f"{node.name}_alpha", @@ -86,7 +86,7 @@ def call(self, graph_module: torch.fx.GraphModule): mm_result = mm_node if beta != 1: - beta_const = get_const_node( + beta_const = create_const_node( graph, graph_module, f"{node.name}_beta", diff --git a/backends/qualcomm/_passes/decompose_atan2.py b/backends/qualcomm/_passes/decompose_atan2.py index 0f54e555e03..a411f997b61 100644 --- a/backends/qualcomm/_passes/decompose_atan2.py +++ b/backends/qualcomm/_passes/decompose_atan2.py @@ -9,7 +9,7 @@ from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, PassResult -from .utils import copy_meta, create_node, get_const_node +from .utils import copy_meta, create_const_node, create_node class DecomposeAtan2(ExportPass): @@ -68,7 +68,7 @@ def _get_constants(self, graph, graph_module, node, is_edge, const_cache): def make_const(name, val): if name not in const_cache: - const_cache[name] = get_const_node( + const_cache[name] = create_const_node( graph, graph_module, name, val, node ) return const_cache[name] diff --git a/backends/qualcomm/_passes/decompose_hyperbolic_variants.py b/backends/qualcomm/_passes/decompose_hyperbolic_variants.py new file mode 100644 index 00000000000..9664565b309 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_hyperbolic_variants.py @@ -0,0 +1,202 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import create_const_node, create_node + + +class DecomposeHyperbolicVariants(ExportPass): + """ + Decompose hyperbolic functions into supported primitives: + sinh(x) = 0.5 * (exp(x) - exp(-x)) + cosh(x) = 0.5 * (exp(x) + exp(-x)) + asinh(x) = log(x + sqrt(x*x + 1)) + acosh(x) = log(x + sqrt(x*x - 1)) + atanh(x) = 0.5 * log((1 + x) / (1 - x)) + """ + + _EDGE_OPS = { + exir_ops.edge.aten.sinh.default, + exir_ops.edge.aten.cosh.default, + exir_ops.edge.aten.asinh.default, + exir_ops.edge.aten.acosh.default, + exir_ops.edge.aten.atanh.default, + } + + def __init__(self): + super(DecomposeHyperbolicVariants, self).__init__() + self._dispatcher = { + # ATen dialect + torch.ops.aten.sinh.default: self._decompose_sinh, + torch.ops.aten.cosh.default: self._decompose_cosh, + torch.ops.aten.asinh.default: self._decompose_asinh, + torch.ops.aten.acosh.default: self._decompose_acosh, + torch.ops.aten.atanh.default: self._decompose_atanh, + # Edge dialect + exir_ops.edge.aten.sinh.default: self._decompose_sinh, + exir_ops.edge.aten.cosh.default: self._decompose_cosh, + exir_ops.edge.aten.asinh.default: self._decompose_asinh, + exir_ops.edge.aten.acosh.default: self._decompose_acosh, + exir_ops.edge.aten.atanh.default: self._decompose_atanh, + } + + def _get_ops(self, is_edge): + if is_edge: + return { + "exp": exir_ops.edge.aten.exp.default, + "neg": exir_ops.edge.aten.neg.default, + "add": exir_ops.edge.aten.add.Tensor, + "sub": exir_ops.edge.aten.sub.Tensor, + "mul": exir_ops.edge.aten.mul.Tensor, + "div": exir_ops.edge.aten.div.Tensor, + "log": exir_ops.edge.aten.log.default, + "sqrt": exir_ops.edge.aten.sqrt.default, + } + return { + "exp": torch.ops.aten.exp.default, + "neg": torch.ops.aten.neg.default, + "add": torch.ops.aten.add.Tensor, + "sub": torch.ops.aten.sub.Tensor, + "mul": torch.ops.aten.mul.Tensor, + "div": torch.ops.aten.div.Tensor, + "log": torch.ops.aten.log.default, + "sqrt": torch.ops.aten.sqrt.default, + } + + def _decompose_exp_symmetry( + self, node, graph, graph_module, const_cache, combine_op_key + ): + """Shared helper for sinh and cosh: (exp(x) ± exp(-x)) / 2.""" + is_edge = node.target in self._EDGE_OPS + ops = self._get_ops(is_edge) + meta = node.meta + input_node = node.args[0] + + if is_edge: + half_name = "_half_constant" + if half_name not in const_cache: + const_cache[half_name] = create_const_node( + graph, graph_module, half_name, 0.5, node + ) + half_arg = const_cache[half_name] + else: + half_arg = 0.5 + + with graph.inserting_before(node): + exp_pos = create_node(graph, ops["exp"], (input_node,), meta) + neg_x = create_node(graph, ops["neg"], (input_node,), meta) + exp_neg = create_node(graph, ops["exp"], (neg_x,), meta) + combine = create_node(graph, ops[combine_op_key], (exp_pos, exp_neg), meta) + result = create_node(graph, ops["mul"], (combine, half_arg), meta) + + for user in node.users.copy(): + user.replace_input_with(node, result) + + def _decompose_sinh(self, node, graph, graph_module, const_cache): + self._decompose_exp_symmetry(node, graph, graph_module, const_cache, "sub") + + def _decompose_cosh(self, node, graph, graph_module, const_cache): + self._decompose_exp_symmetry(node, graph, graph_module, const_cache, "add") + + def _decompose_asinh(self, node, graph, graph_module, const_cache): + is_edge = node.target in self._EDGE_OPS + ops = self._get_ops(is_edge) + meta = node.meta + input_node = node.args[0] + + if is_edge: + one_name = "_one_constant" + if one_name not in const_cache: + const_cache[one_name] = create_const_node( + graph, graph_module, one_name, 1.0, node + ) + one_arg = const_cache[one_name] + else: + one_arg = 1.0 + + with graph.inserting_before(node): + x_sq = create_node(graph, ops["mul"], (input_node, input_node), meta) + x_sq_plus_1 = create_node(graph, ops["add"], (x_sq, one_arg), meta) + sqrt_node = create_node(graph, ops["sqrt"], (x_sq_plus_1,), meta) + sum_node = create_node(graph, ops["add"], (input_node, sqrt_node), meta) + result = create_node(graph, ops["log"], (sum_node,), meta) + + for user in node.users.copy(): + user.replace_input_with(node, result) + + def _decompose_acosh(self, node, graph, graph_module, const_cache): + is_edge = node.target in self._EDGE_OPS + ops = self._get_ops(is_edge) + meta = node.meta + input_node = node.args[0] + + if is_edge: + one_name = "_one_constant" + if one_name not in const_cache: + const_cache[one_name] = create_const_node( + graph, graph_module, one_name, 1.0, node + ) + one_arg = const_cache[one_name] + else: + one_arg = 1.0 + + with graph.inserting_before(node): + x_sq = create_node(graph, ops["mul"], (input_node, input_node), meta) + x_sq_minus_1 = create_node(graph, ops["sub"], (x_sq, one_arg), meta) + sqrt_node = create_node(graph, ops["sqrt"], (x_sq_minus_1,), meta) + sum_node = create_node(graph, ops["add"], (input_node, sqrt_node), meta) + result = create_node(graph, ops["log"], (sum_node,), meta) + + for user in node.users.copy(): + user.replace_input_with(node, result) + + def _decompose_atanh(self, node, graph, graph_module, const_cache): + is_edge = node.target in self._EDGE_OPS + ops = self._get_ops(is_edge) + meta = node.meta + input_node = node.args[0] + + if is_edge: + one_name = "_one_constant" + if one_name not in const_cache: + const_cache[one_name] = create_const_node( + graph, graph_module, one_name, 1.0, node + ) + one_arg = const_cache[one_name] + half_name = "_half_constant" + if half_name not in const_cache: + const_cache[half_name] = create_const_node( + graph, graph_module, half_name, 0.5, node + ) + half_arg = const_cache[half_name] + else: + one_arg = 1.0 + half_arg = 0.5 + + with graph.inserting_before(node): + one_plus_x = create_node(graph, ops["add"], (one_arg, input_node), meta) + one_minus_x = create_node(graph, ops["sub"], (one_arg, input_node), meta) + ratio = create_node(graph, ops["div"], (one_plus_x, one_minus_x), meta) + log_node = create_node(graph, ops["log"], (ratio,), meta) + result = create_node(graph, ops["mul"], (half_arg, log_node), meta) + + for user in node.users.copy(): + user.replace_input_with(node, result) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + const_cache = {} + + for node in list(graph.nodes): + if node.target in self._dispatcher: + self._dispatcher[node.target](node, graph, graph_module, const_cache) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_log_variants.py b/backends/qualcomm/_passes/decompose_log_variants.py index 2b394806b68..904900dd205 100644 --- a/backends/qualcomm/_passes/decompose_log_variants.py +++ b/backends/qualcomm/_passes/decompose_log_variants.py @@ -11,7 +11,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from .utils import copy_meta, get_const_node +from .utils import copy_meta, create_const_node class DecomposeLogVariants(ExportPass): @@ -50,7 +50,7 @@ def _decompose_log_n(self, node, graph, graph_module, const_cache, n): div_op = exir_ops.edge.aten.div.Tensor attr_name = f"_log_base_{n}_constant" if attr_name not in const_cache: - const_cache[attr_name] = get_const_node( + const_cache[attr_name] = create_const_node( graph, graph_module, attr_name, math.log(n), node ) div_arg = const_cache[attr_name] @@ -81,7 +81,7 @@ def _decompose_log_p(self, node, graph, graph_module, const_cache, p): log_op = exir_ops.edge.aten.log.default attr_name = f"_log1p_addend_{p}_constant" if attr_name not in const_cache: - const_cache[attr_name] = get_const_node( + const_cache[attr_name] = create_const_node( graph, graph_module, attr_name, p, node ) add_arg = const_cache[attr_name] diff --git a/backends/qualcomm/_passes/decompose_remainder.py b/backends/qualcomm/_passes/decompose_remainder.py index 4e5ea739856..a6c260d217b 100644 --- a/backends/qualcomm/_passes/decompose_remainder.py +++ b/backends/qualcomm/_passes/decompose_remainder.py @@ -10,7 +10,7 @@ from executorch.exir.pass_base import ExportPass, PassResult from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix -from .utils import copy_meta, get_const_node +from .utils import copy_meta, create_const_node class DecomposeRemainder(ExportPass): @@ -69,7 +69,7 @@ def call(self, graph_module: torch.fx.GraphModule): attr_name = get_new_attr_name_with_prefix("_remainder_const_")( graph_module ) - const_cache[x_arg] = get_const_node( + const_cache[x_arg] = create_const_node( graph, graph_module, attr_name, x_arg, node ) x_node = const_cache[x_arg] @@ -82,7 +82,7 @@ def call(self, graph_module: torch.fx.GraphModule): attr_name = get_new_attr_name_with_prefix("_remainder_const_")( graph_module ) - const_cache[y_arg] = get_const_node( + const_cache[y_arg] = create_const_node( graph, graph_module, attr_name, y_arg, node ) y_node = const_cache[y_arg] diff --git a/backends/qualcomm/_passes/decompose_var.py b/backends/qualcomm/_passes/decompose_var.py index 923fae4977f..c89929fa50e 100644 --- a/backends/qualcomm/_passes/decompose_var.py +++ b/backends/qualcomm/_passes/decompose_var.py @@ -10,7 +10,7 @@ from executorch.exir.pass_base import ExportPass, PassResult from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix -from .utils import copy_meta, get_const_node +from .utils import copy_meta, create_const_node class DecomposeVar(ExportPass): @@ -155,7 +155,7 @@ def call(self, graph_module: torch.fx.GraphModule): attr_name = get_new_attr_name_with_prefix( "_var_scale_const_" )(graph_module) - const_cache[cache_key] = get_const_node( + const_cache[cache_key] = create_const_node( graph, graph_module, attr_name, scale, node ) scale_node = const_cache[cache_key] diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 4e2d9f3d3a0..7efb4a293e1 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -32,6 +32,7 @@ DecomposeFill, DecomposeFloorDivide, DecomposeGlu, + DecomposeHyperbolicVariants, DecomposeLinalgVectorNorm, DecomposeLogVariants, DecomposeMaxPool3d, @@ -130,6 +131,7 @@ def get_default_pass_activations(cls): (DecomposeCDist, True), (DecomposeDivMode, True), (DecomposeFill, True), + (DecomposeHyperbolicVariants, True), (DecomposeLogVariants, True), (DecomposeMaxPool3d, True), (DecomposeMinMaxDim, True), @@ -182,6 +184,7 @@ def get_annotation_passes(cls): DecomposeExpM1, DecomposeFill, DecomposeGlu, + DecomposeHyperbolicVariants, DecomposeRemainder, DecomposeSelectScatter, DecomposeLinalgVectorNorm, @@ -285,6 +288,7 @@ def get_passes_dependency_for_capture_program(cls): DecomposeCDist: [RemoveRedundancy], DecomposeDivMode: [RemoveRedundancy], DecomposeFill: [RemoveRedundancy], + DecomposeHyperbolicVariants: [RemoveRedundancy], DecomposeLinalgVectorNorm: [RemoveRedundancy], DecomposeLogVariants: [RemoveRedundancy], DecomposeMaxPool3d: [RemoveRedundancy], diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 92a75703bbd..2a580ab11a4 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -343,7 +343,7 @@ def append_qdq( return dq_node -def get_const_node( +def create_const_node( graph: torch.fx.Graph, graph_module: torch.fx.GraphModule, attr_name: str, diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 2a838707e05..b8d86b9d6da 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -498,12 +498,16 @@ The following PyTorch operators are supported through decomposition or annotatio | PyTorch Op | Decomposition Pass | |---|---| | `aten.acos` | `DecomposeAcos` | +| `aten.acosh` | `DecomposeHyperbolicVariants` | | `aten.addmm` | `DecomposeAddmm` | | `aten.adaptive_avg_pool1d`, `aten.avg_pool1d` | `AnnotateAvgPool1D` | | `aten.any` | `DecomposeAny` | +| `aten.asinh` | `DecomposeHyperbolicVariants` | | `aten.atan2.default`, `aten.atan2.out` | `DecomposeAtan2` | +| `aten.atanh` | `DecomposeHyperbolicVariants` | | `aten.add` (with alpha), `aten.sub` (with alpha) | `DecomposeBinaryAlpha` | | `aten.cdist`, `aten._cdist_forward` | `DecomposeCDist` | +| `aten.cosh` | `DecomposeHyperbolicVariants` | | `aten.div.Tensor_mode` | `DecomposeDivMode` | | `aten.div.Scalar_mode` | `LiftConstantScalarOperands` → `DecomposeDivMode` | | `aten.im2col`, `aten.col2im` | `DecomposeColIm` | @@ -523,6 +527,7 @@ The following PyTorch operators are supported through decomposition or annotatio | `aten.roll` | `DecomposeRoll` | | `aten.select_scatter` | `DecomposeSelectScatter` | | `aten.silu` | `DecomposeSilu` | +| `aten.sinh` | `DecomposeHyperbolicVariants` | | `aten.tan` | `DecomposeTan` | | `aten.threshold` | `DecomposeThreshold` | | `aten.triu` | `DecomposeTriu` | diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 140fbfe5cc0..0201edb6dee 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -49,6 +49,14 @@ def forward(self, x): return torch.acos(x) +class Acosh(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.acosh(x) + + class AcosMultiNode(torch.nn.Module): def __init__(self): super().__init__() @@ -257,6 +265,14 @@ def forward(self, x, y): return squeeze_out, conv_out +class Asinh(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.asinh(x) + + class Asin(torch.nn.Module): def __init__(self): super().__init__() @@ -289,6 +305,14 @@ def forward(self, x1, y1, x2, y2): return torch.atan2(x1, y1), torch.atan2(x2, y2) +class Atanh(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.atanh(x) + + class AvgPool1D(torch.nn.Module): def __init__(self): super().__init__() @@ -999,6 +1023,14 @@ def forward(self, x): return torch.cos(x) +class Cosh(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.cosh(x) + + class CumSum(torch.nn.Module): def __init__(self): super().__init__() @@ -2300,6 +2332,14 @@ def forward(self, x): return torch.sin(x) +class Sinh(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sinh(x) + + class SimpleModel(torch.nn.Module): def __init__(self, kernel_size=3): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 290320ba38f..4c23cc64193 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -150,6 +150,11 @@ def test_qnn_backend_acos(self): index += 1 self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_acosh(self): + module = Acosh() # noqa: F405 + sample_input = (torch.tensor([1.0, 1.5, 2.0, 3.0, 5.0, 10.0]).reshape(2, 3),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_adaptive_avg_pool1d(self): module = AdaptiveAvgPool1D() # noqa: F405 sample_input = (torch.randn(1, 512, 7),) @@ -324,6 +329,11 @@ def test_qnn_backend_argmin(self): case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS] ) + def test_qnn_backend_asinh(self): + module = Asinh() # noqa: F405 + sample_input = (torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]).reshape(2, 3),) + self.lower_module_and_test_output(module, sample_input) + @unittest.expectedFailure def test_qnn_backend_asin(self): sample_input = (torch.rand(3, 4) * 2 - 1,) @@ -375,6 +385,11 @@ def test_qnn_backend_atan2(self): index += 1 self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_atanh(self): + module = Atanh() # noqa: F405 + sample_input = (torch.tensor([-0.9, -0.5, -0.1, 0.1, 0.5, 0.9]).reshape(2, 3),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_avg_pool1d(self): module = AvgPool1D() # noqa: F405 sample_input = (torch.randn(1, 512, 7),) @@ -637,6 +652,11 @@ def test_qnn_backend_cos(self): sample_input = (torch.randn(2, 5, 1, 3),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cosh(self): + module = Cosh() # noqa: F405 + sample_input = (torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]).reshape(2, 3),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cumsum(self): sample_input = () test_comb = [ @@ -2171,6 +2191,11 @@ def test_qnn_backend_sin(self): sample_input = (torch.randn(2, 5, 1, 3),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_sinh(self): + module = Sinh() # noqa: F405 + sample_input = (torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]).reshape(2, 3),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_select_copy(self): module = SelectCopy() # noqa: F405 sample_input = (torch.randn([1, 3, 3, 3]),) @@ -2949,6 +2974,12 @@ def test_qnn_backend_acos(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_acosh(self): + module = Acosh() # noqa: F405 + sample_input = (torch.tensor([1.0, 1.5, 2.0, 3.0, 5.0, 10.0]).reshape(2, 3),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_adaptive_avg_pool1d(self): module = AdaptiveAvgPool1D() # noqa: F405 sample_input = (torch.randn(1, 512, 7),) @@ -3142,6 +3173,12 @@ def test_qnn_backend_asin(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_asinh(self): + module = Asinh() # noqa: F405 + sample_input = (torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]).reshape(2, 3),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_atan(self): sample_input = (torch.randn(3, 4),) module = Atan() # noqa: F405 @@ -3181,6 +3218,12 @@ def test_qnn_backend_atan2(self): qdq_module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(qdq_module, sample_input) + def test_qnn_backend_atanh(self): + module = Atanh() # noqa: F405 + sample_input = (torch.tensor([-0.9, -0.5, -0.1, 0.1, 0.5, 0.9]).reshape(2, 3),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_avg_pool1d(self): module = AvgPool1D() # noqa: F405 sample_input = (torch.randn(1, 512, 7),) @@ -3527,6 +3570,12 @@ def test_qnn_backend_cos(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cosh(self): + module = Cosh() # noqa: F405 + sample_input = (torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]).reshape(2, 3),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cumsum(self): module = CumSum() # noqa: F405 sample_input = (torch.randn(4),) @@ -5311,6 +5360,12 @@ def test_qnn_backend_sin(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_sinh(self): + module = Sinh() # noqa: F405 + sample_input = (torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]).reshape(2, 3),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_slice_copy(self): modules = [ SliceCopyDefaultParameter(), # noqa: F405