Skip to content

Commit 4267446

Browse files
more stable contractor infra
1 parent aaf221a commit 4267446

File tree

3 files changed

+48
-28
lines changed

3 files changed

+48
-28
lines changed

examples/lattice_neighbor_benchmark.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
neighbor search and a baseline all-to-all distance matrix method.
66
As shown by the results, the KDTree approach offers a significant speedup,
77
especially when calculating for a large number of neighbor shells (large max_k).
8-
9-
To run this script from the project's root directory:
10-
python examples/templates/lattice_neighbor_benchmark.py
118
"""
129

1310
import timeit

tensorcircuit/cons.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,39 @@
2323

2424
logger = logging.getLogger(__name__)
2525

26+
## monkey patch
27+
_NODE_CREATION_COUNTER = 0
28+
_original_node_init = tn.Node.__init__
29+
30+
31+
@wraps(_original_node_init)
32+
def _patched_node_init(self: Any, *args: Any, **kwargs: Any) -> None:
33+
"""Patched Node.__init__ to add a stable creation ID."""
34+
global _NODE_CREATION_COUNTER
35+
_original_node_init(self, *args, **kwargs)
36+
self._stable_id_ = _NODE_CREATION_COUNTER
37+
_NODE_CREATION_COUNTER += 1
38+
39+
40+
tn.Node.__init__ = _patched_node_init
41+
42+
43+
def _get_edge_stable_key(edge: tn.Edge) -> Tuple[int, int, int, int]:
44+
n1, n2 = edge.node1, edge.node2
45+
id1 = getattr(n1, "_stable_id_", -1)
46+
id2 = getattr(n2, "_stable_id_", -1) if n2 is not None else -2 # -2 for dangling
47+
48+
if id1 > id2 or (id1 == id2 and edge.axis1 > edge.axis2):
49+
id1, id2, ax1, ax2 = id2, id1, edge.axis2, edge.axis1
50+
else:
51+
ax1, ax2 = edge.axis1, edge.axis2
52+
return (id1, ax1, id2, ax2)
53+
54+
55+
def sorted_edges(edges: Iterator[tn.Edge]) -> List[tn.Edge]:
56+
return sorted(edges, key=_get_edge_stable_key)
57+
58+
2659
package_name = "tensorcircuit"
2760
thismodule = sys.modules[__name__]
2861
dtypestr = "complex64"
@@ -477,39 +510,29 @@ def _identity(*args: Any, **kws: Any) -> Any:
477510
return args
478511

479512

480-
def _sort_tuple_list(input_list: List[Any], output_list: List[Any]) -> List[Any]:
481-
sorted_elements = [(tuple(sorted(t)), i) for i, t in enumerate(input_list)]
482-
sorted_elements.sort()
483-
return [output_list[i] for _, i in sorted_elements]
484-
485-
486513
def _get_path_cache_friendly(
487514
nodes: List[tn.Node], algorithm: Any
488515
) -> Tuple[List[Tuple[int, int]], List[tn.Node]]:
489516
nodes = list(nodes)
490-
mapping_dict = {}
491-
i = 0
492-
for n in nodes:
493-
for e in n:
494-
if id(e) not in mapping_dict:
495-
mapping_dict[id(e)] = get_symbol(i)
496-
i += 1
497-
# TODO(@refraction-ray): may be not that cache friendly, since the edge id correspondence is not that fixed?
498-
input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes]
499-
# placeholder = [[1e20 for _ in range(100)]]
500-
# order = np.argsort(np.array(list(map(sorted, input_sets)), dtype=object)) # type: ignore
501-
# nodes_new = [nodes[i] for i in order]
502-
nodes_new = _sort_tuple_list(input_sets, nodes)
517+
518+
nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
503519
if isinstance(algorithm, list):
504520
return algorithm, nodes_new
505521

522+
all_edges = tn.get_all_edges(nodes_new)
523+
all_edges_sorted = sorted_edges(all_edges)
524+
mapping_dict = {}
525+
i = 0
526+
for edge in all_edges_sorted:
527+
if id(edge) not in mapping_dict:
528+
mapping_dict[id(edge)] = get_symbol(i)
529+
i += 1
530+
506531
input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]
507532
output_set = list(
508-
[mapping_dict[id(e)] for e in tn.get_subgraph_dangling(nodes_new)]
533+
[mapping_dict[id(e)] for e in sorted_edges(tn.get_subgraph_dangling(nodes_new))]
509534
)
510-
size_dict = {
511-
mapping_dict[id(edge)]: edge.dimension for edge in tn.get_all_edges(nodes_new)
512-
}
535+
size_dict = {mapping_dict[id(edge)]: edge.dimension for edge in all_edges_sorted}
513536
logger.debug("input_sets: %s" % input_sets)
514537
logger.debug("output_set: %s" % output_set)
515538
logger.debug("size_dict: %s" % size_dict)

tensorcircuit/quantum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _reachable(nodes: List[AbstractNode]) -> List[AbstractNode]:
7171
if n not in seen_nodes and n not in node_que[i + 1 :]:
7272
node_que.append(n)
7373
i += 1
74-
return seen_nodes
74+
return sorted(seen_nodes, key=lambda node: getattr(node, "_stable_id_", -1))
7575

7676

7777
def reachable(
@@ -1164,7 +1164,7 @@ def tn2qop(tn_mpo: Any) -> QuOperator:
11641164
nwires = len(tn_mpo)
11651165
mpo = []
11661166
for i in range(nwires):
1167-
mpo.append(Node(tn_mpo[i]))
1167+
mpo.append(Node(tn_mpo[i], name=f"mpo_{i}"))
11681168

11691169
for i in range(nwires - 1):
11701170
connect(mpo[i][1], mpo[i + 1][0])

0 commit comments

Comments
 (0)