Skip to content

Commit 43a8bba

Browse files
add jax function load save
1 parent 114458c commit 43a8bba

File tree

8 files changed

+226
-17
lines changed

8 files changed

+226
-17
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
## Unreleased
44

5+
### Added
6+
7+
- Add jax jitted function load/save utilities in experimental module
8+
9+
- Add `circuit.to_openqasm_file` function for compatibility of qiskit>1
10+
511
## 1.0.2
612

713
### Added

docs/source/advance.rst

Lines changed: 130 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,65 @@ Advanced Usage
55
MPS Simulator
66
----------------
77

8-
(Still experimental support)
9-
10-
Very simple, we provide the same set of API for ``MPSCircuit`` as ``Circuit``,
8+
Very straightforward to use, we provide the same set of API for ``MPSCircuit`` as ``Circuit``,
119
the only new line is to set the bond dimension for the new simulator.
1210

1311
.. code-block:: python
1412
1513
c = tc.MPSCircuit(n)
1614
c.set_split_rules({"max_singular_values": 50})
1715
18-
The larger bond dimension we set, the better approximation ratio (of course the more computational cost we pay)
16+
The larger bond dimension we set, the better approximation ratio (of course the more computational cost we pay).
17+
18+
19+
Stacked gates
20+
----------------
21+
22+
Stacked gates is a simple grammar sugar to make constructing the circuit easily when multiple gate of the same type are applied on the different qubits, namely, the index for gate function can accept list of ints instead of one integer.
23+
24+
.. code-block:: python
25+
26+
>>> import tensorcircuit as tc
27+
>>> c = tc.Circuit(4)
28+
>>> c.h(range(3))
29+
>>> c.draw()
30+
┌───┐
31+
q_0: ┤ H ├
32+
├───┤
33+
q_1: ┤ H ├
34+
├───┤
35+
q_2: ┤ H ├
36+
└───┘
37+
q_3: ─────
38+
39+
40+
>>> c = tc.Circuit(4)
41+
>>> c.cnot([0, 1], [2, 3])
42+
>>> c.draw()
43+
44+
q_0: ──■───────
45+
46+
q_1: ──┼────■──
47+
┌─┴─┐ │
48+
q_2: ┤ X ├──┼──
49+
└───┘┌─┴─┐
50+
q_3: ─────┤ X ├
51+
└───┘
52+
53+
>>> c = tc.Circuit(4)
54+
>>> c.rx(range(4), theta=tc.backend.convert_to_tensor([0.1, 0.2, 0.3, 0.4]))
55+
>>> c.draw()
56+
┌─────────┐
57+
q_0: ┤ Rx(0.1) ├
58+
├─────────┤
59+
q_1: ┤ Rx(0.2) ├
60+
├─────────┤
61+
q_2: ┤ Rx(0.3) ├
62+
├─────────┤
63+
q_3: ┤ Rx(0.4) ├
64+
└─────────┘
65+
66+
1967
2068
Split Two-qubit Gates
2169
-------------------------
@@ -45,15 +93,92 @@ The two-qubit gates applied on the circuit can be decomposed via SVD, which may
4593
4694
Note ``max_singular_values`` must be specified to make the whole procedure static and thus jittable.
4795

96+
Analog circuit simulation
97+
-----------------------------
98+
99+
TensorCircuit-NG support digital-analog hybrid simulation (say cases in Rydberg atom arrays), where the analog part is simulated by the neural differential equation solver given the API to specify a time dependent Hamiltonian.
100+
The simulation is still differentiable and jittable. Only jax backend is supported for analog simulation as the neural ode engine is built on top of jax.
101+
This utility is super helpful for optimizing quantum control or investigating digital-analog hybrid variational quantum schemes.
102+
We support two modes of analog simulation, where :py:meth:`tensorcircuit.experimentaql.evol_global` evolve the state via a Hamiltonian define on the whole system, and :py:meth:`tensorcircuit.experimentaql.evol_local` evolve the state via a Hamiltonian define on a local subsystem.
103+
104+
.. Note::
105+
106+
``evol_global`` use sparse Hamiltonian while ``evol_local`` use dense Hamiltonian.
107+
108+
109+
.. code-block:: python
110+
111+
# in this demo, we build a jittable and differentiable simulation function `hybrid_evol`
112+
# with both digital gates and local/global analog Hamiltonian evolutions
113+
114+
import optax
115+
import tensorcircuit as tc
116+
from tensorcircuit.experimental import evol_global, evol_local
117+
118+
K = tc.set_backend("jax")
119+
120+
121+
def h_fun(t, b):
122+
return b * tc.gates.x().tensor
123+
124+
125+
hy = tc.quantum.PauliStringSum2COO([[2, 0]])
126+
127+
128+
def h_fun2(t, b):
129+
return b[2] * K.cos(b[0] * t + b[1]) * hy
130+
131+
132+
@K.jit
133+
@K.value_and_grad
134+
def hybrid_evol(params):
135+
c = tc.Circuit(2)
136+
c.x([0, 1])
137+
c = evol_local(c, [1], h_fun, 1.0, params[0])
138+
c.cx(1, 0)
139+
c.h(0)
140+
c = evol_global(c, h_fun2, 1.0, params[1:])
141+
return K.real(c.expectation_ps(z=[0, 1]))
142+
143+
144+
b = K.implicit_randn([4])
145+
v, gs = hybrid_evol(b)
146+
147+
48148
49149
Jitted Function Save/Load
50150
-----------------------------
51151

