|
| 1 | +Experimental Support for GPUs/TPUs |
| 2 | +================================== |
| 3 | + |
| 4 | +The current development branch ``dev/jax`` implements |
| 5 | +experimental support for GPUs/TPUs. |
| 6 | + |
| 7 | +Although OQuPy is built on top of the backend-agnostic |
| 8 | +`TensorNetwork <https://github.com/google/TensorNetwork>`__ library, |
| 9 | +OQuPy uses vanilla NumPy and SciPy throughout its implementation. |
| 10 | +The ``dev/jax`` branch adds support for GPUs/TPUs via the |
| 11 | +`JAX <https://jax.readthedocs.io/en/latest/>`__ library. A new |
| 12 | +``oqupy.backends.numerical_backend.py`` module handles the |
| 13 | +`breaking changes in JAX |
| 14 | +NumPy <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html>`__, |
| 15 | +while the rest of the modules utilizes ``numpy`` and ``scipy.linalg`` |
| 16 | +instances from there without explicitly importing JAX-based libraries. |
| 17 | + |
| 18 | +Enabling Experimental Features |
| 19 | +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 20 | + |
| 21 | +To enable experimental features, switch to the ``dev/jax`` branch and use |
| 22 | + |
| 23 | +.. code:: python |
| 24 | +
|
| 25 | + from oqupy.backends import enable_jax_features |
| 26 | + enable_jax_features() |
| 27 | +
|
| 28 | +Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to |
| 29 | +initialize the jax backend by default. |
| 30 | + |
| 31 | +Contributing Guidelines |
| 32 | +~~~~~~~~~~~~~~~~~~~~~~~ |
| 33 | + |
| 34 | +To contribute features compatible with the JAX backend, |
| 35 | +please adhere to the following set of guidelines: |
| 36 | + |
| 37 | +- avoid wildcard imports of NumPy and SciPy. |
| 38 | +- use ``from oqupy.backends.numerical_backend import np`` instead of |
| 39 | + ``import numpy as np`` and use the alias ``default_np`` in cases |
| 40 | + vanilla NumPy is explicitly required. |
| 41 | +- use ``from oqupy.backends.numerical_backend import la`` instead of |
| 42 | + ``import scipy.linalg as la``, except that for non-symmetric |
| 43 | + eigen-decomposition, ``scipy.linalg.eig`` should be used. |
| 44 | +- use one of ``np.dtype_complex`` (``np.dtype_float``) or |
| 45 | + ``oqupy.config.NumPyDtypeComplex`` (``oqupy.config.NumPyDtypeFloat``) |
| 46 | + instead of ``np.complex_`` (``np.float_``). |
| 47 | +- convert lists or tuples to arrays when passing them as arguments |
| 48 | + inside functions. |
| 49 | +- use ``array = np.update(array, indices, values)`` instead of |
| 50 | + ``array[indices] = values``. |
| 51 | +- use ``np.get_random_floats(seed, shape)`` instead of |
| 52 | + ``np.random.default_rng(seed).random(shape)``. |
| 53 | +- declare signatures for ``np.vectorize`` explicitly. |
| 54 | +- avoid directly changing the ``shape`` attribute of an array (use |
| 55 | + ``.reshape`` instead) |
0 commit comments