"""
Multi-controller (multi-process) distributed VQE simulation on a single server.

Each process controls a subset of the local devices; see the example launch
commands at the bottom of this file.
"""

import os
import time
import argparse
import logging

import jax
import jax.distributed
import numpy as np
import optax
import tensornetwork as tn
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

import tensorcircuit as tc
from tensorcircuit.experimental import DistributedContractor, broadcast_py_object


# --- Static Configuration ---
NUM_DEVICES_TOTAL = 4
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={NUM_DEVICES_TOTAL}"
# The two lines above fake multiple devices on a CPU-only host for local testing;
# remove them when running on real GPU devices.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

K = tc.set_backend("jax")
tc.set_dtype("complex64")

N_QUBITS = 10
DEPTH = 4


def circuit_ansatz(n, d, params):
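    """Hardware-efficient ansatz: a layer of Hadamards followed by ``d`` layers of
    nearest-neighbor RZZ entanglers and single-qubit RX/RY rotations.

    ``params`` has shape (n, d, 3): index 0 holds the RZZ angles, 1 the RX angles,
    and 2 the RY angles.
    """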
    c = tc.Circuit(n)
    c.h(range(n))
    for i in range(d):
        for j in range(n - 1):
            c.rzz(j, j + 1, theta=params[j, i, 0])
        for j in range(n):
            c.rx(j, theta=params[j, i, 1])
        for j in range(n):
            c.ry(j, theta=params[j, i, 2])
    return c


def get_tfi_mpo(n):
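    """Build the transverse-field Ising Hamiltonian as an MPO via tensornetwork's
    ``FiniteTFI`` (couplings ``Jx``, transverse fields ``Bz``) and convert it to a
    tensorcircuit ``QuOperator`` with ``tc.quantum.tn2qop``.
    """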
    Jx = np.ones(n - 1)
    Bz = -1.0 * np.ones(n)
    tn_mpo = tn.matrixproductstates.mpo.FiniteTFI(Jx, Bz, dtype=np.complex64)
    return tc.quantum.tn2qop(tn_mpo)


def get_nodes_fn(n, d, mpo):
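    """Return a function that maps circuit parameters to the tensor network nodes
    of <psi(params)| H |psi(params)>, which is what ``DistributedContractor``
    expects as its ``nodes_fn`` argument.
    """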
    def nodes_fn(params):
        psi = circuit_ansatz(n, d, params).get_quvector()
        expression = psi.adjoint() @ mpo @ psi
        return expression.nodes

    return nodes_fn


def run_vqe_main(coordinator_address: str, num_processes: int, process_id: int):
    """
    Main logic run by ALL processes.
    """
    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )
    print(
        f"[Process {process_id}] Initialized. jax.process_index() reports: {jax.process_index()}"
    )

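    # After jax.distributed.initialize, jax.devices() returns the devices of ALL
    # processes, so this mesh spans every device participating in the job.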
    global_mesh = Mesh(jax.devices(), axis_names=("devices",))
    if jax.process_index() == 0:
        print(f"--- Global mesh created with devices: {global_mesh.devices}")

    tfi_mpo = get_tfi_mpo(N_QUBITS)
    nodes_fn = get_nodes_fn(N_QUBITS, DEPTH, tfi_mpo)
    params_shape = [N_QUBITS, DEPTH, 3]

    # --- KEY CHANGE: create params on host 0 and broadcast to all others ---
    params_cpu = None
    if jax.process_index() == 0:
        key = jax.random.PRNGKey(42)
        params_cpu = (
            jax.random.normal(key, shape=params_shape, dtype=tc.rdtypestr) * 0.1
        )

    # Broadcast the CPU array so every process holds the same concrete `params_cpu`.
    # This is critical: without it, the non-coordinator processes would pass None
    # to the contractor and fail at initialization.
    params_cpu = broadcast_py_object(params_cpu)

    # With `params_cpu` now available on every process, the contractor can be
    # initialized safely. It uses this concrete array to trace the network and
    # runs its internal "find the path on process 0 and broadcast it" logic.
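    # cotengra path-search options: `target_size` bounds the size of the largest
    # sliced intermediate tensor, `max_repeats` limits the number of search trials,
    # `minimize="write"` optimizes for memory writes rather than flops, and
    # `parallel` sets the number of workers used for the path search.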
    DC = DistributedContractor(
        nodes_fn=nodes_fn,
        params=params_cpu,
        mesh=global_mesh,
        cotengra_options={
            "slicing_reconf_opts": {"target_size": 2**8},
            "max_repeats": 16,
            "progbar": True,
            "minimize": "write",
            "parallel": 4,
        },
    )

    # Replicate the parameters across all devices for the actual GPU/TPU
    # computation (a PartitionSpec with every axis None means fully replicated).
    params_sharding = NamedSharding(global_mesh, P(*([None] * len(params_shape))))
    params = jax.device_put(params_cpu, params_sharding)

    # Initialize the optimizer and its state.
    optimizer = optax.adam(2e-2)
    opt_state = optimizer.init(params)  # Can init directly with sharded params

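    # Standard Optax/Adam update step, jitted so the parameter update runs on device.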
    @jax.jit
    def opt_update(params, opt_state, grads):
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state

    # Run the optimization loop.
    n_steps = 100
    if jax.process_index() == 0:
        print("\nStarting VQE optimization loop...")

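    # Each step evaluates <psi|H|psi> and its gradient through the distributed
    # contractor and then applies an Adam update; only process 0 prints progress.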
    for i in range(n_steps):
        t0 = time.time()
        loss, grads = DC.value_and_grad(params)
        params, opt_state = opt_update(params, opt_state, grads)
        t1 = time.time()

        if jax.process_index() == 0:
            print(f"Step {i+1:03d} | Loss: {loss:.8f} | Time: {t1 - t0:.4f} s")

    jax.distributed.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="JAX Multi-Host VQE Simulation")
    parser.add_argument(
        "--process_id",
        type=int,
        required=True,
        help="Rank of the current process (e.g., 0, 1).",
    )
    parser.add_argument(
        "--num_processes", type=int, default=2, help="Total number of processes."
    )
    parser.add_argument(
        "--coordinator_address",
        type=str,
        default="127.0.0.1:8888",
        help="IP address and port of the coordinator (process 0).",
    )
    args = parser.parse_args()

    print(f"--- Starting Process {args.process_id}/{args.num_processes} ---")

    if args.process_id == 0:
        print(
            "\n>>> This is the coordinator process. Waiting for other processes to connect."
        )

    run_vqe_main(
        coordinator_address=args.coordinator_address,
        num_processes=args.num_processes,
        process_id=args.process_id,
    )

# 5090 (NCCL P2P and SHM disabled), one command per process:
#   CUDA_VISIBLE_DEVICES=0,1 NCCL_P2P_DISABLE=1 NCCL_SHM_DISABLE=1 python multihost_vqe.py --process_id=0
#   CUDA_VISIBLE_DEVICES=2,3 NCCL_P2P_DISABLE=1 NCCL_SHM_DISABLE=1 python multihost_vqe.py --process_id=1
# H200: no need to disable P2P and SHM thanks to properly configured NVLink.