Skip to content

Commit 836cadd

Browse files
Support hyperedges in TensorCircuit using cotengra with optimized execution.
This change introduces compatibility for hyperedges (CopyNodes) when using the cotengra contractor. It features a new execution engine that avoids instantiating dense CopyNode tensors, preventing OOM errors on large hyperedges. Key changes: - `tensorcircuit/cons.py`: - Introduced `_get_path_info` to handle hyperedge topology analysis using UnionFind, while keeping `_get_path_cache_friendly` backward compatible. - Implemented a new primitive-based execution path (`_base`) using `einsum` on bare tensors. This handles hyperedges (shared indices) naturally without materializing large CopyNodes. - Preserved legacy contraction logic as a safe fallback when no hyperedges are present. - Updated `set_contractor` to accept `use_primitives` for explicit control over the execution engine. - Implemented output edge reordering logic for the new execution path using the edge-to-symbol mapping. - Added `examples/hyperedge_demo.py` demonstrating the feature with a large-scale (20-leg) example. - Added `tests/test_hyperedge.py` for verification using pytest fixtures, covering single, chained, dangling hyperedges, and output reordering. Co-authored-by: refraction-ray <35157286+refraction-ray@users.noreply.github.com>
1 parent e712310 commit 836cadd

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

tensorcircuit/cons.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ def _identity(*args: Any, **kws: Any) -> Any:
517517
return args
518518

519519

520-
def _get_path_cache_friendly(
520+
def _get_path_info(
521521
nodes: List[tn.Node], algorithm: Any
522522
) -> Tuple[List[Tuple[int, int]], List[tn.Node], Dict[Any, str]]:
523523
# Refactored to return output_symbols as well
@@ -622,21 +622,13 @@ def _get_path_cache_friendly(
622622

623623
size_dict = {}
624624
for root, symbol in mapping_dict.items():
625-
size_dict[symbol] = root.dimension # type: ignore # type: ignore
625+
size_dict[symbol] = root.dimension # type: ignore
626626

627627
logger.debug("input_sets: %s" % input_sets)
628628
logger.debug("output_set: %s" % output_set)
629629
logger.debug("size_dict: %s" % size_dict)
630630
logger.debug("path finder algorithm: %s" % algorithm)
631631

632-
# We need a way to map original edges to symbols for rebind_dangling_edges.
633-
# uf[edge] -> root -> symbol (in mapping_dict)
634-
# So we return uf and mapping_dict? Or just a helper dict.
635-
# Let's return a dict {edge: symbol} for all edges?
636-
# Or better, just return the `uf` and `mapping_dict`.
637-
# But `uf` is not picklable or standard?
638-
# Let's construct a simple dict: {id(edge): symbol}
639-
640632
edge_to_symbol = {}
641633
for edge in all_edges:
642634
if edge in uf: # Should be all
@@ -649,6 +641,14 @@ def _get_path_cache_friendly(
649641
)
650642

651643

644+
def _get_path_cache_friendly(
645+
nodes: List[tn.Node], algorithm: Any
646+
) -> Tuple[List[Tuple[int, int]], List[tn.Node]]:
647+
# Legacy wrapper for backward compatibility of get_tn_info
648+
path, regular_nodes, _ = _get_path_info(nodes, algorithm)
649+
return path, regular_nodes
650+
651+
652652
get_tn_info = partial(_get_path_cache_friendly, algorithm=_identity)
653653

654654

@@ -760,7 +760,7 @@ def _base(
760760
# nodes = list(nodes_set)
761761

762762
# 1. FRONTEND: Resolve topology
763-
path, regular_nodes, edge_to_symbol = _get_path_cache_friendly(nodes, algorithm)
763+
path, regular_nodes, edge_to_symbol = _get_path_info(nodes, algorithm)
764764

765765
# Detect if we should use the new primitive-based engine
766766
# If the number of regular nodes returned differs from input nodes (meaning CopyNodes were filtered out),
@@ -826,15 +826,11 @@ def _base(
826826
# ==========================================
827827
# be = regular_nodes[0].backend
828828
be = backend
829-
830829
# Determine output symbols
831830
output_symbols = set()
832-
if output_edge_order:
833-
for edge in output_edge_order:
834-
# We must use edge_to_symbol
835-
# If edge was not in edge_to_symbol (e.g. not connected?), this would key error.
836-
# But edge_to_symbol covers all edges in nodes_new.
837-
# Since output_edge_order edges are dangling edges of the graph, they are in all_edges.
831+
dangling_edges = tn.get_subgraph_dangling(nodes)
832+
for edge in dangling_edges:
833+
if id(edge) in edge_to_symbol:
838834
output_symbols.add(edge_to_symbol[id(edge)])
839835

840836
# Extract bare tensors and their initial symbols into a working pool

0 commit comments

Comments
 (0)