52152
To reuse the jitted function, we can save it on the disk via support from the TensorFlow `SavedModel <https://www.tensorflow.org/guide/saved_model>`_. That is to say, only jitted quantum function on the TensorFlow backend can be saved on the disk.
53153

154+
We wrap the tf-backend `SavedModel` as very easy-to-use function :py:meth:`tensorcircuit.keras.save_func` and :py:meth:`tensorcircuit.keras.load_func`.
155+
54156
For the JAX-backend quantum function, one can first transform them into the tf-backend function via JAX experimental support: `jax2tf <https://github.com/google/jax/tree/main/jax/experimental/jax2tf>`_.
55157

56-
We wrap the tf-backend `SavedModel` as very easy-to-use function :py:meth:`tensorcircuit.keras.save_func` and :py:meth:`tensorcircuit.keras.load_func`.
158+
**Updates**: jax now also support jitted function save/load via ``export`` module, see `jax documentation <https://jax.readthedocs.io/en/latest/export/export.html>_`.
159+
160+
We wrape the jax function export capability in ``experimental`` module and can be used as follows
161+
162+
.. code-block:: python
163+
164+
from tensorcircuit import experimental
165+
166+
K = tc.set_backend("jax")
167+
168+
@K.jit
169+
def f(weights):
170+
c = tc.Circuit(3)
171+
c.rx(range(3), theta=weights)
172+
return K.real(c.expectation_ps(z=[0]))
173+
174+
print(f(K.ones([3])))
175+
176+
experimental.jax_jitted_function_save("temp.bin", f, K.ones([3]))
177+
178+
f_load = tc.experimental.jax_jitted_function_load("temp.bin")
179+
f_load(K.ones([3]))
180+
181+
57182
58183
Parameterized Measurements
59184
-----------------------------

