Skip to content

Commit e712310

Browse files
Support hyperedges in TensorCircuit using cotengra with optimized execution.
This change introduces compatibility for hyperedges (CopyNodes) when using the cotengra contractor. It features a new execution engine that avoids instantiating dense CopyNode tensors, preventing OOM errors on large hyperedges. Key changes: - `tensorcircuit/cons.py`: - Updated `_get_path_cache_friendly` to use UnionFind to group edges connected by CopyNodes. - Implemented a new primitive-based execution path (`_base`) using `einsum` on bare tensors. This handles hyperedges (shared indices) naturally without materializing large CopyNodes. - Preserved legacy contraction logic as a safe fallback when no hyperedges are present. - Updated `set_contractor` to accept `use_primitives` for explicit control over the execution engine. - Implemented output edge reordering logic for the new execution path. - Added `examples/hyperedge_demo.py` demonstrating the feature with a large-scale (20-leg) example. - Added `tests/test_hyperedge.py` for verification using pytest fixtures, covering single, chained, dangling hyperedges, and output reordering. Co-authored-by: refraction-ray <35157286+refraction-ray@users.noreply.github.com>
1 parent fa7f09e commit e712310

File tree

2 files changed

+97
-10
lines changed

2 files changed

+97
-10
lines changed

tensorcircuit/cons.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def _identity(*args: Any, **kws: Any) -> Any:
519519

520520
def _get_path_cache_friendly(
521521
nodes: List[tn.Node], algorithm: Any
522-
) -> Tuple[List[Tuple[int, int]], List[tn.Node], List[str]]:
522+
) -> Tuple[List[Tuple[int, int]], List[tn.Node], Dict[Any, str]]:
523523
# Refactored to return output_symbols as well
524524
nodes = list(nodes)
525525

@@ -559,8 +559,8 @@ def _get_path_cache_friendly(
559559
return (
560560
algorithm(input_sets, output_set, size_dict),
561561
nodes_new,
562-
output_set,
563-
) # Added output_set
562+
mapping_dict, # Return mapping dict to reconstruct output order
563+
)
564564

565565
# Hyperedge logic with UnionFind
566566
uf = UnionFind()
@@ -622,17 +622,31 @@ def _get_path_cache_friendly(
622622

623623
size_dict = {}
624624
for root, symbol in mapping_dict.items():
625-
size_dict[symbol] = root.dimension # type: ignore # type: ignore # type: ignore
625+
size_dict[symbol] = root.dimension  # type: ignore
626626

627627
logger.debug("input_sets: %s" % input_sets)
628628
logger.debug("output_set: %s" % output_set)
629629
logger.debug("size_dict: %s" % size_dict)
630630
logger.debug("path finder algorithm: %s" % algorithm)
631+
632+
# We need a way to map original edges to symbols for rebind_dangling_edges.
633+
# uf[edge] -> root -> symbol (in mapping_dict)
634+
# So we return uf and mapping_dict? Or just a helper dict.
635+
# Let's return a dict {edge: symbol} for all edges?
636+
# Or better, just return the `uf` and `mapping_dict`.
637+
# But `uf` is not picklable or standard?
638+
# Let's construct a simple dict: {id(edge): symbol}
639+
640+
edge_to_symbol = {}
641+
for edge in all_edges:
642+
if edge in uf: # Should be all
643+
edge_to_symbol[id(edge)] = mapping_dict[uf[edge]]
644+
631645
return (
632646
algorithm(input_sets, output_set, size_dict),
633647
regular_nodes,
634-
output_set,
635-
) # Added output_set
648+
edge_to_symbol, # Use this instead of output_set/mapping_dict for output reordering
649+
)
636650

637651

