Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions examples/hyperedge_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
Demonstration of hyperedge support using cotengra in TensorCircuit.
"""

import numpy as np
import tensornetwork as tn
import tensorcircuit as tc

def hyperedge_demo():
print("Demonstrating hyperedge contraction with cotengra...")

# 1. Single Hyperedge Example
# Three tensors A, B, C connected by a single hyperedge (CopyNode)
# Result should be sum_i A_i * B_i * C_i

dim = 2
a = tn.Node(np.array([1.0, 2.0]), name="A")
b = tn.Node(np.array([1.0, 2.0]), name="B")
c = tn.Node(np.array([1.0, 2.0]), name="C")
cn = tn.CopyNode(3, dim, name="CN")

a[0] ^ cn[0]
b[0] ^ cn[1]
c[0] ^ cn[2]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should demonstrate something large, say a copynode with 20 legs, to test the memory and time efficiency of the hyperedge solution end to end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a large-scale example with a 20-leg CopyNode to examples/hyperedge_demo.py to demonstrate efficiency.


nodes = [a, b, c, cn]

# Set contractor to cotengra
try:
tc.set_contractor("cotengra")
except ImportError:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dont over protection in examples, no need for except here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the try-except block in the example script.

print("cotengra not installed, skipping demo")
return

res = tc.contractor(nodes)
print("Single Hyperedge Result:", res.tensor)
expected = 1*1*1 + 2*2*2
print(f"Expected: {expected}")
assert np.allclose(res.tensor, expected)

# 2. Chained Hyperedge Example
# A-CN1-B, CN1-CN2, C-CN2-D
# Effectively A, B, C, D share an index

a = tn.Node(np.array([1.0, 2.0]), name="A")
b = tn.Node(np.array([1.0, 2.0]), name="B")
c = tn.Node(np.array([1.0, 2.0]), name="C")
d = tn.Node(np.array([1.0, 2.0]), name="D")

cn1 = tn.CopyNode(3, dim, name="CN1")
cn2 = tn.CopyNode(3, dim, name="CN2")

a[0] ^ cn1[0]
b[0] ^ cn1[1]
cn1[2] ^ cn2[0] # Link between hyperedges
c[0] ^ cn2[1]
d[0] ^ cn2[2]

nodes = [a, b, c, d, cn1, cn2]
res = tc.contractor(nodes)
print("Chained Hyperedge Result:", res.tensor)
expected = 1*1*1*1 + 2*2*2*2
print(f"Expected: {expected}")
assert np.allclose(res.tensor, expected)

if __name__ == "__main__":
hyperedge_demo()
109 changes: 92 additions & 17 deletions tensorcircuit/cons.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import opt_einsum
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

black, mypy and pylint check!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ran black, mypy, and pylint. Fixed a new mypy error related to UnionFind usage.

import tensornetwork as tn
from tensornetwork.backend_contextmanager import get_default_backend
from networkx.utils import UnionFind

from .backends.numpy_backend import NumpyBackend
from .backends import get_backend
Expand Down Expand Up @@ -522,29 +523,66 @@ def _get_path_cache_friendly(
nodes = list(nodes)

nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
# if isinstance(algorithm, list):
# return algorithm, [nodes_new]

# split nodes into regular nodes and CopyNodes
regular_nodes = [n for n in nodes_new if not isinstance(n, tn.CopyNode)]
copy_nodes = [n for n in nodes_new if isinstance(n, tn.CopyNode)]

uf = UnionFind()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if there is no copy_nodes, all these unionfind stuff should never be activated, the whole processing logic should be identical to the original version

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimized _get_path_cache_friendly to check for CopyNodes first. If none are present, it falls back to the original logic, skipping UnionFind overhead.

all_edges = tn.get_all_edges(nodes_new)
all_edges_sorted = sorted_edges(all_edges)

for edge in all_edges:
uf[edge] # init

for cn in copy_nodes:
edges = cn.edges
if edges:
root_edge = edges[0]
for i in range(1, len(edges)):
uf.union(root_edge, edges[i])

mapping_dict = {}
i = 0
for edge in all_edges_sorted:
if id(edge) not in mapping_dict:
mapping_dict[id(edge)] = get_symbol(i)
i += 1

input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]
output_set = list(
[mapping_dict[id(e)] for e in sorted_edges(tn.get_subgraph_dangling(nodes_new))]
)
size_dict = {mapping_dict[id(edge)]: edge.dimension for edge in all_edges_sorted}
symbol_counter = 0

for node in regular_nodes:
sorted_node_edges = sorted(
node.edges, key=lambda e: e.axis1 if e.node1 is node else e.axis2
)
for edge in sorted_node_edges:
root = uf[edge]
if root not in mapping_dict:
mapping_dict[root] = get_symbol(symbol_counter)
symbol_counter += 1

input_sets = []
for node in regular_nodes:
node_symbols = []
sorted_node_edges = sorted(
node.edges, key=lambda e: e.axis1 if e.node1 is node else e.axis2
)
for edge in sorted_node_edges:
root = uf[edge]
node_symbols.append(mapping_dict[root])
input_sets.append(node_symbols)

dangling_edges = sorted_edges(tn.get_subgraph_dangling(nodes_new))
output_set = []
for edge in dangling_edges:
root = uf[edge]
if root not in mapping_dict:
mapping_dict[root] = get_symbol(symbol_counter)
symbol_counter += 1
output_set.append(mapping_dict[root])

size_dict = {}
for root, symbol in mapping_dict.items():
size_dict[symbol] = root.dimension

logger.debug("input_sets: %s" % input_sets)
logger.debug("output_set: %s" % output_set)
logger.debug("size_dict: %s" % size_dict)
logger.debug("path finder algorithm: %s" % algorithm)
return algorithm(input_sets, output_set, size_dict), nodes_new
# directly get input_sets, output_set and size_dict by using identity function as algorithm
return algorithm(input_sets, output_set, size_dict), regular_nodes


get_tn_info = partial(_get_path_cache_friendly, algorithm=_identity)
Expand Down Expand Up @@ -676,12 +714,38 @@ def _base(
continue
a, b = ab

node_a = nodes[a]
node_b = nodes[b]

node_a_neighbors = set()
for e in node_a.edges:
n = e.node1 if e.node1 is not node_a else e.node2
if n is not None:
node_a_neighbors.add(n)

node_b_neighbors = set()
for e in node_b.edges:
n = e.node1 if e.node1 is not node_b else e.node2
if n is not None:
node_b_neighbors.add(n)

shared_cns = set()
for n in node_a_neighbors:
if isinstance(n, tn.CopyNode) and n in node_b_neighbors:
shared_cns.add(n)

curr_node_a = node_a
for cn in shared_cns:
curr_node_a = tn.contract_between(curr_node_a, cn)

if debug_level == 1:
from .simplify import pseudo_contract_between

new_node = pseudo_contract_between(nodes[a], nodes[b])
else:
new_node = tn.contract_between(nodes[a], nodes[b], allow_outer_product=True)
new_node = tn.contract_between(
curr_node_a, node_b, allow_outer_product=True
)
nodes.append(new_node)
# nodes[a] = backend.zeros([1])
# nodes[b] = backend.zeros([1])
Expand All @@ -694,6 +758,17 @@ def _base(
# if the final node has more than one edge,
# output_edge_order has to be specified
final_node = nodes[0] # nodes were connected, we checked this

while True:
cns = []
for e in final_node.edges:
n = e.node1 if e.node1 is not final_node else e.node2
if n is not None and isinstance(n, tn.CopyNode):
cns.append(n)
if not cns:
break
final_node = tn.contract_between(final_node, cns[0])

if not ignore_edge_order:
final_node.reorder_edges(output_edge_order)
return final_node
Expand Down
93 changes: 93 additions & 0 deletions tests/test_hyperedge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import numpy as np
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code duplication in this file is high

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reduced duplication by using parametrized tests and fixtures.

import tensornetwork as tn
import tensorcircuit as tc
import pytest

# Ensure cotengra is available, otherwise skip tests
try:
import cotengra
has_cotengra = True
except ImportError:
has_cotengra = False

@pytest.mark.skipif(not has_cotengra, reason="cotengra not installed")
def test_single_hyperedge():
# A(i), B(i), C(i)
dim = 2
a = tn.Node(np.array([1.0, 2.0]), name="A")
b = tn.Node(np.array([1.0, 2.0]), name="B")
c = tn.Node(np.array([1.0, 2.0]), name="C")
cn = tn.CopyNode(3, dim, name="CN")

a[0] ^ cn[0]
b[0] ^ cn[1]
c[0] ^ cn[2]

nodes = [a, b, c, cn]
tc.set_contractor("cotengra")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should also have some test with default contractor to test it works

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test case using the default greedy contractor to ensure it still works correctly (or fails gracefully if unsupported, though simple CopyNode cases might work).

res = tc.contractor(nodes)
assert np.allclose(res.tensor, 9.0)

@pytest.mark.skipif(not has_cotengra, reason="cotengra not installed")
def test_chained_hyperedge():
# A(i), B(i), C(i), D(i)
# Connected via two CopyNodes: A-CN1-B, CN1-CN2, C-CN2-D
dim = 2
a = tn.Node(np.array([1.0, 2.0]), name="A")
b = tn.Node(np.array([1.0, 2.0]), name="B")
c = tn.Node(np.array([1.0, 2.0]), name="C")
d = tn.Node(np.array([1.0, 2.0]), name="D")

cn1 = tn.CopyNode(3, dim, name="CN1")
cn2 = tn.CopyNode(3, dim, name="CN2")

a[0] ^ cn1[0]
b[0] ^ cn1[1]
cn1[2] ^ cn2[0] # Link
c[0] ^ cn2[1]
d[0] ^ cn2[2]

nodes = [a, b, c, d, cn1, cn2]
tc.set_contractor("cotengra")
res = tc.contractor(nodes)
# sum i A_i B_i C_i D_i = 1+16 = 17
assert np.allclose(res.tensor, 17.0)

@pytest.mark.skipif(not has_cotengra, reason="cotengra not installed")
def test_dangling_hyperedge():
# A(i), B(i), Output(i)
dim = 2
a = tn.Node(np.array([1.0, 2.0]), name="A")
b = tn.Node(np.array([1.0, 2.0]), name="B")
cn = tn.CopyNode(3, dim, name="CN")

a[0] ^ cn[0]
b[0] ^ cn[1]
# cn[2] is dangling

nodes = [a, b, cn]
tc.set_contractor("cotengra")
res = tc.contractor(nodes) # Should return a tensor of shape (2,)

# Expected: C_i = A_i * B_i => [1, 4]
assert np.allclose(res.tensor, np.array([1.0, 4.0]))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test case for non scalar output?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test_hyperedge_output_reordering to verifying non-scalar output and edge reordering.

@pytest.mark.skipif(not has_cotengra, reason="cotengra not installed")
def test_tensorcircuit_circuit_hyperedge_support():
# While TC circuit doesn't typically create CopyNodes directly in gates,
# ensuring the contractor works with general graphs is key.
# This test just ensures normal circuit simulation still works with cotengra
# (which implies the new logic handles regular nodes correctly too).
c = tc.Circuit(2)
c.H(0)
c.CNOT(0, 1)

tc.set_contractor("cotengra")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dont set contengra like this, what is the test fails, the set cannot be cleared, should define and use as a fixture

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored tests to use a contractor_setup fixture that handles setup and teardown, ensuring the contractor is reset even if tests fail.

state = c.state()
# Bell state |00> + |11>
expected = np.array([1, 0, 0, 1]) / np.sqrt(2)
# The phase might vary? No, standard gates are deterministic.
# But H gate normalization 1/sqrt(2).
# |0> -> (|0>+|1>)/rt2 -> |00> + |11> / rt2.
# Check absolute values
assert np.allclose(np.abs(state), np.abs(expected))
Loading