docs/source/quickstart.rst

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,8 @@ For more details on docker setup, please refer to `docker readme <https://github
2929

3030
- For Windows, due to the lack of support for Jax, we recommend to use docker or WSL, please refer to `TC via windows docker <contribs/development_windows.html>`_ or `TC via WSL <contribs/development_wsl2.html>`_.
3131

32-
- For MacOS, please refer to `TC on Mac <contribs/development_Mac.html>`_.
33-
34-
Overall, the installation of TensorCircuit is simple, since it is purely in Python and hence very portable.
35-
As long as the users can take care of the installation of ML frameworks on the corresponding system, TensorCircuit will work as expected.
32+
Overall, the installation of TensorCircuit-NG is simple, since it is purely in Python and hence very portable.
33+
As long as the users can take care of the installation of ML frameworks on the corresponding system, TensorCircuit-NG will work as expected.
3634

3735
To debug the installation issue or report bugs, please check the environment information by ``tc.about()``.
3836

tensorcircuit/abstractcircuit.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,16 @@ def to_openqasm(self, **kws: Any) -> str:
792792
qasm_str = dumps(qc) # type: ignore
793793
return qasm_str # type: ignore
794794

795+
def to_openqasm_file(self, file: str, **kws: Any) -> None:
796+
"""
797+
save the circuit to openqasm file
798+
799+
:param file: the file path to save the circuit
800+
:type file: str
801+
"""
802+
with open(file, "w") as f:
803+
f.write(self.to_openqasm(**kws))
804+
795805
@classmethod
796806
def from_openqasm(
797807
cls,

tensorcircuit/experimental.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,6 @@ def grad_f(*args: Any, **kws: Any) -> Any:
350350
return grad_f
351351

352352

353-
# TODO(@refraction-ray): add SPSA gradient wrapper similar to parameter shift
354-
# -- using noisyopt package instead
355-
356-
357353
def finite_difference_differentiator(
358354
f: Callable[..., Any],
359355
argnums: Tuple[int, ...] = (0,),
@@ -456,7 +452,7 @@ def evol_local(
456452
457453
:param c: _description_
458454
:type c: Circuit
459-
:param index: _description_
455+
:param index: qubit sites to evolve
460456
:type index: Sequence[int]
461457
:param h_fun: h_fun should return a dense Hamiltonian matrix
462458
with input arguments time and *args
@@ -525,3 +521,48 @@ def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
525521
ts = backend.cast(ts, dtype=rdtypestr)
526522
s1 = odeint(f, s, ts, *args, **solver_kws)
527523
return type(c)(n, inputs=s1[-1])
524+
525+
526+
def jax_jitted_function_save(filename: str, f: Callable[..., Any], *args: Any) -> None:
527+
"""
528+
save a jitted jax function as a file
529+
530+
:param filename: _description_
531+
:type filename: str
532+
:param f: the jitted function
533+
:type f: Callable[..., Any]
534+
:param args: example function arguments for ``f``
535+
"""
536+
537+
from jax import export
538+
539+
f_export = export.export(f)(*args) # type: ignore
540+
barray = f_export.serialize()
541+
542+
with open(filename, "wb") as file:
543+
file.write(barray)
544+
545+
546+
jax_func_save = jax_jitted_function_save
547+
548+
549+
def jax_jitted_function_load(filename: str) -> Callable[..., Any]:
550+
"""
551+
load a jitted function from file
552+
553+
:param filename: _description_
554+
:type filename: str
555+
:return: the loaded function
556+
:rtype: _type_
557+
"""
558+
from jax import export
559+
560+
with open(filename, "rb") as f:
561+
barray = f.read()
562+
563+
f_load = export.deserialize(barray) # type: ignore
564+
565+
return f_load.call
566+
567+
568+
jax_func_load = jax_jitted_function_load

tensorcircuit/fgs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,14 @@ def get_cmatrix(self, now_i: bool = True, now_j: bool = True) -> Tensor:
176176
return cmatrix
177177

178178
def get_reduced_cmatrix(self, subsystems_to_trace_out: List[int]) -> Tensor:
179+
"""
180+
get reduced correlation matrix by tracing out subsystems
181+
182+
:param subsystems_to_trace_out: list of sites to be traced out
183+
:type subsystems_to_trace_out: List[int]
184+
:return: reduced density matrix
185+
:rtype: Tensor
186+
"""
179187
m = self.get_cmatrix()
180188
if subsystems_to_trace_out is None:
181189
subsystems_to_trace_out = []

tests/test_circuit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,7 +1497,7 @@ def test_gate_count():
14971497
# {'x': 1, 'h': 2, 'rx': 1, 'multicontrol': 1, 'toffoli': 3}
14981498

14991499

1500-
def test_to_openqasm():
1500+
def test_to_openqasm(tmp_path):
15011501
c = tc.Circuit(3)
15021502
c.H(0)
15031503
c.rz(2, theta=0.2)
@@ -1511,8 +1511,8 @@ def test_to_openqasm():
15111511
c1 = tc.Circuit.from_openqasm(s)
15121512
print(c1.draw())
15131513
np.testing.assert_allclose(c.state(), c1.state())
1514-
c.to_openqasm(filename="test.qasm")
1515-
c2 = tc.Circuit.from_openqasm_file("test.qasm")
1514+
c.to_openqasm_file(os.path.join(tmp_path, "test.qasm"))
1515+
c2 = tc.Circuit.from_openqasm_file(os.path.join(tmp_path, "test.qasm"))
15161516
np.testing.assert_allclose(c.state(), c2.state())
15171517

15181518

tests/test_miscs.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,24 @@ def h_square_sparse(t, b):
254254
def test_energy_baseline():
255255
print(TFIM1Denergy(10))
256256
print(Heisenberg1Denergy(10))
257+
258+
259+
def test_jax_function_load(jaxb, tmp_path):
260+
K = tc.backend
261+
262+
@K.jit
263+
def f(weights):
264+
c = tc.Circuit(3)
265+
c.rx(range(3), theta=weights)
266+
return K.real(c.expectation_ps(z=[0]))
267+
268+
print(f(K.ones([3])))
269+
270+
experimental.jax_jitted_function_save(
271+
os.path.join(tmp_path, "temp.bin"), f, K.ones([3])
272+
)
273+
274+
f_load = tc.experimental.jax_jitted_function_load(
275+
os.path.join(tmp_path, "temp.bin")
276+
)
277+
np.testing.assert_allclose(f_load(K.ones([3])), 0.5403, atol=1e-4)

0 commit comments

Comments
 (0)