diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d99f56f..bc8c88c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -11,199 +11,50 @@ on: default: "" jobs: - # The deploy_test job is part of the test of whether we should deploy to PyPI - # or test.PyPI. The job will succeed if either the confirmation reference is - # empty, 'test' or if the confirmation is the selected branch or tag name. - # It will fail if it is nonempty and does not match. All later jobs depend - # on this job, so that they will be immediately cancelled if the confirmation - # is bad. The dependency is currently necessary (2021-03) because GitHub - # Actions does not have a simpler method of cancelling an entire workflow--- - # the normal use-case expects to try and run as much as possible despite one - # or two failures. - deploy_test: - name: Verify PyPI deployment confirmation - runs-on: ubuntu-latest - env: - GITHUB_REF: ${{ github.ref }} - CONFIRM_REF: ${{ github.event.inputs.confirm_ref }} - steps: - - name: Compare confirmation to current reference - shell: bash - run: | - [[ -z $CONFIRM_REF || $GITHUB_REF =~ ^refs/(heads|tags)/$CONFIRM_REF$ || $CONFIRM_REF == "test" ]] - if [[ $CONFIRM_REF == "test" ]]; then - echo "Build and deploy to test.pypi.org." - elif [[ -z $CONFIRM_REF ]]; then - echo "Build only. Nothing will be uploaded to PyPI." - else - echo "Full build and deploy. Wheels and source will be uploaded to PyPI." - fi - - build_sdist: - name: Build sdist on Ubuntu - needs: deploy_test - runs-on: ubuntu-latest - env: - OVERRIDE_VERSION: ${{ github.event.inputs.override_version }} - - steps: - - uses: actions/checkout@v3 - - - uses: actions/setup-python@v4 - name: Install Python - with: - # For the sdist we should be as conservative as possible with our - # Python version. This should be the lowest supported version. This - # means that no unsupported syntax can sneak through. 
- python-version: "3.10" - - - name: Install pip build - run: | - python -m pip install 'build' - - - name: Build sdist tarball - shell: bash - run: | - if [[ ! -z "$OVERRIDE_VERSION" ]]; then echo "$OVERRIDE_VERSION" > VERSION; fi - # The build package is the reference PEP 517 package builder. All - # dependencies are specified by our setup code. - python -m build --sdist . - - # Zip files are not part of PEP 517, so we need to make our own. - - name: Create zipfile from tarball - shell: bash - working-directory: dist - run: | - # First assert that there is exactly one tarball, and find its name. - shopt -s failglob - tarball_pattern="*.tar.gz" - tarballs=($tarball_pattern) - [[ ${#tarballs[@]} == 1 ]] - tarball="${tarballs[0]}" - # Get the stem and make the zipfile name. - stem="${tarball%.tar.gz}" - zipfile="${stem}.zip" - # Extract the tarball and rezip it. - tar -xzf "$tarball" - zip "$zipfile" -r "$stem" - rm -r "$stem" - - - uses: actions/upload-artifact@v3 - with: - name: sdist - path: | - dist/*.tar.gz - dist/*.zip - if-no-files-found: error - - build_wheels: - name: Build wheels on ${{ matrix.os }} - needs: deploy_test + build: + name: Build distribution 📦 runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-latest] - env: - # Set up wheels matrix. This is CPython 3.10--3.12 for all OS targets. - CIBW_BUILD: "cp3{10,11,12}-*" - # Numpy and SciPy do not supply wheels for i686 or win32 for - # Python 3.10+, so we skip those: - CIBW_SKIP: "*-musllinux* cp3{10,11,12}-manylinux_i686 cp3{10,11,12}-win32" - OVERRIDE_VERSION: ${{ github.event.inputs.override_version }} + os: [ubuntu-latest, windows-latest, macOS-latest] steps: - - uses: actions/checkout@v3 - - - uses: actions/setup-python@v4 - name: Install Python + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 with: - # This is about the build environment, not the released wheel version. 
python-version: "3.10" - - name: Install cibuildwheel - run: | - # cibuildwheel does the heavy lifting for us. Originally tested on - # 2.11.3, but should be fine at least up to any minor new release. - python -m pip install 'cibuildwheel==2.11.*' - - - name: Build wheels - shell: bash - run: | - # If the version override was specified, then write it the VERSION - # file with it. - if [[ ! -z "$OVERRIDE_VERSION" ]]; then echo "$OVERRIDE_VERSION" > VERSION; fi - python -m cibuildwheel --output-dir wheelhouse - - - uses: actions/upload-artifact@v3 + run: >- + python3 -m + pip install + cibuildwheel + --user + - name: Build a binary wheel and a source tarball + run: python3 -m cibuildwheel --output-dir dist + - name: Store the distribution packages + uses: actions/upload-artifact@v4 with: - name: wheels - path: ./wheelhouse/*.whl + name: python-package-distributions + path: dist/ - deploy: - name: "Deploy to PyPI if desired" - # The confirmation is tested explicitly in `deploy_test`, so we know it is - # either a missing confirmation (so we shouldn't run this job), 'test' or a - # valid confirmation. We don't need to retest the value of the - # confirmation, beyond checking that one existed. 
+ publish-to-pypi: + name: Publish Python 🐍 distribution 📦 to PyPI if: ${{ github.event.inputs.confirm_ref != '' && github.event.inputs.confirm_ref != 'test' }} - needs: [deploy_test, build_sdist, build_wheels] - runs-on: ubuntu-latest - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} - TWINE_NON_INTERACTIVE: 1 - TWINE_REPOSITORY: pypi - - steps: - - name: Download build artifacts to local runner - uses: actions/download-artifact@v3 - - - uses: actions/setup-python@v4 - name: Install Python - with: - python-version: "3.10" - - - name: Verify this is not a dev version - shell: bash - run: | - python -m pip install wheels/*-cp310-cp310-manylinux*.whl - python -c 'import qutip_qoc; print(qutip_qoc.__version__); assert "dev" not in qutip_qoc.__version__; assert "+" not in qutip_qoc.__version__' - - # We built the zipfile for convenience distributing to Windows users on - # our end, but PyPI only needs the tarball. - - name: Upload sdist and wheels to PyPI - run: | - python -m pip install "twine" - python -m twine upload --verbose wheels/*.whl sdist/*.tar.gz - - deploy_testpypi: - name: "Deploy to TestPyPI if desired" - if: ${{ github.event.inputs.confirm_ref == 'test' }} - needs: [deploy_test, build_sdist, build_wheels] + needs: + - build runs-on: ubuntu-latest - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.TESTPYPI_TOKEN }} - TWINE_NON_INTERACTIVE: 1 + environment: + name: pypi + url: https://pypi.org/p/qutip-qoc + permissions: + id-token: write steps: - - name: Download build artifacts to local runner - uses: actions/download-artifact@v3 - - - uses: actions/setup-python@v4 - name: Install Python + - name: Download all the dists + uses: actions/download-artifact@v4 with: - python-version: "3.10" - - - name: Verify this is not a dev version - shell: bash - run: | - python -m pip install wheels/*-cp310-cp310-manylinux*.whl - python -c 'import qutip_qoc; print(qutip_qoc.__version__); assert "dev" not in 
qutip_qoc.__version__; assert "+" not in qutip_qoc.__version__' - - # We built the zipfile for convenience distributing to Windows users on - # our end, but PyPI only needs the tarball. - - name: Upload sdist and wheels to TestPyPI - run: | - python -m pip install "twine" - python -m twine upload --repository testpypi --verbose wheels/*.whl sdist/*.tar.gz + name: python-package-distributions + path: dist/ + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml index f2c45e2..dd1950f 100644 --- a/.github/workflows/build_documentation.yml +++ b/.github/workflows/build_documentation.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 name: Install Python @@ -42,7 +42,7 @@ jobs: # -T : display a full traceback if a Python exception occurs - name: Upload built files - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: qutip_qoc_html_docs path: doc/_build/html/* diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 69e1b03..14185e3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,7 +47,7 @@ jobs: python-version: "3.12" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 with: @@ -92,7 +92,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0bcf23a..0c027ee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -92,3 +92,5 @@ repos: - jaxlib - diffrax - pytest + - gymnasium + - stable-baselines3 diff --git a/MANIFEST.in b/MANIFEST.in index 4681b1f..7eac3bf 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ 
include README.md +include VERSION include LICENSE include requirements.txt include pyproject.toml diff --git a/README.md b/README.md index 9dde70f..b42d479 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,12 @@ To install the package, use pip install qutip-qoc ``` +By default, the dependencies required for JOPT and for the RL (reinforcement learning) algorithm are omitted. +They can be included by using the targets `qutip-qoc[jopt]` and `qutip-qoc[rl]`, respectively (or `qutip-qoc[full]` for all dependencies). + ## Documentation and tutorials -The documentation of `qutip-qoc` updated to the latest development version is hosted at [qutip-qoc.readthedocs.io](https://qutip-qoc.readthedocs.io/en/latest/). +The documentation of `qutip-qoc` updated to the latest development version is hosted at [qutip-qoc.readthedocs.io](https://qutip-qoc.readthedocs.io/latest/). Tutorials related to using quantum optimal control in `qutip-qoc` can be found [_here_](https://qutip.org/qutip-tutorials/#optimal-control). ## Installation from source @@ -40,7 +43,7 @@ pip install --upgrade pip pip install -e . ``` -which makes sure that you are up to date with the latest `pip` version. Contribution guidelines are available [_here_](https://qutip-qoc.readthedocs.io/en/latest/contribution-code.html). +which makes sure that you are up to date with the latest `pip` version. Contribution guidelines are available [_here_](https://qutip-qoc.readthedocs.io/latest/contribution/code.html). To build and test the documentation, additional packages need to be installed: diff --git a/VERSION b/VERSION index 8acdd82..7b32b52 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.1 +0.1.0.dev diff --git a/doc/changelog.rst b/doc/changelog.rst index f2c5afe..2e869b0 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -3,6 +3,30 @@ Changelog ********* +Version 0.0.2 (Oct 04, 2024) ++++++++++++++++++++++++++++++++++ + +This is an update to the beta release of ``qutip-qoc``. 
+ +It mainly introduces the new reinforcement learning algorithm ``qutip_qoc._rl``. + +- Non-public facing functions have been renamed to start with an underscore. +- As with other QuTiP functions, ``optimize_pulses`` now takes a ``tlist`` argument instead of ``_TimeInterval``. +- The structure for the control guess and bounds has changed and now takes in an optional ``__time__`` keyword. +- The ``result`` does no longer return ``optimized_objectives`` but instead ``optimized_H``. + +Features +-------- + +- New reinforcement learning algorithm, developed during GSOC24 (#19, #18, by LegionAtol) +- Automatic transfromation of initial and target operator to superoperator (#23, by flowerthrower) + +Bug Fixes +--------- + +- Prevent loss of `__time__` keyword in optimize_pulses (#22, by flowerthrower) + + Version 0.0.1 (May xx, 2024) +++++++++++++++++++++++++++++++++ diff --git a/doc/conf.py b/doc/conf.py index ab65dc7..96ca014 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -62,6 +62,10 @@ def qutip_qoc_version(): html_theme = "sphinx_rtd_theme" html_static_path = [] +html_js_files = [ + 'https://code.jquery.com/jquery-3.6.0.min.js', +] + # -- Options for numpydoc --------------------------------------- diff --git a/doc/contribution/code.rst b/doc/contribution/code.rst index 8a961b8..6cf0705 100644 --- a/doc/contribution/code.rst +++ b/doc/contribution/code.rst @@ -7,7 +7,7 @@ Contributing to the source code Build up an development environment =================================== -Please follow the instruction on the `QuTiP contribution guide `_ to +Please follow the instruction on the `QuTiP contribution guide `_ to build a conda environment. You don't need to build ``qutip`` in the editable mode unless you also want to contribute to `qutip`. diff --git a/doc/guide/guide-control.rst b/doc/guide/guide-control.rst index b2029cd..915b6b6 100644 --- a/doc/guide/guide-control.rst +++ b/doc/guide/guide-control.rst @@ -195,6 +195,44 @@ experimental systematic noise, ...) 
can be done all in one, using this algorithm. +The RL Algorithm +================ +Reinforcement Learning (RL) represents a different approach compared to traditional +quantum control methods, such as GRAPE and CRAB. Instead of relying on gradients or +prior knowledge of the system, RL uses an agent that autonomously learns to optimize +control policies by interacting with the quantum environment. + +The RL algorithm consists of three main components: + +**Agent**: The RL agent is responsible for making decisions regarding control +parameters at each time step. The agent observes the current state of the quantum +system and chooses an action (i.e., a set of control parameters) based on the current policy. +**Environment**: The environment represents the quantum system that evolves over time. +The environment is defined by the system's dynamics, which include drift and control Hamiltonians. +Each action chosen by the agent induces a response in the environment, which manifests as an +evolution of the system's state. From this, a reward can be derived. +**Reward**: The reward is a measure of how much the action chosen by the agent brings the +quantum system closer to the desired objective. In this context, the objective could be the +preparation of a specific state, state-to-state transfer, or the synthesis of a quantum gate. + +Each interaction between the agent and the environment defines a step. +A sequence of steps forms an episode. The episode ends when certain conditions, such as reaching +a specific fidelity, are met. +The reward function is a crucial component of the RL algorithm, carefully designed to +reflect the objective of the quantum control problem. +It guides the algorithm in updating its policy to maximize the reward obtained during the various +training episodes. +For example, in a state-to-state transfer problem, the reward is based on the fidelity +between the achieved final state and the desired target state. 
+In addition, a constant penalty term is subtracted in order to encourages the agent to reach the +objective in as few steps as possible. + +In QuTiP, the RL environment is modeled as a custom class derived from the gymnasium library. +This class allows defining the quantum system's dynamics at each step, the actions the agent +can take, the observation space, and so on. The RL agent is trained using the Proximal Policy Optimization +(PPO) algorithm from the stable baselines3 library. + + Optimal Quantum Control in QuTiP ================================ Defining a control problem with QuTiP is very easy. diff --git a/doc/installation.rst b/doc/installation.rst index f619143..5754564 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -23,7 +23,19 @@ In particular, the following packages are necessary for running ``qutip-qoc``: .. code-block:: bash - numpy scipy jax jaxlib cython qutip qutip-jax qutip-qtrl + numpy scipy cython qutip qutip-qtrl + +The following packages are required for using the JOPT algorithm: + +.. code-block:: bash + + jax jaxlib qutip-jax + +The following packages are required for the RL (reinforcement learning) algorithm: + +.. 
code-block:: bash + + gymnasium stable-baselines3 The following package is used for testing: diff --git a/doc/requirements.txt b/doc/requirements.txt index 48e2d81..3bc7ba6 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,4 +1,4 @@ -numpy>=1.16.6 +numpy>=1.16.6,<2.0 scipy>=1.10.1 jax==0.4.28 jaxlib==0.4.28 @@ -9,18 +9,18 @@ matplotlib>=3.6.1 docutils==0.18.1 alabaster==0.7.12 Babel==2.9.1 -certifi==2020.12.5 +certifi==2024.7.4 chardet==4.0.0 colorama==0.4.4 -idna==2.10 +idna==3.7 imagesize==1.4.1 -Jinja2==3.0.1 +Jinja2==3.1.6 MarkupSafe==2.0.1 packaging==23.2 Pygments==2.17.2 pyparsing==2.4.7 pytz==2021.1 -requests==2.25.1 +requests==2.32.2 snowballstemmer==2.1.0 Sphinx==6.1.3 sphinx-gallery==0.12.2 @@ -32,4 +32,4 @@ sphinxcontrib-htmlhelp==2.0.1 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 -urllib3==1.26.4 +urllib3==1.26.19 diff --git a/doc/rtd-environment.yml b/doc/rtd-environment.yml index d2e31c3..1d8a3db 100644 --- a/doc/rtd-environment.yml +++ b/doc/rtd-environment.yml @@ -2,7 +2,7 @@ name: rtd-environment channels: - conda-forge dependencies: - - numpy>=1.16.6 + - numpy>=1.16.6,<2.0 - scipy>=1.10.1 - cython>=0.29.33 - sphinx==6.1.3 diff --git a/requirements.txt b/requirements.txt index 70eaeb7..7c55daa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,6 @@ cython>=1.0 -numpy>=1.16.6 +numpy>=1.16.6,<2.0 scipy>=1.10.1 -jax>=0.4.23 -jaxlib>=0.4.23 qutip>=5.0.1 -qutip-qtrl @ git+https://github.com/qutip/qutip-qtrl.git@master -qutip-jax @ git+https://github.com/qutip/qutip-jax.git@master +qutip-qtrl pre-commit diff --git a/setup.cfg b/setup.cfg index 1bb2a96..04fe89b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ keywords = quantum, physics, dynamics license = BSD 3-Clause License license_files = LICENSE classifiers = - Development Status :: 2 - Pre-Alpha + Development Status :: 4 - Beta Intended Audience :: Science/Research License :: OSI Approved :: BSD License Programming 
Language :: Python @@ -28,12 +28,10 @@ package_dir= packages = find: include_package_data = True install_requires = - jax - jaxlib packaging qutip - qutip-qtrl @ git+https://github.com/qutip/qutip-qtrl.git@master - qutip-jax @ git+https://github.com/qutip/qutip-jax.git@master + qutip-qtrl + numpy>=1.16.6,<2.0 setup_requires = cython>=1.0 packaging @@ -41,8 +39,24 @@ setup_requires = [options.packages.find] where = src +[options.entry_points] +qutip.family = + qutip_qoc = qutip_qoc.family + [options.extras_require] +jopt = + jax + jaxlib + qutip-jax +rl = + gymnasium>=0.29.1 + stable-baselines3>=2.3.2 tests = pytest>=6.0 + jupytext + nbconvert + ipykernel full = + %(jopt)s + %(rl)s %(tests)s diff --git a/src/qutip_qoc/_goat.py b/src/qutip_qoc/_goat.py index 007388c..617e660 100644 --- a/src/qutip_qoc/_goat.py +++ b/src/qutip_qoc/_goat.py @@ -65,7 +65,9 @@ def __init__( self._var_t = "guess" in time_options # num of params for each control function - self._para_counts = [len(v["guess"]) for v in control_parameters.values()] + self._para_counts = [ + len(v["guess"]) for k, v in control_parameters.items() if k != "__time__" + ] # inferred attributes self._tot_n_para = sum(self._para_counts) # excl. 
time diff --git a/src/qutip_qoc/_jopt.py b/src/qutip_qoc/_jopt.py index 0e94491..6c9cfe2 100644 --- a/src/qutip_qoc/_jopt.py +++ b/src/qutip_qoc/_jopt.py @@ -5,38 +5,47 @@ import qutip as qt from qutip import Qobj, QobjEvo -from diffrax import Dopri5, PIDController +try: + import jax + from jax import custom_jvp + import jax.numpy as jnp + import qutip_jax # noqa: F401 -import jax -from jax import custom_jvp -import jax.numpy as jnp -import qutip_jax # noqa: F401 + import jaxlib # noqa: F401 + from diffrax import Dopri5, PIDController -@custom_jvp -def _abs(x): - return jnp.abs(x) + _jax_available = True +except ImportError: + _jax_available = False -def _abs_jvp(primals, tangents): - """ - Custom jvp for absolute value of complex functions - """ - (x,) = primals - (t,) = tangents +if _jax_available: + + @custom_jvp + def _abs(x): + return jnp.abs(x) + - abs_x = _abs(x) - res = jnp.where( - abs_x == 0, - 0.0, # prevent division by zero - jnp.real(jnp.multiply(jnp.conj(x), t)) / abs_x, - ) + def _abs_jvp(primals, tangents): + """ + Custom jvp for absolute value of complex functions + """ + (x,) = primals + (t,) = tangents + + abs_x = _abs(x) + res = jnp.where( + abs_x == 0, + 0.0, # prevent division by zero + jnp.real(jnp.multiply(jnp.conj(x), t)) / abs_x, + ) - return abs_x, res + return abs_x, res -# register custom jvp for absolut value of complex functions -_abs.defjvp(_abs_jvp) + # register custom jvp for absolut value of complex functions + _abs.defjvp(_abs_jvp) class _JOPT: @@ -55,6 +64,10 @@ def __init__( guess_params, **integrator_kwargs, ): + if not _jax_available: + raise ImportError("The JOPT algorithm requires the modules jax, " + "jaxlib, and qutip_jax to be installed.") + self._Hd = objective.H[0] self._Hc_lst = objective.H[1:] @@ -137,8 +150,8 @@ def _infid(self, params): if self._fid_type == "TRACEDIFF": diff = X - self._target # to prevent if/else in qobj.dag() and qobj.tr() - diff_dag = Qobj(diff.data.adjoint(), dims=diff.dims) - g = 1 / 2 * 
(diff_dag * diff).data.trace() + diff_dag = diff.dag() # direct access to JAX array, no fallback! + g = 1 / 2 * jnp.trace(diff_dag.data._jxa @ diff.data._jxa) infid = jnp.real(self._norm_fac * g) else: g = self._norm_fac * self._target.overlap(X) @@ -147,4 +160,4 @@ def _infid(self, params): elif self._fid_type == "SU": # f_SU (incl global phase) infid = 1 - jnp.real(g) - return infid + return infid \ No newline at end of file diff --git a/src/qutip_qoc/_optimizer.py b/src/qutip_qoc/_optimizer.py index e5b12f5..8218d51 100644 --- a/src/qutip_qoc/_optimizer.py +++ b/src/qutip_qoc/_optimizer.py @@ -283,9 +283,6 @@ def _global_local_optimization( _get_init_and_bounds_from_options(x0, control_parameters[key].get("guess")) _get_init_and_bounds_from_options(bounds, control_parameters[key].get("bounds")) - _get_init_and_bounds_from_options(x0, time_options.get("guess", None)) - _get_init_and_bounds_from_options(bounds, time_options.get("bounds", None)) - optimizer_kwargs["x0"] = np.concatenate(x0) multi_objective = _MultiObjective( diff --git a/src/qutip_qoc/_rl.py b/src/qutip_qoc/_rl.py new file mode 100644 index 0000000..21e8cb1 --- /dev/null +++ b/src/qutip_qoc/_rl.py @@ -0,0 +1,385 @@ +""" +This module contains functions that implement quantum optimal control +using reinforcement learning (RL) techniques, allowing for the optimization +of control pulse sequences in quantum systems. +""" +import qutip as qt +from qutip import Qobj +from qutip_qoc import Result + +import numpy as np + +import gymnasium as gym +from gymnasium import spaces +from stable_baselines3 import PPO +from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.callbacks import BaseCallback + +import time + + +class _RL(gym.Env): + """ + Class for storing a control problem and implementing quantum optimal + control using reinforcement learning. 
This class defines a custom + Gym environment that models the dynamics of quantum systems + under various control pulses, and uses RL algorithms to optimize the + parameters of these pulses. + """ + + def __init__( + self, + objectives, + control_parameters, + time_interval, + time_options, + alg_kwargs, + optimizer_kwargs, + minimizer_kwargs, + integrator_kwargs, + qtrl_optimizers, + ): + """ + Initialize the reinforcement learning environment for quantum + optimal control. Sets up the system Hamiltonian, control parameters, + and defines the observation and action spaces for the RL agent. + """ + + super(_RL, self).__init__() + + self._Hd_lst, self._Hc_lst = [], [] + for objective in objectives: + # extract drift and control Hamiltonians from the objective + self._Hd_lst.append(objective.H[0]) + self._Hc_lst.append( + [H[0] if isinstance(H, list) else H for H in objective.H[1:]] + ) + + def create_pulse_func(idx): + """ + Create a control pulse lambda function for a given index. + """ + return lambda t, args: self._pulse(t, args, idx + 1) + + # create the QobjEvo with Hd, Hc and controls(args) + self._H_lst = [self._Hd_lst[0]] + dummy_args = {f"alpha{i+1}": 1.0 for i in range(len(self._Hc_lst[0]))} + for i, Hc in enumerate(self._Hc_lst[0]): + self._H_lst.append([Hc, create_pulse_func(i)]) + self._H = qt.QobjEvo(self._H_lst, args=dummy_args) + + self.shorter_pulses = alg_kwargs.get( + "shorter_pulses", False + ) # lengthen the training to look for pulses of shorter duration, therefore episodes with fewer steps + + # extract bounds for control_parameters + bounds = [] + for key in control_parameters.keys(): + bounds.append(control_parameters[key].get("bounds")) + self._lbound = [b[0][0] for b in bounds] + self._ubound = [b[0][1] for b in bounds] + + self._alg_kwargs = alg_kwargs + + self._initial = objectives[0].initial + self._target = objectives[0].target + self._state = None + self._dim = self._initial.shape[0] + + self._result = Result( + objectives=objectives, 
+ time_interval=time_interval, + start_local_time=time.time(), # initial optimization time + n_iters=0, # Number of iterations(episodes) until convergence + iter_seconds=[], # list containing the time taken for each iteration(episode) of the optimization + var_time=True, # Whether the optimization was performed with variable time + guess_params=[], + ) + + self._backup_result = Result( # used as a backup in case the algorithm with shorter_pulses does not find an episode with infid= self.max_steps + ) # if the episode ended without reaching the goal + + observation = self._get_obs() + return observation, reward, bool(self.terminated), bool(self.truncated), {} + + def _get_obs(self): + """ + Get the current state observation for the RL agent. Converts the system's + quantum state or matrix into a real-valued NumPy array suitable for RL algorithms. + """ + rho = self._state.full().flatten() + obs = np.concatenate((np.real(rho), np.imag(rho))) + return obs.astype( + np.float32 + ) # Gymnasium expects the observation to be of type float32 + + def reset(self, seed=None): + """ + Reset the environment to the initial state, preparing for a new episode. + """ + self._save_episode_info() + + time_diff = self._episode_info[-1]["elapsed_time"] - ( + self._episode_info[-2]["elapsed_time"] + if len(self._episode_info) > 1 + else self._result.start_local_time + ) + self._result.iter_seconds.append(time_diff) + self._current_step = 0 # Reset the step counter + self.current_episode += 1 # Increment episode counter + self._actions = self._temp_actions.copy() + self.terminated = False + self.truncated = False + self._temp_actions = [] + self._result._final_states = [self._state] + self._state = self._initial + return self._get_obs(), {} + + def _save_result(self): + """ + Save the results of the optimization process, including the optimized + pulse sequences, final states, and performance metrics. 
+ """ + result_obj = self._backup_result if self._use_backup_result else self._result + + if self._use_backup_result: + self._backup_result.iter_seconds = self._result.iter_seconds.copy() + self._backup_result._final_states = self._result._final_states.copy() + self._backup_result.infidelity = self._result.infidelity + + result_obj.end_local_time = time.time() + result_obj.n_iters = len(self._result.iter_seconds) + result_obj.optimized_params = self._actions.copy() + [ + self._result.total_seconds + ] # If var_time is True, the last parameter is the evolution time + result_obj._optimized_controls = self._actions.copy() + result_obj._guess_controls = [] + result_obj._optimized_H = [self._H] + + def result(self): + """ + Final conversions and return of optimization results + """ + if self._use_backup_result: + self._backup_result.start_local_time = time.strftime( + "%Y-%m-%d %H:%M:%S", time.localtime(self._backup_result.start_local_time) + ) + self._backup_result.end_local_time = time.strftime( + "%Y-%m-%d %H:%M:%S", time.localtime(self._backup_result.end_local_time) + ) + return self._backup_result + else: + self._save_result() + self._result.start_local_time = time.strftime( + "%Y-%m-%d %H:%M:%S", time.localtime(self._result.start_local_time) + ) + self._result.end_local_time = time.strftime( + "%Y-%m-%d %H:%M:%S", time.localtime(self._result.end_local_time) + ) + return self._result + + def train(self): + """ + Train the RL agent on the defined quantum control problem using the specified + reinforcement learning algorithm. Checks environment compatibility with Gym API. 
+ """ + # Check if the environment follows Gym API + check_env(self, warn=True) + + # Create the model + model = PPO( + "MlpPolicy", self, verbose=1 + ) # verbose = 1 to display training progress and statistics in the terminal + + stop_callback = EarlyStopTraining(verbose=1) + + # Train the model + model.learn(total_timesteps=self._total_timesteps, callback=stop_callback) + + +class EarlyStopTraining(BaseCallback): + """ + A callback to stop training based on specific conditions (steps, infidelity, max iterations) + """ + + def __init__(self, verbose: int = 0): + super(EarlyStopTraining, self).__init__(verbose) + + def _on_step(self) -> bool: + """ + This method is required by the BaseCallback class. We use it to stop the training. + - Stop training if the maximum number of episodes is reached. + - Stop training if it finds an episode with infidelity <= than target infidelity + - If all of the last 100 episodes have infidelity below the target and use the same number of steps, stop training. + """ + env = self.training_env.get_attr("unwrapped")[0] + + # Check if we need to stop training + if env.current_episode >= env.max_episodes: + if env._use_backup_result is True: + env._backup_result.message = f"Reached {env.max_episodes} episodes, stopping training. Return the last founded episode with infid < target_infid" + else: + env._result.message = ( + f"Reached {env.max_episodes} episodes, stopping training." 
+ ) + return False # Stop training + elif (env._result.infidelity <= env._fid_err_targ) and not (env.shorter_pulses): + env._result.message = "Stop training because an episode with infidelity <= target infidelity was found" + return False # Stop training + elif env.shorter_pulses: + if ( + env._result.infidelity <= env._fid_err_targ + ): # if it finds an episode with infidelity lower than target infidelity, I'll save it in the meantime + env._use_backup_result = True + env._save_result() + if len(env._episode_info) >= 100: + last_100_episodes = env._episode_info[-100:] + + min_steps = min(info["steps_used"] for info in last_100_episodes) + steps_condition = all( + ep["steps_used"] == min_steps for ep in last_100_episodes + ) + infid_condition = all( + ep["final_infidelity"] <= env._fid_err_targ + for ep in last_100_episodes + ) + + if steps_condition and infid_condition: + env._use_backup_result = False + env._result.message = "Training finished. No episode in the last 100 used fewer steps and infidelity was below target infid." + return False # Stop training + return True # Continue training diff --git a/src/qutip_qoc/family.py b/src/qutip_qoc/family.py new file mode 100644 index 0000000..d76645a --- /dev/null +++ b/src/qutip_qoc/family.py @@ -0,0 +1,8 @@ +"""QuTiP family package entry point.""" + +from . import __version__ + + +def version(): + """Return information to include in qutip.about().""" + return "qutip-qoc", __version__ diff --git a/src/qutip_qoc/pulse_optim.py b/src/qutip_qoc/pulse_optim.py index 45866be..fc44231 100644 --- a/src/qutip_qoc/pulse_optim.py +++ b/src/qutip_qoc/pulse_optim.py @@ -1,7 +1,7 @@ """ This module is the entry point for the optimization of control pulses. It provides the function `optimize_pulses` which prepares and runs the -GOAT, JOPT, GRAPE or CRAB optimization. +GOAT, JOPT, GRAPE, CRAB or RL optimization. 
""" import numpy as np @@ -11,6 +11,14 @@ from qutip_qoc._optimizer import _global_local_optimization from qutip_qoc._time import _TimeInterval +import qutip as qt + +try: + from qutip_qoc._rl import _RL + _rl_available = True +except ImportError: + _rl_available = False + __all__ = ["optimize_pulses"] @@ -22,9 +30,10 @@ def optimize_pulses( optimizer_kwargs=None, minimizer_kwargs=None, integrator_kwargs=None, + optimization_type=None, ): """ - Run GOAT, JOPT, GRAPE or CRAB optimization. + Run GOAT, JOPT, GRAPE, CRAB or RL optimization. Parameters ---------- @@ -40,6 +49,7 @@ def optimize_pulses( control_id : dict - guess: ndarray, shape (n,) + For RL you don't need to specify the guess. Initial guess. Array of real elements of size (n,), where ``n`` is the number of independent variables. @@ -48,7 +58,7 @@ def optimize_pulses( `guess`. None is used to specify no bound. __time__ : dict, optional - Only supported by GOAT and JOPT. + Only supported by GOAT, JOPT (for RL use `algorithm_kwargs: 'shorter_pulses'`). If given the pulse duration is treated as optimization parameter. It must specify both: @@ -70,14 +80,15 @@ def optimize_pulses( - alg : str Algorithm to use for the optimization. - Supported are: "GRAPE", "CRAB", "GOAT", "JOPT". + Supported are: "GRAPE", "CRAB", "GOAT", "JOPT" and "RL". - fid_err_targ : float, optional Fidelity error target for the optimization. - max_iter : int, optional Maximum number of iterations to perform. - Referes to local minimizer steps. + Referes to local minimizer steps or in the context of + `alg: "RL"` to the max. number of episodes. Global steps default to 0 (no global optimization). Can be overridden by specifying in minimizer_kwargs. @@ -116,6 +127,11 @@ def optimize_pulses( Options for the solver, see :obj:`MESolver.options` and `Integrator <./classes.html#classes-ode>`_ for a list of all options. + optimization_type : str, optional + Type of optimization. 
By default, QuTiP-QOC will try to automatically determine + whether this is a *state transfer* or a *gate synthesis* problem. Set this + flag to ``"state_transfer"`` or ``"gate_synthesis"`` to set the mode manually. + Returns ------- result : :class:`qutip_qoc.Result` @@ -133,7 +149,7 @@ def optimize_pulses( # create time interval time_interval = _TimeInterval(tslots=tlist) - time_options = control_parameters.pop("__time__", {}) + time_options = control_parameters.get("__time__", {}) if time_options: # convert to list of bounds if not already if not isinstance(time_options["bounds"][0], (list, tuple)): time_options["bounds"] = [time_options["bounds"]] @@ -151,8 +167,9 @@ def optimize_pulses( # extract guess and bounds for the control pulses x0, bounds = [], [] for key in control_parameters.keys(): - x0.append(control_parameters[key].get("guess")) - bounds.append(control_parameters[key].get("bounds")) + if key != "__time__": + x0.append(control_parameters[key].get("guess")) + bounds.append(control_parameters[key].get("bounds")) try: # GRAPE, CRAB format lbound = [b[0][0] for b in bounds] ubound = [b[0][1] for b in bounds] @@ -179,9 +196,63 @@ def optimize_pulses( "gtol": algorithm_kwargs.get("min_grad", 0.0 if alg == "CRAB" else 1e-8), } + # Iterate over objectives and convert initial and target states based on the optimization type + for objective in objectives: + H_list = objective.H if isinstance(objective.H, list) else [objective.H] + if any(qt.issuper(H_i) for H_i in H_list): + if isinstance(optimization_type, str) and optimization_type.lower() == "state_transfer": + if qt.isket(objective.initial): + dim = objective.initial.shape[0] + objective.initial = qt.operator_to_vector(qt.ket2dm(objective.initial)) + elif qt.isoper(objective.initial): + dim = objective.initial.shape[0] + objective.initial = qt.operator_to_vector(objective.initial) + + if qt.isket(objective.target): + objective.target = qt.operator_to_vector(qt.ket2dm(objective.target)) + elif 
qt.isoper(objective.target): + objective.target = qt.operator_to_vector(objective.target) + + algorithm_kwargs.setdefault("fid_params", {}) + algorithm_kwargs["fid_params"].setdefault("scale_factor", 1.0 / dim) + + elif isinstance(optimization_type, str) and optimization_type.lower() == "gate_synthesis": + objective.initial = qt.to_super(objective.initial) + objective.target = qt.to_super(objective.target) + + elif optimization_type is None: + is_state_transfer = False + if qt.isoper(objective.initial) and qt.isoper(objective.target): + if np.isclose(objective.initial.tr(), 1) and np.isclose(objective.target.tr(), 1): + dim = objective.initial.shape[0] + objective.initial = qt.operator_to_vector(objective.initial) + objective.target = qt.operator_to_vector(objective.target) + is_state_transfer = True + else: + objective.initial = qt.to_super(objective.initial) + objective.target = qt.to_super(objective.target) + + if qt.isket(objective.initial): + dim = objective.initial.shape[0] + objective.initial = qt.operator_to_vector(qt.ket2dm(objective.initial)) + is_state_transfer = True + + if qt.isket(objective.target): + objective.target = qt.operator_to_vector(qt.ket2dm(objective.target)) + is_state_transfer = True + + if is_state_transfer: + algorithm_kwargs.setdefault("fid_params", {}) + algorithm_kwargs["fid_params"].setdefault("scale_factor", 1.0 / dim) + # prepare qtrl optimizers qtrl_optimizers = [] if alg == "CRAB" or alg == "GRAPE": + dyn_type = "GEN_MAT" + for objective in objectives: + if any(qt.isoper(H_i) for H_i in (objective.H if isinstance(objective.H, list) else [objective.H])): + dyn_type = "UNIT" + if alg == "GRAPE": # algorithm specific kwargs use_as_amps = True minimizer_kwargs.setdefault("method", "L-BFGS-B") # gradient @@ -238,7 +309,7 @@ def optimize_pulses( "accuracy_factor": None, # deprecated "alg_params": alg_params, "optim_params": algorithm_kwargs.get("optim_params", None), - "dyn_type": algorithm_kwargs.get("dyn_type", "GEN_MAT"), + 
"dyn_type": algorithm_kwargs.get("dyn_type", dyn_type), "dyn_params": algorithm_kwargs.get("dyn_params", None), "prop_type": algorithm_kwargs.get( "prop_type", "DEF" @@ -348,6 +419,27 @@ def optimize_pulses( qtrl_optimizers.append(qtrl_optimizer) + elif alg == "RL": + if not _rl_available: + raise ImportError( + "The required dependencies (gymnasium, stable-baselines3) for " + "the reinforcement learning algorithm are not available." + ) + + rl_env = _RL( + objectives, + control_parameters, + time_interval, + time_options, + algorithm_kwargs, + optimizer_kwargs, + minimizer_kwargs, + integrator_kwargs, + qtrl_optimizers, + ) + rl_env.train() + return rl_env.result() + return _global_local_optimization( objectives, control_parameters, diff --git a/src/qutip_qoc/result.py b/src/qutip_qoc/result.py index 493104a..7738aa9 100644 --- a/src/qutip_qoc/result.py +++ b/src/qutip_qoc/result.py @@ -2,7 +2,6 @@ This module contains the Result class for storing and reporting the results of a full pulse control optimization run. 
""" -import jaxlib import pickle import textwrap import numpy as np @@ -11,6 +10,13 @@ import qutip as qt +try: + import jax + import jaxlib + _jitfun_type = type(jax.jit(lambda x: x)) +except ImportError: + _jitfun_type = None + __all__ = ["Result"] @@ -331,8 +337,10 @@ def final_states(self): evo_time = self.time_interval.evo_time # choose solver method based on type of control function - if isinstance( - self.objectives[0].H[1][1], jaxlib.xla_extension.PjitFunction + # if jax is installed, _jitfun_type is set to + # jaxlib.xla_extension.PjitFunction, otherwise it is None + if _jitfun_type is not None and isinstance( + self.objectives[0].H[1][1], _jitfun_type ): method = "diffrax" # for JAX defined contols else: diff --git a/tests/interactive/CRAB_gate_closed.md b/tests/interactive/CRAB_gate_closed.md new file mode 100644 index 0000000..ea36556 --- /dev/null +++ b/tests/interactive/CRAB_gate_closed.md @@ -0,0 +1,102 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# CRAB algorithm for a closed system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import gates, qeye, sigmax, sigmay, sigmaz +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +def fidelity(gate, target_gate): + """ + Fidelity used for unitary gates in qutip-qtrl and qutip-qoc + """ + return np.abs(gate.overlap(target_gate) / target_gate.norm()) +``` + +## Problem setup + +```python +omega = 0.1 # energy splitting +sx, sy, sz = sigmax(), sigmay(), sigmaz() + +Hd = 1 / 2 * omega * sz +Hc = [sx, sy, sz] +H = [Hd, Hc[0], Hc[1], Hc[2]] + +# objective for optimization +initial_gate = qeye(2) +target_gate = gates.hadamard_transform() + +times = np.linspace(0, np.pi / 2, 250) +``` + +## CRAB algorithm + +```python +n_params = 3 # adjust in steps of 3 +control_params = { 
+ "ctrl_x": {"guess": [1 for _ in range(n_params)], "bounds": [(-1, 1)] * n_params}, + "ctrl_y": {"guess": [1 for _ in range(n_params)], "bounds": [(-1, 1)] * n_params}, + "ctrl_z": {"guess": [1 for _ in range(n_params)], "bounds": [(-1, 1)] * n_params}, +} + +res_crab = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "CRAB", + "fid_err_targ": 0.001 + }, +) + +print('Infidelity: ', res_crab.infidelity) + +plt.plot(times, res_crab.optimized_controls[0], 'b', label='optimized pulse sx') +plt.plot(times, res_crab.optimized_controls[1], 'g', label='optimized pulse sy') +plt.plot(times, res_crab.optimized_controls[2], 'r', label='optimized pulse sz') +plt.title('CRAB pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc[0], res_crab.optimized_controls[0]], [Hc[1], res_crab.optimized_controls[1]], [Hc[2], res_crab.optimized_controls[2]]] +evolution = qt.sesolve(H_result, initial_gate, times) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution.states], label="Overlap with target gate") + +plt.title('CRAB performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +assert res_crab.infidelity < 0.001 +assert fidelity(evolution.states[-1], target_gate) > 1-0.001 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/CRAB_state_closed.md b/tests/interactive/CRAB_state_closed.md new file mode 100644 index 0000000..278ab21 --- /dev/null +++ b/tests/interactive/CRAB_state_closed.md @@ -0,0 +1,95 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 
+--- + +# CRAB algorithm for 2 level system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import basis, Qobj +import qutip as qt +from qutip_qoc import Objective, optimize_pulses +``` + +## Problem setup + +```python +# Energy levels +E1, E2 = 1.0, 2.0 + +Hd = Qobj(np.diag([E1, E2])) +Hc = Qobj(np.array([ + [0, 1], + [1, 0] +])) +H = [Hd, Hc] + +initial_state = basis(2, 0) # |1> +target_state = basis(2, 1) # |2> + +times = np.linspace(0, 2 * np.pi, 250) +``` + +## CRAB algorithm + +```python +n_params = 6 # adjust in steps of 3 +control_params = { + "ctrl_x": {"guess": [1 for _ in range(n_params)], "bounds": [(-1, 1)] * n_params}, +} + +res_crab = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "CRAB", + "fid_err_targ": 0.001 + }, +) + +print('Infidelity: ', res_crab.infidelity) + +plt.plot(times, res_crab.optimized_controls[0], label='optimized pulse') +plt.title('CRAB pulse') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc, np.array(res_crab.optimized_controls[0])]] +evolution = qt.sesolve(H_result, initial_state, times) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], '--', label="Fidelity") + +plt.title("CRAB performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +assert res_crab.infidelity < 0.001 +assert np.abs(evolution.states[-1].overlap(target_state)) > 1-0.001 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/GOAT_gate_closed.md b/tests/interactive/GOAT_gate_closed.md new file mode 100644 index 0000000..4eafa97 --- 
/dev/null +++ b/tests/interactive/GOAT_gate_closed.md @@ -0,0 +1,280 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# GOAT algorithm for a closed system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import gates, qeye, sigmax, sigmay, sigmaz +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +def fidelity(gate, target_gate): + """ + Fidelity used for unitary gates in qutip-qtrl and qutip-qoc + """ + return np.abs(gate.overlap(target_gate) / target_gate.norm()) +``` + +## Problem setup + +```python +omega = 0.1 # energy splitting +sx, sy, sz = sigmax(), sigmay(), sigmaz() + +Hd = 1 / 2 * omega * sz +Hc = [sx, sy, sz] + +# objective for optimization +initial_gate = qeye(2) +target_gate = gates.hadamard_transform() + +times = np.linspace(0, np.pi / 2, 250) +``` + +## Guess + +```python +goat_guess = [1, 1] +guess_pulse = goat_guess[0] * np.sin(goat_guess[1] * times) + +H_guess = [Hd] + [[hc, guess_pulse] for hc in Hc] +evolution_guess = qt.sesolve(H_guess, initial_gate, times) + +print('Fidelity: ', fidelity(evolution_guess.states[-1], target_gate)) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution_guess.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution_guess.states], label="Overlap with target gate") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## GOAT algorithm + +```python +# control function +def sin(t, c): + return c[0] * np.sin(c[1] * t) + +# derivatives +def grad_sin(t, c, idx): + if idx == 0: # w.r.t. c0 + return np.sin(c[1] * t) + if idx == 1: # w.r.t. c1 + return c[0] * np.cos(c[1] * t) * t + if idx == 2: # w.r.t. 
time + return c[0] * np.cos(c[1] * t) * c[1] + +H = [Hd] + [[hc, sin, {"grad": grad_sin}] for hc in Hc] +``` + +### a) not optimized over time + +```python +control_params = { + id: {"guess": goat_guess, "bounds": [(-1, 1), (0, 2 * np.pi)]} # c0 and c1 + for id in ['x', 'y', 'z'] +} + +# run the optimization +res_goat = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001, + }, +) + +print('Infidelity: ', res_goat.infidelity) + +plt.plot(times, guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times, res_goat.optimized_controls[0], 'b', label='optimized pulse sx') +plt.plot(times, res_goat.optimized_controls[1], 'g', label='optimized pulse sy') +plt.plot(times, res_goat.optimized_controls[2], 'r', label='optimized pulse sz') +plt.title('GOAT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, + [Hc[0], np.array(res_goat.optimized_controls[0])], + [Hc[1], np.array(res_goat.optimized_controls[1])], + [Hc[2], np.array(res_goat.optimized_controls[2])]] +evolution = qt.sesolve(H_result, initial_gate, times) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution.states], label="Overlap with target gate") + +plt.title('GOAT performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +### b) optimized over time + +```python +# treats time as optimization variable +control_params["__time__"] = { + "guess": times[len(times) // 2], + "bounds": [times[0], times[-1]], +} + +# run the optimization +res_goat_time = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001, + }, +) + +opt_time = 
res_goat_time.optimized_params[-1][0] +time_range = times < opt_time + +print('Infidelity: ', res_goat_time.infidelity) +print('Optimized time: ', opt_time) + +plt.plot(times, guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times[time_range], np.array(res_goat_time.optimized_controls[0])[time_range], 'b', label='optimized pulse sx') +plt.plot(times[time_range], np.array(res_goat_time.optimized_controls[1])[time_range], 'g', label='optimized pulse sy') +plt.plot(times[time_range], np.array(res_goat_time.optimized_controls[2])[time_range], 'r', label='optimized pulse sz') +plt.title('GOAT pulses (time optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times2 = times[time_range] +if opt_time not in times2: + times2 = np.append(times2, opt_time) + +H_result = qt.QobjEvo( + [Hd, [Hc[0], np.array(res_goat_time.optimized_controls[0])], + [Hc[1], np.array(res_goat_time.optimized_controls[1])], + [Hc[2], np.array(res_goat_time.optimized_controls[2])]], tlist=times) +evolution_time = qt.sesolve(H_result, initial_gate, times2) + +plt.plot(times2, [fidelity(gate, initial_gate) for gate in evolution_time.states], label="Overlap with initial gate") +plt.plot(times2, [fidelity(gate, target_gate) for gate in evolution_time.states], label="Overlap with target gate") + +plt.title('GOAT (optimized over time) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Global optimization + +```python +res_goat_global = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001, + }, + optimizer_kwargs = { + "method": "basinhopping", + "max_iter": 100, + } +) + +global_time = res_goat_global.optimized_params[-1][0] +global_range = times < global_time + +print('Infidelity: ', res_goat_global.infidelity) +print('Optimized time: ', global_time) + +plt.plot(times, 
guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times[global_range], np.array(res_goat_global.optimized_controls[0])[global_range], 'b', label='optimized pulse sx') +plt.plot(times[global_range], np.array(res_goat_global.optimized_controls[1])[global_range], 'g', label='optimized pulse sy') +plt.plot(times[global_range], np.array(res_goat_global.optimized_controls[2])[global_range], 'r', label='optimized pulse sz') +plt.title('GOAT pulses (global optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times3 = times[global_range] +if global_time not in times3: + times3 = np.append(times3, global_time) + +H_result = qt.QobjEvo( + [Hd, [Hc[0], np.array(res_goat_global.optimized_controls[0])], + [Hc[1], np.array(res_goat_global.optimized_controls[1])], + [Hc[2], np.array(res_goat_global.optimized_controls[2])]], tlist=times) +evolution_global = qt.sesolve(H_result, initial_gate, times3) + +plt.plot(times3, [fidelity(gate, initial_gate) for gate in evolution_global.states], label="Overlap with initial gate") +plt.plot(times3, [fidelity(gate, target_gate) for gate in evolution_global.states], label="Overlap with target gate") + +plt.title('GOAT (global optimization) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Comparison + +```python +fig, axes = plt.subplots(1, 3, figsize=(18, 4)) # 1 row, 3 columns + +titles = ["GOAT sx pulses", "GOAT sy pulses", "GOAT sz pulses"] + +for i, ax in enumerate(axes): + ax.plot(times, guess_pulse, label='initial guess') + ax.plot(times, res_goat.optimized_controls[i], label='optimized pulse') + ax.plot(times[time_range], np.array(res_goat_time.optimized_controls[i])[time_range], label='optimized (over time) pulse') + ax.plot(times[global_range], np.array(res_goat_global.optimized_controls[i])[global_range], label='global optimized pulse') + ax.set_title(titles[i]) + ax.set_xlabel('Time') + ax.set_ylabel('Pulse amplitude') + ax.legend() + 
+plt.tight_layout() +plt.show() +``` + +## Validation + +```python +assert res_goat.infidelity < 0.001 +assert fidelity(evolution.states[-1], target_gate) > 1-0.001 + +assert res_goat_time.infidelity < 0.001 +assert fidelity(evolution_time.states[-1], target_gate) > 1-0.001 + +assert res_goat_global.infidelity < 0.001 +assert fidelity(evolution_global.states[-1], target_gate) > 1-0.001 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/GOAT_state_closed.md b/tests/interactive/GOAT_state_closed.md new file mode 100644 index 0000000..bbce5f2 --- /dev/null +++ b/tests/interactive/GOAT_state_closed.md @@ -0,0 +1,292 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# GOAT algorithm for a 2 level system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import basis, Qobj +import qutip as qt +from qutip_qoc import Objective, optimize_pulses +``` + +## Problem setup + +```python +# Energy levels +E1, E2 = 1.0, 2.0 + +Hd = Qobj(np.diag([E1, E2])) +Hc = Qobj(np.array([ + [0, 1], + [1, 0] +])) +H = [Hd, Hc] + +initial_state = basis(2, 0) # |1> +target_state = basis(2, 1) # |2> + +times = np.linspace(0, 2 * np.pi, 250) +``` + +## Guess + +```python +goat_guess = [1, 0.5] +guess_pulse = goat_guess[0] * np.sin(goat_guess[1] * times) + +H_guess = [Hd, [Hc, guess_pulse]] +evolution_guess = qt.sesolve(H_guess, initial_state, times) + +print('Fidelity: ', qt.fidelity(evolution_guess.states[-1], target_state)) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution_guess.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution_guess.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution_guess.states], 
'--', label="Fidelity") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## GOAT algorithm + +```python +# control function +def sin(t, c): + return c[0] * np.sin(c[1] * t) + +# gradient +def grad_sin(t, c, idx): + if idx == 0: # w.r.t. c0 + return np.sin(c[1] * t) + if idx == 1: # w.r.t. c1 + return c[0] * np.cos(c[1] * t) * t + if idx == 2: # w.r.t. time + return c[0] * np.cos(c[1] * t) * c[1] + +H = [Hd] + [[Hc, sin, {"grad": grad_sin}]] +``` + +### a) not optimized over time + +```python +control_params = { + "ctrl_x": {"guess": goat_guess, "bounds": [(-1, 1), (0, 2 * np.pi)]} # c0 and c1 +} + +# run the optimization +res_goat = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001 + }, +) + +print('Infidelity: ', res_goat.infidelity) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_goat.optimized_controls[0], label='optimized pulse') +plt.title('GOAT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc, np.array(res_goat.optimized_controls[0])]] +evolution = qt.sesolve(H_result, initial_state, times) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], '--', label="Fidelity") + +plt.title('GOAT performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +Here, GOAT is stuck in a local minimum and does not reach the desired fidelity. 
+ + +### b) optimized over time + +```python +# treats time as optimization variable +control_params["__time__"] = { + "guess": times[len(times) // 2], + "bounds": [times[0], times[-1]], +} + +# run the optimization +res_goat_time = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001 + }, +) + +opt_time = res_goat_time.optimized_params[-1][0] +time_range = times < opt_time + +print('Infidelity: ', res_goat_time.infidelity) +print('Optimized time: ', opt_time) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_goat.optimized_controls[0], label='optimized pulse') +plt.plot(times[time_range], np.array(res_goat_time.optimized_controls[0])[time_range], label='optimized (over time) pulse') +plt.title('GOAT pulses (time optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times2 = times[time_range] +if opt_time not in times2: + times2 = np.append(times2, opt_time) + +H_result = qt.QobjEvo([Hd, [Hc, np.array(res_goat_time.optimized_controls[0])]], tlist=times) +evolution_time = qt.sesolve(H_result, initial_state, times2) + +plt.plot(times2, [np.abs(state.overlap(initial_state)) for state in evolution_time.states], label="Overlap with initial state") +plt.plot(times2, [np.abs(state.overlap(target_state)) for state in evolution_time.states], label="Overlap with target state") +plt.plot(times2, [qt.fidelity(state, target_state) for state in evolution_time.states], '--', label="Fidelity") + +plt.title('GOAT (optimized over time) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +GOAT is still stuck in a local minimum, but the fidelity has improved. 
+ + +### c) global optimization + +```python +res_goat_global = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GOAT", + "fid_err_targ": 0.001 + }, + optimizer_kwargs={ + "method": "basinhopping", + "max_iter": 1000 + } +) + +global_time = res_goat_global.optimized_params[-1][0] +global_range = times < global_time + +print('Infidelity: ', res_goat_global.infidelity) +print('Optimized time: ', global_time) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_goat.optimized_controls[0], label='optimized pulse') +plt.plot(times[global_range], np.array(res_goat_global.optimized_controls[0])[global_range], label='global optimized pulse') +plt.title('GOAT pulses (global)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times3 = times[global_range] +if global_time not in times3: + times3 = np.append(times3, global_time) + +H_result = qt.QobjEvo([Hd, [Hc, np.array(res_goat_global.optimized_controls[0])]], tlist=times) +evolution_global = qt.sesolve(H_result, initial_state, times3) + +plt.plot(times3, [np.abs(state.overlap(initial_state)) for state in evolution_global.states], label="Overlap with initial state") +plt.plot(times3, [np.abs(state.overlap(target_state)) for state in evolution_global.states], label="Overlap with target state") +plt.plot(times3, [qt.fidelity(state, target_state) for state in evolution_global.states], '--', label="Fidelity") + +plt.title('GOAT (global) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Comparison + +```python +plt.plot(times, guess_pulse, color='blue', label='initial guess') +plt.plot(times, res_goat.optimized_controls[0], color='orange', label='optimized pulse') +plt.plot(times[time_range], np.array(res_goat_time.optimized_controls[0])[time_range], + color='green', label='optimized (over time) pulse') 
+plt.plot(times[global_range], np.array(res_goat_global.optimized_controls[0])[global_range], + color='red', label='global optimized pulse') + +plt.title('GOAT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +print('Guess Fidelity: ', qt.fidelity(evolution_guess.states[-1], target_state)) +print('GOAT Fidelity: ', 1 - res_goat.infidelity) +print('Time Fidelity: ', 1 - res_goat_time.infidelity) +print('GLobal Fidelity: ', 1 - res_goat_global.infidelity) + +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution_guess.states], color='blue', label="Guess") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], color='orange', label="GOAT") +plt.plot(times2, [qt.fidelity(state, target_state) for state in evolution_time.states], + color='green', label="Time") +plt.plot(times3, [qt.fidelity(state, target_state) for state in evolution_global.states], + color='red', label="Global") + +plt.title('Fidelities') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +guess_fidelity = qt.fidelity(evolution_guess.states[-1], target_state) + +# target fidelity not reached in part a), check that it is better than the guess +assert 1 - res_goat.infidelity > guess_fidelity +assert np.allclose(np.abs(evolution.states[-1].overlap(target_state)), 1 - res_goat.infidelity, atol=1e-3) + +# target fidelity not reached in part b), check that it is better than part a) +assert res_goat_time.infidelity < res_goat.infidelity +assert np.allclose(np.abs(evolution_time.states[-1].overlap(target_state)), 1 - res_goat_time.infidelity, atol=1e-3) + +assert res_goat_global.infidelity < 0.001 +assert np.abs(evolution_global.states[-1].overlap(target_state)) > 1 - 0.001 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/GRAPE_gate_closed.md b/tests/interactive/GRAPE_gate_closed.md new file mode 100644 index 0000000..bb9b145 --- /dev/null +++ 
b/tests/interactive/GRAPE_gate_closed.md @@ -0,0 +1,124 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# GRAPE algorithm for a closed system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import gates, qeye, sigmax, sigmay, sigmaz +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +def fidelity(gate, target_gate): + """ + Fidelity used for unitary gates in qutip-qtrl and qutip-qoc + """ + return np.abs(gate.overlap(target_gate) / target_gate.norm()) +``` + +## Problem setup + +```python +omega = 0.1 # energy splitting +sx, sy, sz = sigmax(), sigmay(), sigmaz() + +Hd = 1 / 2 * omega * sz +Hc = [sx, sy, sz] +H = [Hd, Hc[0], Hc[1], Hc[2]] + +# objective for optimization +initial_gate = qeye(2) +target_gate = gates.hadamard_transform() + +times = np.linspace(0, np.pi / 2, 250) +``` + +## Guess + +```python +guess_pulse_x = np.sin(times) +guess_pulse_y = np.cos(times) +guess_pulse_z = np.tanh(times) + +H_guess = [Hd, [Hc[0], guess_pulse_x], [Hc[1], guess_pulse_y], [Hc[2], guess_pulse_z]] +evolution_guess = qt.sesolve(H_guess, initial_gate, times) + +print('Fidelity: ', fidelity(evolution_guess.states[-1], target_gate)) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution_guess.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution_guess.states], label="Overlap with target gate") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## GRAPE algorithm + +```python +control_params = { + "ctrl_x": {"guess": np.sin(times), "bounds": [-1, 1]}, + "ctrl_y": {"guess": np.cos(times), "bounds": [-1, 1]}, + "ctrl_z": {"guess": np.tanh(times), "bounds": [-1, 1]}, +} + +res_grape = optimize_pulses( + objectives = Objective(initial_gate, 
H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GRAPE", + "fid_err_targ": 0.001 + }, +) + +print('Infidelity: ', res_grape.infidelity) + +plt.plot(times, guess_pulse_x, 'b--', label='guess pulse sx') +plt.plot(times, res_grape.optimized_controls[0], 'b', label='optimized pulse sx') +plt.plot(times, guess_pulse_y, 'g--', label='guess pulse sy') +plt.plot(times, res_grape.optimized_controls[1], 'g', label='optimized pulse sy') +plt.plot(times, guess_pulse_z, 'r--', label='guess pulse sz') +plt.plot(times, res_grape.optimized_controls[2], 'r', label='optimized pulse sz') +plt.title('GRAPE pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc[0], res_grape.optimized_controls[0]], [Hc[1], res_grape.optimized_controls[1]], [Hc[2], res_grape.optimized_controls[2]]] +evolution = qt.sesolve(H_result, initial_gate, times) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution.states], label="Overlap with target gate") + +plt.title('GRAPE performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +assert res_grape.infidelity < 0.001 +assert fidelity(evolution.states[-1], target_gate) > 1-0.001 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/GRAPE_state_closed.md b/tests/interactive/GRAPE_state_closed.md new file mode 100644 index 0000000..2681df7 --- /dev/null +++ b/tests/interactive/GRAPE_state_closed.md @@ -0,0 +1,114 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# GRAPE algorithm for 2 level system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip 
import basis, Qobj +import qutip as qt +from qutip_qoc import Objective, optimize_pulses +``` + +## Problem setup + +```python +# Energy levels +E1, E2 = 1.0, 2.0 + +Hd = Qobj(np.diag([E1, E2])) +Hc = Qobj(np.array([ + [0, 1], + [1, 0] +])) +H = [Hd, Hc] + +initial_state = basis(2, 0) # |1> +target_state = basis(2, 1) # |2> + +times = np.linspace(0, 2 * np.pi, 250) +``` + +## Guess + +```python +guess_pulse = np.sin(times) + +H_guess = [Hd, [Hc, guess_pulse]] +evolution_guess = qt.sesolve(H_guess, initial_state, times) + +print('Fidelity: ', qt.fidelity(evolution_guess.states[-1], target_state)) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution_guess.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution_guess.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution_guess.states], '--', label="Fidelity") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## GRAPE algorithm + +```python +control_params = { + "ctrl_x": {"guess": guess_pulse, "bounds": [-1, 1]}, # Control pulse for Hc1 +} + +res_grape = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "GRAPE", + "fid_err_targ": 0.001 + }, +) + +print('Infidelity: ', res_grape.infidelity) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_grape.optimized_controls[0], label='optimized pulse') +plt.title('GRAPE pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc, np.array(res_grape.optimized_controls[0])]] +evolution = qt.sesolve(H_result, initial_state, times) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution.states], label="Overlap with initial state") +plt.plot(times, 
[np.abs(state.overlap(target_state)) for state in evolution.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], '--', label="Fidelity") + +plt.title('GRAPE performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +assert res_grape.infidelity < 0.01 +assert np.abs(evolution.states[-1].overlap(target_state)) > 1-0.01 +``` + +```python +qt.about() +``` diff --git a/tests/interactive/JOPT_gate_closed.md b/tests/interactive/JOPT_gate_closed.md new file mode 100644 index 0000000..1f3370c --- /dev/null +++ b/tests/interactive/JOPT_gate_closed.md @@ -0,0 +1,285 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# JOPT algorithm for a closed system + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import gates, qeye, sigmax, sigmay, sigmaz +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +try: + from jax import jit + from jax import numpy as jnp +except ImportError: # JAX not available, skip test + import pytest + pytest.skip("JAX not available") + +def fidelity(gate, target_gate): + """ + Fidelity used for unitary gates in qutip-qtrl and qutip-qoc + """ + return np.abs(gate.overlap(target_gate) / target_gate.norm()) +``` + +## Problem setup + +```python +omega = 0.1 # energy splitting +sx, sy, sz = sigmax(), sigmay(), sigmaz() + +Hd = 1 / 2 * omega * sz +Hc = [sx, sy, sz] + +# objective for optimization +initial_gate = qeye(2) +target_gate = gates.hadamard_transform() + +times = np.linspace(0, np.pi / 2, 250) +``` + +## Guess + +```python +jopt_guess = [1, 1] +guess_pulse = jopt_guess[0] * np.sin(jopt_guess[1] * times) + +H_guess = [Hd] + [[hc, guess_pulse] for hc in Hc] +evolution_guess = qt.sesolve(H_guess, initial_gate, 
times) + +print('Fidelity: ', fidelity(evolution_guess.states[-1], target_gate)) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution_guess.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution_guess.states], label="Overlap with target gate") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## JOPT algorithm + +```python +@jit +def sin_x(t, c, **kwargs): + return c[0] * jnp.sin(c[1] * t) + +H = [Hd] + [[hc, sin_x] for hc in Hc] +``` + +### a) not optimized over time + +```python +control_params = { + id: {"guess": jopt_guess, "bounds": [(-1, 1), (0, 2 * np.pi)]} # c0 and c1 + for id in ['x', 'y', 'z'] +} + +res_jopt = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + minimizer_kwargs = { + "method": "Nelder-Mead", + }, + algorithm_kwargs={ + "alg": "JOPT", + "fid_err_targ": 0.001, + }, +) + +print('Infidelity: ', res_jopt.infidelity) + +plt.plot(times, guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times, res_jopt.optimized_controls[0], 'b', label='optimized pulse sx') +plt.plot(times, res_jopt.optimized_controls[1], 'g', label='optimized pulse sy') +plt.plot(times, res_jopt.optimized_controls[2], 'r', label='optimized pulse sz') +plt.title('JOPT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, + [Hc[0], np.array(res_jopt.optimized_controls[0])], + [Hc[1], np.array(res_jopt.optimized_controls[1])], + [Hc[2], np.array(res_jopt.optimized_controls[2])]] +evolution = qt.sesolve(H_result, initial_gate, times) + +plt.plot(times, [fidelity(gate, initial_gate) for gate in evolution.states], label="Overlap with initial gate") +plt.plot(times, [fidelity(gate, target_gate) for gate in evolution.states], label="Overlap with target gate") + +plt.title('JOPT performance') +plt.xlabel('Time') 
+plt.legend() +plt.show() +``` + +### b) optimized over time + +```python +# treats time as optimization variable +control_params["__time__"] = { + "guess": times[len(times) // 2], + "bounds": [times[0], times[-1]], +} + +# run the optimization +res_jopt_time = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + minimizer_kwargs = { + "method": "Nelder-Mead", + }, + algorithm_kwargs={ + "alg": "JOPT", + "fid_err_targ": 0.001, + }, +) + +opt_time = res_jopt_time.optimized_params[-1][0] +time_range = times < opt_time + +print('Infidelity: ', res_jopt_time.infidelity) +print('Optimized time: ', opt_time) + +plt.plot(times, guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times[time_range], np.array(res_jopt_time.optimized_controls[0])[time_range], 'b', label='optimized pulse sx') +plt.plot(times[time_range], np.array(res_jopt_time.optimized_controls[1])[time_range], 'g', label='optimized pulse sy') +plt.plot(times[time_range], np.array(res_jopt_time.optimized_controls[2])[time_range], 'r', label='optimized pulse sz') +plt.title('JOPT pulses (time optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times2 = times[time_range] +if opt_time not in times2: + times2 = np.append(times2, opt_time) + +H_result = qt.QobjEvo( + [Hd, [Hc[0], np.array(res_jopt_time.optimized_controls[0])], + [Hc[1], np.array(res_jopt_time.optimized_controls[1])], + [Hc[2], np.array(res_jopt_time.optimized_controls[2])]], tlist=times) +evolution_time = qt.sesolve(H_result, initial_gate, times2) + +plt.plot(times2, [fidelity(gate, initial_gate) for gate in evolution_time.states], label="Overlap with initial gate") +plt.plot(times2, [fidelity(gate, target_gate) for gate in evolution_time.states], label="Overlap with target gate") + +plt.title('JOPT (optimized over time) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Global 
optimization + +```python +res_jopt_global = optimize_pulses( + objectives = Objective(initial_gate, H, target_gate), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "JOPT", + "fid_err_targ": 0.001, + }, + optimizer_kwargs = { + "method": "basinhopping", + "max_iter": 100, + } +) + +global_time = res_jopt_global.optimized_params[-1][0] +global_range = times < global_time + +print('Infidelity: ', res_jopt_global.infidelity) +print('Optimized time: ', global_time) + +plt.plot(times, guess_pulse, 'k--', label='guess pulse sx, sy, sz') +plt.plot(times[global_range], np.array(res_jopt_global.optimized_controls[0])[global_range], 'b', label='optimized pulse sx') +plt.plot(times[global_range], np.array(res_jopt_global.optimized_controls[1])[global_range], 'g', label='optimized pulse sy') +plt.plot(times[global_range], np.array(res_jopt_global.optimized_controls[2])[global_range], 'r', label='optimized pulse sz') +plt.title('JOPT pulses (global optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times3 = times[global_range] +if global_time not in times3: + times3 = np.append(times3, global_time) + +H_result = qt.QobjEvo( + [Hd, [Hc[0], np.array(res_jopt_global.optimized_controls[0])], + [Hc[1], np.array(res_jopt_global.optimized_controls[1])], + [Hc[2], np.array(res_jopt_global.optimized_controls[2])]], tlist=times) +evolution_global = qt.sesolve(H_result, initial_gate, times3) + +plt.plot(times3, [fidelity(gate, initial_gate) for gate in evolution_global.states], label="Overlap with initial gate") +plt.plot(times3, [fidelity(gate, target_gate) for gate in evolution_global.states], label="Overlap with target gate") + +plt.title('JOPT (global optimization) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Comparison + +```python +fig, axes = plt.subplots(1, 3, figsize=(18, 4)) # 1 row, 3 columns + +titles = ["JOPT sx pulses", "JOPT sy pulses", "JOPT sz 
pulses"] + +for i, ax in enumerate(axes): + ax.plot(times, guess_pulse, label='initial guess') + ax.plot(times, res_jopt.optimized_controls[i], label='optimized pulse') + ax.plot(times[time_range], np.array(res_jopt_time.optimized_controls[i])[time_range], label='optimized (over time) pulse') + ax.plot(times[global_range], np.array(res_jopt_global.optimized_controls[i])[global_range], label='global optimized pulse') + ax.set_title(titles[i]) + ax.set_xlabel('time') + ax.set_ylabel('Pulse amplitude') + ax.legend() + +plt.tight_layout() +plt.show() +``` + +## Validation + +```python +assert res_jopt.infidelity < 0.001 +assert fidelity(evolution.states[-1], target_gate) > 1-0.001 + +assert res_jopt_time.infidelity < 0.001 +assert fidelity(evolution_time.states[-1], target_gate) > 1-0.001 + +assert res_jopt_global.infidelity < 0.001 +assert fidelity(evolution_global.states[-1], target_gate) > 1-0.001 +``` + +```python +qt.about() +``` + + diff --git a/tests/interactive/JOPT_state_closed.md b/tests/interactive/JOPT_state_closed.md new file mode 100644 index 0000000..11547cb --- /dev/null +++ b/tests/interactive/JOPT_state_closed.md @@ -0,0 +1,296 @@ +--- +jupyter: + jupytext: + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.17.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# JOPT algorithm for a 2 level system + + +```python +import matplotlib.pyplot as plt +import numpy as np +from qutip import basis, Qobj +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +try: + from jax import jit + from jax import numpy as jnp +except ImportError: # JAX not available, skip test + import pytest + pytest.skip("JAX not available") +``` + +## Problem setup + +```python +# Energy levels +E1, E2 = 1.0, 2.0 + +Hd = Qobj(np.diag([E1, E2])) +Hc = Qobj(np.array([ + [0, 1], + [1, 0] +])) +H = [Hd, Hc] + +initial_state = basis(2, 0) # |1> +target_state = basis(2, 
1) # |2> + +times = np.linspace(0, 2 * np.pi, 250) +``` + +## Guess + +```python +jopt_guess = [1, 0.5] +guess_pulse = jopt_guess[0] * np.sin(jopt_guess[1] * times) + +H_guess = [Hd, [Hc, guess_pulse]] +evolution_guess = qt.sesolve(H_guess, initial_state, times) + +print('Fidelity: ', qt.fidelity(evolution_guess.states[-1], target_state)) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution_guess.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution_guess.states], label="Overlap with target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution_guess.states], '--', label="Fidelity") +plt.title("Guess performance") +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## JOPT algorithm + +```python +@jit +def sin(t, c, **kwargs): + return c[0] * jnp.sin(c[1] * t) + +H = [Hd] + [[Hc, sin]] +``` + +### a) not optimized over time + +```python +control_params = { + "ctrl_x": {"guess": [1, 0], "bounds": [(-1, 1), (0, 2 * np.pi)]} # c0 and c1 +} + +res_jopt = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + minimizer_kwargs = { + "method": "Nelder-Mead", + }, + algorithm_kwargs = { + "alg": "JOPT", + "fid_err_targ": 0.001, + }, +) + +print('Infidelity: ', res_jopt.infidelity) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_jopt.optimized_controls[0], label='optimized pulse') +plt.title('JOPT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +H_result = [Hd, [Hc, np.array(res_jopt.optimized_controls[0])]] +evolution = qt.sesolve(H_result, initial_state, times) + +plt.plot(times, [np.abs(state.overlap(initial_state)) for state in evolution.states], label="Overlap with initial state") +plt.plot(times, [np.abs(state.overlap(target_state)) for state in evolution.states], label="Overlap with 
target state") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], '--', label="Fidelity") + +plt.title('JOPT performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +Here, JOPT is stuck in a local minimum and does not reach the desired fidelity. + + +### b) optimized over time + +```python +# treats time as optimization variable +control_params["__time__"] = { + "guess": times[len(times) // 2], + "bounds": [times[0], times[-1]], +} + +# run the optimization +res_jopt_time = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + minimizer_kwargs = { + "method": "Nelder-Mead", + }, + algorithm_kwargs = { + "alg": "JOPT", + "fid_err_targ": 0.001, + }, +) + +opt_time = res_jopt_time.optimized_params[-1][0] +time_range = times < opt_time + +print('Infidelity: ', res_jopt_time.infidelity) +print('Optimized time: ', opt_time) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_jopt.optimized_controls[0], label='optimized pulse') +plt.plot(times[time_range], np.array(res_jopt_time.optimized_controls[0])[time_range], label='optimized (over time) pulse') +plt.title('JOPT pulses (time optimization)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times2 = times[time_range] +if opt_time not in times2: + times2 = np.append(times2, opt_time) + +H_result = qt.QobjEvo([Hd, [Hc, np.array(res_jopt_time.optimized_controls[0])]], tlist=times) +evolution_time = qt.sesolve(H_result, initial_state, times2) + +plt.plot(times2, [np.abs(state.overlap(initial_state)) for state in evolution_time.states], label="Overlap with initial state") +plt.plot(times2, [np.abs(state.overlap(target_state)) for state in evolution_time.states], label="Overlap with target state") +plt.plot(times2, [qt.fidelity(state, target_state) for state in evolution_time.states], '--', label="Fidelity") + +plt.title('JOPT 
(optimized over time) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +JOPT is still stuck in a local minimum, but the fidelity has improved. + + +### c) global optimization + +```python +res_jopt_global = optimize_pulses( + objectives = Objective(initial_state, H, target_state), + control_parameters = control_params, + tlist = times, + algorithm_kwargs = { + "alg": "JOPT", + "fid_err_targ": 0.001 + }, + optimizer_kwargs={ + "method": "basinhopping", + "max_iter": 1000 + } +) + +global_time = res_jopt_global.optimized_params[-1][0] +global_range = times < global_time + +print('Infidelity: ', res_jopt_global.infidelity) +print('Optimized time: ', global_time) + +plt.plot(times, guess_pulse, label='initial guess') +plt.plot(times, res_jopt.optimized_controls[0], label='optimized pulse') +plt.plot(times[global_range], np.array(res_jopt_global.optimized_controls[0])[global_range], label='global optimized pulse') +plt.title('JOPT pulses (global)') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +times3 = times[global_range] +if global_time not in times3: + times3 = np.append(times3, global_time) + +H_result = qt.QobjEvo([Hd, [Hc, np.array(res_jopt_global.optimized_controls[0])]], tlist=times) +evolution_global = qt.sesolve(H_result, initial_state, times3) + +plt.plot(times3, [np.abs(state.overlap(initial_state)) for state in evolution_global.states], label="Overlap with initial state") +plt.plot(times3, [np.abs(state.overlap(target_state)) for state in evolution_global.states], label="Overlap with target state") +plt.plot(times3, [qt.fidelity(state, target_state) for state in evolution_global.states], '--', label="Fidelity") + +plt.title('JOPT (global) performance') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Comparison + +```python +plt.plot(times, guess_pulse, color='blue', label='initial guess') +plt.plot(times, res_jopt.optimized_controls[0], color='orange', label='optimized pulse') 
+plt.plot(times[time_range], np.array(res_jopt_time.optimized_controls[0])[time_range], + color='green', label='optimized (over time) pulse') +plt.plot(times[global_range], np.array(res_jopt_global.optimized_controls[0])[global_range], + color='red', label='global optimized pulse') + +plt.title('JOPT pulses') +plt.xlabel('Time') +plt.ylabel('Pulse amplitude') +plt.legend() +plt.show() +``` + +```python +print('Guess Fidelity: ', qt.fidelity(evolution_guess.states[-1], target_state)) +print('JOPT Fidelity: ', 1 - res_jopt.infidelity) +print('Time Fidelity: ', 1 - res_jopt_time.infidelity) +print('GLobal Fidelity: ', 1 - res_jopt_global.infidelity) + +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution_guess.states], color='blue', label="Guess") +plt.plot(times, [qt.fidelity(state, target_state) for state in evolution.states], color='orange', label="Goat") +plt.plot(times2, [qt.fidelity(state, target_state) for state in evolution_time.states], + color='green', label="Time") +plt.plot(times3, [qt.fidelity(state, target_state) for state in evolution_global.states], + color='red', label="Global") + +plt.title('Fidelities') +plt.xlabel('Time') +plt.legend() +plt.show() +``` + +## Validation + +```python +guess_fidelity = qt.fidelity(evolution_guess.states[-1], target_state) + +# target fidelity not reached in part a), check that it is better than the guess +assert 1 - res_jopt.infidelity > guess_fidelity +assert np.allclose(np.abs(evolution.states[-1].overlap(target_state)), 1 - res_jopt.infidelity, atol=1e-3) + +# target fidelity not reached in part b), check that it is better than part a) +assert res_jopt_time.infidelity < res_jopt.infidelity +assert np.allclose(np.abs(evolution_time.states[-1].overlap(target_state)), 1 - res_jopt_time.infidelity, atol=1e-3) + +assert res_jopt_global.infidelity < 0.001 +assert np.abs(evolution_global.states[-1].overlap(target_state)) > 1 - 0.001 +``` + +```python +qt.about() +``` diff --git 
a/tests/interactive/about.md b/tests/interactive/about.md new file mode 100644 index 0000000..9d22894 --- /dev/null +++ b/tests/interactive/about.md @@ -0,0 +1,21 @@ +This directory contains notebooks showcasing the use of all algorithms contained within `QuTiP-QOC`. +The examples are chosen to be as simple as possible, for the purpose of demonstrating how to use `QuTiP-QOC` in all of these scenarios, and for the purpose of automatically testing `QuTiP-QOC`'s basic functionality in all of these scenarios. +The included algorithms are: +- GRAPE +- CRAB +- GOAT +- JOPT + +For each algorithm, we have: +- a closed-system state transfer example +- an open-system state transfer example +- a closed-system gate synthesis example +- an open-system gate synthesis example + +The notebooks are included automatically in runs of the test suite (see `test_interactive.py`). + +To view and run the notebooks manually, the `jupytext` package is required. +The notebooks can then either be opened from within Jupyter Notebook using "Open With" -> "Jupytext Notebook", or by converting them first to the `ipynb` format using +``` +jupytext --to ipynb [notebook name].md +``` \ No newline at end of file diff --git a/tests/test_analytical_pulses.py b/tests/test_analytical_pulses.py index 616106b..ecc1417 100644 --- a/tests/test_analytical_pulses.py +++ b/tests/test_analytical_pulses.py @@ -4,11 +4,16 @@ import pytest import qutip as qt -import qutip_jax # noqa: F401 import numpy as np -import jax.numpy as jnp import collections +try: + import jax.numpy as jnp + import qutip_jax # noqa: F401 + _jax_available = True +except ImportError: + _jax_available = False + from qutip_qoc.pulse_optim import optimize_pulses from qutip_qoc.objective import Objective @@ -139,52 +144,59 @@ def grad_sin(t, p, idx): ) -# ----------------------- System and JAX Control --------------------- +if _jax_available: + # ----------------------- System and JAX Control --------------------- -def sin_jax(t, p): - return p[0] * 
jnp.sin(p[1] * t + p[2]) + def sin_jax(t, p): + return p[0] * jnp.sin(p[1] * t + p[2]) -Hc_jax = [ - [qt.sigmax(), lambda t, p: sin_jax(t, p)], - [qt.sigmay(), lambda t, q: sin_jax(t, q)], -] + Hc_jax = [ + [qt.sigmax(), lambda t, p: sin_jax(t, p)], + [qt.sigmay(), lambda t, q: sin_jax(t, q)], + ] -H_jax = H_d + Hc_jax + H_jax = H_d + Hc_jax -# ------------------------------- Objective ------------------------------- + # ------------------------------- Objective ------------------------------- -# state to state transfer -state2state_jax = state2state._replace( - objectives=[Objective(initial, H_jax, target)], - algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, -) + # state to state transfer + state2state_jax = state2state._replace( + objectives=[Objective(initial, H_jax, target)], + algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, + ) -# unitary gate synthesis -unitary_jax = unitary._replace( - objectives=[Objective(initial_U, H_jax, target_U)], - algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, -) + # unitary gate synthesis + unitary_jax = unitary._replace( + objectives=[Objective(initial_U, H_jax, target_U)], + algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, + ) -# unitary gate synthesis - time optimization -time_jax = time._replace( - objectives=[Objective(initial_U, H_jax, target_U)], - algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, -) + # unitary gate synthesis - time optimization + time_jax = time._replace( + objectives=[Objective(initial_U, H_jax, target_U)], + algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, + ) -# map synthesis -Lc_jax = [ - [qt.liouvillian(qt.sigmax()), lambda t, p: sin_jax(t, p)], - [qt.liouvillian(qt.sigmay()), lambda t, q: sin_jax(t, q)], -] -L_jax = L_d + Lc_jax + # map synthesis + Lc_jax = [ + [qt.liouvillian(qt.sigmax()), lambda t, p: sin_jax(t, p)], + [qt.liouvillian(qt.sigmay()), lambda t, q: sin_jax(t, q)], + ] + L_jax = L_d + Lc_jax -mapping_jax = mapping._replace( - 
objectives=[Objective(initial_map, L_jax, target_map)], - algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.1}, # relaxed objective -) + mapping_jax = mapping._replace( + objectives=[Objective(initial_map, L_jax, target_map)], + algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.1}, # relaxed objective + ) +else: + # jax not available, set these to none so tests will be skipped + state2state_jax = None + unitary_jax = None + time_jax = None + mapping_jax = None @pytest.fixture( params=[ @@ -205,6 +217,9 @@ def tst(request): def test_optimize_pulses(tst): + if tst is None: + pytest.skip("JAX not available") + result = optimize_pulses( tst.objectives, tst.control_parameters, diff --git a/tests/test_fidelity.py b/tests/test_fidelity.py index 8490a3e..2c3509d 100644 --- a/tests/test_fidelity.py +++ b/tests/test_fidelity.py @@ -5,9 +5,14 @@ import pytest import qutip as qt import numpy as np -import jax.numpy as jnp import collections +try: + import jax.numpy as jnp + _jax_available = True +except ImportError: + _jax_available = False + from qutip_qoc.pulse_optim import optimize_pulses from qutip_qoc.objective import Objective @@ -108,55 +113,64 @@ def grad_sin(t, p, idx): algorithm_kwargs={"alg": "GOAT", "fid_type": "TRACEDIFF"}, ) -# ----------------------- System and JAX Control --------------------- +if _jax_available: + # ----------------------- System and JAX Control --------------------- -def sin_jax(t, p): - return p[0] * jnp.sin(p[1] * t + p[2]) + def sin_jax(t, p): + return p[0] * jnp.sin(p[1] * t + p[2]) -Hc_jax = [ - [qt.sigmax(), lambda t, p: sin_jax(t, p)], - [qt.sigmay(), lambda t, q: sin_jax(t, q)], -] + Hc_jax = [ + [qt.sigmax(), lambda t, p: sin_jax(t, p)], + [qt.sigmay(), lambda t, q: sin_jax(t, q)], + ] -H_jax = H_d + Hc_jax + H_jax = H_d + Hc_jax -# ------------------------------- Objective ------------------------------- + # ------------------------------- Objective ------------------------------- -# state to state transfer -PSU_state2state_jax = 
PSU_state2state._replace( - objectives=[Objective(initial, H_jax, (-1j) * initial)], - algorithm_kwargs={"alg": "JOPT"}, -) + # state to state transfer + PSU_state2state_jax = PSU_state2state._replace( + objectives=[Objective(initial, H_jax, (-1j) * initial)], + algorithm_kwargs={"alg": "JOPT"}, + ) -SU_state2state_jax = SU_state2state._replace( - objectives=[Objective(initial, H_jax, initial)], algorithm_kwargs={"alg": "JOPT"} -) + SU_state2state_jax = SU_state2state._replace( + objectives=[Objective(initial, H_jax, initial)], algorithm_kwargs={"alg": "JOPT"} + ) -# unitary gate synthesis -PSU_unitary_jax = PSU_unitary._replace( - objectives=[Objective(initial_U, H_jax, (-1j) * initial_U)], - algorithm_kwargs={"alg": "JOPT"}, -) + # unitary gate synthesis + PSU_unitary_jax = PSU_unitary._replace( + objectives=[Objective(initial_U, H_jax, (-1j) * initial_U)], + algorithm_kwargs={"alg": "JOPT"}, + ) -SU_unitary_jax = SU_unitary._replace( - objectives=[Objective(initial_U, H_jax, initial_U)], - algorithm_kwargs={"alg": "JOPT"}, -) + SU_unitary_jax = SU_unitary._replace( + objectives=[Objective(initial_U, H_jax, initial_U)], + algorithm_kwargs={"alg": "JOPT"}, + ) -# map synthesis -Lc_jax = [ - [qt.liouvillian(qt.sigmax()), lambda t, p: sin_jax(t, p)], - [qt.liouvillian(qt.sigmay()), lambda t, q: sin_jax(t, q)], -] -L_jax = L_d + Lc_jax + # map synthesis + Lc_jax = [ + [qt.liouvillian(qt.sigmax()), lambda t, p: sin_jax(t, p)], + [qt.liouvillian(qt.sigmay()), lambda t, q: sin_jax(t, q)], + ] + L_jax = L_d + Lc_jax -TRCDIFF_map_jax = TRCDIFF_map._replace( - objectives=[Objective(initial_map, L_jax, initial_map)], - algorithm_kwargs={"alg": "JOPT", "fid_type": "TRACEDIFF"}, -) + TRCDIFF_map_jax = TRCDIFF_map._replace( + objectives=[Objective(initial_map, L_jax, initial_map)], + algorithm_kwargs={"alg": "JOPT", "fid_type": "TRACEDIFF"}, + ) + +else: + # jax not available, set these to none so tests will be skipped + PSU_state2state_jax = None + SU_state2state_jax = None + 
PSU_unitary_jax = None + SU_unitary_jax = None + TRCDIFF_map_jax = None @pytest.fixture( @@ -182,6 +196,9 @@ def tst(request): def test_optimize_pulses(tst): + if tst is None: + pytest.skip("JAX not available") + result = optimize_pulses( tst.objectives, tst.control_parameters, diff --git a/tests/test_interactive.py b/tests/test_interactive.py new file mode 100644 index 0000000..0961db0 --- /dev/null +++ b/tests/test_interactive.py @@ -0,0 +1,32 @@ +""" +This file contains the test suite for running the interactive test notebooks +in the 'tests/interactive' directory. +Taken, modified from https://github.com/SeldonIO/alibi/blob/master/testing/test_notebooks.py +""" + +import glob +import nbclient.exceptions +import pytest + +from pathlib import Path +from jupytext.cli import jupytext +import nbclient + +# Set of all example notebooks +NOTEBOOK_DIR = 'tests/interactive' +ALL_NOTEBOOKS = { + Path(x).name + for x in glob.glob(str(Path(NOTEBOOK_DIR).joinpath('*.md'))) + if Path(x).name != 'about.md' +} + +@pytest.mark.parametrize("notebook", ALL_NOTEBOOKS) +def test_notebook(notebook): + notebook = Path(NOTEBOOK_DIR, notebook) + try: + jupytext(args=[str(notebook), "--execute"]) + except nbclient.exceptions.CellExecutionError as e: + if e.ename == "Skipped": + pytest.skip(e.evalue) + else: + raise e \ No newline at end of file diff --git a/tests/test_jopt_open_system_bug.py b/tests/test_jopt_open_system_bug.py new file mode 100644 index 0000000..a4c9d26 --- /dev/null +++ b/tests/test_jopt_open_system_bug.py @@ -0,0 +1,39 @@ +import numpy as np +import qutip as qt +from qutip_qoc import Objective, optimize_pulses + +from jax import jit, numpy + +def test_open_system_jopt_runs_without_error(): + Hd = qt.Qobj(np.diag([1, 2])) + c_ops = [np.sqrt(0.1) * qt.sigmam()] + Hc = qt.sigmax() + + Ld = qt.liouvillian(H=Hd, c_ops=c_ops) + Lc = qt.liouvillian(Hc) + + initial_state = qt.fock_dm(2, 0) + target_state = qt.fock_dm(2, 1) + + times = np.linspace(0, 2 * np.pi, 250) + + @jit 
+ def sin_x(t, c, **kwargs): + return c[0] * numpy.sin(c[1] * t) + L = [Ld, [Lc, sin_x]] + + guess_params = [1, 0.5] + + res_jopt = optimize_pulses( + objectives = Objective(initial_state, L, target_state), + control_parameters = { + "ctrl_x": {"guess": guess_params, "bounds": [(-1, 1), (0, 2 * np.pi)]} + }, + tlist = times, + algorithm_kwargs = { + "alg": "JOPT", + "fid_err_targ": 0.001, + }, + ) + + assert res_jopt.infidelity < 0.25, f"Fidelity error too high: {res_jopt.infidelity}" \ No newline at end of file diff --git a/tests/test_result.py b/tests/test_result.py index 0299e95..6bd7f8c 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -5,10 +5,22 @@ import pytest import qutip as qt import numpy as np -import jax -import jax.numpy as jnp import collections +try: + import jax + import jax.numpy as jnp + _jax_available = True +except ImportError: + _jax_available = False + +try: + import gymnasium + import stable_baselines3 + _rl_available = True +except ImportError: + _rl_available = False + from qutip_qoc.pulse_optim import optimize_pulses from qutip_qoc.objective import Objective from qutip_qoc._time import _TimeInterval @@ -84,37 +96,42 @@ def grad_sin(t, p, idx): objectives=[Objective(initial, H, target)], algorithm_kwargs={"alg": "CRAB", "fid_err_targ": 0.01}, ) -# ----------------------- JAX --------------------- +if _jax_available: + # ----------------------- JAX --------------------- -def sin_jax(t, p): - return p[0] * jnp.sin(p[1] * t + p[2]) + def sin_jax(t, p): + return p[0] * jnp.sin(p[1] * t + p[2]) -@jax.jit -def sin_x_jax(t, p, **kwargs): - return sin_jax(t, p) + @jax.jit + def sin_x_jax(t, p, **kwargs): + return sin_jax(t, p) -@jax.jit -def sin_y_jax(t, q, **kwargs): - return sin_jax(t, q) + @jax.jit + def sin_y_jax(t, q, **kwargs): + return sin_jax(t, q) -@jax.jit -def sin_z_jax(t, r, **kwargs): - return sin_jax(t, r) + @jax.jit + def sin_z_jax(t, r, **kwargs): + return sin_jax(t, r) -Hc_jax = [[qt.sigmax(), sin_x_jax], 
[qt.sigmay(), sin_y_jax], [qt.sigmaz(), sin_z_jax]] -H_jax = H_d + Hc_jax + Hc_jax = [[qt.sigmax(), sin_x_jax], [qt.sigmay(), sin_y_jax], [qt.sigmaz(), sin_z_jax]] -# state to state transfer -state2state_jax = state2state_goat._replace( - objectives=[Objective(initial, H_jax, target)], - algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, -) + H_jax = H_d + Hc_jax + + # state to state transfer + state2state_jax = state2state_goat._replace( + objectives=[Objective(initial, H_jax, target)], + algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, + ) + +else: + state2state_jax = None # ------------------- discrete CRAB / GRAPE control ------------------------ @@ -152,6 +169,57 @@ def sin_z_jax(t, r, **kwargs): algorithm_kwargs={"alg": "CRAB", "fid_err_targ": 0.01, "fix_frequency": False}, ) +if _rl_available: + # ----------------------- RL -------------------- + + # state to state transfer + initial = qt.basis(2, 0) + target = (qt.basis(2, 0) + qt.basis(2, 1)).unit() # |+⟩ + + H_c = [qt.sigmax(), qt.sigmay(), qt.sigmaz()] # control Hamiltonians + + w, d, y = 0.1, 1.0, 0.1 + H_d = 1 / 2 * (w * qt.sigmaz() + d * qt.sigmax()) # drift Hamiltonian + + H = [H_d] + H_c # total Hamiltonian + + state2state_rl = Case( + objectives=[Objective(initial, H, target)], + control_parameters={ + "p": {"bounds": [(-13, 13)]}, + }, + tlist=np.linspace(0, 10, 100), + algorithm_kwargs={ + "fid_err_targ": 0.01, + "alg": "RL", + "max_iter": 20000, + "shorter_pulses": True, + }, + optimizer_kwargs={}, + ) + + # no big difference for unitary evolution + + initial = qt.qeye(2) # Identity + target = qt.gates.hadamard_transform() + + unitary_rl = state2state_rl._replace( + objectives=[Objective(initial, H, target)], + control_parameters={ + "p": {"bounds": [(-13, 13)]}, + }, + algorithm_kwargs={ + "fid_err_targ": 0.01, + "alg": "RL", + "max_iter": 300, + "shorter_pulses": True, + }, + ) + +else: # skip RL tests + state2state_rl = None + unitary_rl = None + @pytest.fixture( params=[ @@ 
-160,6 +228,8 @@ def sin_z_jax(t, r, **kwargs): pytest.param(state2state_param_crab, id="State to state (param. CRAB)"), pytest.param(state2state_goat, id="State to state (GOAT)"), pytest.param(state2state_jax, id="State to state (JAX)"), + pytest.param(state2state_rl, id="State to state (RL)"), + pytest.param(unitary_rl, id="Unitary (RL)"), ] ) def tst(request): @@ -167,6 +237,9 @@ def tst(request): def test_optimize_pulses(tst): + if tst is None: + pytest.skip("Dependency not available") + result = optimize_pulses( tst.objectives, tst.control_parameters,