Skip to content

Commit 126f2cc

Browse files
Support hyperedges in TensorCircuit using cotengra.
This change introduces compatibility for hyperedges (represented by CopyNodes in TensorNetwork) when using the cotengra contractor. It processes the tensor network graph to merge edges connected via CopyNodes into hyperedges for cotengra's path finding, and then correctly executes the contraction path by absorbing CopyNodes into tensors during contraction. Key changes: - `tensorcircuit/cons.py`: - Updated `_get_path_cache_friendly` to use UnionFind to group edges connected by CopyNodes. - Updated `_base` to absorb CopyNodes when contracting nodes that share them. - Added post-processing to absorb any remaining CopyNodes connected to the final result. - Added `examples/hyperedge_demo.py` demonstrating the feature. - Added `tests/test_hyperedge.py` for verification. Co-authored-by: refraction-ray <35157286+refraction-ray@users.noreply.github.com>
1 parent 8607526 commit 126f2cc

File tree

3 files changed

+252
-17
lines changed

3 files changed

+252
-17
lines changed

examples/hyperedge_demo.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""
2+
Demonstration of hyperedge support using cotengra in TensorCircuit.
3+
"""
4+
5+
import numpy as np
6+
import tensornetwork as tn
7+
import tensorcircuit as tc
8+
9+
def hyperedge_demo():
10+
print("Demonstrating hyperedge contraction with cotengra...")
11+
12+
# 1. Single Hyperedge Example
13+
# Three tensors A, B, C connected by a single hyperedge (CopyNode)
14+
# Result should be sum_i A_i * B_i * C_i
15+
16+
dim = 2
17+
a = tn.Node(np.array([1.0, 2.0]), name="A")
18+
b = tn.Node(np.array([1.0, 2.0]), name="B")
19+
c = tn.Node(np.array([1.0, 2.0]), name="C")
20+
cn = tn.CopyNode(3, dim, name="CN")
21+
22+
a[0] ^ cn[0]
23+
b[0] ^ cn[1]
24+
c[0] ^ cn[2]
25+
26+
nodes = [a, b, c, cn]
27+
28+
# Set contractor to cotengra
29+
try:
30+
tc.set_contractor("cotengra")
31+
except ImportError:
32+
print("cotengra not installed, skipping demo")
33+
return
34+
35+
res = tc.contractor(nodes)
36+
print("Single Hyperedge Result:", res.tensor)
37+
expected = 1*1*1 + 2*2*2
38+
print(f"Expected: {expected}")
39+
assert np.allclose(res.tensor, expected)
40+
41+
# 2. Chained Hyperedge Example
42+
# A-CN1-B, CN1-CN2, C-CN2-D
43+
# Effectively A, B, C, D share an index
44+
45+
a = tn.Node(np.array([1.0, 2.0]), name="A")
46+
b = tn.Node(np.array([1.0, 2.0]), name="B")
47+
c = tn.Node(np.array([1.0, 2.0]), name="C")
48+
d = tn.Node(np.array([1.0, 2.0]), name="D")
49+
50+
cn1 = tn.CopyNode(3, dim, name="CN1")
51+
cn2 = tn.CopyNode(3, dim, name="CN2")
52+
53+
a[0] ^ cn1[0]
54+
b[0] ^ cn1[1]
55+
cn1[2] ^ cn2[0] # Link between hyperedges
56+
c[0] ^ cn2[1]
57+
d[0] ^ cn2[2]
58+
59+
nodes = [a, b, c, d, cn1, cn2]
60+
res = tc.contractor(nodes)
61+
print("Chained Hyperedge Result:", res.tensor)
62+
expected = 1*1*1*1 + 2*2*2*2
63+
print(f"Expected: {expected}")
64+
assert np.allclose(res.tensor, expected)
65+
66+
if __name__ == "__main__":
67+
hyperedge_demo()

tensorcircuit/cons.py

Lines changed: 92 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import opt_einsum
1717
import tensornetwork as tn
1818
from tensornetwork.backend_contextmanager import get_default_backend
19+
from networkx.utils import UnionFind
1920

2021
from .backends.numpy_backend import NumpyBackend
2122
from .backends import get_backend
@@ -522,29 +523,66 @@ def _get_path_cache_friendly(
522523
nodes = list(nodes)
523524

524525
nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
525-
# if isinstance(algorithm, list):
526-
# return algorithm, [nodes_new]
527526

527+
# split nodes into regular nodes and CopyNodes
528+
regular_nodes = [n for n in nodes_new if not isinstance(n, tn.CopyNode)]
529+
copy_nodes = [n for n in nodes_new if isinstance(n, tn.CopyNode)]
530+
531+
uf = UnionFind()
528532
all_edges = tn.get_all_edges(nodes_new)
529-
all_edges_sorted = sorted_edges(all_edges)
533+
534+
for edge in all_edges:
535+
uf[edge] # init
536+
537+
for cn in copy_nodes:
538+
edges = cn.edges
539+
if edges:
540+
root_edge = edges[0]
541+
for i in range(1, len(edges)):
542+
uf.union(root_edge, edges[i])
543+
530544
mapping_dict = {}
531-
i = 0
532-
for edge in all_edges_sorted:
533-
if id(edge) not in mapping_dict:
534-
mapping_dict[id(edge)] = get_symbol(i)
535-
i += 1
536-
537-
input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]
538-
output_set = list(
539-
[mapping_dict[id(e)] for e in sorted_edges(tn.get_subgraph_dangling(nodes_new))]
540-
)
541-
size_dict = {mapping_dict[id(edge)]: edge.dimension for edge in all_edges_sorted}
545+
symbol_counter = 0
546+
547+
for node in regular_nodes:
548+
sorted_node_edges = sorted(
549+
node.edges, key=lambda e: e.axis1 if e.node1 is node else e.axis2
550+
)
551+
for edge in sorted_node_edges:
552+
root = uf[edge]
553+
if root not in mapping_dict:
554+
mapping_dict[root] = get_symbol(symbol_counter)
555+
symbol_counter += 1
556+
557+
input_sets = []
558+
for node in regular_nodes:
559+
node_symbols = []
560+
sorted_node_edges = sorted(
561+
node.edges, key=lambda e: e.axis1 if e.node1 is node else e.axis2
562+
)
563+
for edge in sorted_node_edges:
564+
root = uf[edge]
565+
node_symbols.append(mapping_dict[root])
566+
input_sets.append(node_symbols)
567+
568+
dangling_edges = sorted_edges(tn.get_subgraph_dangling(nodes_new))
569+
output_set = []
570+
for edge in dangling_edges:
571+
root = uf[edge]
572+
if root not in mapping_dict:
573+
mapping_dict[root] = get_symbol(symbol_counter)
574+
symbol_counter += 1
575+
output_set.append(mapping_dict[root])
576+
577+
size_dict = {}
578+
for root, symbol in mapping_dict.items():
579+
size_dict[symbol] = root.dimension
580+
542581
logger.debug("input_sets: %s" % input_sets)
543582
logger.debug("output_set: %s" % output_set)
544583
logger.debug("size_dict: %s" % size_dict)
545584
logger.debug("path finder algorithm: %s" % algorithm)
546-
return algorithm(input_sets, output_set, size_dict), nodes_new
547-
# directly get input_sets, output_set and size_dict by using identity function as algorithm
585+
return algorithm(input_sets, output_set, size_dict), regular_nodes
548586

549587

550588
get_tn_info = partial(_get_path_cache_friendly, algorithm=_identity)
@@ -676,12 +714,38 @@ def _base(
676714
continue
677715
a, b = ab
678716

717+
node_a = nodes[a]
718+
node_b = nodes[b]
719+
720+
node_a_neighbors = set()
721+
for e in node_a.edges:
722+
n = e.node1 if e.node1 is not node_a else e.node2
723+
if n is not None:
724+
node_a_neighbors.add(n)
725+
726+
node_b_neighbors = set()
727+
for e in node_b.edges:
728+
n = e.node1 if e.node1 is not node_b else e.node2
729+
if n is not None:
730+
node_b_neighbors.add(n)
731+
732+
shared_cns = set()
733+
for n in node_a_neighbors:
734+
if isinstance(n, tn.CopyNode) and n in node_b_neighbors:
735+
shared_cns.add(n)
736+
737+
curr_node_a = node_a
738+
for cn in shared_cns:
739+
curr_node_a = tn.contract_between(curr_node_a, cn)
740+
679741
if debug_level == 1:
680742
from .simplify import pseudo_contract_between
681743

682744
new_node = pseudo_contract_between(nodes[a], nodes[b])
683745
else:
684-
new_node = tn.contract_between(nodes[a], nodes[b], allow_outer_product=True)
746+
new_node = tn.contract_between(
747+
curr_node_a, node_b, allow_outer_product=True
748+
)
685749
nodes.append(new_node)
686750
# nodes[a] = backend.zeros([1])
687751
# nodes[b] = backend.zeros([1])
@@ -694,6 +758,17 @@ def _base(
694758
# if the final node has more than one edge,
695759
# output_edge_order has to be specified
696760
final_node = nodes[0] # nodes were connected, we checked this
761+
762+
while True:
763+
cns = []
764+
for e in final_node.edges:
765+
n = e.node1 if e.node1 is not final_node else e.node2
766+
if n is not None and isinstance(n, tn.CopyNode):
767+
cns.append(n)
768+
if not cns:
769+
break
770+
final_node = tn.contract_between(final_node, cns[0])
771+
697772
if not ignore_edge_order:
698773
final_node.reorder_edges(output_edge_order)
699774
return final_node

tests/test_hyperedge.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import numpy as np
2+
import tensornetwork as tn
3+
import tensorcircuit as tc
4+
import pytest
5+
6+
# Ensure cotengra is available, otherwise skip tests
7+
try:
8+
import cotengra
9+
has_cotengra = True
10+
except ImportError:
11+
has_cotengra = False
12+
13+
@pytest.mark.skipif(not has_cotengra, reason="cotengra not installed")
14+
def test_single_hyperedge():
15+
# A(i), B(i), C(i)
16+
dim = 2
17+
a = tn.Node(np.array([1.0, 2.0]), name="A")
18+
b = tn.Node(np.array([1.0, 2.0]), name="B")
19+
c = tn.Node(np.array([1.0, 2.0]), name="C")
20+
cn = tn.CopyNode(3, dim, name="CN")
21+
22+
a[0] ^ cn[0]
23+
b[0] ^ cn[1]
24+
c[0] ^ cn[2]
25+
26+
nodes = [a, b, c, cn]
27+
tc.set_contractor("cotengra")
28+
res = tc.contractor(nodes)
29+
assert np.allclose(res.tensor, 9.0)
30+
31+
@pytest.mark.skipif(not has_cotengra, reason="cotengra not installed")
32+
def test_chained_hyperedge():
33+
# A(i), B(i), C(i), D(i)
34+
# Connected via two CopyNodes: A-CN1-B, CN1-CN2, C-CN2-D
35+
dim = 2
36+
a = tn.Node(np.array([1.0, 2.0]), name="A")
37+
b = tn.Node(np.array([1.0, 2.0]), name="B")
38+
c = tn.Node(np.array([1.0, 2.0]), name="C")
39+
d = tn.Node(np.array([1.0, 2.0]), name="D")
40+
41+
cn1 = tn.CopyNode(3, dim, name="CN1")
42+
cn2 = tn.CopyNode(3, dim, name="CN2")
43+
44+
a[0] ^ cn1[0]
45+
b[0] ^ cn1[1]
46+
cn1[2] ^ cn2[0] # Link
47+
c[0] ^ cn2[1]
48+
d[0] ^ cn2[2]
49+
50+
nodes = [a, b, c, d, cn1, cn2]
51+
tc.set_contractor("cotengra")
52+
res = tc.contractor(nodes)
53+
# sum i A_i B_i C_i D_i = 1+16 = 17
54+
assert np.allclose(res.tensor, 17.0)
55+
56+
@pytest.mark.skipif(not has_cotengra, reason="cotengra not installed")
57+
def test_dangling_hyperedge():
58+
# A(i), B(i), Output(i)
59+
dim = 2
60+
a = tn.Node(np.array([1.0, 2.0]), name="A")
61+
b = tn.Node(np.array([1.0, 2.0]), name="B")
62+
cn = tn.CopyNode(3, dim, name="CN")
63+
64+
a[0] ^ cn[0]
65+
b[0] ^ cn[1]
66+
# cn[2] is dangling
67+
68+
nodes = [a, b, cn]
69+
tc.set_contractor("cotengra")
70+
res = tc.contractor(nodes) # Should return a tensor of shape (2,)
71+
72+
# Expected: C_i = A_i * B_i => [1, 4]
73+
assert np.allclose(res.tensor, np.array([1.0, 4.0]))
74+
75+
@pytest.mark.skipif(not has_cotengra, reason="cotengra not installed")
76+
def test_tensorcircuit_circuit_hyperedge_support():
77+
# While TC circuit doesn't typically create CopyNodes directly in gates,
78+
# ensuring the contractor works with general graphs is key.
79+
# This test just ensures normal circuit simulation still works with cotengra
80+
# (which implies the new logic handles regular nodes correctly too).
81+
c = tc.Circuit(2)
82+
c.H(0)
83+
c.CNOT(0, 1)
84+
85+
tc.set_contractor("cotengra")
86+
state = c.state()
87+
# Bell state |00> + |11>
88+
expected = np.array([1, 0, 0, 1]) / np.sqrt(2)
89+
# The phase might vary? No, standard gates are deterministic.
90+
# But H gate normalization 1/sqrt(2).
91+
# |0> -> (|0>+|1>)/rt2 -> |00> + |11> / rt2.
92+
# Check absolute values
93+
assert np.allclose(np.abs(state), np.abs(expected))

0 commit comments

Comments
 (0)