Skip to content

Commit e7b5db8

Browse files
authored
Merge pull request #145 from Sampreet/pr/jax-docs
Update Documentation for Features under Development
2 parents 1ebaaa8 + 996d3e5 commit e7b5db8

File tree

5 files changed

+110
-1
lines changed

5 files changed

+110
-1
lines changed

CONTRIBUTING.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ The current setup uses:
2929
* [tox](https://tox.readthedocs.io) ... for testing with different environments.
3030
* [travis](https://travis-ci.com) ... for continuous integration.
3131

32+
We are actively incorporating additional features to OQuPy,
33+
details of which can be found in [DEVELOPMENT.md](./DEVELOPMENT.md).
34+
3235
## How to contribute to the code or documentation
3336
Please use the
3437
[Issues](https://github.com/tempoCollaboration/OQuPy/issues) and

DEVELOPMENT.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Development
2+
3+
The following feautres are under active development:
4+
5+
* [Experimental Support for GPUs/TPUs](#experimental-support-for-gpustpus)
6+
7+
## Experimental Support for GPUs/TPUs
8+
9+
The current development branch `dev/jax` implements
10+
experimental support for GPUs/TPUs.
11+
12+
Although OQuPy is built on top of the backend-agnostic
13+
[TensorNetwork](https://github.com/google/TensorNetwork) library,
14+
OQuPy uses vanilla NumPy and SciPy throughout its implementation.
15+
The `dev/jax` branch adds support for GPUs/TPUs via the
16+
[JAX](https://jax.readthedocs.io/en/latest/) library.
17+
A new `oqupy.backends.numerical_backend.py` module handles the
18+
[breaking changes in JAX NumPy](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html),
19+
while the rest of the modules utilizes `numpy` and `scipy.linalg` instances from there
20+
without explicitly importing JAX-based libraries.
21+
22+
### Enabling Experimental Features
23+
24+
To enable experimental features switch to the `dev/jax` branch and use
25+
```python
26+
from oqupy.backends import enable_jax_features
27+
enable_jax_features()
28+
```
29+
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to
30+
initialize the jax backend by default.
31+
32+
### Contributing Guidelines
33+
34+
To contribute features compatible with the JAX backend,
35+
please adhere to the following set of guidelines:
36+
37+
* avoid wildcard imports of NumPy and SciPy.
38+
* use `from oqupy.backends.numerical_backend import np` instead of `import numpy as np` and use the alias `default_np` in cases vanilla NumPy is explicitly required.
39+
* use `from oqupy.backends.numerical_backend import la` instead of `import scipy.linalg as la`, except that for non-symmetric eigen-decomposition, `scipy.linalg.eig` should be used.
40+
* use one of `np.dtype_complex` (`np.dtype_float`) or `oqupy.config.NumPyDtypeComplex` (`oqupy.config.NumPyDtypeFloat`) instead of `np.complex_` (`np.float_`).
41+
* convert lists or tuples to arrays when passing them as arguments inside functions.
42+
* use `array = np.update(array, indices, values)` instead of `array[indices] = values`.
43+
* use `np.get_random_floats(seed, shape)` instead of `np.random.default_rng(seed).random(shape)`.
44+
* declare signatures for `np.vectorize` explicitly.
45+
* avoid directly changing the `shape` attribute of an array (use `.reshape` instead)

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ Furthermore, OQuPy implements methods to ...
112112
:caption: Development
113113

114114
pages/contributing
115+
pages/jax_features
115116
pages/authors
116117
pages/how_to_cite
117118
pages/sharing

docs/pages/authors.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@ Authors & Acknowledgements
44
- Lead developer since 2020: `Gerald E.
55
Fux <https://github.com/gefux>`__ ([email protected])
66
- Co-lead developer since 2022: `Piper
7-
Fowler-Wright <https://github.com/piperfw>`__ ([email protected])
7+
Fowler-Wright <https://github.com/piperfw>`__ ([email protected])
88

99
Major code contributions
1010
------------------------
1111

12+
**Experimental features**
13+
14+
- `Sampreet Kalita <https://github.com/Sampreet>`__: JAX numerical backend for
15+
GPU/TPU support
16+
1217
**Version 0.5.0**
1318

1419
- `Aidan Strathearn <https://github.com/aidanstrathearn>`__: Gibbs state TEMPO [Chiu2022].

docs/pages/jax_features.rst

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
Experimental Support for GPUs/TPUs
2+
==================================
3+
4+
The current development branch ``dev/jax`` implements
5+
experimental support for GPUs/TPUs.
6+
7+
Although OQuPy is built on top of the backend-agnostic
8+
`TensorNetwork <https://github.com/google/TensorNetwork>`__ library,
9+
OQuPy uses vanilla NumPy and SciPy throughout its implementation.
10+
The ``dev/jax`` branch adds support for GPUs/TPUs via the
11+
`JAX <https://jax.readthedocs.io/en/latest/>`__ library. A new
12+
``oqupy.backends.numerical_backend.py`` module handles the
13+
`breaking changes in JAX
14+
NumPy <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html>`__,
15+
while the rest of the modules utilizes ``numpy`` and ``scipy.linalg``
16+
instances from there without explicitly importing JAX-based libraries.
17+
18+
Enabling Experimental Features
19+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
20+
21+
To enable experimental features, switch to the ``dev/jax`` branch and use
22+
23+
.. code:: python
24+
25+
from oqupy.backends import enable_jax_features
26+
enable_jax_features()
27+
28+
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to
29+
initialize the jax backend by default.
30+
31+
Contributing Guidelines
32+
~~~~~~~~~~~~~~~~~~~~~~~
33+
34+
To contribute features compatible with the JAX backend,
35+
please adhere to the following set of guidelines:
36+
37+
- avoid wildcard imports of NumPy and SciPy.
38+
- use ``from oqupy.backends.numerical_backend import np`` instead of
39+
``import numpy as np`` and use the alias ``default_np`` in cases
40+
vanilla NumPy is explicitly required.
41+
- use ``from oqupy.backends.numerical_backend import la`` instead of
42+
``import scipy.linalg as la``, except that for non-symmetric
43+
eigen-decomposition, ``scipy.linalg.eig`` should be used.
44+
- use one of ``np.dtype_complex`` (``np.dtype_float``) or
45+
``oqupy.config.NumPyDtypeComplex`` (``oqupy.config.NumPyDtypeFloat``)
46+
instead of ``np.complex_`` (``np.float_``).
47+
- convert lists or tuples to arrays when passing them as arguments
48+
inside functions.
49+
- use ``array = np.update(array, indices, values)`` instead of
50+
``array[indices] = values``.
51+
- use ``np.get_random_floats(seed, shape)`` instead of
52+
``np.random.default_rng(seed).random(shape)``.
53+
- declare signatures for ``np.vectorize`` explicitly.
54+
- avoid directly changing the ``shape`` attribute of an array (use
55+
``.reshape`` instead)

0 commit comments

Comments
 (0)