
Commit 908e25d

add more docs
1 parent b1a501a commit 908e25d

File tree

3 files changed: +313 −10 lines changed

docs/source/advance.rst

Lines changed: 129 additions & 0 deletions
@@ -220,6 +220,135 @@ Please refer to :py:meth:`tensorcircuit.templates.measurements.sparse_expectatio
For different representations to evaluate Hamiltonian expectation in tensorcircuit, please refer to :doc:`tutorials/tfim_vqe_diffreph`.

Hamiltonian Matrix Building
----------------------------

TensorCircuit-NG provides multiple ways to build Hamiltonian matrices, especially sparse Hamiltonians constructed from Pauli strings. This is crucial for quantum many-body physics simulations and variational quantum algorithms.

**Pauli String Based Construction:**

The most flexible way to build Hamiltonians is through Pauli strings:

.. code-block:: python

    import tensorcircuit as tc

    # Define Pauli strings and their weights
    # Each Pauli string is represented by a list of integers:
    # 0: Identity, 1: X, 2: Y, 3: Z
    # (qubit indices in the comments below are 0-based)
    pauli_strings = [
        [1, 1, 0],  # X₀X₁I₂
        [3, 3, 0],  # Z₀Z₁I₂
        [0, 0, 1],  # I₀I₁X₂
    ]
    weights = [0.5, 1.0, -0.2]

    # Build the Hamiltonian as a sparse matrix (COO format)
    h_sparse = tc.quantum.PauliStringSum2COO(pauli_strings, weights)

    # Or as a dense matrix if preferred
    h_dense = tc.quantum.PauliStringSum2Dense(pauli_strings, weights)

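As a quick usage sketch, the sparse matrix can be fed straight into expectation evaluation via the ``sparse_expectation`` template referenced at the top of this section:

.. code-block:: python

    # Minimal usage sketch: <psi|H|psi> for the 3-qubit Hamiltonian built above
    c = tc.Circuit(3)
    for i in range(3):
        c.h(i)
    energy = tc.templates.measurements.sparse_expectation(c, h_sparse)
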
**High-Level Hamiltonian Construction:**

For common Hamiltonians such as the Heisenberg model:

.. code-block:: python

    # Create a 1D chain with 10 sites
    g = tc.templates.graphs.Line1D(10, pbc=True)  # periodic boundary condition

    # XXZ model with a transverse field
    h = tc.quantum.heisenberg_hamiltonian(
        g,
        hxx=1.0,  # XX coupling
        hyy=1.0,  # YY coupling
        hzz=1.2,  # ZZ coupling
        hx=0.5,  # X field
        sparse=True,
    )

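A small follow-up sketch (assuming the ``operator_expectation`` measurement template, which dispatches on dense, sparse, and MPO Hamiltonians) to evaluate the energy of a state under this Hamiltonian:

.. code-block:: python

    # Energy of the XXZ chain on the all-zero product state (a sanity check)
    c = tc.Circuit(10)
    energy = tc.templates.measurements.operator_expectation(c, h)
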
**Advanced Usage:**

1. Converting between xyz and Pauli string representations:

.. code-block:: python

    # Convert a Pauli string to the xyz dict format (indices are 0-based)
    xyz_dict = tc.quantum.ps2xyz([1, 2, 2, 0])  # X₀Y₁Y₂I₃
    print(xyz_dict)  # {'x': [0], 'y': [1, 2], 'z': []}

    # Convert back to the Pauli string format
    ps = tc.quantum.xyz2ps(xyz_dict, n=4)
    print(ps)  # [1, 2, 2, 0]

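Since the keys of the xyz dict mirror the keyword arguments of ``Circuit.expectation_ps``, the conversion composes naturally with expectation evaluation; a small sketch (assuming empty index lists are accepted):

.. code-block:: python

    c = tc.Circuit(4)
    xyz = tc.quantum.ps2xyz([1, 2, 2, 0])
    # equivalent to c.expectation_ps(x=[0], y=[1, 2])
    e = c.expectation_ps(**xyz)
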
2. Working with MPO format:

TensorCircuit-NG supports conversion from different MPO (Matrix Product Operator) formats, particularly from the TensorNetwork and Quimb libraries. This is useful when you want to leverage existing MPO implementations or convert between different frameworks.

**TensorNetwork MPO:**

For TensorNetwork MPOs, you can convert predefined models such as the transverse field Ising (TFI) model:

.. code-block:: python

    import numpy as np
    import tensorcircuit as tc
    import tensornetwork as tn

    # Create the TFI Hamiltonian MPO from TensorNetwork
    nwires = 6
    Jx = np.array([1.0] * (nwires - 1))  # XX coupling strength
    Bz = np.array([-1.0] * nwires)  # transverse field strength

    tn_mpo = tn.matrixproductstates.mpo.FiniteTFI(Jx, Bz, dtype=np.complex64)

    # Convert to TensorCircuit format
    tc_mpo = tc.quantum.tn2qop(tn_mpo)

    # Get the dense matrix representation
    h_matrix = tc_mpo.eval_matrix()

Note: TensorNetwork MPOs currently only support open boundary conditions.

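A minimal sketch (assuming the ``mpo_expectation`` measurement template) showing how the converted MPO can be used for expectation evaluation without densifying the Hamiltonian:

.. code-block:: python

    # <psi|H|psi> directly in MPO form, for the |+>^n product state
    c = tc.Circuit(nwires)
    for i in range(nwires):
        c.h(i)
    energy = tc.templates.measurements.mpo_expectation(c, tc_mpo)
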
**Quimb MPO:**

Quimb provides more flexible MPO construction options:

.. code-block:: python

    import tensorcircuit as tc
    import quimb.tensor as qtn

    # Create an Ising Hamiltonian MPO using Quimb
    nwires = 6
    J = 4.0  # ZZ coupling
    h = 2.0  # X field
    qb_mpo = qtn.MPO_ham_ising(
        nwires,
        J,
        h,
        cyclic=True,  # periodic boundary conditions
    )

    # Convert to TensorCircuit format
    tc_mpo = tc.quantum.quimb2qop(qb_mpo)

    # Custom Hamiltonian construction
    builder = qtn.SpinHam1D()
    builder += 1.0, "Y"  # add a Y term with strength 1.0
    builder += 0.5, "X"  # add an X term with strength 0.5
    H = builder.build_mpo(3)  # build the MPO for 3 sites

    # Convert to a TensorCircuit MPO
    h_tc = tc.quantum.quimb2qop(H)

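As a quick sanity check, the converted operator densifies to the expected dimension:

.. code-block:: python

    # The 3-site custom Hamiltonian densifies to a 2**3 x 2**3 matrix
    m = h_tc.eval_matrix()
    print(m.shape)  # (8, 8)
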
Fermion Gaussian State Simulator
--------------------------------

docs/source/quickstart.rst

Lines changed: 165 additions & 9 deletions
@@ -224,21 +224,104 @@ The related API design in TensorCircuit-NG closely follows the functional progra

**AD Support:**

Automatic Differentiation (AD) is crucial for quantum circuit optimization. TensorCircuit-NG supports various differentiation operations:

* Gradients: first-order derivatives
* Vector-Jacobian products (vjps): efficient reverse-mode differentiation
* Jacobian-vector products (jvps): efficient forward-mode differentiation
* Natural gradients: geometry-aware optimization
* Jacobians: full derivative matrices
* Hessians: second-order derivatives

Example of gradient computation:

.. code-block:: python

    import tensorcircuit as tc

    K = tc.set_backend("tensorflow")

    def circuit(params):
        c = tc.Circuit(2)
        c.rx(0, theta=params[0])
        c.ry(1, theta=params[1])
        c.cnot(0, 1)
        return K.real(c.expectation([tc.gates.z(), [0]]))

    # Get value and gradient in one pass
    params = K.ones([2])
    value, grad = K.value_and_grad(circuit)(params)
    print("Value:", value)
    print("Gradient:", grad)

    # Compute the Hessian
    hess = K.hessian(circuit)(params)
    print("Hessian:", hess)

**JIT Support:**

Just-In-Time (JIT) compilation significantly accelerates quantum circuit simulation by optimizing the computation graph. Key points:

* Use JIT for functions that will be evaluated many times; speedups of two to three orders of magnitude are common
* JIT compilation itself has overhead, so it is most beneficial for repeated executions
* Keep input shapes and dtypes consistent to avoid recompilation
* Function inputs and outputs should all be tensors; non-tensor arguments are treated as static (see the sketch after the example)

Example of JIT acceleration:

.. code-block:: python

    import time

    # A Monte Carlo noisy circuit; `key` supplies the random status values
    def noisy_circuit(key):
        c = tc.Circuit(5)
        for i in range(5):
            c.h(i)
            c.depolarizing(i, px=0.01, py=0.01, pz=0.01, status=key[i])
        return c.expectation_ps(z=[0])

    # Compare performance with and without JIT
    start = time.time()
    for _ in range(100):
        noisy_circuit(K.ones([5]))
    print("Without JIT:", time.time() - start)

    jitted_circuit = K.jit(noisy_circuit)
    start = time.time()
    for _ in range(100):
        jitted_circuit(K.ones([5]))
    print("With JIT:", time.time() - start)

Note that the first call of the jitted function includes the compilation time; subsequent calls reuse the compiled graph.

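A hedged sketch of the static-argument caveat (assuming the backend ``jit`` forwards ``static_argnums`` in the ``jax.jit`` style): non-tensor arguments change the computation graph, so each new value triggers one recompilation.

.. code-block:: python

    def energy(params, n):  # `n` is a plain Python int, hence static
        c = tc.Circuit(n)
        for i in range(n):
            c.rx(i, theta=params[i])
        return K.real(c.expectation_ps(z=[0]))

    energy_jit = K.jit(energy, static_argnums=1)
    energy_jit(K.ones([5]), 5)  # compiles for n=5; later n=5 calls reuse the trace
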
**VMAP Support:**

Vectorized mapping (vmap) enables parallel evaluation across multiple inputs or parameters:

* Batch processing of quantum circuit input wavefunctions
* Batch processing of quantum circuit structures
* Parallel parameter optimization (see the ``vvag`` sketch below)
* Efficient Monte Carlo sampling for noise simulation
* Vectorized measurement operations

Example of vmap for parallel circuit evaluation:

.. code-block:: python

    # Define a parameterized circuit
    def param_circuit(params):
        c = tc.Circuit(2)
        c.rx(0, theta=params[0])
        c.ry(1, theta=params[1])
        return K.real(c.expectation([tc.gates.z(), [0]]))

    # Create a batch of 10 parameter sets
    batch_params = K.ones([10, 2])

    # Vectorize the circuit evaluation over the batch dimension
    vmap_circuit = K.vmap(param_circuit)
    results = vmap_circuit(batch_params)

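The batched evaluation above pairs naturally with differentiation via ``vvag`` (``vectorized_value_and_grad``, which also appears in the backend API listing later on this page); a short sketch:

.. code-block:: python

    # values and gradients over the whole batch in one fused call
    vvag_circuit = K.vvag(param_circuit, argnums=0, vectorized_argnums=0)
    values, grads = vvag_circuit(batch_params)
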
For more advanced usage patterns and detailed examples of vmap, refer to our `vmap tutorial <https://tensorcircuit-ng.readthedocs.io/en/latest/whitepaper/6-3-vmap.html>`_.

Backend Agnosticism
@@ -424,15 +507,14 @@ and the other part is implemented in `TensorCircuit package <modules.html#module
'vvag',
'zeros']

Switch the Dtype
--------------------

TensorCircuit-NG supports simulation using 32/64-bit precision. The default dtype is 32-bit, namely "complex64".
Change this by ``tc.set_dtype("complex128")``.

``tc.dtypestr`` always returns the current dtype string: either "complex64" or "complex128". Accordingly, ``tc.rdtypestr`` always returns the current real dtype string: either "float32" or "float64".
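A minimal illustration:

.. code-block:: python

    tc.set_dtype("complex128")
    print(tc.dtypestr)   # "complex128"
    print(tc.rdtypestr)  # "float64"
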

Setup the Contractor
@@ -769,6 +851,79 @@ We also provide wrapper of quantum function for keras layer as :py:meth:`tensor
    l = layer(v)
    grad = tape.gradient(l, layer.trainable_variables)

**JAX interfaces:**

TensorCircuit-NG also provides a JAX interface to seamlessly integrate with JAX's ecosystem.
This allows you to use JAX's powerful features such as automatic differentiation, JIT compilation, and vectorization with quantum circuits or functions running on any backend.

Basic usage with the JAX interface:

.. code-block:: python

    import tensorcircuit as tc
    import jax
    import jax.numpy as jnp

    # Set a non-jax backend
    tc.set_backend("tensorflow")

    def circuit(params):
        c = tc.Circuit(2)
        c.rx(0, theta=params[0])
        c.ry(1, theta=params[1])
        c.cnot(0, 1)
        return tc.backend.real(c.expectation_ps(z=[1]))

    # Wrap the circuit with the JAX interface
    jax_circuit = tc.interfaces.jax_interface(circuit, jit=True)

    # Now you can use JAX features on the wrapped function
    params = jnp.ones(2)
    value, grad = jax.value_and_grad(jax_circuit)(params)
    print("Value:", value)
    print("Gradient:", grad)

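As a small usage sketch, the wrapped function can drive a plain gradient-descent loop entirely from JAX while the circuit itself executes on the TensorFlow backend:

.. code-block:: python

    params = jnp.ones(2)
    for _ in range(50):
        value, grad = jax.value_and_grad(jax_circuit)(params)
        params = params - 0.1 * grad
    print("Optimized value:", jax_circuit(params))
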
Some advanced features:

1. DLPack support for efficient tensor conversion:

.. code-block:: python

    # Enable DLPack for zero-copy tensor conversion between backends
    jax_circuit = tc.interfaces.jax_interface(
        circuit, jit=True, enable_dlpack=True
    )

2. Explicit output shape specification for better performance:

.. code-block:: python

    # Specify output shape and dtype
    jax_circuit = tc.interfaces.jax_interface(
        circuit, jit=True, output_shape=(1,), output_dtype=jnp.float32
    )

3. Multiple outputs support:

.. code-block:: python

    def multi_output_circuit(params):
        c = tc.Circuit(2)
        c.rx(0, theta=params[0])
        c.ry(1, theta=params[1])
        z0 = c.expectation([tc.gates.z(), [0]])
        z1 = c.expectation([tc.gates.z(), [1]])
        return tc.backend.real(z0), tc.backend.real(z1)

    jax_circuit = tc.interfaces.jax_interface(
        multi_output_circuit,
        jit=True,
        output_shape=[[], []],
        output_dtype=[jnp.float32, jnp.float32],
    )

    # jax.value_and_grad requires a scalar output, so sum the two expectations
    params = jnp.ones(2)
    value, grad = jax.value_and_grad(tc.utils.append(jax_circuit, sum))(params)

**Scipy Interface to Utilize Scipy Optimizers:**
@@ -827,3 +982,4 @@ See :py:meth:`tensorcircuit.templates.measurements.heisenberg_measurements`
827982
828983
.. figure:: statics/bell_pair_block.png
829984
:scale: 50%
985+

docs/source/sharpbits.rst

Lines changed: 19 additions & 1 deletion
@@ -191,6 +191,9 @@ If the device is not consistent, one can move the tensor between devices by ``tc
AD Consistency
---------------------

Gradients in terms of complex dtypes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The TF and JAX backends manage differentiation rules differently for complex-valued functions (the results actually differ up to a complex conjugate). See the discussion in this `tensorflow issue <https://github.com/tensorflow/tensorflow/issues/3348>`_.

In TensorCircuit-NG, we currently keep this difference in AD transparent, namely, when switching the backend, the AD behavior and results for a complex-valued function can differ, determined by the native behavior of the corresponding backend framework.
@@ -222,4 +225,19 @@ Also see the code below for a reference:
    # tf.Tensor([0.90929717 0.90929717], shape=(2,), dtype=float32)
    # jax backend
    # [0.90929747-0.9228759j 0.90929747-0.9228759j]
    # [0.90929747 0.90929747]

VMAP outside grad-like function on tensorflow backend
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Vmap (vectorized map) outside a grad-like function may give incorrect results on the TensorFlow backend due to a long-standing `bug <https://github.com/tensorflow/tensorflow/issues/52148>`_ in the TensorFlow codebase. So it is better to always stick to the vmap-first-then-differentiate paradigm, as sketched below.

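A hedged sketch of the safe ordering (``f`` here is an illustrative single-sample function with scalar output):

.. code-block:: python

    # vmap inside, differentiation outside
    f_batched = K.vmap(f, vectorized_argnums=0)
    loss = lambda xs: K.sum(f_batched(xs))
    loss_grad = K.grad(loss)  # safe on all backends
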
Grad over vmap function
~~~~~~~~~~~~~~~~~~~~~~~~~

A related issue is the different behavior of ``K.grad(K.vmap(f))`` on different backends. On the TensorFlow backend, the function to be differentiated is treated as having a scalar output given by the sum of all vectorized outputs.

On the JAX backend, however, the call simply raises an error, since only scalar-output functions can be differentiated and no implicit sum over the vectorized ``f`` is assumed. For non-scalar outputs, one should use ``jacrev`` or ``jacfwd`` to get the gradient information.

Specifically, ``K.grad(K.vmap(f))`` on the TensorFlow backend is equivalent to ``K.grad(tc.utils.append(K.vmap(f), K.sum))`` on the JAX backend.
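A hedged illustration of this equivalence (``f`` and ``batch_x`` are placeholders):

.. code-block:: python

    # TensorFlow backend: implicit sum over the batched outputs
    g_tf = K.grad(K.vmap(f))(batch_x)

    # JAX backend equivalent: make the sum explicit
    g_jax = K.grad(tc.utils.append(K.vmap(f), K.sum))(batch_x)
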
