Commit 54b9f16: robust slurm cluster support

1 parent: 4a53d86

8 files changed (+884 / -41 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 data
+tree.pkl
 *.bk
 git_stat_v2.sh
 *.token

examples/multi_host/README.md

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
## single host, multiple controllers

- `multicontroller_vqe.py`: one end-to-end script.

- `pathfinding.py` + `multicontroller_vqe_with_path.py`: the contraction path search is split into a separate step so that no GPU time is spent on it.

## multiple hosts managed by Slurm

- `pathfind.py` + `slurm_vqe_with_path.py`: for use on a Slurm cluster. The Slurm batch script is `slurm_submit.sh`; a sketch of the per-process setup under Slurm is shown below.
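Below is a minimal sketch of the per-process JAX setup under Slurm (illustrative only; the actual `slurm_vqe_with_path.py` and `slurm_submit.sh` in this commit may differ). It reads the rank and world size from standard Slurm environment variables; recent JAX versions can also auto-detect a Slurm environment via a bare `jax.distributed.initialize()`. The `MASTER_ADDR`/`MASTER_PORT` names are assumed to be exported by the batch script.

import os

import jax
import jax.distributed

# Rank and world size from standard Slurm environment variables.
process_id = int(os.environ["SLURM_PROCID"])
num_processes = int(os.environ["SLURM_NTASKS"])

# Assumption: the batch script exports MASTER_ADDR (e.g. the first node of
# SLURM_NODELIST) and a free MASTER_PORT; these variable names are illustrative.
coordinator = f"{os.environ['MASTER_ADDR']}:{os.environ.get('MASTER_PORT', '8888')}"

jax.distributed.initialize(
    coordinator_address=coordinator,
    num_processes=num_processes,
    process_id=process_id,
)

print(
    f"process {jax.process_index()}/{jax.process_count()} "
    f"sees {jax.local_device_count()} local devices"
)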
Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
1+
"""
2+
multi controller distributed VQE on one server
3+
"""
4+
5+
import os
6+
import time
7+
import argparse
8+
import logging
9+
10+
import jax
11+
import jax.distributed
12+
import numpy as np
13+
import optax
14+
import tensornetwork as tn
15+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
16+
17+
import tensorcircuit as tc
18+
from tensorcircuit.experimental import DistributedContractor, broadcast_py_object
19+
20+
21+
# --- Static Configuration ---
22+
NUM_DEVICES_TOTAL = 4
23+
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={NUM_DEVICES_TOTAL}"
24+
# delete the above fake lines when using GPU devices
25+
logging.basicConfig(level=logging.INFO)
26+
logger = logging.getLogger(__name__)
27+
28+
K = tc.set_backend("jax")
29+
tc.set_dtype("complex64")
30+
31+
N_QUBITS = 10
32+
DEPTH = 4
33+
34+
35+
def circuit_ansatz(n, d, params):
36+
c = tc.Circuit(n)
37+
c.h(range(n))
38+
for i in range(d):
39+
for j in range(0, n - 1):
40+
c.rzz(j, j + 1, theta=params[j, i, 0])
41+
for j in range(n):
42+
c.rx(j, theta=params[j, i, 1])
43+
for j in range(n):
44+
c.ry(j, theta=params[j, i, 2])
45+
return c
46+
47+
48+
def get_tfi_mpo(n):
49+
Jx = np.ones(n - 1)
50+
Bz = -1.0 * np.ones(n)
51+
tn_mpo = tn.matrixproductstates.mpo.FiniteTFI(Jx, Bz, dtype=np.complex64)
52+
return tc.quantum.tn2qop(tn_mpo)
53+
54+
55+
def get_nodes_fn(n, d, mpo):
56+
def nodes_fn(params):
57+
psi = circuit_ansatz(n, d, params).get_quvector()
58+
expression = psi.adjoint() @ mpo @ psi
59+
return expression.nodes
60+
61+
return nodes_fn
62+
63+
64+
def run_vqe_main(coordinator_address: str, num_processes: int, process_id: int):
65+
"""
66+
Main logic run by ALL processes.
67+
"""
68+
jax.distributed.initialize(
69+
coordinator_address=coordinator_address,
70+
num_processes=num_processes,
71+
process_id=process_id,
72+
)
73+
print(
74+
f"[Process {process_id}] Initialized. jax.process_index() reports: {jax.process_index()}"
75+
)
76+
77+
global_mesh = Mesh(jax.devices(), axis_names=("devices",))
78+
if jax.process_index() == 0:
79+
print(f"--- Global mesh created with devices: {global_mesh.devices}")
80+
81+
tfi_mpo = get_tfi_mpo(N_QUBITS)
82+
nodes_fn = get_nodes_fn(N_QUBITS, DEPTH, tfi_mpo)
83+
params_shape = [N_QUBITS, DEPTH, 3]
84+
85+
# --- KEY CHANGE: Create params on host 0 and broadcast to all others ---
86+
params_cpu = None
87+
if jax.process_index() == 0:
88+
key = jax.random.PRNGKey(42)
89+
params_cpu = (
90+
jax.random.normal(key, shape=params_shape, dtype=tc.rdtypestr) * 0.1
91+
)
92+
93+
# Broadcast the CPU array. Now all processes have a concrete `params_cpu`.
94+
# This is CRITICAL to prevent the NoneType error upon contractor initialization.
95+
params_cpu = broadcast_py_object(params_cpu)
96+
97+
# Now that all processes have `params_cpu`, we can initialize the contractor safely.
98+
# The contractor will use this concrete array to run its (now internal)
99+
# "find path on 0 and broadcast" logic.
100+
DC = DistributedContractor(
101+
nodes_fn=nodes_fn,
102+
params=params_cpu,
103+
mesh=global_mesh,
104+
cotengra_options={
105+
"slicing_reconf_opts": {"target_size": 2**8},
106+
"max_repeats": 16,
107+
"progbar": True,
108+
"minimize": "write",
109+
"parallel": 4,
110+
},
111+
)
112+
113+
# Shard the parameters onto devices for the actual GPU/TPU computation.
114+
params_sharding = NamedSharding(global_mesh, P(*([None] * len(params_shape))))
115+
params = jax.device_put(params_cpu, params_sharding)
116+
117+
# Initialize the optimizer and its state.
118+
optimizer = optax.adam(2e-2)
119+
opt_state = optimizer.init(params) # Can init directly with sharded params
120+
121+
@jax.jit
122+
def opt_update(params, opt_state, grads):
123+
updates, new_opt_state = optimizer.update(grads, opt_state, params)
124+
new_params = optax.apply_updates(params, updates)
125+
return new_params, new_opt_state
126+
127+
# Run the optimization loop.
128+
n_steps = 100
129+
if jax.process_index() == 0:
130+
print("\nStarting VQE optimization loop...")
131+
132+
for i in range(n_steps):
133+
t0 = time.time()
134+
loss, grads = DC.value_and_grad(params)
135+
params, opt_state = opt_update(params, opt_state, grads)
136+
t1 = time.time()
137+
138+
if jax.process_index() == 0:
139+
print(f"Step {i+1:03d} | " f"Loss: {loss:.8f} | " f"Time: {t1 - t0:.4f} s")
140+
141+
jax.distributed.shutdown()
142+
143+
144+
if __name__ == "__main__":
145+
parser = argparse.ArgumentParser(description="JAX Multi-Host VQE Simulation")
146+
parser.add_argument(
147+
"--process_id",
148+
type=int,
149+
required=True,
150+
help="Rank of the current process (e.g., 0, 1).",
151+
)
152+
parser.add_argument(
153+
"--num_processes", type=int, default=2, help="Total number of processes."
154+
)
155+
parser.add_argument(
156+
"--coordinator_address",
157+
type=str,
158+
default="127.0.0.1:8888",
159+
help="IP address and port of the coordinator (process 0).",
160+
)
161+
args = parser.parse_args()
162+
163+
print(f"--- Starting Process {args.process_id}/{args.num_processes} ---")
164+
165+
if args.process_id == 0:
166+
print(
167+
"\n>>> This is the coordinator process. Waiting for other processes to connect."
168+
)
169+
170+
run_vqe_main(
171+
coordinator_address=args.coordinator_address,
172+
num_processes=args.num_processes,
173+
process_id=args.process_id,
174+
)
175+
176+
# 5090: CUDA_VISIBLE_DEVICES=0,1 NCCL_P2P_DISABLE=1 NCCL_SHM_DISABLE=1 python multihost_vqe.py --process_id=1
177+
# CUDA_VISIBLE_DEVICES=2,3 NCCL_P2P_DISABLE=1 NCCL_SHM_DISABLE=1 python multihost_vqe.py --process_id=1
178+
# H200: no need to disable P2P and SHM due to well configured nvlink
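For reference, the two launch commands in the trailing comments can also be driven from a single terminal with a small launcher. This is an illustrative sketch only: the script filename and the GPU split per controller are assumptions taken from the comments above.

import os
import subprocess

SCRIPT = "multicontroller_vqe.py"  # assumed filename; adjust to the actual script
GPU_SPLITS = ["0,1", "2,3"]  # two controller processes, two GPUs each

procs = []
for rank, gpus in enumerate(GPU_SPLITS):
    env = dict(
        os.environ,
        CUDA_VISIBLE_DEVICES=gpus,
        NCCL_P2P_DISABLE="1",  # needed on consumer cards such as the 5090
        NCCL_SHM_DISABLE="1",  # not needed on NVLink-connected H200
    )
    procs.append(
        subprocess.Popen(
            ["python", SCRIPT, f"--process_id={rank}", "--num_processes=2"],
            env=env,
        )
    )

for p in procs:
    p.wait()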
Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
1+
"""
2+
multi controller distributed VQE on one server
3+
"""
4+
5+
import os
6+
import time
7+
import argparse
8+
import logging
9+
10+
import jax
11+
import jax.distributed
12+
import numpy as np
13+
import optax
14+
import tensornetwork as tn
15+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
16+
17+
import tensorcircuit as tc
18+
from tensorcircuit.experimental import DistributedContractor, broadcast_py_object
19+
20+
21+
# --- Static Configuration ---
22+
NUM_DEVICES_TOTAL = 4
23+
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={NUM_DEVICES_TOTAL}"
24+
# delete the above fake lines when using GPU devices
25+
logging.basicConfig(level=logging.INFO)
26+
logger = logging.getLogger(__name__)
27+
28+
K = tc.set_backend("jax")
29+
tc.set_dtype("complex64")
30+
31+
N_QUBITS = 15
32+
DEPTH = 5
33+
34+
35+
def circuit_ansatz(n, d, params):
36+
c = tc.Circuit(n)
37+
c.h(range(n))
38+
for i in range(d):
39+
for j in range(0, n - 1):
40+
c.rzz(j, j + 1, theta=params[j, i, 0])
41+
for j in range(n):
42+
c.rx(j, theta=params[j, i, 1])
43+
for j in range(n):
44+
c.ry(j, theta=params[j, i, 2])
45+
return c
46+
47+
48+
def get_tfi_mpo(n):
49+
Jx = np.ones(n - 1)
50+
Bz = -1.0 * np.ones(n)
51+
tn_mpo = tn.matrixproductstates.mpo.FiniteTFI(Jx, Bz, dtype=np.complex64)
52+
return tc.quantum.tn2qop(tn_mpo)
53+
54+
55+
def get_nodes_fn(n, d, mpo):
56+
def nodes_fn(params):
57+
psi = circuit_ansatz(n, d, params).get_quvector()
58+
expression = psi.adjoint() @ mpo @ psi
59+
return expression.nodes
60+
61+
return nodes_fn
62+
63+
64+
def run_vqe_main(coordinator_address: str, num_processes: int, process_id: int):
65+
"""
66+
Main logic run by ALL processes.
67+
"""
68+
jax.distributed.initialize(
69+
coordinator_address=coordinator_address,
70+
num_processes=num_processes,
71+
process_id=process_id,
72+
)
73+
print(
74+
f"[Process {process_id}] Initialized. jax.process_index() reports: {jax.process_index()}"
75+
)
76+
77+
global_mesh = Mesh(jax.devices(), axis_names=("devices",))
78+
if jax.process_index() == 0:
79+
print(f"--- Global mesh created with devices: {global_mesh.devices}")
80+
81+
tfi_mpo = get_tfi_mpo(N_QUBITS)
82+
nodes_fn = get_nodes_fn(N_QUBITS, DEPTH, tfi_mpo)
83+
params_shape = [N_QUBITS, DEPTH, 3]
84+
85+
# --- KEY CHANGE: Create params on host 0 and broadcast to all others ---
86+
params_cpu = None
87+
if jax.process_index() == 0:
88+
key = jax.random.PRNGKey(42)
89+
params_cpu = (
90+
jax.random.normal(key, shape=params_shape, dtype=tc.rdtypestr) * 0.1
91+
)
92+
93+
# Broadcast the CPU array. Now all processes have a concrete `params_cpu`.
94+
# This is CRITICAL to prevent the NoneType error upon contractor initialization.
95+
params_cpu = broadcast_py_object(params_cpu)
96+
97+
# Now that all processes have `params_cpu`, we can initialize the contractor safely.
98+
# The contractor will use this concrete array to run its (now internal)
99+
# "find path on 0 and broadcast" logic.},
100+
101+
DC = DistributedContractor.from_path(
102+
filepath="tree.pkl",
103+
nodes_fn=nodes_fn,
104+
mesh=global_mesh,
105+
)
106+
107+
# Shard the parameters onto devices for the actual GPU/TPU computation.
108+
params_sharding = NamedSharding(global_mesh, P(*([None] * len(params_shape))))
109+
params = jax.device_put(params_cpu, params_sharding)
110+
111+
# Initialize the optimizer and its state.
112+
optimizer = optax.adam(2e-2)
113+
opt_state = optimizer.init(params) # Can init directly with sharded params
114+
115+
@jax.jit
116+
def opt_update(params, opt_state, grads):
117+
updates, new_opt_state = optimizer.update(grads, opt_state, params)
118+
new_params = optax.apply_updates(params, updates)
119+
return new_params, new_opt_state
120+
121+
# Run the optimization loop.
122+
n_steps = 100
123+
if jax.process_index() == 0:
124+
print("\nStarting VQE optimization loop...")
125+
126+
for i in range(n_steps):
127+
t0 = time.time()
128+
loss, grads = DC.value_and_grad(params)
129+
params, opt_state = opt_update(params, opt_state, grads)
130+
t1 = time.time()
131+
132+
if jax.process_index() == 0:
133+
print(f"Step {i+1:03d} | " f"Loss: {loss:.8f} | " f"Time: {t1 - t0:.4f} s")
134+
135+
jax.distributed.shutdown()
136+
137+
138+
if __name__ == "__main__":
139+
parser = argparse.ArgumentParser(description="JAX Multi-Host VQE Simulation")
140+
parser.add_argument(
141+
"--process_id",
142+
type=int,
143+
required=True,
144+
help="Rank of the current process (e.g., 0, 1).",
145+
)
146+
parser.add_argument(
147+
"--num_processes", type=int, default=2, help="Total number of processes."
148+
)
149+
parser.add_argument(
150+
"--coordinator_address",
151+
type=str,
152+
default="127.0.0.1:8888",
153+
help="IP address and port of the coordinator (process 0).",
154+
)
155+
args = parser.parse_args()
156+
157+
print(f"--- Starting Process {args.process_id}/{args.num_processes} ---")
158+
159+
if args.process_id == 0:
160+
print(
161+
"\n>>> This is the coordinator process. Waiting for other processes to connect."
162+
)
163+
164+
run_vqe_main(
165+
coordinator_address=args.coordinator_address,
166+
num_processes=args.num_processes,
167+
process_id=args.process_id,
168+
)
169+
170+
# 5090: CUDA_VISIBLE_DEVICES=0,1 NCCL_P2P_DISABLE=1 NCCL_SHM_DISABLE=1 python multihost_vqe.py --process_id=1
171+
# CUDA_VISIBLE_DEVICES=2,3 NCCL_P2P_DISABLE=1 NCCL_SHM_DISABLE=1 python multihost_vqe.py --process_id=1
172+
# H200: no need to disable P2P and SHM due to well configured nvlink
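The script above expects a contraction tree already saved to `tree.pkl` by a separate path-finding step (`pathfinding.py` / `pathfind.py` in the README, not shown in this excerpt). A rough sketch of what such a step could look like with cotengra follows; it assumes `tree.pkl` is a pickled cotengra `ContractionTree`, reuses `get_nodes_fn`, `get_tfi_mpo`, `N_QUBITS`, and `DEPTH` from the script above, and is not necessarily how the commit's own path-finding script works. The on-disk format actually expected by `DistributedContractor.from_path` may differ.

import pickle

import cotengra as ctg
import numpy as np

# Only the network *structure* is needed, so a dummy parameter array of the
# right shape is enough; the helpers are reused from the VQE script above.
nodes_fn = get_nodes_fn(N_QUBITS, DEPTH, get_tfi_mpo(N_QUBITS))
params_dummy = np.zeros([N_QUBITS, DEPTH, 3], dtype=np.float32)
nodes = nodes_fn(params_dummy)

# Translate the tensornetwork nodes into einsum-style inputs for cotengra.
edge_label = {}
inputs, size_dict = [], {}
for node in nodes:
    labels = []
    for edge in node.edges:
        if edge not in edge_label:
            edge_label[edge] = f"e{len(edge_label)}"
            size_dict[edge_label[edge]] = edge.dimension
        labels.append(edge_label[edge])
    inputs.append(tuple(labels))
output = tuple(edge_label[e] for e in edge_label if e.is_dangling())

# Search for a sliced contraction tree on CPU, so no GPU time is spent on it.
opt = ctg.HyperOptimizer(
    minimize="write",
    max_repeats=16,
    slicing_reconf_opts={"target_size": 2**8},
    progbar=True,
)
tree = opt.search(inputs, output, size_dict)

with open("tree.pkl", "wb") as f:
    pickle.dump(tree, f)  # assumed on-disk format for DistributedContractor.from_path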
