|
5 | 5 |
|
6 | 6 |
|
7 | 7 | import itertools |
8 | | -from typing import Any, Set, Type |
| 8 | +from typing import Any, cast, Set, Type |
9 | 9 |
|
10 | 10 | import torch |
11 | 11 | from executorch.backends.arm._passes import ArmPass |
|
39 | 39 | from executorch.exir.dialects._ops import ops as exir_ops |
40 | 40 | from executorch.exir.pass_base import ExportPass, PassResult |
41 | 41 |
|
| 42 | +from torch._subclasses.fake_tensor import FakeTensor |
42 | 43 | from torch.export.graph_signature import InputKind |
43 | 44 |
|
44 | 45 |
|
@@ -350,6 +351,68 @@ def _has_int32_rescale_user(self, node: torch.fx.Node) -> bool: |
350 | 351 | return True |
351 | 352 | return False |
352 | 353 |
|
| 354 | + def _insert_output_conversion( |
| 355 | + self, |
| 356 | + graph_module: torch.fx.GraphModule, |
| 357 | + node: torch.fx.Node, |
| 358 | + tosa_op: torch.fx.Node, |
| 359 | + input_fake_tensor: torch.Tensor, |
| 360 | + tosa_node_fake_tensor: torch.Tensor, |
| 361 | + ) -> tuple[torch.fx.Node, FakeTensor]: |
| 362 | + node_replacement: torch.fx.Node = tosa_op |
| 363 | + node_replacement_fake_tensor = tosa_node_fake_tensor |
| 364 | + if ( |
| 365 | + tosa_node_fake_tensor.dtype == torch.int32 |
| 366 | + and input_fake_tensor.dtype == torch.int8 |
| 367 | + ): |
| 368 | + node_replacement, node_replacement_fake_tensor = self.insert_output_rescale( |
| 369 | + graph_module, node, tosa_op, tosa_node_fake_tensor |
| 370 | + ) |
| 371 | + elif ( |
| 372 | + tosa_node_fake_tensor.dtype == torch.int32 |
| 373 | + and input_fake_tensor.dtype == torch.int16 |
| 374 | + ): |
| 375 | + # Explicit layout paths require a post-conv permute, which does |
| 376 | + # not support INT48. Always rescale before post-permute. |
| 377 | + if self._has_int32_rescale_user(node): |
| 378 | + node_replacement, node_replacement_fake_tensor = ( |
| 379 | + self.insert_identity_int32_rescale( |
| 380 | + graph_module, node, tosa_op, tosa_node_fake_tensor |
| 381 | + ) |
| 382 | + ) |
| 383 | + else: |
| 384 | + node_replacement, node_replacement_fake_tensor = ( |
| 385 | + self.insert_output_rescale( |
| 386 | + graph_module, node, tosa_op, tosa_node_fake_tensor |
| 387 | + ) |
| 388 | + ) |
| 389 | + |
| 390 | + tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 |
| 391 | + elif ( |
| 392 | + tosa_node_fake_tensor.dtype == torch.float16 |
| 393 | + and input_fake_tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) |
| 394 | + ): |
| 395 | + node_output_fake_tensor = get_first_fake_tensor(node) |
| 396 | + # TOSA FP8 conv widens the output. Cast back to the exported |
| 397 | + # graph dtype before the post-layout permute. |
| 398 | + node_replacement_fake_tensor = ( |
| 399 | + exir_ops.edge.dim_order_ops._to_dim_order_copy.default( |
| 400 | + tosa_node_fake_tensor, |
| 401 | + dtype=node_output_fake_tensor.dtype, |
| 402 | + ) |
| 403 | + ) |
| 404 | + with graph_module.graph.inserting_after(tosa_op): |
| 405 | + node_replacement = create_node( |
| 406 | + graph=graph_module.graph, |
| 407 | + op_target=exir_ops.edge.dim_order_ops._to_dim_order_copy.default, |
| 408 | + args=(tosa_op,), |
| 409 | + kwargs={"dtype": node_output_fake_tensor.dtype}, |
| 410 | + from_node=tosa_op, |
| 411 | + ) |
| 412 | + node_replacement.meta["val"] = node_replacement_fake_tensor |
| 413 | + |
| 414 | + return node_replacement, cast(FakeTensor, node_replacement_fake_tensor) |
| 415 | + |
353 | 416 | def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 |
354 | 417 | modified = False |
355 | 418 | for node in graph_module.graph.nodes: |
@@ -561,37 +624,15 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 |
561 | 624 | ) |
562 | 625 | tosa_op.meta["val"] = tosa_node_fake_tensor |
563 | 626 |
|
564 | | - node_replacement: torch.fx.Node = tosa_op |
565 | | - node_replacement_fake_tensor = tosa_node_fake_tensor |
566 | | - if ( |
567 | | - tosa_node_fake_tensor.dtype == torch.int32 |
568 | | - and input_fake_tensor.dtype == torch.int8 |
569 | | - ): |
570 | | - output_rescale, output_rescale_fake = self.insert_output_rescale( |
571 | | - graph_module, node, tosa_op, tosa_node_fake_tensor |
| 627 | + node_replacement, node_replacement_fake_tensor = ( |
| 628 | + self._insert_output_conversion( |
| 629 | + graph_module, |
| 630 | + node, |
| 631 | + tosa_op, |
| 632 | + input_fake_tensor, |
| 633 | + tosa_node_fake_tensor, |
572 | 634 | ) |
573 | | - node_replacement = output_rescale |
574 | | - node_replacement_fake_tensor = output_rescale_fake |
575 | | - elif ( |
576 | | - tosa_node_fake_tensor.dtype == torch.int32 |
577 | | - and input_fake_tensor.dtype == torch.int16 |
578 | | - ): |
579 | | - # Explicit layout paths require a post-conv permute, which does |
580 | | - # not support INT48. Always rescale before post-permute. |
581 | | - if self._has_int32_rescale_user(node): |
582 | | - output_rescale, output_rescale_fake = ( |
583 | | - self.insert_identity_int32_rescale( |
584 | | - graph_module, node, tosa_op, tosa_node_fake_tensor |
585 | | - ) |
586 | | - ) |
587 | | - else: |
588 | | - output_rescale, output_rescale_fake = self.insert_output_rescale( |
589 | | - graph_module, node, tosa_op, tosa_node_fake_tensor |
590 | | - ) |
591 | | - node_replacement = output_rescale |
592 | | - node_replacement_fake_tensor = output_rescale_fake |
593 | | - |
594 | | - tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 |
| 635 | + ) |
595 | 636 |
|
596 | 637 | if post_permute_dims is None: |
597 | 638 | raise RuntimeError("Expected post permute dims for explicit layout") |
|
0 commit comments