|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -from typing import Callable, Dict, final, List, Optional, Tuple |
| 7 | +from typing import final, List |
8 | 8 |
|
9 | | -import torch |
| 9 | +from executorch.backends.aoti.aoti_partitioner import AotiPartitioner |
10 | 10 | from executorch.backends.apple.metal.metal_backend import MetalBackend # usort: skip |
11 | 11 | from executorch.exir._warnings import experimental |
12 | 12 | from executorch.exir.backend.compile_spec_schema import CompileSpec |
13 | | -from executorch.exir.backend.partitioner import ( |
14 | | - DelegationSpec, |
15 | | - Partitioner, |
16 | | - PartitionResult, |
17 | | -) |
18 | | -from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer |
19 | | -from torch.export.exported_program import ExportedProgram |
20 | 13 |
|
21 | 14 |
|
22 | 15 | @final |
23 | 16 | @experimental( |
24 | 17 | "This API and all of Metal backend related functionality are experimental." |
25 | 18 | ) |
26 | | -class MetalPartitioner(Partitioner): |
| 19 | +class MetalPartitioner(AotiPartitioner): |
27 | 20 | """ |
28 | | - Metal partitioner for AOTInductor backend integration. |
29 | | -
|
30 | | - This partitioner creates a single partition containing all operators from the input graph. |
31 | | - It skips core ATen decomposition, allowing the Metal backend to handle decomposition using |
32 | | - AOTInductor's MPS-specific decomposition table. |
33 | | -
|
34 | | - Only operators that cannot be handled by the aoti-mps library will be excluded from |
35 | | - the partition and fall back to ExecuTorch's default or custom handling. |
| 21 | + Metal partitioner driven by AOTInductor backend. |
36 | 22 | """ |
37 | 23 |
|
38 | 24 | def __init__(self, compile_spec: List[CompileSpec]) -> None: |
39 | | - self.delegation_spec = DelegationSpec(MetalBackend.__name__, compile_spec) |
40 | | - |
41 | | - def partition(self, exported_program: ExportedProgram) -> PartitionResult: |
42 | | - """ |
43 | | - Fully delegate the graph to AOTInductor by tagging all nodes as a single partition. |
44 | | - """ |
45 | | - |
46 | | - partition_tags: Dict[str, DelegationSpec] = {} |
47 | | - tag = "tag0" |
48 | | - |
49 | | - for node in exported_program.graph.nodes: |
50 | | - if node.op != "call_function": |
51 | | - continue |
52 | | - node.meta["delegation_tag"] = tag |
53 | | - |
54 | | - partition_tags[tag] = self.delegation_spec |
55 | | - |
56 | | - tag_constant_data(exported_program) |
57 | | - tag_mutated_buffer(exported_program) |
58 | | - |
59 | | - return PartitionResult( |
60 | | - tagged_exported_program=exported_program, partition_tags=partition_tags |
61 | | - ) |
62 | | - |
63 | | - def ops_to_not_decompose( |
64 | | - self, ep: ExportedProgram |
65 | | - ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: |
66 | | - """ |
67 | | - Return a list of operations that should not be decomposed and let the AOT compiler handle them. |
68 | | - Currently we skip ATen decompositon for all ops, and let the Metal backend handle them. |
69 | | - """ |
70 | | - do_not_decompose = set() |
71 | | - |
72 | | - for node in ep.graph.nodes: |
73 | | - if node.op == "call_function" and isinstance( |
74 | | - node.target, torch._ops.OpOverload |
75 | | - ): |
76 | | - do_not_decompose.add(node.target) |
77 | | - return list(do_not_decompose), None |
| 25 | + super().__init__(MetalBackend.__name__, compile_spec) |
0 commit comments