|
23 | 23 |
|
24 | 24 | logger = logging.getLogger(__name__) |
25 | 25 |
|
| 26 | +## monkey patch |
| 27 | +_NODE_CREATION_COUNTER = 0 |
| 28 | +_original_node_init = tn.Node.__init__ |
| 29 | + |
| 30 | + |
| 31 | +@wraps(_original_node_init) |
| 32 | +def _patched_node_init(self: Any, *args: Any, **kwargs: Any) -> None: |
| 33 | + """Patched Node.__init__ to add a stable creation ID.""" |
| 34 | + global _NODE_CREATION_COUNTER |
| 35 | + _original_node_init(self, *args, **kwargs) |
| 36 | + self._stable_id_ = _NODE_CREATION_COUNTER |
| 37 | + _NODE_CREATION_COUNTER += 1 |
| 38 | + |
| 39 | + |
| 40 | +tn.Node.__init__ = _patched_node_init |
| 41 | + |
| 42 | + |
| 43 | +def _get_edge_stable_key(edge: tn.Edge) -> Tuple[int, int, int, int]: |
| 44 | + n1, n2 = edge.node1, edge.node2 |
| 45 | + id1 = getattr(n1, "_stable_id_", -1) |
| 46 | + id2 = getattr(n2, "_stable_id_", -1) if n2 is not None else -2 # -2 for dangling |
| 47 | + |
| 48 | + if id1 > id2 or (id1 == id2 and edge.axis1 > edge.axis2): |
| 49 | + id1, id2, ax1, ax2 = id2, id1, edge.axis2, edge.axis1 |
| 50 | + else: |
| 51 | + ax1, ax2 = edge.axis1, edge.axis2 |
| 52 | + return (id1, ax1, id2, ax2) |
| 53 | + |
| 54 | + |
| 55 | +def sorted_edges(edges: Iterator[tn.Edge]) -> List[tn.Edge]: |
| 56 | + return sorted(edges, key=_get_edge_stable_key) |
| 57 | + |
| 58 | + |
26 | 59 | package_name = "tensorcircuit" |
27 | 60 | thismodule = sys.modules[__name__] |
28 | 61 | dtypestr = "complex64" |
@@ -477,39 +510,29 @@ def _identity(*args: Any, **kws: Any) -> Any: |
477 | 510 | return args |
478 | 511 |
|
479 | 512 |
|
480 | | -def _sort_tuple_list(input_list: List[Any], output_list: List[Any]) -> List[Any]: |
481 | | - sorted_elements = [(tuple(sorted(t)), i) for i, t in enumerate(input_list)] |
482 | | - sorted_elements.sort() |
483 | | - return [output_list[i] for _, i in sorted_elements] |
484 | | - |
485 | | - |
486 | 513 | def _get_path_cache_friendly( |
487 | 514 | nodes: List[tn.Node], algorithm: Any |
488 | 515 | ) -> Tuple[List[Tuple[int, int]], List[tn.Node]]: |
489 | 516 | nodes = list(nodes) |
490 | | - mapping_dict = {} |
491 | | - i = 0 |
492 | | - for n in nodes: |
493 | | - for e in n: |
494 | | - if id(e) not in mapping_dict: |
495 | | - mapping_dict[id(e)] = get_symbol(i) |
496 | | - i += 1 |
497 | | - # TODO(@refraction-ray): may be not that cache friendly, since the edge id correspondence is not that fixed? |
498 | | - input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes] |
499 | | - # placeholder = [[1e20 for _ in range(100)]] |
500 | | - # order = np.argsort(np.array(list(map(sorted, input_sets)), dtype=object)) # type: ignore |
501 | | - # nodes_new = [nodes[i] for i in order] |
502 | | - nodes_new = _sort_tuple_list(input_sets, nodes) |
| 517 | + |
| 518 | + nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1)) |
503 | 519 | if isinstance(algorithm, list): |
504 | 520 | return algorithm, nodes_new |
505 | 521 |
|
| 522 | + all_edges = tn.get_all_edges(nodes_new) |
| 523 | + all_edges_sorted = sorted_edges(all_edges) |
| 524 | + mapping_dict = {} |
| 525 | + i = 0 |
| 526 | + for edge in all_edges_sorted: |
| 527 | + if id(edge) not in mapping_dict: |
| 528 | + mapping_dict[id(edge)] = get_symbol(i) |
| 529 | + i += 1 |
| 530 | + |
506 | 531 | input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new] |
507 | 532 | output_set = list( |
508 | | - [mapping_dict[id(e)] for e in tn.get_subgraph_dangling(nodes_new)] |
| 533 | + [mapping_dict[id(e)] for e in sorted_edges(tn.get_subgraph_dangling(nodes_new))] |
509 | 534 | ) |
510 | | - size_dict = { |
511 | | - mapping_dict[id(edge)]: edge.dimension for edge in tn.get_all_edges(nodes_new) |
512 | | - } |
| 535 | + size_dict = {mapping_dict[id(edge)]: edge.dimension for edge in all_edges_sorted} |
513 | 536 | logger.debug("input_sets: %s" % input_sets) |
514 | 537 | logger.debug("output_set: %s" % output_set) |
515 | 538 | logger.debug("size_dict: %s" % size_dict) |
|
0 commit comments