diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml index f2c45e2..dd1950f 100644 --- a/.github/workflows/build_documentation.yml +++ b/.github/workflows/build_documentation.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 name: Install Python @@ -42,7 +42,7 @@ jobs: # -T : display a full traceback if a Python exception occurs - name: Upload built files - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: qutip_qoc_html_docs path: doc/_build/html/* diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 69e1b03..14185e3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,7 +47,7 @@ jobs: python-version: "3.12" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 with: @@ -92,7 +92,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 with: diff --git a/README.md b/README.md index 9dde70f..b42d479 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,12 @@ To install the package, use pip install qutip-qoc ``` +By default, the dependencies required for JOPT and for the RL (reinforcement learning) algorithm are omitted. +They can be included by using the targets `qutip-qoc[jopt]` and `qutip-qoc[rl]`, respectively (or `qutip-qoc[full]` for all dependencies). + ## Documentation and tutorials -The documentation of `qutip-qoc` updated to the latest development version is hosted at [qutip-qoc.readthedocs.io](https://qutip-qoc.readthedocs.io/en/latest/). +The documentation of `qutip-qoc` updated to the latest development version is hosted at [qutip-qoc.readthedocs.io](https://qutip-qoc.readthedocs.io/latest/). Tutorials related to using quantum optimal control in `qutip-qoc` can be found [_here_](https://qutip.org/qutip-tutorials/#optimal-control). ## Installation from source @@ -40,7 +43,7 @@ pip install --upgrade pip pip install -e . ``` -which makes sure that you are up to date with the latest `pip` version. Contribution guidelines are available [_here_](https://qutip-qoc.readthedocs.io/en/latest/contribution-code.html). +which makes sure that you are up to date with the latest `pip` version. Contribution guidelines are available [_here_](https://qutip-qoc.readthedocs.io/latest/contribution/code.html). To build and test the documentation, additional packages need to be installed: diff --git a/VERSION b/VERSION index 6da28dd..d917d3e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.1 \ No newline at end of file +0.1.2 diff --git a/doc/changelog.rst b/doc/changelog.rst index a90329f..2e869b0 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -10,6 +10,11 @@ This is an update to the beta release of ``qutip-qoc``. It mainly introduces the new reinforcement learning algorithm ``qutip_qoc._rl``. +- Non-public facing functions have been renamed to start with an underscore. +- As with other QuTiP functions, ``optimize_pulses`` now takes a ``tlist`` argument instead of ``_TimeInterval``. +- The structure for the control guess and bounds has changed and now takes in an optional ``__time__`` keyword. +- The ``result`` does no longer return ``optimized_objectives`` but instead ``optimized_H``. + Features -------- diff --git a/doc/conf.py b/doc/conf.py index ab65dc7..96ca014 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -62,6 +62,10 @@ def qutip_qoc_version(): html_theme = "sphinx_rtd_theme" html_static_path = [] +html_js_files = [ + 'https://code.jquery.com/jquery-3.6.0.min.js', +] + # -- Options for numpydoc --------------------------------------- diff --git a/doc/contribution/code.rst b/doc/contribution/code.rst index 8a961b8..6cf0705 100644 --- a/doc/contribution/code.rst +++ b/doc/contribution/code.rst @@ -7,7 +7,7 @@ Contributing to the source code Build up an development environment =================================== -Please follow the instruction on the `QuTiP contribution guide `_ to +Please follow the instruction on the `QuTiP contribution guide `_ to build a conda environment. You don't need to build ``qutip`` in the editable mode unless you also want to contribute to `qutip`. diff --git a/doc/installation.rst b/doc/installation.rst index f619143..5754564 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -23,7 +23,19 @@ In particular, the following packages are necessary for running ``qutip-qoc``: .. code-block:: bash - numpy scipy jax jaxlib cython qutip qutip-jax qutip-qtrl + numpy scipy cython qutip qutip-qtrl + +The following packages are required for using the JOPT algorithm: + +.. code-block:: bash + + jax jaxlib qutip-jax + +The following packages are required for the RL (reinforcement learning) algorithm: + +.. code-block:: bash + + gymnasium stable-baselines3 The following package is used for testing: diff --git a/doc/requirements.txt b/doc/requirements.txt index 1cf3d0a..3bc7ba6 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -9,18 +9,18 @@ matplotlib>=3.6.1 docutils==0.18.1 alabaster==0.7.12 Babel==2.9.1 -certifi==2020.12.5 +certifi==2024.7.4 chardet==4.0.0 colorama==0.4.4 -idna==2.10 +idna==3.7 imagesize==1.4.1 -Jinja2==3.0.1 +Jinja2==3.1.6 MarkupSafe==2.0.1 packaging==23.2 Pygments==2.17.2 pyparsing==2.4.7 pytz==2021.1 -requests==2.25.1 +requests==2.32.2 snowballstemmer==2.1.0 Sphinx==6.1.3 sphinx-gallery==0.12.2 @@ -32,4 +32,4 @@ sphinxcontrib-htmlhelp==2.0.1 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 -urllib3==1.26.4 +urllib3==1.26.19 diff --git a/requirements.txt b/requirements.txt index 50da21d..bfebcb8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,8 @@ cython>=1.0 numpy>=1.16.6,<2.0 scipy>=1.10.1 -jax==0.4.28 -jaxlib==0.4.28 qutip>=5.0.1 qutip-qtrl -qutip-jax pre-commit gymnasium>=0.29.1 stable-baselines3>=2.3.2 diff --git a/setup.cfg b/setup.cfg index 7872023..04fe89b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,15 +28,10 @@ package_dir= packages = find: include_package_data = True install_requires = - jax - jaxlib packaging qutip qutip-qtrl - qutip-jax numpy>=1.16.6,<2.0 - gymnasium>=0.29.1 - stable-baselines3>=2.3.2 setup_requires = cython>=1.0 packaging @@ -44,8 +39,24 @@ setup_requires = [options.packages.find] where = src +[options.entry_points] +qutip.family = + qutip_qoc = qutip_qoc.family + [options.extras_require] +jopt = + jax + jaxlib + qutip-jax +rl = + gymnasium>=0.29.1 + stable-baselines3>=2.3.2 tests = pytest>=6.0 + jupytext + nbconvert + ipykernel full = + %(jopt)s + %(rl)s %(tests)s diff --git a/src/qutip_qoc/_jopt.py b/src/qutip_qoc/_jopt.py index 0e94491..6c9cfe2 100644 --- a/src/qutip_qoc/_jopt.py +++ b/src/qutip_qoc/_jopt.py @@ -5,38 +5,47 @@ import qutip as qt from qutip import Qobj, QobjEvo -from diffrax import Dopri5, PIDController +try: + import jax + from jax import custom_jvp + import jax.numpy as jnp + import qutip_jax # noqa: F401 -import jax -from jax import custom_jvp -import jax.numpy as jnp -import qutip_jax # noqa: F401 + import jaxlib # noqa: F401 + from diffrax import Dopri5, PIDController -@custom_jvp -def _abs(x): - return jnp.abs(x) + _jax_available = True +except ImportError: + _jax_available = False -def _abs_jvp(primals, tangents): - """ - Custom jvp for absolute value of complex functions - """ - (x,) = primals - (t,) = tangents +if _jax_available: + + @custom_jvp + def _abs(x): + return jnp.abs(x) + - abs_x = _abs(x) - res = jnp.where( - abs_x == 0, - 0.0, # prevent division by zero - jnp.real(jnp.multiply(jnp.conj(x), t)) / abs_x, - ) + def _abs_jvp(primals, tangents): + """ + Custom jvp for absolute value of complex functions + """ + (x,) = primals + (t,) = tangents + + abs_x = _abs(x) + res = jnp.where( + abs_x == 0, + 0.0, # prevent division by zero + jnp.real(jnp.multiply(jnp.conj(x), t)) / abs_x, + ) - return abs_x, res + return abs_x, res -# register custom jvp for absolut value of complex functions -_abs.defjvp(_abs_jvp) + # register custom jvp for absolut value of complex functions + _abs.defjvp(_abs_jvp) class _JOPT: @@ -55,6 +64,10 @@ def __init__( guess_params, **integrator_kwargs, ): + if not _jax_available: + raise ImportError("The JOPT algorithm requires the modules jax, " + "jaxlib, and qutip_jax to be installed.") + self._Hd = objective.H[0] self._Hc_lst = objective.H[1:] @@ -137,8 +150,8 @@ def _infid(self, params): if self._fid_type == "TRACEDIFF": diff = X - self._target # to prevent if/else in qobj.dag() and qobj.tr() - diff_dag = Qobj(diff.data.adjoint(), dims=diff.dims) - g = 1 / 2 * (diff_dag * diff).data.trace() + diff_dag = diff.dag() # direct access to JAX array, no fallback! + g = 1 / 2 * jnp.trace(diff_dag.data._jxa @ diff.data._jxa) infid = jnp.real(self._norm_fac * g) else: g = self._norm_fac * self._target.overlap(X) @@ -147,4 +160,4 @@ def _infid(self, params): elif self._fid_type == "SU": # f_SU (incl global phase) infid = 1 - jnp.real(g) - return infid + return infid \ No newline at end of file diff --git a/src/qutip_qoc/_rl.py b/src/qutip_qoc/_rl.py index 9a75a55..21e8cb1 100644 --- a/src/qutip_qoc/_rl.py +++ b/src/qutip_qoc/_rl.py @@ -89,7 +89,7 @@ def create_pulse_func(idx): self._result = Result( objectives=objectives, time_interval=time_interval, - start_local_time=time.localtime(), # initial optimization time + start_local_time=time.time(), # initial optimization time n_iters=0, # Number of iterations(episodes) until convergence iter_seconds=[], # list containing the time taken for each iteration(episode) of the optimization var_time=True, # Whether the optimization was performed with variable time @@ -99,7 +99,7 @@ def create_pulse_func(idx): self._backup_result = Result( # used as a backup in case the algorithm with shorter_pulses does not find an episode with infid 1 - else time.mktime(self._result.start_local_time) + else self._result.start_local_time ) self._result.iter_seconds.append(time_diff) self._current_step = 0 # Reset the step counter @@ -281,7 +281,7 @@ def _save_result(self): self._backup_result._final_states = self._result._final_states.copy() self._backup_result.infidelity = self._result.infidelity - result_obj.end_local_time = time.localtime() + result_obj.end_local_time = time.time() result_obj.n_iters = len(self._result.iter_seconds) result_obj.optimized_params = self._actions.copy() + [ self._result.total_seconds @@ -296,20 +296,20 @@ def result(self): """ if self._use_backup_result: self._backup_result.start_local_time = time.strftime( - "%Y-%m-%d %H:%M:%S", self._backup_result.start_local_time - ) # Convert to a string + "%Y-%m-%d %H:%M:%S", time.localtime(self._backup_result.start_local_time) + ) self._backup_result.end_local_time = time.strftime( - "%Y-%m-%d %H:%M:%S", self._backup_result.end_local_time - ) # Convert to a string + "%Y-%m-%d %H:%M:%S", time.localtime(self._backup_result.end_local_time) + ) return self._backup_result else: self._save_result() self._result.start_local_time = time.strftime( - "%Y-%m-%d %H:%M:%S", self._result.start_local_time - ) # Convert to a string + "%Y-%m-%d %H:%M:%S", time.localtime(self._result.start_local_time) + ) self._result.end_local_time = time.strftime( - "%Y-%m-%d %H:%M:%S", self._result.end_local_time - ) # Convert to a string + "%Y-%m-%d %H:%M:%S", time.localtime(self._result.end_local_time) + ) return self._result def train(self): diff --git a/src/qutip_qoc/family.py b/src/qutip_qoc/family.py new file mode 100644 index 0000000..d76645a --- /dev/null +++ b/src/qutip_qoc/family.py @@ -0,0 +1,8 @@ +"""QuTiP family package entry point.""" + +from . import __version__ + + +def version(): + """Return information to include in qutip.about().""" + return "qutip-qoc", __version__ diff --git a/src/qutip_qoc/pulse_optim.py b/src/qutip_qoc/pulse_optim.py index 5f65e46..76ac741 100644 --- a/src/qutip_qoc/pulse_optim.py +++ b/src/qutip_qoc/pulse_optim.py @@ -12,6 +12,14 @@ from qutip_qoc._time import _TimeInterval from qutip_qoc._rl import _RL +import qutip as qt + +try: + from qutip_qoc._rl import _RL + _rl_available = True +except ImportError: + _rl_available = False + __all__ = ["optimize_pulses"] @@ -23,6 +31,7 @@ def optimize_pulses( optimizer_kwargs=None, minimizer_kwargs=None, integrator_kwargs=None, + optimization_type=None, ): """ Run GOAT, JOPT, GRAPE, CRAB or RL optimization. @@ -119,6 +128,11 @@ def optimize_pulses( Options for the solver, see :obj:`MESolver.options` and `Integrator <./classes.html#classes-ode>`_ for a list of all options. + optimization_type : str, optional + Type of optimization. By default, QuTiP-QOC will try to automatically determine + whether this is a *state transfer* or a *gate synthesis* problem. Set this + flag to ``"state_transfer"`` or ``"gate_synthesis"`` to set the mode manually. + Returns ------- result : :class:`qutip_qoc.Result` @@ -183,9 +197,63 @@ def optimize_pulses( "gtol": algorithm_kwargs.get("min_grad", 0.0 if alg == "CRAB" else 1e-8), } + # Iterate over objectives and convert initial and target states based on the optimization type + for objective in objectives: + H_list = objective.H if isinstance(objective.H, list) else [objective.H] + if any(qt.issuper(H_i) for H_i in H_list): + if isinstance(optimization_type, str) and optimization_type.lower() == "state_transfer": + if qt.isket(objective.initial): + dim = objective.initial.shape[0] + objective.initial = qt.operator_to_vector(qt.ket2dm(objective.initial)) + elif qt.isoper(objective.initial): + dim = objective.initial.shape[0] + objective.initial = qt.operator_to_vector(objective.initial) + + if qt.isket(objective.target): + objective.target = qt.operator_to_vector(qt.ket2dm(objective.target)) + elif qt.isoper(objective.target): + objective.target = qt.operator_to_vector(objective.target) + + algorithm_kwargs.setdefault("fid_params", {}) + algorithm_kwargs["fid_params"].setdefault("scale_factor", 1.0 / dim) + + elif isinstance(optimization_type, str) and optimization_type.lower() == "gate_synthesis": + objective.initial = qt.to_super(objective.initial) + objective.target = qt.to_super(objective.target) + + elif optimization_type is None: + is_state_transfer = False + if qt.isoper(objective.initial) and qt.isoper(objective.target): + if np.isclose(objective.initial.tr(), 1) and np.isclose(objective.target.tr(), 1): + dim = objective.initial.shape[0] + objective.initial = qt.operator_to_vector(objective.initial) + objective.target = qt.operator_to_vector(objective.target) + is_state_transfer = True + else: + objective.initial = qt.to_super(objective.initial) + objective.target = qt.to_super(objective.target) + + if qt.isket(objective.initial): + dim = objective.initial.shape[0] + objective.initial = qt.operator_to_vector(qt.ket2dm(objective.initial)) + is_state_transfer = True + + if qt.isket(objective.target): + objective.target = qt.operator_to_vector(qt.ket2dm(objective.target)) + is_state_transfer = True + + if is_state_transfer: + algorithm_kwargs.setdefault("fid_params", {}) + algorithm_kwargs["fid_params"].setdefault("scale_factor", 1.0 / dim) + # prepare qtrl optimizers qtrl_optimizers = [] if alg == "CRAB" or alg == "GRAPE": + dyn_type = "GEN_MAT" + for objective in objectives: + if any(qt.isoper(H_i) for H_i in (objective.H if isinstance(objective.H, list) else [objective.H])): + dyn_type = "UNIT" + if alg == "GRAPE": # algorithm specific kwargs use_as_amps = True minimizer_kwargs.setdefault("method", "L-BFGS-B") # gradient @@ -242,7 +310,7 @@ def optimize_pulses( "accuracy_factor": None, # deprecated "alg_params": alg_params, "optim_params": algorithm_kwargs.get("optim_params", None), - "dyn_type": algorithm_kwargs.get("dyn_type", "GEN_MAT"), + "dyn_type": algorithm_kwargs.get("dyn_type", dyn_type), "dyn_params": algorithm_kwargs.get("dyn_params", None), "prop_type": algorithm_kwargs.get( "prop_type", "DEF" @@ -353,6 +421,11 @@ def optimize_pulses( qtrl_optimizers.append(qtrl_optimizer) elif alg == "RL": + if not _rl_available: + raise ImportError( + "The required dependencies (gymnasium, stable-baselines3) for " + "the reinforcement learning algorithm are not available." + ) rl_env = _RL( objectives, control_parameters, diff --git a/src/qutip_qoc/result.py b/src/qutip_qoc/result.py index 493104a..7738aa9 100644 --- a/src/qutip_qoc/result.py +++ b/src/qutip_qoc/result.py @@ -2,7 +2,6 @@ This module contains the Result class for storing and reporting the results of a full pulse control optimization run. """ -import jaxlib import pickle import textwrap import numpy as np @@ -11,6 +10,13 @@ import qutip as qt +try: + import jax + import jaxlib + _jitfun_type = type(jax.jit(lambda x: x)) +except ImportError: + _jitfun_type = None + __all__ = ["Result"] @@ -331,8 +337,10 @@ def final_states(self): evo_time = self.time_interval.evo_time # choose solver method based on type of control function - if isinstance( - self.objectives[0].H[1][1], jaxlib.xla_extension.PjitFunction + # if jax is installed, _jitfun_type is set to + # jaxlib.xla_extension.PjitFunction, otherwise it is None + if _jitfun_type is not None and isinstance( + self.objectives[0].H[1][1], _jitfun_type ): method = "diffrax" # for JAX defined contols else: diff --git a/tests/interactive/CRAB_gate_closed.md b/tests/interactive/CRAB_gate_closed.md new file mode 100644 index 0000000..ea36556 --- /dev/null +++ b/tests/interactive/CRAB_gate_closed.md @@ -0,0 +1,102 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# CRAB algorithm for a closed system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import gates, qeye, sigmax, sigmay, sigmaz +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +def fidelity(gate, target_gate): + """ + Fidelity used for unitary gates in qutip-qtrl and qutip-qoc + """ + return np.abs(gate.overlap(target_gate) / target_gate.norm()) +``` + +## Problem setup + +```python +omega = 0.1 # energy splitting +sx, sy, sz = sigmax(), sigmay(), sigmaz() + +Hd = 1 / 2 * omega * sz +Hc = [sx, sy, sz] +H = [Hd, Hc[0], Hc[1], Hc[2]] + +# objective for optimization +initial_gate = qeye(2) +target_gate = gates.hadamard_transform() + +times = np.linspace(0, np.pi / 2, 250) +``` + +## CRAB algorithm + +```python +n_params = 3 # adjust in steps of 3 +control_params = { + "ctrl_x": {"guess": [1 for _ in range(n_params)], "bounds": [(-1, 1)] * n_params}, + "ctrl_y": {"guess": [1 for _ in range(n_params)], "bounds": [(-1, 1)] * n_params}, + "ctrl_z": {"guess": [1 for _ in range(n_params)], "bounds": [(-1, 1)] * n_params}, +} + +res_crab = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "CRAB", + "fid_err_targ": 0.001 + }, +) + +print('Infidelity: ', res_crab.infidelity) + +plt.plot(times, res_crab.optimized_controls[0], 'b', label='optimized pulse sx') +plt.plot(times, res_crab.optimized_controls[1], 'g', label='optimized pulse sy') +plt.plot(times, res_crab.optimized_controls[2], 'r', label='optimized pulse sz') +plt.title('CRAB pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc[0], res_crab.optimized_controls[0]], [Hc[1], res_crab.optimized_controls[1]], [Hc[2], res_crab.optimized_controls[2]]] +evolution = qt.sesolve(H_result, initial_gate, times) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution.states], label="Overlap with target gate") + +plt.title('CRAB performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +assert res_crab.infidelity < 0.001 +assert fidelity(evolution.states[-1], target_gate) > 1-0.001 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/CRAB_state_closed.md b/tests/interactive/CRAB_state_closed.md new file mode 100644 index 0000000..278ab21 --- /dev/null +++ b/tests/interactive/CRAB_state_closed.md @@ -0,0 +1,95 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# CRAB algorithm for 2 level system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import basis, Qobj +import qutip as qt +from qutip_qoc import Objective, optimize_pulses +``` + +## Problem setup + +```python +# Energy levels +E1, E2 = 1.0, 2.0 + +Hd = Qobj(np.diag([E1, E2])) +Hc = Qobj(np.array([ + [0, 1], + [1, 0] +])) +H = [Hd, Hc] + +initial_state = basis(2, 0) # |1> +target_state = basis(2, 1) # |2> + +times = np.linspace(0, 2 * np.pi, 250) +``` + +## CRAB algorithm + +```python +n_params = 6 # adjust in steps of 3 +control_params = { + "ctrl_x": {"guess": [1 for _ in range(n_params)], "bounds": [(-1, 1)] * n_params}, +} + +res_crab = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "CRAB", + "fid_err_targ": 0.001 + }, +) + +print('Infidelity: ', res_crab.infidelity) + +plt.plot(times, res_crab.optimized_controls[0], label='optimized pulse') +plt.title('CRAB pulse') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc, np.array(res_crab.optimized_controls[0])]] +evolution = qt.sesolve(H_result, initial_state, times) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], '--', label="Fidelity") + +plt.title("CRAB performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +assert res_crab.infidelity < 0.001 +assert np.abs(evolution.states[-1].overlap(target_state)) > 1-0.001 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/GOAT_gate_closed.md b/tests/interactive/GOAT_gate_closed.md new file mode 100644 index 0000000..4eafa97 --- /dev/null +++ b/tests/interactive/GOAT_gate_closed.md @@ -0,0 +1,280 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# GOAT algorithm for a closed system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import gates, qeye, sigmax, sigmay, sigmaz +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +def fidelity(gate, target_gate): + """ + Fidelity used for unitary gates in qutip-qtrl and qutip-qoc + """ + return np.abs(gate.overlap(target_gate) / target_gate.norm()) +``` + +## Problem setup + +```python +omega = 0.1 # energy splitting +sx, sy, sz = sigmax(), sigmay(), sigmaz() + +Hd = 1 / 2 * omega * sz +Hc = [sx, sy, sz] + +# objective for optimization +initial_gate = qeye(2) +target_gate = gates.hadamard_transform() + +times = np.linspace(0, np.pi / 2, 250) +``` + +## Guess + +```python +goat_guess = [1, 1] +guess_pulse = goat_guess[0] * np.sin(goat_guess[1] * times) + +H_guess = [Hd] + [[hc, guess_pulse] for hc in Hc] +evolution_guess = qt.sesolve(H_guess, initial_gate, times) + +print('Fidelity: ', fidelity(evolution_guess.states[-1], target_gate)) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution_guess.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution_guess.states], label="Overlap with target gate") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## GOAT algorithm + +```python +# control function +def sin(t, c): + return c[0] * np.sin(c[1] * t) + +# derivatives +def grad_sin(t, c, idx): + if idx == 0: # w.r.t. c0 + return np.sin(c[1] * t) + if idx == 1: # w.r.t. c1 + return c[0] * np.cos(c[1] * t) * t + if idx == 2: # w.r.t. time + return c[0] * np.cos(c[1] * t) * c[1] + +H = [Hd] + [[hc, sin, {"grad": grad_sin}] for hc in Hc] +``` + +### a) not optimized over time + +```python +control_params = { + id: {"guess": goat_guess, "bounds": [(-1, 1), (0, 2 * np.pi)]} # c0 and c1 + for id in ['x', 'y', 'z'] +} + +# run the optimization +res_goat = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001, + }, +) + +print('Infidelity: ', res_goat.infidelity) + +plt.plot(times, guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times, res_goat.optimized_controls[0], 'b', label='optimized pulse sx') +plt.plot(times, res_goat.optimized_controls[1], 'g', label='optimized pulse sy') +plt.plot(times, res_goat.optimized_controls[2], 'r', label='optimized pulse sz') +plt.title('GOAT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, + [Hc[0], np.array(res_goat.optimized_controls[0])], + [Hc[1], np.array(res_goat.optimized_controls[1])], + [Hc[2], np.array(res_goat.optimized_controls[2])]] +evolution = qt.sesolve(H_result, initial_gate, times) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution.states], label="Overlap with target gate") + +plt.title('GOAT performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +### b) optimized over time + +```python +# treats time as optimization variable +control_params["__time__"] = { + "guess": times[len(times) // 2], + "bounds": [times[0], times[-1]], +} + +# run the optimization +res_goat_time = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001, + }, +) + +opt_time = res_goat_time.optimized_params[-1][0] +time_range = times < opt_time + +print('Infidelity: ', res_goat_time.infidelity) +print('Optimized time: ', opt_time) + +plt.plot(times, guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times[time_range], np.array(res_goat_time.optimized_controls[0])[time_range], 'b', label='optimized pulse sx') +plt.plot(times[time_range], np.array(res_goat_time.optimized_controls[1])[time_range], 'g', label='optimized pulse sy') +plt.plot(times[time_range], np.array(res_goat_time.optimized_controls[2])[time_range], 'r', label='optimized pulse sz') +plt.title('GOAT pulses (time optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times2 = times[time_range] +if opt_time not in times2: + times2 = np.append(times2, opt_time) + +H_result = qt.QobjEvo( + [Hd, [Hc[0], np.array(res_goat_time.optimized_controls[0])], + [Hc[1], np.array(res_goat_time.optimized_controls[1])], + [Hc[2], np.array(res_goat_time.optimized_controls[2])]], tlist=times) +evolution_time = qt.sesolve(H_result, initial_gate, times2) + +plt.plot(times2, [fidelity(gate, initial_gate) for gate in evolution_time.states], label="Overlap with initial gate") +plt.plot(times2, [fidelity(gate, target_gate) for gate in evolution_time.states], label="Overlap with target gate") + +plt.title('GOAT (optimized over time) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Global optimization + +```python +res_goat_global = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001, + }, + optimizer_kwargs = { + "method": "basinhopping", + "max_iter": 100, + } +) + +global_time = res_goat_global.optimized_params[-1][0] +global_range = times < global_time + +print('Infidelity: ', res_goat_global.infidelity) +print('Optimized time: ', global_time) + +plt.plot(times, guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times[global_range], np.array(res_goat_global.optimized_controls[0])[global_range], 'b', label='optimized pulse sx') +plt.plot(times[global_range], np.array(res_goat_global.optimized_controls[1])[global_range], 'g', label='optimized pulse sy') +plt.plot(times[global_range], np.array(res_goat_global.optimized_controls[2])[global_range], 'r', label='optimized pulse sz') +plt.title('GOAT pulses (global optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times3 = times[global_range] +if global_time not in times3: + times3 = np.append(times3, global_time) + +H_result = qt.QobjEvo( + [Hd, [Hc[0], np.array(res_goat_global.optimized_controls[0])], + [Hc[1], np.array(res_goat_global.optimized_controls[1])], + [Hc[2], np.array(res_goat_global.optimized_controls[2])]], tlist=times) +evolution_global = qt.sesolve(H_result, initial_gate, times3) + +plt.plot(times3, [fidelity(gate, initial_gate) for gate in evolution_global.states], label="Overlap with initial gate") +plt.plot(times3, [fidelity(gate, target_gate) for gate in evolution_global.states], label="Overlap with target gate") + +plt.title('GOAT (global optimization) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Comparison + +```python +fig, axes = plt.subplots(1, 3, figsize=(18, 4)) # 1 row, 3 columns + +titles = ["GOAT sx pulses", "GOAT sy pulses", "GOAT sz pulses"] + +for i, ax in enumerate(axes): + ax.plot(times, guess_pulse, label='initial guess') + ax.plot(times, res_goat.optimized_controls[i], label='optimized pulse') + ax.plot(times[time_range], np.array(res_goat_time.optimized_controls[i])[time_range], label='optimized (over time) pulse') + ax.plot(times[global_range], np.array(res_goat_global.optimized_controls[i])[global_range], label='global optimized pulse') + ax.set_title(titles[i]) + ax.set_xlabel('Time') + ax.set_ylabel('Pulse amplitude') + ax.legend() + +plt.tight_layout() +plt.show() +``` + +## Validation + +```python +assert res_goat.infidelity < 0.001 +assert fidelity(evolution.states[-1], target_gate) > 1-0.001 + +assert res_goat_time.infidelity < 0.001 +assert fidelity(evolution_time.states[-1], target_gate) > 1-0.001 + +assert res_goat_global.infidelity < 0.001 +assert fidelity(evolution_global.states[-1], target_gate) > 1-0.001 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/GOAT_state_closed.md b/tests/interactive/GOAT_state_closed.md new file mode 100644 index 0000000..bbce5f2 --- /dev/null +++ b/tests/interactive/GOAT_state_closed.md @@ -0,0 +1,292 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# GOAT algorithm for a 2 level system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import basis, Qobj +import qutip as qt +from qutip_qoc import Objective, optimize_pulses +``` + +## Problem setup + +```python +# Energy levels +E1, E2 = 1.0, 2.0 + +Hd = Qobj(np.diag([E1, E2])) +Hc = Qobj(np.array([ + [0, 1], + [1, 0] +])) +H = [Hd, Hc] + +initial_state = basis(2, 0) # |1> +target_state = basis(2, 1) # |2> + +times = np.linspace(0, 2 * np.pi, 250) +``` + +## Guess + +```python +goat_guess = [1, 0.5] +guess_pulse = goat_guess[0] * np.sin(goat_guess[1] * times) + +H_guess = [Hd, [Hc, guess_pulse]] +evolution_guess = qt.sesolve(H_guess, initial_state, times) + +print('Fidelity: ', qt.fidelity(evolution_guess.states[-1], target_state)) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution_guess.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution_guess.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution_guess.states], '--', label="Fidelity") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## GOAT algorithm + +```python +# control function +def sin(t, c): + return c[0] * np.sin(c[1] * t) + +# gradient +def grad_sin(t, c, idx): + if idx == 0: # w.r.t. c0 + return np.sin(c[1] * t) + if idx == 1: # w.r.t. c1 + return c[0] * np.cos(c[1] * t) * t + if idx == 2: # w.r.t. time + return c[0] * np.cos(c[1] * t) * c[1] + +H = [Hd] + [[Hc, sin, {"grad": grad_sin}]] +``` + +### a) not optimized over time + +```python +control_params = { + "ctrl_x": {"guess": goat_guess, "bounds": [(-1, 1), (0, 2 * np.pi)]} # c0 and c1 +} + +# run the optimization +res_goat = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001 + }, +) + +print('Infidelity: ', res_goat.infidelity) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_goat.optimized_controls[0], label='optimized pulse') +plt.title('GOAT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc, np.array(res_goat.optimized_controls[0])]] +evolution = qt.sesolve(H_result, initial_state, times) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], '--', label="Fidelity") + +plt.title('GOAT performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +Here, GOAT is stuck in a local minimum and does not reach the desired fidelity. + + +### b) optimized over time + +```python +# treats time as optimization variable +control_params["__time__"] = { + "guess": times[len(times) // 2], + "bounds": [times[0], times[-1]], +} + +# run the optimization +res_goat_time = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001 + }, +) + +opt_time = res_goat_time.optimized_params[-1][0] +time_range = times < opt_time + +print('Infidelity: ', res_goat_time.infidelity) +print('Optimized time: ', opt_time) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_goat.optimized_controls[0], label='optimized pulse') +plt.plot(times[time_range], np.array(res_goat_time.optimized_controls[0])[time_range], label='optimized (over time) pulse') +plt.title('GOAT pulses (time optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times2 = times[time_range] +if opt_time not in times2: + times2 = np.append(times2, opt_time) + +H_result = qt.QobjEvo([Hd, [Hc, np.array(res_goat_time.optimized_controls[0])]], tlist=times) +evolution_time = qt.sesolve(H_result, initial_state, times2) + +plt.plot(times2, [np.abs(state.overlap(initial_state)) for state in evolution_time.states], label="Overlap with initial state") +plt.plot(times2, [np.abs(state.overlap(target_state)) for state in evolution_time.states], label="Overlap with target state") +plt.plot(times2, [qt.fidelity(state, target_state) for state in evolution_time.states], '--', label="Fidelity") + +plt.title('GOAT (optimized over time) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +GOAT is still stuck in a local minimum, but the fidelity has improved. + + +### c) global optimization + +```python +res_goat_global = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001 + }, + optimizer_kwargs={ + "method": "basinhopping", + "max_iter": 1000 + } +) + +global_time = res_goat_global.optimized_params[-1][0] +global_range = times < global_time + +print('Infidelity: ', res_goat_global.infidelity) +print('Optimized time: ', global_time) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_goat.optimized_controls[0], label='optimized pulse') +plt.plot(times[global_range], np.array(res_goat_global.optimized_controls[0])[global_range], label='global optimized pulse') +plt.title('GOAT pulses (global)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times3 = times[global_range] +if global_time not in times3: + times3 = np.append(times3, global_time) + +H_result = qt.QobjEvo([Hd, [Hc, np.array(res_goat_global.optimized_controls[0])]], tlist=times) +evolution_global = qt.sesolve(H_result, initial_state, times3) + +plt.plot(times3, [np.abs(state.overlap(initial_state)) for state in evolution_global.states], label="Overlap with initial state") +plt.plot(times3, [np.abs(state.overlap(target_state)) for state in evolution_global.states], label="Overlap with target state") +plt.plot(times3, [qt.fidelity(state, target_state) for state in evolution_global.states], '--', label="Fidelity") + +plt.title('GOAT (global) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Comparison + +```python +plt.plot(times, guess_pulse, color='blue', label='initial guess') +plt.plot(times, res_goat.optimized_controls[0], color='orange', label='optimized pulse') +plt.plot(times[time_range], np.array(res_goat_time.optimized_controls[0])[time_range], + color='green', label='optimized (over time) pulse') +plt.plot(times[global_range], np.array(res_goat_global.optimized_controls[0])[global_range], + color='red', label='global optimized pulse') + +plt.title('GOAT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +print('Guess Fidelity: ', qt.fidelity(evolution_guess.states[-1], target_state)) +print('GOAT Fidelity: ', 1 - res_goat.infidelity) +print('Time Fidelity: ', 1 - res_goat_time.infidelity) +print('GLobal Fidelity: ', 1 - res_goat_global.infidelity) + +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution_guess.states], color='blue', label="Guess") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], color='orange', label="GOAT") +plt.plot(times2, [qt.fidelity(state, target_state) for state in evolution_time.states], + color='green', label="Time") +plt.plot(times3, [qt.fidelity(state, target_state) for state in evolution_global.states], + color='red', label="Global") + +plt.title('Fidelities') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +guess_fidelity = qt.fidelity(evolution_guess.states[-1], target_state) + +# target fidelity not reached in part a), check that it is better than the guess +assert 1 - res_goat.infidelity > guess_fidelity +assert np.allclose(np.abs(evolution.states[-1].overlap(target_state)), 1 - res_goat.infidelity, atol=1e-3) + +# target fidelity not reached in part b), check that it is better than part a) +assert res_goat_time.infidelity < res_goat.infidelity +assert np.allclose(np.abs(evolution_time.states[-1].overlap(target_state)), 1 - res_goat_time.infidelity, atol=1e-3) + +assert res_goat_global.infidelity < 0.001 +assert np.abs(evolution_global.states[-1].overlap(target_state)) > 1 - 0.001 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/GRAPE_gate_closed.md b/tests/interactive/GRAPE_gate_closed.md new file mode 100644 index 0000000..bb9b145 --- /dev/null +++ b/tests/interactive/GRAPE_gate_closed.md @@ -0,0 +1,124 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# GRAPE algorithm for a closed system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import gates, qeye, sigmax, sigmay, sigmaz +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +def fidelity(gate, target_gate): + """ + Fidelity used for unitary gates in qutip-qtrl and qutip-qoc + """ + return np.abs(gate.overlap(target_gate) / target_gate.norm()) +``` + +## Problem setup + +```python +omega = 0.1 # energy splitting +sx, sy, sz = sigmax(), sigmay(), sigmaz() + +Hd = 1 / 2 * omega * sz +Hc = [sx, sy, sz] +H = [Hd, Hc[0], Hc[1], Hc[2]] + +# objective for optimization +initial_gate = qeye(2) +target_gate = gates.hadamard_transform() + +times = np.linspace(0, np.pi / 2, 250) +``` + +## Guess + +```python +guess_pulse_x = np.sin(times) +guess_pulse_y = np.cos(times) +guess_pulse_z = np.tanh(times) + +H_guess = [Hd, [Hc[0], guess_pulse_x], [Hc[1], guess_pulse_y], [Hc[2], guess_pulse_z]] +evolution_guess = qt.sesolve(H_guess, initial_gate, times) + +print('Fidelity: ', fidelity(evolution_guess.states[-1], target_gate)) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution_guess.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution_guess.states], label="Overlap with target gate") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## GRAPE algorithm + +```python +control_params = { + "ctrl_x": {"guess": np.sin(times), "bounds": [-1, 1]}, + "ctrl_y": {"guess": np.cos(times), "bounds": [-1, 1]}, + "ctrl_z": {"guess": np.tanh(times), "bounds": [-1, 1]}, +} + +res_grape = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GRAPE", + "fid_err_targ": 0.001 + }, +) + +print('Infidelity: ', res_grape.infidelity) + +plt.plot(times, guess_pulse_x, 'b--', label='guess pulse sx') +plt.plot(times, res_grape.optimized_controls[0], 'b', label='optimized pulse sx') +plt.plot(times, guess_pulse_y, 'g--', label='guess pulse sy') +plt.plot(times, res_grape.optimized_controls[1], 'g', label='optimized pulse sy') +plt.plot(times, guess_pulse_z, 'r--', label='guess pulse sz') +plt.plot(times, res_grape.optimized_controls[2], 'r', label='optimized pulse sz') +plt.title('GRAPE pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc[0], res_grape.optimized_controls[0]], [Hc[1], res_grape.optimized_controls[1]], [Hc[2], res_grape.optimized_controls[2]]] +evolution = qt.sesolve(H_result, initial_gate, times) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution.states], label="Overlap with target gate") + +plt.title('GRAPE performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +assert res_grape.infidelity < 0.001 +assert fidelity(evolution.states[-1], target_gate) > 1-0.001 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/GRAPE_state_closed.md b/tests/interactive/GRAPE_state_closed.md new file mode 100644 index 0000000..2681df7 --- /dev/null +++ b/tests/interactive/GRAPE_state_closed.md @@ -0,0 +1,114 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# GRAPE algorithm for 2 level system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import basis, Qobj +import qutip as qt +from qutip_qoc import Objective, optimize_pulses +``` + +## Problem setup + +```python +# Energy levels +E1, E2 = 1.0, 2.0 + +Hd = Qobj(np.diag([E1, E2])) +Hc = Qobj(np.array([ + [0, 1], + [1, 0] +])) +H = [Hd, Hc] + +initial_state = basis(2, 0) # |1> +target_state = basis(2, 1) # |2> + +times = np.linspace(0, 2 * np.pi, 250) +``` + +## Guess + +```python +guess_pulse = np.sin(times) + +H_guess = [Hd, [Hc, guess_pulse]] +evolution_guess = qt.sesolve(H_guess, initial_state, times) + +print('Fidelity: ', qt.fidelity(evolution_guess.states[-1], target_state)) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution_guess.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution_guess.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution_guess.states], '--', label="Fidelity") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## GRAPE algorithm + +```python +control_params = { + "ctrl_x": {"guess": guess_pulse, "bounds": [-1, 1]}, # Control pulse for Hc1 +} + +res_grape = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GRAPE", + "fid_err_targ": 0.001 + }, +) + +print('Infidelity: ', res_grape.infidelity) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_grape.optimized_controls[0], label='optimized pulse') +plt.title('GRAPE pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc, np.array(res_grape.optimized_controls[0])]] +evolution = qt.sesolve(H_result, initial_state, times) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], '--', label="Fidelity") + +plt.title('GRAPE performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +assert res_grape.infidelity < 0.01 +assert np.abs(evolution.states[-1].overlap(target_state)) > 1-0.01 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/JOPT_gate_closed.md b/tests/interactive/JOPT_gate_closed.md new file mode 100644 index 0000000..1f3370c --- /dev/null +++ b/tests/interactive/JOPT_gate_closed.md @@ -0,0 +1,285 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# JOPT algorithm for a closed system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import gates, qeye, sigmax, sigmay, sigmaz +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +try: + from jax import jit + from jax import numpy as jnp +except ImportError: # JAX not available, skip test + import pytest + pytest.skip("JAX not available") + +def fidelity(gate, target_gate): + """ + Fidelity used for unitary gates in qutip-qtrl and qutip-qoc + """ + return np.abs(gate.overlap(target_gate) / target_gate.norm()) +``` + +## Problem setup + +```python +omega = 0.1 # energy splitting +sx, sy, sz = sigmax(), sigmay(), sigmaz() + +Hd = 1 / 2 * omega * sz +Hc = [sx, sy, sz] + +# objective for optimization +initial_gate = qeye(2) +target_gate = gates.hadamard_transform() + +times = np.linspace(0, np.pi / 2, 250) +``` + +## Guess + +```python +jopt_guess = [1, 1] +guess_pulse = jopt_guess[0] * np.sin(jopt_guess[1] * times) + +H_guess = [Hd] + [[hc, guess_pulse] for hc in Hc] +evolution_guess = qt.sesolve(H_guess, initial_gate, times) + +print('Fidelity: ', fidelity(evolution_guess.states[-1], target_gate)) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution_guess.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution_guess.states], label="Overlap with target gate") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## JOPT algorithm + +```python +@jit +def sin_x(t, c, **kwargs): + return c[0] * jnp.sin(c[1] * t) + +H = [Hd] + [[hc, sin_x] for hc in Hc] +``` + +### a) not optimized over time + +```python +control_params = { + id: {"guess": jopt_guess, "bounds": [(-1, 1), (0, 2 * np.pi)]} # c0 and c1 + for id in ['x', 'y', 'z'] +} + +res_jopt = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + minimizer_kwargs = { + "method": "Nelder-Mead", + }, + algorithm_kwargs={ + "alg": "JOPT", + "fid_err_targ": 0.001, + }, +) + +print('Infidelity: ', res_jopt.infidelity) + +plt.plot(times, guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times, res_jopt.optimized_controls[0], 'b', label='optimized pulse sx') +plt.plot(times, res_jopt.optimized_controls[1], 'g', label='optimized pulse sy') +plt.plot(times, res_jopt.optimized_controls[2], 'r', label='optimized pulse sz') +plt.title('JOPT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, + [Hc[0], np.array(res_jopt.optimized_controls[0])], + [Hc[1], np.array(res_jopt.optimized_controls[1])], + [Hc[2], np.array(res_jopt.optimized_controls[2])]] +evolution = qt.sesolve(H_result, initial_gate, times) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution.states], label="Overlap with target gate") + +plt.title('JOPT performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +### b) optimized over time + +```python +# treats time as optimization variable +control_params["__time__"] = { + "guess": times[len(times) // 2], + "bounds": [times[0], times[-1]], +} + +# run the optimization +res_jopt_time = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + minimizer_kwargs = { + "method": "Nelder-Mead", + }, + algorithm_kwargs={ + "alg": "JOPT", + "fid_err_targ": 0.001, + }, +) + +opt_time = res_jopt_time.optimized_params[-1][0] +time_range = times < opt_time + +print('Infidelity: ', res_jopt_time.infidelity) +print('Optimized time: ', opt_time) + +plt.plot(times, guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times[time_range], np.array(res_jopt_time.optimized_controls[0])[time_range], 'b', label='optimized pulse sx') +plt.plot(times[time_range], np.array(res_jopt_time.optimized_controls[1])[time_range], 'g', label='optimized pulse sy') +plt.plot(times[time_range], np.array(res_jopt_time.optimized_controls[2])[time_range], 'r', label='optimized pulse sz') +plt.title('JOPT pulses (time optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times2 = times[time_range] +if opt_time not in times2: + times2 = np.append(times2, opt_time) + +H_result = qt.QobjEvo( + [Hd, [Hc[0], np.array(res_jopt_time.optimized_controls[0])], + [Hc[1], np.array(res_jopt_time.optimized_controls[1])], + [Hc[2], np.array(res_jopt_time.optimized_controls[2])]], tlist=times) +evolution_time = qt.sesolve(H_result, initial_gate, times2) + +plt.plot(times2, [fidelity(gate, initial_gate) for gate in evolution_time.states], label="Overlap with initial gate") +plt.plot(times2, [fidelity(gate, target_gate) for gate in evolution_time.states], label="Overlap with target gate") + +plt.title('JOPT (optimized over time) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Global optimization + +```python +res_jopt_global = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "JOPT", + "fid_err_targ": 0.001, + }, + optimizer_kwargs = { + "method": "basinhopping", + "max_iter": 100, + } +) + +global_time = res_jopt_global.optimized_params[-1][0] +global_range = times < global_time + +print('Infidelity: ', res_jopt_global.infidelity) +print('Optimized time: ', global_time) + +plt.plot(times, guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times[global_range], np.array(res_jopt_global.optimized_controls[0])[global_range], 'b', label='optimized pulse sx') +plt.plot(times[global_range], np.array(res_jopt_global.optimized_controls[1])[global_range], 'g', label='optimized pulse sy') +plt.plot(times[global_range], np.array(res_jopt_global.optimized_controls[2])[global_range], 'r', label='optimized pulse sz') +plt.title('JOPT pulses (global optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times3 = times[global_range] +if global_time not in times3: + times3 = np.append(times3, global_time) + +H_result = qt.QobjEvo( + [Hd, [Hc[0], np.array(res_jopt_global.optimized_controls[0])], + [Hc[1], np.array(res_jopt_global.optimized_controls[1])], + [Hc[2], np.array(res_jopt_global.optimized_controls[2])]], tlist=times) +evolution_global = qt.sesolve(H_result, initial_gate, times3) + +plt.plot(times3, [fidelity(gate, initial_gate) for gate in evolution_global.states], label="Overlap with initial gate") +plt.plot(times3, [fidelity(gate, target_gate) for gate in evolution_global.states], label="Overlap with target gate") + +plt.title('JOPT (global optimization) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Comparison + +```python +fig, axes = plt.subplots(1, 3, figsize=(18, 4)) # 1 row, 3 columns + +titles = ["JOPT sx pulses", "JOPT sy pulses", "JOPT sz pulses"] + +for i, ax in enumerate(axes): + ax.plot(times, guess_pulse, label='initial guess') + ax.plot(times, res_jopt.optimized_controls[i], label='optimized pulse') + ax.plot(times[time_range], np.array(res_jopt_time.optimized_controls[i])[time_range], label='optimized (over time) pulse') + ax.plot(times[global_range], np.array(res_jopt_global.optimized_controls[i])[global_range], label='global optimized pulse') + ax.set_title(titles[i]) + ax.set_xlabel('time') + ax.set_ylabel('Pulse amplitude') + ax.legend() + +plt.tight_layout() +plt.show() +``` + +## Validation + +```python +assert res_jopt.infidelity < 0.001 +assert fidelity(evolution.states[-1], target_gate) > 1-0.001 + +assert res_jopt_time.infidelity < 0.001 +assert fidelity(evolution_time.states[-1], target_gate) > 1-0.001 + +assert res_jopt_global.infidelity < 0.001 +assert fidelity(evolution_global.states[-1], target_gate) > 1-0.001 +``` + +```python +qt.about() +``` + + diff --git a/tests/interactive/JOPT_state_closed.md b/tests/interactive/JOPT_state_closed.md new file mode 100644 index 0000000..11547cb --- /dev/null +++ b/tests/interactive/JOPT_state_closed.md @@ -0,0 +1,296 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# JOPT algorithm for a 2 level system + + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import basis, Qobj +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +try: + from jax import jit + from jax import numpy as jnp +except ImportError: # JAX not available, skip test + import pytest + pytest.skip("JAX not available") +``` + +## Problem setup + +```python +# Energy levels +E1, E2 = 1.0, 2.0 + +Hd = Qobj(np.diag([E1, E2])) +Hc = Qobj(np.array([ + [0, 1], + [1, 0] +])) +H = [Hd, Hc] + +initial_state = basis(2, 0) # |1> +target_state = basis(2, 1) # |2> + +times = np.linspace(0, 2 * np.pi, 250) +``` + +## Guess + +```python +jopt_guess = [1, 0.5] +guess_pulse = jopt_guess[0] * np.sin(jopt_guess[1] * times) + +H_guess = [Hd, [Hc, guess_pulse]] +evolution_guess = qt.sesolve(H_guess, initial_state, times) + +print('Fidelity: ', qt.fidelity(evolution_guess.states[-1], target_state)) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution_guess.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution_guess.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution_guess.states], '--', label="Fidelity") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## JOPT algorithm + +```python +@jit +def sin(t, c, **kwargs): + return c[0] * jnp.sin(c[1] * t) + +H = [Hd] + [[Hc, sin]] +``` + +### a) not optimized over time + +```python +control_params = { + "ctrl_x": {"guess": [1, 0], "bounds": [(-1, 1), (0, 2 * np.pi)]} # c0 and c1 +} + +res_jopt = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + minimizer_kwargs = { + "method": "Nelder-Mead", + }, + algorithm_kwargs = { + "alg": "JOPT", + "fid_err_targ": 0.001, + }, +) + +print('Infidelity: ', res_jopt.infidelity) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_jopt.optimized_controls[0], label='optimized pulse') +plt.title('JOPT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc, np.array(res_jopt.optimized_controls[0])]] +evolution = qt.sesolve(H_result, initial_state, times) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], '--', label="Fidelity") + +plt.title('JOPT performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +Here, JOPT is stuck in a local minimum and does not reach the desired fidelity. + + +### b) optimized over time + +```python +# treats time as optimization variable +control_params["__time__"] = { + "guess": times[len(times) // 2], + "bounds": [times[0], times[-1]], +} + +# run the optimization +res_jopt_time = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + minimizer_kwargs = { + "method": "Nelder-Mead", + }, + algorithm_kwargs = { + "alg": "JOPT", + "fid_err_targ": 0.001, + }, +) + +opt_time = res_jopt_time.optimized_params[-1][0] +time_range = times < opt_time + +print('Infidelity: ', res_jopt_time.infidelity) +print('Optimized time: ', opt_time) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_jopt.optimized_controls[0], label='optimized pulse') +plt.plot(times[time_range], np.array(res_jopt_time.optimized_controls[0])[time_range], label='optimized (over time) pulse') +plt.title('JOPT pulses (time optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times2 = times[time_range] +if opt_time not in times2: + times2 = np.append(times2, opt_time) + +H_result = qt.QobjEvo([Hd, [Hc, np.array(res_jopt_time.optimized_controls[0])]], tlist=times) +evolution_time = qt.sesolve(H_result, initial_state, times2) + +plt.plot(times2, [np.abs(state.overlap(initial_state)) for state in evolution_time.states], label="Overlap with initial state") +plt.plot(times2, [np.abs(state.overlap(target_state)) for state in evolution_time.states], label="Overlap with target state") +plt.plot(times2, [qt.fidelity(state, target_state) for state in evolution_time.states], '--', label="Fidelity") + +plt.title('JOPT (optimized over time) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +JOPT is still stuck in a local minimum, but the fidelity has improved. + + +### c) global optimization + +```python +res_jopt_global = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "JOPT", + "fid_err_targ": 0.001 + }, + optimizer_kwargs={ + "method": "basinhopping", + "max_iter": 1000 + } +) + +global_time = res_jopt_global.optimized_params[-1][0] +global_range = times < global_time + +print('Infidelity: ', res_jopt_global.infidelity) +print('Optimized time: ', global_time) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_jopt.optimized_controls[0], label='optimized pulse') +plt.plot(times[global_range], np.array(res_jopt_global.optimized_controls[0])[global_range], label='global optimized pulse') +plt.title('JOPT pulses (global)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times3 = times[global_range] +if global_time not in times3: + times3 = np.append(times3, global_time) + +H_result = qt.QobjEvo([Hd, [Hc, np.array(res_jopt_global.optimized_controls[0])]], tlist=times) +evolution_global = qt.sesolve(H_result, initial_state, times3) + +plt.plot(times3, [np.abs(state.overlap(initial_state)) for state in evolution_global.states], label="Overlap with initial state") +plt.plot(times3, [np.abs(state.overlap(target_state)) for state in evolution_global.states], label="Overlap with target state") +plt.plot(times3, [qt.fidelity(state, target_state) for state in evolution_global.states], '--', label="Fidelity") + +plt.title('JOPT (global) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Comparison + +```python +plt.plot(times, guess_pulse, color='blue', label='initial guess') +plt.plot(times, res_jopt.optimized_controls[0], color='orange', label='optimized pulse') +plt.plot(times[time_range], np.array(res_jopt_time.optimized_controls[0])[time_range], + color='green', label='optimized (over time) pulse') +plt.plot(times[global_range], np.array(res_jopt_global.optimized_controls[0])[global_range], + color='red', label='global optimized pulse') + +plt.title('JOPT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +print('Guess Fidelity: ', qt.fidelity(evolution_guess.states[-1], target_state)) +print('JOPT Fidelity: ', 1 - res_jopt.infidelity) +print('Time Fidelity: ', 1 - res_jopt_time.infidelity) +print('GLobal Fidelity: ', 1 - res_jopt_global.infidelity) + +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution_guess.states], color='blue', label="Guess") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], color='orange', label="Goat") +plt.plot(times2, [qt.fidelity(state, target_state) for state in evolution_time.states], + color='green', label="Time") +plt.plot(times3, [qt.fidelity(state, target_state) for state in evolution_global.states], + color='red', label="Global") + +plt.title('Fidelities') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +guess_fidelity = qt.fidelity(evolution_guess.states[-1], target_state) + +# target fidelity not reached in part a), check that it is better than the guess +assert 1 - res_jopt.infidelity > guess_fidelity +assert np.allclose(np.abs(evolution.states[-1].overlap(target_state)), 1 - res_jopt.infidelity, atol=1e-3) + +# target fidelity not reached in part b), check that it is better than part a) +assert res_jopt_time.infidelity < res_jopt.infidelity +assert np.allclose(np.abs(evolution_time.states[-1].overlap(target_state)), 1 - res_jopt_time.infidelity, atol=1e-3) + +assert res_jopt_global.infidelity < 0.001 +assert np.abs(evolution_global.states[-1].overlap(target_state)) > 1 - 0.001 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/about.md b/tests/interactive/about.md new file mode 100644 index 0000000..9d22894 --- /dev/null +++ b/tests/interactive/about.md @@ -0,0 +1,21 @@ +This directory contains notebooks showcasing the use of all algorithms contained within `QuTiP-QOC`. +The examples are chosen as simple as possible, for the purpose of demonstrating how to use `QuTiP-QOC` in all of these scenarios, and for the purpose of automatically testing `QuTiP-QOC`'s basic functionality in all of these scenarios. +The included algorithms are: +- GRAPE +- CRAB +- GOAT +- JOPT + +For each algorithm, we have: +- a closed-system state transfer example +- an open-system state transfer example +- a closed-system gate synthesis example +- an open-system gate synthesis example + +The notebooks are included automatically in runs of the test suite (see `test_interactive.py`). + +To view and run the notebooks manually, the `jupytext` package is required. +The notebooks can then either be opened from within Jupyter Notebook using "Open With" -> "Jupytext Notebook", or by converting them first to the `ipynb` format using +``` +jupytext --to ipynb [notebook name].nb +``` \ No newline at end of file diff --git a/tests/test_analytical_pulses.py b/tests/test_analytical_pulses.py index 616106b..ecc1417 100644 --- a/tests/test_analytical_pulses.py +++ b/tests/test_analytical_pulses.py @@ -4,11 +4,16 @@ import pytest import qutip as qt -import qutip_jax # noqa: F401 import numpy as np -import jax.numpy as jnp import collections +try: + import jax.numpy as jnp + import qutip_jax # noqa: F401 + _jax_available = True +except ImportError: + _jax_available = False + from qutip_qoc.pulse_optim import optimize_pulses from qutip_qoc.objective import Objective @@ -139,52 +144,59 @@ def grad_sin(t, p, idx): ) -# ----------------------- System and JAX Control --------------------- +if _jax_available: + # ----------------------- System and JAX Control --------------------- -def sin_jax(t, p): - return p[0] * jnp.sin(p[1] * t + p[2]) + def sin_jax(t, p): + return p[0] * jnp.sin(p[1] * t + p[2]) -Hc_jax = [ - [qt.sigmax(), lambda t, p: sin_jax(t, p)], - [qt.sigmay(), lambda t, q: sin_jax(t, q)], -] + Hc_jax = [ + [qt.sigmax(), lambda t, p: sin_jax(t, p)], + [qt.sigmay(), lambda t, q: sin_jax(t, q)], + ] -H_jax = H_d + Hc_jax + H_jax = H_d + Hc_jax -# ------------------------------- Objective ------------------------------- + # ------------------------------- Objective ------------------------------- -# state to state transfer -state2state_jax = state2state._replace( - objectives=[Objective(initial, H_jax, target)], - algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, -) + # state to state transfer + state2state_jax = state2state._replace( + objectives=[Objective(initial, H_jax, target)], + algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, + ) -# unitary gate synthesis -unitary_jax = unitary._replace( - objectives=[Objective(initial_U, H_jax, target_U)], - algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, -) + # unitary gate synthesis + unitary_jax = unitary._replace( + objectives=[Objective(initial_U, H_jax, target_U)], + algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, + ) -# unitary gate synthesis - time optimization -time_jax = time._replace( - objectives=[Objective(initial_U, H_jax, target_U)], - algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, -) + # unitary gate synthesis - time optimization + time_jax = time._replace( + objectives=[Objective(initial_U, H_jax, target_U)], + algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, + ) -# map synthesis -Lc_jax = [ - [qt.liouvillian(qt.sigmax()), lambda t, p: sin_jax(t, p)], - [qt.liouvillian(qt.sigmay()), lambda t, q: sin_jax(t, q)], -] -L_jax = L_d + Lc_jax + # map synthesis + Lc_jax = [ + [qt.liouvillian(qt.sigmax()), lambda t, p: sin_jax(t, p)], + [qt.liouvillian(qt.sigmay()), lambda t, q: sin_jax(t, q)], + ] + L_jax = L_d + Lc_jax -mapping_jax = mapping._replace( - objectives=[Objective(initial_map, L_jax, target_map)], - algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.1}, # relaxed objective -) + mapping_jax = mapping._replace( + objectives=[Objective(initial_map, L_jax, target_map)], + algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.1}, # relaxed objective + ) +else: + # jax not available, set these to none so tests will be skipped + state2state_jax = None + unitary_jax = None + time_jax = None + mapping_jax = None @pytest.fixture( params=[ @@ -205,6 +217,9 @@ def tst(request): def test_optimize_pulses(tst): + if tst is None: + pytest.skip("JAX not available") + result = optimize_pulses( tst.objectives, tst.control_parameters, diff --git a/tests/test_fidelity.py b/tests/test_fidelity.py index 8490a3e..2c3509d 100644 --- a/tests/test_fidelity.py +++ b/tests/test_fidelity.py @@ -5,9 +5,14 @@ import pytest import qutip as qt import numpy as np -import jax.numpy as jnp import collections +try: + import jax.numpy as jnp + _jax_available = True +except ImportError: + _jax_available = False + from qutip_qoc.pulse_optim import optimize_pulses from qutip_qoc.objective import Objective @@ -108,55 +113,64 @@ def grad_sin(t, p, idx): algorithm_kwargs={"alg": "GOAT", "fid_type": "TRACEDIFF"}, ) -# ----------------------- System and JAX Control --------------------- +if _jax_available: + # ----------------------- System and JAX Control --------------------- -def sin_jax(t, p): - return p[0] * jnp.sin(p[1] * t + p[2]) + def sin_jax(t, p): + return p[0] * jnp.sin(p[1] * t + p[2]) -Hc_jax = [ - [qt.sigmax(), lambda t, p: sin_jax(t, p)], - [qt.sigmay(), lambda t, q: sin_jax(t, q)], -] + Hc_jax = [ + [qt.sigmax(), lambda t, p: sin_jax(t, p)], + [qt.sigmay(), lambda t, q: sin_jax(t, q)], + ] -H_jax = H_d + Hc_jax + H_jax = H_d + Hc_jax -# ------------------------------- Objective ------------------------------- + # ------------------------------- Objective ------------------------------- -# state to state transfer -PSU_state2state_jax = PSU_state2state._replace( - objectives=[Objective(initial, H_jax, (-1j) * initial)], - algorithm_kwargs={"alg": "JOPT"}, -) + # state to state transfer + PSU_state2state_jax = PSU_state2state._replace( + objectives=[Objective(initial, H_jax, (-1j) * initial)], + algorithm_kwargs={"alg": "JOPT"}, + ) -SU_state2state_jax = SU_state2state._replace( - objectives=[Objective(initial, H_jax, initial)], algorithm_kwargs={"alg": "JOPT"} -) + SU_state2state_jax = SU_state2state._replace( + objectives=[Objective(initial, H_jax, initial)], algorithm_kwargs={"alg": "JOPT"} + ) -# unitary gate synthesis -PSU_unitary_jax = PSU_unitary._replace( - objectives=[Objective(initial_U, H_jax, (-1j) * initial_U)], - algorithm_kwargs={"alg": "JOPT"}, -) + # unitary gate synthesis + PSU_unitary_jax = PSU_unitary._replace( + objectives=[Objective(initial_U, H_jax, (-1j) * initial_U)], + algorithm_kwargs={"alg": "JOPT"}, + ) -SU_unitary_jax = SU_unitary._replace( - objectives=[Objective(initial_U, H_jax, initial_U)], - algorithm_kwargs={"alg": "JOPT"}, -) + SU_unitary_jax = SU_unitary._replace( + objectives=[Objective(initial_U, H_jax, initial_U)], + algorithm_kwargs={"alg": "JOPT"}, + ) -# map synthesis -Lc_jax = [ - [qt.liouvillian(qt.sigmax()), lambda t, p: sin_jax(t, p)], - [qt.liouvillian(qt.sigmay()), lambda t, q: sin_jax(t, q)], -] -L_jax = L_d + Lc_jax + # map synthesis + Lc_jax = [ + [qt.liouvillian(qt.sigmax()), lambda t, p: sin_jax(t, p)], + [qt.liouvillian(qt.sigmay()), lambda t, q: sin_jax(t, q)], + ] + L_jax = L_d + Lc_jax -TRCDIFF_map_jax = TRCDIFF_map._replace( - objectives=[Objective(initial_map, L_jax, initial_map)], - algorithm_kwargs={"alg": "JOPT", "fid_type": "TRACEDIFF"}, -) + TRCDIFF_map_jax = TRCDIFF_map._replace( + objectives=[Objective(initial_map, L_jax, initial_map)], + algorithm_kwargs={"alg": "JOPT", "fid_type": "TRACEDIFF"}, + ) + +else: + # jax not available, set these to none so tests will be skipped + PSU_state2state_jax = None + SU_state2state_jax = None + PSU_unitary_jax = None + SU_unitary_jax = None + TRCDIFF_map_jax = None @pytest.fixture( @@ -182,6 +196,9 @@ def tst(request): def test_optimize_pulses(tst): + if tst is None: + pytest.skip("JAX not available") + result = optimize_pulses( tst.objectives, tst.control_parameters, diff --git a/tests/test_interactive.py b/tests/test_interactive.py new file mode 100644 index 0000000..0961db0 --- /dev/null +++ b/tests/test_interactive.py @@ -0,0 +1,32 @@ +""" +This file contains the test suite for running the interactive test notebooks +in the 'tests/interactive' directory. +Taken, modified from https://github.com/SeldonIO/alibi/blob/master/testing/test_notebooks.py +""" + +import glob +import nbclient.exceptions +import pytest + +from pathlib import Path +from jupytext.cli import jupytext +import nbclient + +# Set of all example notebooks +NOTEBOOK_DIR = 'tests/interactive' +ALL_NOTEBOOKS = { + Path(x).name + for x in glob.glob(str(Path(NOTEBOOK_DIR).joinpath('*.md'))) + if Path(x).name != 'about.md' +} + +@pytest.mark.parametrize("notebook", ALL_NOTEBOOKS) +def test_notebook(notebook): + notebook = Path(NOTEBOOK_DIR, notebook) + try: + jupytext(args=[str(notebook), "--execute"]) + except nbclient.exceptions.CellExecutionError as e: + if e.ename == "Skipped": + pytest.skip(e.evalue) + else: + raise e \ No newline at end of file diff --git a/tests/test_jopt_open_system_bug.py b/tests/test_jopt_open_system_bug.py new file mode 100644 index 0000000..a4c9d26 --- /dev/null +++ b/tests/test_jopt_open_system_bug.py @@ -0,0 +1,39 @@ +import numpy as np +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +from jax import jit, numpy + +def test_open_system_jopt_runs_without_error(): + Hd = qt.Qobj(np.diag([1, 2])) + c_ops = [np.sqrt(0.1) * qt.sigmam()] + Hc = qt.sigmax() + + Ld = qt.liouvillian(H=Hd, c_ops=c_ops) + Lc = qt.liouvillian(Hc) + + initial_state = qt.fock_dm(2, 0) + target_state = qt.fock_dm(2, 1) + + times = np.linspace(0, 2 * np.pi, 250) + + @jit + def sin_x(t, c, **kwargs): + return c[0] * numpy.sin(c[1] * t) + L = [Ld, [Lc, sin_x]] + + guess_params = [1, 0.5] + + res_jopt = optimize_pulses( + objectives = Objective(initial_state, L, target_state), + control_parameters = { + "ctrl_x": {"guess": guess_params, "bounds": [(-1, 1), (0, 2 * np.pi)]} + }, + tlist = times, + algorithm_kwargs = { + "alg": "JOPT", + "fid_err_targ": 0.001, + }, + ) + + assert res_jopt.infidelity < 0.25, f"Fidelity error too high: {res_jopt.infidelity}" \ No newline at end of file diff --git a/tests/test_result.py b/tests/test_result.py index 7146b04..407ee49 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -5,10 +5,22 @@ import pytest import qutip as qt import numpy as np -import jax -import jax.numpy as jnp import collections +try: + import jax + import jax.numpy as jnp + _jax_available = True +except ImportError: + _jax_available = False + +try: + import gymnasium + import stable_baselines3 + _rl_available = True +except ImportError: + _rl_available = False + from qutip_qoc.pulse_optim import optimize_pulses from qutip_qoc.objective import Objective from qutip_qoc._time import _TimeInterval @@ -84,37 +96,42 @@ def grad_sin(t, p, idx): objectives=[Objective(initial, H, target)], algorithm_kwargs={"alg": "CRAB", "fid_err_targ": 0.01}, ) -# ----------------------- JAX --------------------- +if _jax_available: + # ----------------------- JAX --------------------- -def sin_jax(t, p): - return p[0] * jnp.sin(p[1] * t + p[2]) + def sin_jax(t, p): + return p[0] * jnp.sin(p[1] * t + p[2]) -@jax.jit -def sin_x_jax(t, p, **kwargs): - return sin_jax(t, p) + @jax.jit + def sin_x_jax(t, p, **kwargs): + return sin_jax(t, p) -@jax.jit -def sin_y_jax(t, q, **kwargs): - return sin_jax(t, q) + @jax.jit + def sin_y_jax(t, q, **kwargs): + return sin_jax(t, q) -@jax.jit -def sin_z_jax(t, r, **kwargs): - return sin_jax(t, r) + @jax.jit + def sin_z_jax(t, r, **kwargs): + return sin_jax(t, r) -Hc_jax = [[qt.sigmax(), sin_x_jax], [qt.sigmay(), sin_y_jax], [qt.sigmaz(), sin_z_jax]] -H_jax = H_d + Hc_jax + Hc_jax = [[qt.sigmax(), sin_x_jax], [qt.sigmay(), sin_y_jax], [qt.sigmaz(), sin_z_jax]] -# state to state transfer -state2state_jax = state2state_goat._replace( - objectives=[Objective(initial, H_jax, target)], - algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, -) + H_jax = H_d + Hc_jax + + # state to state transfer + state2state_jax = state2state_goat._replace( + objectives=[Objective(initial, H_jax, target)], + algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, + ) + +else: + state2state_jax = None # ------------------- discrete CRAB / GRAPE control ------------------------ @@ -152,52 +169,56 @@ def sin_z_jax(t, r, **kwargs): algorithm_kwargs={"alg": "CRAB", "fid_err_targ": 0.01, "fix_frequency": False}, ) -# ----------------------- RL -------------------- - -# state to state transfer -initial = qt.basis(2, 0) -target = (qt.basis(2, 0) + qt.basis(2, 1)).unit() # |+⟩ - -H_c = [qt.sigmax(), qt.sigmay(), qt.sigmaz()] # control Hamiltonians - -w, d, y = 0.1, 1.0, 0.1 -H_d = 1 / 2 * (w * qt.sigmaz() + d * qt.sigmax()) # drift Hamiltonian - -H = [H_d] + H_c # total Hamiltonian - -state2state_rl = Case( - objectives=[Objective(initial, H, target)], - control_parameters={ - "p": {"bounds": [(-13, 13)]}, - }, - tlist=np.linspace(0, 10, 100), - algorithm_kwargs={ - "fid_err_targ": 0.01, - "alg": "RL", - "max_iter": 20000, - "shorter_pulses": True, - }, - optimizer_kwargs={}, -) - -# no big difference for unitary evolution - -initial = qt.qeye(2) # Identity -target = qt.gates.hadamard_transform() +if _rl_available: + # ----------------------- RL -------------------- + + # state to state transfer + initial = qt.basis(2, 0) + target = (qt.basis(2, 0) + qt.basis(2, 1)).unit() # |+⟩ + + H_c = [qt.sigmax(), qt.sigmay(), qt.sigmaz()] # control Hamiltonians + + w, d, y = 0.1, 1.0, 0.1 + H_d = 1 / 2 * (w * qt.sigmaz() + d * qt.sigmax()) # drift Hamiltonian + + H = [H_d] + H_c # total Hamiltonian + + state2state_rl = Case( + objectives=[Objective(initial, H, target)], + control_parameters={ + "p": {"bounds": [(-13, 13)]}, + }, + tlist=np.linspace(0, 10, 100), + algorithm_kwargs={ + "fid_err_targ": 0.01, + "alg": "RL", + "max_iter": 20000, + "shorter_pulses": True, + }, + optimizer_kwargs={}, + ) -unitary_rl = state2state_rl._replace( - objectives=[Objective(initial, H, target)], - control_parameters={ - "p": {"bounds": [(-13, 13)]}, - }, - algorithm_kwargs={ - "fid_err_targ": 0.01, - "alg": "RL", - "max_iter": 300, - "shorter_pulses": True, - }, -) + # no big difference for unitary evolution + + initial = qt.qeye(2) # Identity + target = qt.gates.hadamard_transform() + + unitary_rl = state2state_rl._replace( + objectives=[Objective(initial, H, target)], + control_parameters={ + "p": {"bounds": [(-13, 13)]}, + }, + algorithm_kwargs={ + "fid_err_targ": 0.01, + "alg": "RL", + "max_iter": 300, + "shorter_pulses": True, + }, + ) +else: # skip RL tests + state2state_rl = None + unitary_rl = None @pytest.fixture( params=[ @@ -215,6 +236,9 @@ def tst(request): def test_optimize_pulses(tst): + if tst is None: + pytest.skip("Dependency not available") + result = optimize_pulses( tst.objectives, tst.control_parameters,