Skip to content

Commit 14fcdb0

Browse files
more options for benchmark cli
1 parent 71fffb4 commit 14fcdb0

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

benchmarks/scripts_v2/benchmark_cli.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import json
44
import argparse
55
import time
6-
import numpy as np
7-
import cpuinfo
86
import sys
97
import datetime
8+
import numpy as np
9+
import cpuinfo
1010
import tensorcircuit as tc
1111

1212
from benchmark_core import benchmark_mega_function
@@ -131,6 +131,22 @@ def arg():
131131
help="Contractor setting (e.g., cotengra-16-128)",
132132
default=[None],
133133
)
134+
parser.add_argument(
135+
"-bond_dim",
136+
dest="bond_dim",
137+
type=int,
138+
nargs=1,
139+
help="Bond dimension for MPS circuits",
140+
default=[16],
141+
)
142+
parser.add_argument(
143+
"-jit_compile",
144+
dest="jit_compile",
145+
type=int,
146+
nargs=1,
147+
help="Whether to use JIT compilation (0 or 1)",
148+
default=[1],
149+
)
134150
args = parser.parse_args()
135151
return [
136152
args.n[0],
@@ -152,6 +168,8 @@ def arg():
152168
args.backend[0],
153169
args.dtype[0],
154170
args.contractor[0],
171+
args.bond_dim[0],
172+
args.jit_compile[0],
155173
]
156174

157175

@@ -204,6 +222,8 @@ def benchmark_cli(
204222
backend,
205223
dtype,
206224
contractor,
225+
bond_dim,
226+
jit_compile,
207227
path,
208228
):
209229
meta = {}
@@ -244,6 +264,8 @@ def benchmark_cli(
244264
"backend": backend,
245265
"dtype": dtype,
246266
"contractor": contractor,
267+
"bond_dim": bond_dim,
268+
"jit_compile": jit_compile,
247269
}
248270
meta["UUID"] = uuid
249271
meta["Benchmark Time"] = (
@@ -258,13 +280,15 @@ def benchmark_cli(
258280
lx=lx,
259281
ly=ly,
260282
circuit_type=circuit_type,
283+
bond_dim=bond_dim,
261284
layout_type=layout_type,
262285
operation=operation,
263286
noisy=bool(noisy),
264287
noisy_type=noisy_type,
265288
use_grad=bool(use_grad),
266289
use_vmap=bool(use_vmap),
267290
contractor=contractor,
291+
jit_compile=bool(jit_compile),
268292
)
269293

270294
# Create parameters for testing
@@ -313,6 +337,8 @@ def benchmark_cli(
313337
backend,
314338
dtype,
315339
contractor,
340+
bond_dim,
341+
jit_compile,
316342
) = arg()
317343

318344
results = benchmark_cli(
@@ -335,6 +361,8 @@ def benchmark_cli(
335361
backend,
336362
dtype,
337363
contractor,
364+
bond_dim,
365+
jit_compile,
338366
path,
339367
)
340368
save(results, _uuid, path)

benchmarks/scripts_v2/benchmark_core.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def generate_2d_circuit(c, lx, ly, params, nqubits, nlayers):
3636

3737
def generate_noisy_circuit(c, status, type="depolarizing"):
3838
noise_conf = noisemodel.NoiseConf()
39-
# print(type)
4039
if type == "depolarizing":
4140
error1 = channels.depolarizingchannel(0.1, 0.1, 0.1) # px, py, pz probabilities
4241
elif type == "amplitudedamping":
@@ -78,6 +77,7 @@ def benchmark_mega_function(
7877
use_grad=False, # True, False
7978
use_vmap=False, # True, False
8079
contractor=None, # contractor setting like "cotengra-16-128"
80+
jit_compile=True, # True, False
8181
):
8282
"""
8383
Mega benchmark function that can control all parameters via arguments.
@@ -88,12 +88,13 @@ def benchmark_mega_function(
8888
lx: Lattice size x (for 2D)
8989
ly: Lattice size y (for 2D)
9090
circuit_type: Type of circuit ("circuit", "dmcircuit", "mpscircuit")
91+
bond_dim: Bond dimension for MPS circuits
9192
layout_type: Circuit layout ("1d", "2d")
9293
operation: Operation to perform ("state", "sample", "exps")
9394
noisy: Whether to add noise (only for "circuit" and "dmcircuit")
95+
noisy_type: Type of noise channel ("depolarizing", "amplitudedamping")
9496
use_grad: Whether to compute gradient (AD)
9597
use_vmap: Whether to use vectorized operations
96-
batch_size: Batch size for vmap operations
9798
contractor: Contractor setting like "cotengra-16-128"
9899
99100
Returns:
@@ -136,17 +137,17 @@ def circuit_func(params):
136137
# Handle gradient computation
137138
if use_grad and not use_vmap:
138139
grad_func = tc.backend.grad(circuit_func)
139-
return tc.backend.jit(grad_func, jit_compile=True)
140+
return tc.backend.jit(grad_func, jit_compile=jit_compile)
140141

141142
# Handle vmap computation
142143
if use_vmap and not use_grad:
143-
return tc.backend.jit(tc.backend.vmap(circuit_func), jit_compile=True)
144+
return tc.backend.jit(tc.backend.vmap(circuit_func), jit_compile=jit_compile)
144145

145146
# Handle both grad and vmap
146147
if use_grad and use_vmap:
147148
vvag_func = tc.backend.vvag(circuit_func)
148-
return tc.backend.jit(vvag_func, jit_compile=True)
149+
return tc.backend.jit(vvag_func, jit_compile=jit_compile)
149150

150151
# Regular operation (no grad, no vmap)
151152
# Always JIT the returned function
152-
return tc.backend.jit(circuit_func, jit_compile=True)
153+
return tc.backend.jit(circuit_func, jit_compile=jit_compile)

0 commit comments

Comments
 (0)