1- #!/usr/bin/env fbpython
21# Copyright (c) Meta Platforms, Inc. and affiliates.
32# All rights reserved.
43#
76
87import argparse
98import os
9+ import re
1010import sys
11+ from functools import reduce
12+ from pathlib import Path
1113from typing import Any , List
1214
13- from tools_copy .code_analyzer import gen_oplist_copy_from_core
15+ import yaml
16+ from torchgen .selective_build .selector import (
17+ combine_selective_builders ,
18+ SelectiveBuilder ,
19+ )
20+
21+
22+ def throw_if_any_op_includes_overloads (selective_builder : SelectiveBuilder ) -> None :
23+ ops = []
24+ for op_name , op in selective_builder .operators .items ():
25+ if op .include_all_overloads :
26+ ops .append (op_name )
27+ if ops :
28+ raise Exception ( # noqa: TRY002
29+ (
30+ "Operators that include all overloads are "
31+ + "not allowed since --allow-include-all-overloads "
32+ + "was not specified: {}"
33+ ).format (", " .join (ops ))
34+ )
35+
36+
37+ def resolve_model_file_path_to_buck_target (model_file_path : str ) -> str :
38+ real_path = str (Path (model_file_path ).resolve (strict = True ))
39+ # try my best to convert to buck target
40+ prog = re .compile (
41+ r"/.*/buck-out/.*/(fbsource|fbcode)/[0-9a-f]*/(.*)/__(.*)_et_oplist__/out/selected_operators.yaml"
42+ )
43+ match = prog .match (real_path )
44+ if match :
45+ return f"{ match .group (1 )} //{ match .group (2 )} :{ match .group (3 )} "
46+ else :
47+ return real_path
1448
1549
1650def main (argv : List [Any ]) -> None :
17- """This binary is a wrapper for //executorch/codegen/tools/gen_oplist_copy_from_core.py.
18- This is needed because we intend to error out for the case where `model_file_list_path`
19- is empty or invalid, so that the ExecuTorch build will fail when no selective build target
20- is provided as a dependency to ExecuTorch build.
51+ """This binary generates 3 files:
52+
53+ 1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
54+ dtypes captured by tracing
55+ 2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis)
2156 """
2257 parser = argparse .ArgumentParser (description = "Generate operator lists" )
2358 parser .add_argument (
59+ "--output-dir" ,
2460 "--output_dir" ,
2561 help = ("The directory to store the output yaml file (selected_operators.yaml)" ),
2662 required = True ,
2763 )
2864 parser .add_argument (
65+ "--model-file-list-path" ,
2966 "--model_file_list_path" ,
3067 help = (
3168 "Path to a file that contains the locations of individual "
@@ -36,6 +73,7 @@ def main(argv: List[Any]) -> None:
3673 required = True ,
3774 )
3875 parser .add_argument (
76+ "--allow-include-all-overloads" ,
3977 "--allow_include_all_overloads" ,
4078 help = (
4179 "Flag to allow operators that include all overloads. "
@@ -46,26 +84,112 @@ def main(argv: List[Any]) -> None:
4684 default = False ,
4785 required = False ,
4886 )
87+ parser .add_argument (
88+ "--check-ops-not-overlapping" ,
89+ "--check_ops_not_overlapping" ,
90+ help = (
91+ "Flag to check if the operators in the model file list are overlapping. "
92+ + "If not set, the script will not error out for overlapping operators."
93+ ),
94+ action = "store_true" ,
95+ default = False ,
96+ required = False ,
97+ )
98+ options = parser .parse_args (argv )
4999
50- # check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
100+ # Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
51101 # 1. a yaml file containing selected ops (could be empty), or
52- # 2. a non-empty list of yaml files in the `model_file_list_path`.
53- # If none of the two things happened, the build target has no dependency on any selective build and we should error out .
54- options = parser . parse_args ( argv )
102+ # 2. a non-empty list of yaml files in the `model_file_list_path` or
103+ # 3. a non-empty list of directories in the `model_file_list_path`, with each directory containing a `selected_operators.yaml` file .
104+ # If none of the 3 things happened, the build target has no dependency on any selective build and we should error out.
55105 if os .path .isfile (options .model_file_list_path ):
56- pass
106+ print ("Processing model file: " , options .model_file_list_path )
107+ model_dicts = []
108+ model_dict = yaml .safe_load (open (options .model_file_list_path ))
109+ model_dicts .append (model_dict )
57110 else :
111+ print (
112+ "Processing model file list or model directory list: " ,
113+ options .model_file_list_path ,
114+ )
58115 assert (
59116 options .model_file_list_path [0 ] == "@"
60117 ), "model_file_list_path is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."
118+
61119 model_file_list_path = options .model_file_list_path [1 :]
120+
121+ model_dicts = []
62122 with open (model_file_list_path ) as model_list_file :
63123 model_file_names = model_list_file .read ().split ()
64124 assert (
65125 len (model_file_names ) > 0
66126 ), "BUCK was not able to find any `et_operator_library` in the dependency graph of the current ExecuTorch "
67127 "build. Please refer to Selective Build wiki page to add at least one."
68- gen_oplist_copy_from_core .main (argv )
128+ for model_file_name in model_file_names :
129+ if not os .path .isfile (model_file_name ):
130+ model_file_name = os .path .join (
131+ model_file_name , "selected_operators.yaml"
132+ )
133+ print ("Processing model file: " , model_file_name )
134+ assert os .path .isfile (
135+ model_file_name
136+ ), f"{ model_file_name } is not a valid file path. This is likely a BUCK issue."
137+ with open (model_file_name , "rb" ) as model_file :
138+ model_dict = yaml .safe_load (model_file )
139+ resolved = resolve_model_file_path_to_buck_target (model_file_name )
140+ for op in model_dict ["operators" ]:
141+ model_dict ["operators" ][op ]["debug_info" ] = [resolved ]
142+ model_dicts .append (model_dict )
143+
144+ selective_builders = [SelectiveBuilder .from_yaml_dict (m ) for m in model_dicts ]
145+
146+ # Optionally check if the operators in the model file list are overlapping.
147+ if options .check_ops_not_overlapping :
148+ ops = {}
149+ for model_dict in model_dicts :
150+ for op_name in model_dict ["operators" ]:
151+ if op_name in ops :
152+ debug_info_1 = "," .join (ops [op_name ]["debug_info" ])
153+ debug_info_2 = "," .join (
154+ model_dict ["operators" ][op_name ]["debug_info" ]
155+ )
156+ error = f"Operator { op_name } is used in 2 models: { debug_info_1 } and { debug_info_2 } "
157+ if "//" not in debug_info_1 and "//" not in debug_info_2 :
158+ error += "\n We can't determine what BUCK targets these model files belong to."
159+ tail = "."
160+ else :
161+ error += "\n Please run the following commands to find out where is the BUCK target being added as a dependency to your target:\n "
162+ error += f'\n buck2 cquery <mode> "allpaths(<target>, { debug_info_1 } )"'
163+ error += f'\n buck2 cquery <mode> "allpaths(<target>, { debug_info_2 } )"'
164+ tail = "as well as results from BUCK commands listed above."
165+
166+ error += (
167+ "\n \n If issue is not resolved, please post in PyTorch Edge Q&A with this error message"
168+ + tail
169+ )
170+ raise Exception (error ) # noqa: TRY002
171+ ops [op_name ] = model_dict ["operators" ][op_name ]
172+ # We may have 0 selective builders since there may not be any viable
173+ # pt_operator_library rule marked as a dep for the pt_operator_registry rule.
174+ # This is potentially an error, and we should probably raise an assertion
175+ # failure here. However, this needs to be investigated further.
176+ selective_builder = SelectiveBuilder .from_yaml_dict ({})
177+ if len (selective_builders ) > 0 :
178+ selective_builder = reduce (
179+ combine_selective_builders ,
180+ selective_builders ,
181+ )
182+
183+ if not options .allow_include_all_overloads :
184+ throw_if_any_op_includes_overloads (selective_builder )
185+ with open (
186+ os .path .join (options .output_dir , "selected_operators.yaml" ), "wb"
187+ ) as out_file :
188+ out_file .write (
189+ yaml .safe_dump (
190+ selective_builder .to_dict (), default_flow_style = False
191+ ).encode ("utf-8" ),
192+ )
69193
70194
71195if __name__ == "__main__" :
0 commit comments