638652
get_tn_info = partial(_get_path_cache_friendly, algorithm=_identity)
@@ -687,7 +701,24 @@ def _base(
687701
) -> tn.Node:
688702
"""
689703
The base method for all `opt_einsum` contractors.
690-
...
704+
705+
:param nodes: A collection of connected nodes.
706+
:type nodes: List[tn.Node]
707+
:param algorithm: `opt_einsum` contraction method to use.
708+
:type algorithm: Any
709+
:param output_edge_order: An optional list of edges. Edges of the
710+
final node in `nodes_set` are reordered into `output_edge_order`;
711+
if final node has more than one edge, `output_edge_order` must be provided.
712+
:type output_edge_order: Optional[Sequence[tn.Edge]], optional
713+
:param ignore_edge_order: An option to ignore the output edge order.
714+
:type ignore_edge_order: bool
715+
:param total_size: The total size of the tensor network.
716+
:type total_size: Optional[int], optional
717+
:raises ValueError: "The final node after contraction has more than
718+
one remaining edge. In this case `output_edge_order` has to be provided," or
719+
"Output edges are not equal to the remaining non-contracted edges of the final node."
720+
:return: The final node after full contraction.
721+
:rtype: tn.Node
691722
"""
692723
# rewrite tensornetwork default to add logging infras
693724
nodes_set = set(nodes)
@@ -729,7 +760,7 @@ def _base(
729760
# nodes = list(nodes_set)
730761

731762
# 1. FRONTEND: Resolve topology
732-
path, regular_nodes, output_symbols = _get_path_cache_friendly(nodes, algorithm)
763+
path, regular_nodes, edge_to_symbol = _get_path_cache_friendly(nodes, algorithm)
733764

734765
# Detect if we should use the new primitive-based engine
735766
# If the number of regular nodes returned differs from input nodes (meaning CopyNodes were filtered out),
@@ -796,6 +827,16 @@ def _base(
796827
# be = regular_nodes[0].backend
797828
be = backend
798829

830+
# Determine output symbols
831+
output_symbols = set()
832+
if output_edge_order:
833+
for edge in output_edge_order:
834+
# We must use edge_to_symbol
835+
# If edge was not in edge_to_symbol (e.g. not connected?), this would raise a KeyError.
836+
# But edge_to_symbol covers all edges in nodes_new.
837+
# Since output_edge_order edges are dangling edges of the graph, they are in all_edges.
838+
output_symbols.add(edge_to_symbol[id(edge)])
839+
799840
# Extract bare tensors and their initial symbols into a working pool
800841
# Pool elements are just tuples: (raw_tensor, ["a", "b", ...])
801842
# _symbols was attached in _get_path_cache_friendly
@@ -876,7 +917,7 @@ def _base(
876917
total_size += reduce(mul, be.shape_tuple(new_tensor) + (1,)) # type: ignore
877918

878919
# RE-ENTRY: Wrap the final bare tensor back into the Graph world
879-
final_raw_tensor, _ = tensor_pool[0]
920+
final_raw_tensor, final_symbols = tensor_pool[0]
880921

881922
# We need to ensure the final tensor's axes match the output_edge_order
882923
# output_edge_order is a list of Edges.
@@ -915,7 +956,16 @@ def _base(
915956
# To match robustly:
916957
# We need the `uf` and `mapping_dict` from `_get_path_cache_friendly`.
917958
# Refactoring `_get_path_cache_friendly` to return a `get_symbol_for_edge` callable or dict?
918-
pass # Handling below
959+
960+
target_symbols = []
961+
for edge in output_edge_order:
962+
target_symbols.append(edge_to_symbol[id(edge)])
963+
964+
# We need to find permutation such that final_symbols[p] matches target_symbols
965+
# final_symbols should be a permutation of target_symbols (same set)
966+
967+
perm = [final_symbols.index(s) for s in target_symbols]
968+
final_node.tensor = be.transpose(final_node.tensor, perm)
919969

920970
# For now, let's assume the user wants the result.
921971
# TensorNetwork's `reorder_edges` expects the node to have those specific edge objects attached.

tests/test_hyperedge.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,40 @@ def test_tensorcircuit_circuit_hyperedge_support(contractor_setup):
107107
# Bell state |00> + |11>
108108
expected = np.array([1, 0, 0, 1]) / np.sqrt(2)
109109
assert np.allclose(np.abs(state), np.abs(expected))
110+
111+
112+
@pytest.mark.parametrize("contractor_setup", ["cotengra"], indirect=True)
113+
def test_hyperedge_output_reordering(contractor_setup):
114+
# Test ensuring that output edge ordering works with hyperedge contraction logic
115+
# A(i, j) -> connected to two separate CopyNodes?
116+
# Let's take a simple case: A(i) -> CN1 -> Output1(i)
117+
# B(j) -> CN2 -> Output2(j)
118+
# Output edges order [Output2, Output1] -> result should be transpose(A x B)
119+
120+
dim = 2
121+
a = tn.Node(np.array([1.0, 2.0]), name="A")
122+
b = tn.Node(np.array([3.0, 4.0]), name="B")
123+
124+
cn1 = tn.CopyNode(2, dim, name="CN1")
125+
cn2 = tn.CopyNode(2, dim, name="CN2")
126+
127+
a[0] ^ cn1[0]
128+
b[0] ^ cn2[0]
129+
130+
# Dangling: cn1[1], cn2[1]
131+
132+
nodes = [a, b, cn1, cn2]
133+
134+
# Default order is usually arbitrary or determined by graph traversal
135+
# We specify explicit order
136+
output_edge_order = [cn2[1], cn1[1]]
137+
138+
res = tc.contractor(nodes, output_edge_order=output_edge_order)
139+
140+
# Result should be B outer A = [3, 4] outer [1, 2] = [[3, 6], [4, 8]]
141+
# (j, i) where j is from B, i is from A
142+
143+
expected = np.outer(np.array([3.0, 4.0]), np.array([1.0, 2.0]))
144+
145+
assert np.allclose(res.tensor, expected)
146+
assert res.tensor.shape == (2, 2)

0 commit comments

Comments
 (0)