|
3 | 3 | import os |
4 | 4 | import sys |
5 | 5 | import warnings |
6 | | -from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union |
| 6 | +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union |
7 | 7 | import numpy as np |
8 | 8 | import numpy.typing as npt |
9 | 9 | import onnx |
|
15 | 15 | GraphProto, |
16 | 16 | ModelProto, |
17 | 17 | NodeProto, |
| 18 | + OperatorSetIdProto, |
18 | 19 | TensorProto, |
19 | 20 | ValueInfoProto, |
20 | 21 | load as onnx_load, |
@@ -1195,3 +1196,104 @@ def shadowing_names( |
1195 | 1196 | existing |= not_empty |
1196 | 1197 | created |= not_empty |
1197 | 1198 | return shadow, post_shadow, created |
| 1199 | + |
| 1200 | + |
| 1201 | +def extract_subset_of_nodes( |
| 1202 | + model: ModelProto, |
| 1203 | + name: str, |
| 1204 | + node_index: Optional[int] = None, |
| 1205 | + cut_points: Optional[Set[str]] = None, |
| 1206 | +) -> List[NodeProto]: |
| 1207 | + """ |
| 1208 | + Extracts the minimal subgraphs which can produce the output ``name`` |
| 1209 | + knowing ``cut_points``. |
| 1210 | +
|
| 1211 | + :param model: original model |
| 1212 | + :param name: result name |
| 1213 | + :param node_index: if the node index is known, otherwise searches for it |
| 1214 | + :param cut_points: the known results or input name otherwise |
| 1215 | + :return: minimal list of nodes |
| 1216 | + """ |
| 1217 | + if node_index is None: |
| 1218 | + for i, node in enumerate(model.graph.node): |
| 1219 | + if name in node.output: |
| 1220 | + node_index = i |
| 1221 | + break |
| 1222 | + assert ( |
| 1223 | + node_index is not None |
| 1224 | + and node_index < len(model.graph.node) |
| 1225 | + and name in model.graph.node[node_index].output |
| 1226 | + ), f"node_index is still empty or wrong for result {name!r}" |
| 1227 | + if cut_points is None: |
| 1228 | + cut_points = {n.name for n in model.graph.input} | { |
| 1229 | + n.name for n in model.graph.initializer |
| 1230 | + } |
| 1231 | + elif model.graph.initializer: |
| 1232 | + cut_points = cut_points | {n.name for n in model.graph.initializer} |
| 1233 | + |
| 1234 | + node = model.graph.node[node_index] |
| 1235 | + selected = {node_index} |
| 1236 | + current_node_index = node_index |
| 1237 | + current_input_index = 0 |
| 1238 | + intermediate = {name} |
| 1239 | + inputs = set(k for k in node.input if k) |
| 1240 | + while not (inputs <= cut_points) and current_node_index >= 0: |
| 1241 | + node = model.graph.node[current_node_index] |
| 1242 | + if current_input_index == 0: |
| 1243 | + needs = [o for o in node.output if o in intermediate and o not in cut_points] |
| 1244 | + if needs: |
| 1245 | + selected.add(current_node_index) |
| 1246 | + else: |
| 1247 | + current_node_index -= 1 |
| 1248 | + continue |
| 1249 | + res = node.input[current_input_index] |
| 1250 | + if res not in cut_points: |
| 1251 | + intermediate.add(res) |
| 1252 | + current_input_index += 1 |
| 1253 | + if current_input_index >= len(node.input): |
| 1254 | + current_node_index -= 1 |
| 1255 | + current_input_index = 0 |
| 1256 | + |
| 1257 | + return [model.graph.node[i] for i in sorted(selected)] |
| 1258 | + |
| 1259 | + |
| 1260 | +def make_submodel( |
| 1261 | + nodes: List[NodeProto], |
| 1262 | + ir_version: int, |
| 1263 | + opset_imports: List[OperatorSetIdProto], |
| 1264 | + output_names: List[str], |
| 1265 | + type_rank_fn: Callable[[str], Tuple[int, int]], |
| 1266 | +) -> ModelProto: |
| 1267 | + """ |
| 1268 | + Creates a model with the given list of nodes. |
| 1269 | + It computes the minimum list of inputs needed for this model. |
| 1270 | + The function assumes the nodes are sorted. |
| 1271 | + It does not handle yet subgraphs. |
| 1272 | +
|
| 1273 | + :param nodes: list of nodes |
| 1274 | + :param ir_version: ir version |
| 1275 | + :param opset_imports: opset import |
| 1276 | + :param output_names: desired outputs |
| 1277 | + :param function: function returning the type and the rank of a result |
| 1278 | + :return: model proto |
| 1279 | + """ |
| 1280 | + |
| 1281 | + def _mkv_(name, itype, irank): |
| 1282 | + return oh.make_tensor_value_info(name, itype, [f"{name}_d{i}" for i in range(irank)]) |
| 1283 | + |
| 1284 | + not_known: Set[str] = set() |
| 1285 | + for node in nodes[::-1]: |
| 1286 | + not_known -= set(node.output) |
| 1287 | + not_known |= set(node.input) |
| 1288 | + |
| 1289 | + model = oh.make_model( |
| 1290 | + oh.make_graph( |
| 1291 | + nodes, |
| 1292 | + "submodel", |
| 1293 | + [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)], |
| 1294 | + [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)], |
| 1295 | + ), |
| 1296 | + ir_version=ir_version, |
| 1297 | + opset_imports=opset_imports, |
| 1298 | + ) |
| 1299 | + return model |
0 commit comments