Skip to content

Commit 5891a61

Browse files
add node capture infrastructure
1 parent 4dc09da commit 5891a61

File tree

4 files changed

+95
-2
lines changed

4 files changed

+95
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
- Add `with_prob` for `stabilizercircuit.measure()`.
1414

15+
- Add `tc.cons.function_nodes_capture` decorator and `tc.cons.runtime_nodes_capture` context manager for directly return nodes before real contraction.
16+
1517
### Fixed
1618

1719
- Fix the nodes order in contraction by giving each node a global `_stable_id_`.

tensorcircuit/cons.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -516,8 +516,8 @@ def _get_path_cache_friendly(
516516
nodes = list(nodes)
517517

518518
nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
519-
if isinstance(algorithm, list):
520-
return algorithm, nodes_new
519+
# if isinstance(algorithm, list):
520+
# return algorithm, [nodes_new]
521521

522522
all_edges = tn.get_all_edges(nodes_new)
523523
all_edges_sorted = sorted_edges(all_edges)
@@ -693,6 +693,51 @@ def _base(
693693
return final_node
694694

695695

696+
class NodesReturn(Exception):
697+
"""
698+
Intentionally stop execution to return a value.
699+
"""
700+
701+
def __init__(self, value_to_return: Any):
702+
self.value = value_to_return
703+
super().__init__(
704+
f"Intentionally stopping execution to return: {value_to_return}"
705+
)
706+
707+
708+
def _get_sorted_nodes(nodes: List[Any], *args: Any, **kws: Any) -> Any:
709+
nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
710+
raise NodesReturn(nodes_new)
711+
712+
713+
def function_nodes_capture(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
714+
@wraps(func)
715+
def wrapper(*args: Any, **kwargs: Any) -> Any:
716+
with runtime_contractor(method="before"):
717+
try:
718+
result = func(*args, **kwargs)
719+
return result
720+
except NodesReturn as e:
721+
return e.value
722+
723+
return wrapper
724+
725+
726+
@contextmanager
727+
def runtime_nodes_capture(key: str = "nodes") -> Iterator[Any]:
728+
old_contractor = getattr(thismodule, "contractor")
729+
set_contractor(method="before")
730+
captured_value: Dict[str, List[tn.Node]] = {}
731+
try:
732+
yield captured_value
733+
except NodesReturn as e:
734+
captured_value[key] = e.value
735+
finally:
736+
for module in sys.modules:
737+
if module.startswith(package_name):
738+
setattr(sys.modules[module], "contractor", old_contractor)
739+
740+
696741
def custom(
697742
nodes: List[Any],
698743
optimizer: Any,
@@ -763,6 +808,16 @@ def custom_stateful(
763808

764809
# only work for custom
765810
def contraction_info_decorator(algorithm: Callable[..., Any]) -> Callable[..., Any]:
811+
"""Decorator to add contraction information logging to an optimizer.
812+
813+
This decorator wraps an optimization algorithm and prints detailed information
814+
about the contraction cost (FLOPs, size, write) and path finding time.
815+
816+
:param algorithm: The optimization algorithm to decorate.
817+
:type algorithm: Callable[..., Any]
818+
:return: The decorated optimization algorithm.
819+
:rtype: Callable[..., Any]
820+
"""
766821
from cotengra import ContractionTree
767822

768823
def new_algorithm(
@@ -869,6 +924,9 @@ def set_contractor(
869924
**kws,
870925
)
871926

927+
elif method == "before": # a hack way to get the nodes
928+
cf = _get_sorted_nodes
929+
872930
else:
873931
# cf = getattr(tn.contractors, method, None)
874932
# if not cf:

tests/test_circuit.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,19 @@ def test_circuit_quoperator(backend):
889889
np.testing.assert_allclose(qo.eval_matrix(), c.matrix(), atol=1e-5)
890890

891891

892+
def test_perm_matrix():
893+
from tensorcircuit.translation import perm_matrix
894+
895+
p2 = perm_matrix(2)
896+
np.testing.assert_allclose(
897+
p2, np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
898+
)
899+
p3 = perm_matrix(3)
900+
v = np.arange(8)
901+
vt = np.array([0, 4, 2, 6, 1, 5, 3, 7])
902+
np.testing.assert_allclose(p3 @ v, vt)
903+
904+
892905
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
893906
def test_qir2cirq(backend):
894907
try:

tests/test_miscs.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,23 @@ def baseline(params):
312312
return c.expectation_ps(z=[-1])
313313

314314
np.testing.assert_allclose(value, baseline(params), atol=1e-6)
315+
316+
317+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
318+
def test_runtime_nodes_capture(backend):
319+
with tc.cons.runtime_nodes_capture() as captured:
320+
c = tc.Circuit(3)
321+
c.h(0)
322+
c.amplitude("010")
323+
len(captured["nodes"]) == 7
324+
325+
326+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
327+
def test_function_nodes_capture(backend):
328+
@tc.cons.function_nodes_capture
329+
def exp(theta):
330+
c = tc.Circuit(3)
331+
c.h(0)
332+
return c.expectation_ps(z=[-3], reuse=False)
333+
334+
assert len(exp(0.3)) == 9

0 commit comments

Comments
 (0)