Skip to content

Commit c15f9ec

Browse files
add hypernode support for contractor
1 parent d2c63f5 commit c15f9ec

File tree

5 files changed

+594
-441
lines changed

5 files changed

+594
-441
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""
2+
Benchmark JAX JIT staging and execution time for VQE for two different contraction policies.
3+
Comparing use_primitives=True (Algebraic Path) vs use_primitives=False (Legacy Path).
4+
"""
5+
6+
import time
7+
import jax
8+
import jax.numpy as jnp
9+
import tensorcircuit as tc
10+
11+
# Set backend to JAX
12+
tc.set_backend("jax")
13+
14+
15+
def run_vqe_benchmark(n, nlayers, use_primitives):
16+
# Set the contractor configuration
17+
tc.set_contractor("cotengra", use_primitives=use_primitives)
18+
19+
def energy_fn(params):
20+
c = tc.Circuit(n)
21+
idx = 0
22+
for _ in range(nlayers):
23+
# Single qubit rotations
24+
for i in range(n):
25+
c.rx(i, theta=params[idx])
26+
idx += 1
27+
c.ry(i, theta=params[idx])
28+
idx += 1
29+
# Entangling Rzz gates
30+
for i in range(n - 1):
31+
c.rzz(i, i + 1, theta=params[idx])
32+
idx += 1
33+
34+
# Compute expectation value of a TFIM Hamiltonian
35+
e = 0.0
36+
for i in range(n - 1):
37+
e += c.expectation_ps(z=[i, i + 1])
38+
for i in range(n):
39+
e += c.expectation_ps(x=[i])
40+
return jnp.real(e)
41+
42+
# Use value_and_grad for a more realistic VQE benchmark
43+
val_grad_fn = jax.jit(jax.value_and_grad(energy_fn))
44+
45+
# Prepare random parameters
46+
num_params = nlayers * (2 * n + (n - 1))
47+
params = jax.random.normal(jax.random.PRNGKey(42), (num_params,))
48+
49+
# 1. Staging Time (Compilation + First Execution)
50+
start = time.time()
51+
v, g = val_grad_fn(params)
52+
v.block_until_ready()
53+
# Gradient nodes are usually ready when the value is ready if fused,
54+
# but we can block on jnp.sum(g) to be sure
55+
jnp.sum(g).block_until_ready()
56+
staging_time = time.time() - start
57+
58+
# 2. Execution Time (Subsequent Runs)
59+
iters = 10
60+
start = time.time()
61+
for _ in range(iters):
62+
v, g = val_grad_fn(params)
63+
v.block_until_ready()
64+
jnp.sum(g).block_until_ready()
65+
exec_time = (time.time() - start) / iters
66+
67+
return staging_time, exec_time
68+
69+
70+
def main():
71+
n = 12
72+
nlayers = 5
73+
print(f"--- VQE Benchmark: {n} Qubits, {nlayers} Layers ---")
74+
print(f"Backend: {tc.backend.name}")
75+
76+
# Test use_primitives=False (Legacy Path)
77+
print("\n[Case 1] use_primitives=False (Legacy Path)")
78+
s1, e1 = run_vqe_benchmark(n, nlayers, use_primitives=False)
79+
print(f" Staging Time: {s1:.4f}s")
80+
print(f" Execution Time: {e1:.6f}s")
81+
82+
# Test use_primitives=True (Algebraic Path)
83+
print("\n[Case 2] use_primitives=True (Algebraic Path)")
84+
s2, e2 = run_vqe_benchmark(n, nlayers, use_primitives=True)
85+
print(f" Staging Time: {s2:.4f}s")
86+
print(f" Execution Time: {e2:.6f}s")
87+
88+
print("\n--- Summary ---")
89+
staging_gain = (s1 - s2) / s1 * 100 if s1 > 0 else 0
90+
exec_gain = (e1 - e2) / e1 * 100 if e1 > 0 else 0
91+
print(f"Staging Speedup: {staging_gain:.2f}%")
92+
print(f"Execution Speedup: {exec_gain:.2f}%")
93+
94+
95+
if __name__ == "__main__":
96+
main()
97+
98+
# to me, the diff is not significant

examples/hyperedge_demo.py

Lines changed: 0 additions & 78 deletions
This file was deleted.
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""
2+
Physics-relevant demonstration of hyperedge support in TensorCircuit.
3+
Computing the partition function of a 2D classical Ising model using CopyNodes.
4+
"""
5+
6+
import time
7+
import jax
8+
import jax.numpy as jnp
9+
import tensornetwork as tn
10+
import tensorcircuit as tc
11+
12+
# Set backend to JAX for JIT and AD support
13+
tc.set_backend("jax")
14+
tc.set_dtype("complex128")
15+
tc.set_contractor("cotengra")
16+
17+
18+
def ising_partition_function(L, beta, J=1.0):
19+
"""
20+
Compute the partition function of a 2D Ising model on an L x L grid.
21+
Uses CopyNodes to represent spins and 2-index tensors for Boltzmann factors.
22+
23+
The partition function is Z = sum_{s} exp(beta * J * sum_{<i,j>} s_i * s_j).
24+
Each site i has a spin s_i in {1, -1}.
25+
Each bond <i,j> contributes a factor exp(beta * J * s_i * s_j).
26+
"""
27+
# Boltzmann factor matrix M_{si, sj} = exp(beta * J * si * sj)
28+
# spins are {1, -1}, mapped to indices {0, 1}
29+
# si*sj = 1 if si==sj (indices 00 or 11), -1 if si!=sj (indices 01 or 10)
30+
# Using jnp.exp to allow AD through beta
31+
M = jnp.array(
32+
[
33+
[jnp.exp(beta * J), jnp.exp(-beta * J)],
34+
[jnp.exp(-beta * J), jnp.exp(beta * J)],
35+
]
36+
)
37+
38+
nodes = []
39+
# Grid of CopyNodes (Delta tensors) representing the spins
40+
grid = [[None for _ in range(L)] for _ in range(L)]
41+
42+
for i in range(L):
43+
for j in range(L):
44+
# Determine degree of CopyNode based on neighbors (open BC)
45+
degree = 0
46+
if i > 0:
47+
degree += 1
48+
if i < L - 1:
49+
degree += 1
50+
if j > 0:
51+
degree += 1
52+
if j < L - 1:
53+
degree += 1
54+
55+
# A CopyNode(degree, 2) enforces that all connected legs have the same value (spin state)
56+
cn = tn.CopyNode(degree, 2, name=f"site_{i}_{j}")
57+
grid[i][j] = cn
58+
59+
# Track which axis of each CopyNode is used as we connect bonds
60+
axis_ptr = [[0 for _ in range(L)] for _ in range(L)]
61+
62+
# Add bond tensors and connect to the CopyNodes
63+
for i in range(L):
64+
for j in range(L):
65+
# Horizontal bond to the right
66+
if j < L - 1:
67+
bond_h = tn.Node(M, name=f"bond_h_{i}_{j}")
68+
nodes.append(bond_h)
69+
grid[i][j][axis_ptr[i][j]] ^ bond_h[0]
70+
grid[i][j + 1][axis_ptr[i][j + 1]] ^ bond_h[1]
71+
axis_ptr[i][j] += 1
72+
axis_ptr[i][j + 1] += 1
73+
74+
# Vertical bond downwards
75+
if i < L - 1:
76+
bond_v = tn.Node(M, name=f"bond_v_{i}_{j}")
77+
nodes.append(bond_v)
78+
grid[i][j][axis_ptr[i][j]] ^ bond_v[0]
79+
grid[i + 1][j][axis_ptr[i + 1][j]] ^ bond_v[1]
80+
axis_ptr[i][j] += 1
81+
axis_ptr[i + 1][j] += 1
82+
83+
# Multi-node contraction with cotengra (which handles hyperedges efficiently)
84+
# The algebraic path is triggered automatically because CopyNodes are present.
85+
all_nodes = nodes + [grid[i][j] for i in range(L) for j in range(L)]
86+
87+
# Ensure cotengra is used for high-performance contraction
88+
z_node = tc.contractor(all_nodes)
89+
90+
return z_node.tensor
91+
92+
93+
def main():
94+
L = 8
95+
J = 1.0
96+
beta = 0.4 # Near critical point beta_c approx 0.44 for 2D Ising
97+
print(f"--- 2D Ising Model Partition Function ({L}x{L} lattice) ---")
98+
print(f"Backend: {tc.backend.name}, Parameters: J={J}, beta={beta}")
99+
100+
# 1. Direct computation
101+
z = ising_partition_function(L, beta, J)
102+
print(f"Z({beta}) = {z:.6f}")
103+
104+
# 2. JIT-compiled version
105+
# JIT significantly accelerates repeated calls with the same topology
106+
print("\nDemonstrating JIT acceleration...")
107+
ising_jit = jax.jit(ising_partition_function, static_argnums=(0,))
108+
109+
start = time.time()
110+
_ = ising_jit(L, beta, J)
111+
print(f"First run (with JIT warmup): {time.time() - start:.4f}s")
112+
113+
start = time.time()
114+
_ = ising_jit(L, beta, J)
115+
print(f"Second run (JIT cached): {time.time() - start:.4f}s")
116+
117+
# 3. Automatic Differentiation (AD)
118+
# Internal Energy U = - d(ln Z) / d(beta)
119+
print("\nComputing Internal Energy via Automatic Differentiation...")
120+
121+
def log_z(beta_val):
122+
# We take the real part as the partition function is real
123+
val = ising_partition_function(L, beta_val, J)
124+
return jnp.log(tc.backend.real(val))
125+
126+
energy_fn = jax.grad(log_z)
127+
energy = energy_fn(beta)
128+
129+
# U = -d(ln Z)/d(beta) in our convention
130+
print(f"Expectation of Internal Energy <E> = {-energy:.6f}")
131+
132+
# 4. Scaling demonstration
133+
L_larger = 12
134+
print(f"\nScaling check: L={L_larger} ({L_larger*L_larger} spins)...")
135+
start = time.time()
136+
z_large = ising_jit(L_larger, beta, J)
137+
print(
138+
f"Result for L={L_larger}: {z_large:.2e} (computed in {time.time()-start:.4f}s)"
139+
)
140+
141+
142+
if __name__ == "__main__":
143+
main()

0 commit comments

Comments
 (0)