Commit 4a53d86

multi host support
1 parent 0f89bb7 commit 4a53d86

File tree

3 files changed: +360 -57 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@
 
 - Add `su4` as a generic parameterized two-qubit gate.
 
+- Add multi-controller jax support for distributed contraction.
+
 ### Fixed
 
 - Fix the breaking logic change in jax from dlpack API, dlcapsule -> tensor.
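
For orientation, a minimal sketch of driving the distributed contraction API announced above. The names `DistributedContractor`, `value_and_grad`, and the `cotengra_options` keys are taken from the example file added in this commit; the plain single-process setting, toy circuit, parameter values, and reduced option set are illustrative assumptions, not part of the commit.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh

import tensorcircuit as tc
from tensorcircuit.experimental import DistributedContractor

tc.set_backend("jax")


def nodes_fn(params):
    # Tensor network whose contraction gives <psi|psi> for a tiny ansatz.
    c = tc.Circuit(4)
    for j in range(4):
        c.rx(j, theta=params[j])
    psi = c.get_quvector()
    return (psi.adjoint() @ psi).nodes


params = jnp.ones([4])
mesh = Mesh(jax.devices(), axis_names=("devices",))
dc = DistributedContractor(
    nodes_fn=nodes_fn,
    params=params,
    mesh=mesh,
    cotengra_options={"max_repeats": 16, "minimize": "write", "progbar": False},
)
value, grads = dc.value_and_grad(params)  # sliced, device-distributed contraction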

examples/multihost_vqe.py

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
import os
import time
import argparse
import logging

import jax
import jax.distributed
import numpy as np
import optax
import tensornetwork as tn
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

import tensorcircuit as tc
from tensorcircuit.experimental import DistributedContractor, broadcast_py_object


# --- Static Configuration ---
NUM_DEVICES_TOTAL = 4
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={NUM_DEVICES_TOTAL}"
# delete the above fake lines when using GPU devices
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

K = tc.set_backend("jax")
tc.set_dtype("complex64")

N_QUBITS = 10
DEPTH = 4


def circuit_ansatz(n, d, params):
    c = tc.Circuit(n)
    c.h(range(n))
    for i in range(d):
        for j in range(0, n - 1):
            c.rzz(j, j + 1, theta=params[j, i, 0])
        for j in range(n):
            c.rx(j, theta=params[j, i, 1])
        for j in range(n):
            c.ry(j, theta=params[j, i, 2])
    return c


def get_tfi_mpo(n):
    Jx = np.ones(n - 1)
    Bz = -1.0 * np.ones(n)
    tn_mpo = tn.matrixproductstates.mpo.FiniteTFI(Jx, Bz, dtype=np.complex64)
    return tc.quantum.tn2qop(tn_mpo)


def get_nodes_fn(n, d, mpo):
    def nodes_fn(params):
        psi = circuit_ansatz(n, d, params).get_quvector()
        expression = psi.adjoint() @ mpo @ psi
        return expression.nodes

    return nodes_fn


def run_vqe_main(coordinator_address: str, num_processes: int, process_id: int):
    """
    Main logic run by ALL processes.
    """
    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )
    print(
        f"[Process {process_id}] Initialized. jax.process_index() reports: {jax.process_index()}"
    )

    global_mesh = Mesh(jax.devices(), axis_names=("devices",))
    if jax.process_index() == 0:
        print(f"--- Global mesh created with devices: {global_mesh.devices}")

    tfi_mpo = get_tfi_mpo(N_QUBITS)
    nodes_fn = get_nodes_fn(N_QUBITS, DEPTH, tfi_mpo)
    params_shape = [N_QUBITS, DEPTH, 3]

    # --- KEY CHANGE: Create params on host 0 and broadcast to all others ---
    params_cpu = None
    if jax.process_index() == 0:
        key = jax.random.PRNGKey(42)
        params_cpu = (
            jax.random.normal(key, shape=params_shape, dtype=tc.rdtypestr) * 0.1
        )

    # Broadcast the CPU array. Now all processes have a concrete `params_cpu`.
    # This is CRITICAL to prevent the NoneType error upon contractor initialization.
    params_cpu = broadcast_py_object(params_cpu)

    # Now that all processes have `params_cpu`, we can initialize the contractor safely.
    # The contractor will use this concrete array to run its (now internal)
    # "find path on 0 and broadcast" logic.
    DC = DistributedContractor(
        nodes_fn=nodes_fn,
        params=params_cpu,
        mesh=global_mesh,
        cotengra_options={
            "slicing_reconf_opts": {"target_size": 2**8},
            "max_repeats": 16,
            "progbar": True,
            "minimize": "write",
            "parallel": 4,
        },
    )

    # Shard the parameters onto devices for the actual GPU/TPU computation.
    params_sharding = NamedSharding(global_mesh, P(*([None] * len(params_shape))))
    params = jax.device_put(params_cpu, params_sharding)

    # Initialize the optimizer and its state.
    optimizer = optax.adam(2e-2)
    opt_state = optimizer.init(params)  # Can init directly with sharded params

    @jax.jit
    def opt_update(params, opt_state, grads):
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state

    # Run the optimization loop.
    n_steps = 100
    if jax.process_index() == 0:
        print("\nStarting VQE optimization loop...")

    for i in range(n_steps):
        t0 = time.time()
        loss, grads = DC.value_and_grad(params)
        params, opt_state = opt_update(params, opt_state, grads)
        t1 = time.time()

        if jax.process_index() == 0:
            print(f"Step {i+1:03d} | " f"Loss: {loss:.8f} | " f"Time: {t1 - t0:.4f} s")

    jax.distributed.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="JAX Multi-Host VQE Simulation")
    parser.add_argument(
        "--process_id",
        type=int,
        required=True,
        help="Rank of the current process (e.g., 0, 1).",
    )
    parser.add_argument(
        "--num_processes", type=int, default=2, help="Total number of processes."
    )
    parser.add_argument(
        "--coordinator_address",
        type=str,
        default="127.0.0.1:8888",
        help="IP address and port of the coordinator (process 0).",
    )
    args = parser.parse_args()

    print(f"--- Starting Process {args.process_id}/{args.num_processes} ---")

    if args.process_id == 0:
        print(
            "\n>>> This is the coordinator process. Waiting for other processes to connect."
        )

    run_vqe_main(
        coordinator_address=args.coordinator_address,
        num_processes=args.num_processes,
        process_id=args.process_id,
    )

# 5090: CUDA_VISIBLE_DEVICES=0,1 NCCL_P2P_DISABLE=1 NCCL_SHM_DISABLE=1 python multihost_vqe.py --process_id=0
# CUDA_VISIBLE_DEVICES=2,3 NCCL_P2P_DISABLE=1 NCCL_SHM_DISABLE=1 python multihost_vqe.py --process_id=1
# H200: no need to disable P2P and SHM due to well-configured NVLink
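# A possible launch across two separate machines (an illustrative sketch, not part of this commit;
# the host names and reachable coordinator address below are assumptions):
#   host A (coordinator, process 0):
#     python multihost_vqe.py --process_id=0 --num_processes=2 --coordinator_address=hostA:8888
#   host B (process 1):
#     python multihost_vqe.py --process_id=1 --num_processes=2 --coordinator_address=hostA:8888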
