Skip to content

Commit 6ccbd2e

Browse files
authored
Merge pull request #144 from Sampreet/pr/feature-numerical-backend
JAX Numerical Backend for GPU/TPU Support
2 parents 1ebaaa8 + f142c44 commit 6ccbd2e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+1098
-687
lines changed

CONTRIBUTING.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ The current setup uses:
2929
* [tox](https://tox.readthedocs.io) ... for testing with different environments.
3030
* [travis](https://travis-ci.com) ... for continuous integration.
3131

32+
We are actively incorporating additional features to OQuPy,
33+
details of which can be found in [DEVELOPMENT.md](./DEVELOPMENT.md).
34+
3235
## How to contribute to the code or documentation
3336
Please use the
3437
[Issues](https://github.com/tempoCollaboration/OQuPy/issues) and

DEVELOPMENT.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Development
2+
3+
The current development branch "dev/jax" implements
4+
5+
* [Experimental Support for GPUs/TPUs](#experimental-support-for-gpustpus)
6+
7+
## Experimental Support for GPUs/TPUs
8+
9+
Although OQuPy is built on top of the backend-agnostic
10+
[TensorNetwork](https://github.com/google/TensorNetwork) library,
11+
OQuPy uses vanilla NumPy and SciPy throughout its implementation.
12+
13+
The "dev/jax" branch adds supports for GPUs/TPUs via the
14+
[JAX](https://jax.readthedocs.io/en/latest/) library.
15+
A new `oqupy.backends.numerical_backend.py` module handles the
16+
[breaking changes in JAX NumPy](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html),
17+
while the rest of the modules utilizes `numpy` and `scipy.linalg` instances from there
18+
without explicitly importing JAX-based libraries.
19+
20+
### Enabling Experimental Features
21+
22+
To enable experimental features switch to the `dev/jax` branch and use
23+
```python
24+
from oqupy.backends import enable_jax_features
25+
enable_jax_features()
26+
```
27+
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to
28+
initialize the jax backend by default.
29+
30+
### Contributing Guidelines
31+
32+
To contribute features compatible with the JAX backend,
33+
please adhere to the following set of guidelines:
34+
35+
* avoid wildcard imports of NumPy and SciPy.
36+
* use `from oqupy.backends.numerical_backend import np` instead of `import numpy as np` and use the alias `default_np` in cases vanilla NumPy is explicitly required.
37+
* use `from oqupy.backends.numerical_backend import la` instead of `import scipy.linalg as la`, except that for non-symmetric eigen-decomposition, `scipy.linalg.eig` should be used.
38+
* use one of `np.dtype_complex` (`np.dtype_float`) or `oqupy.config.NumPyDtypeComplex` (`oqupy.config.NumPyDtypeFloat`) instead of `np.complex_` (`np.float_`).
39+
* convert lists or tuples to arrays when passing them as arguments inside functions.
40+
* use `array = np.update(array, indices, values)` instead of `array[indices] = values`.
41+
* use `np.get_random_floats(seed, shape)` instead of `np.random.default_rng(seed).random(shape)`.
42+
* declare signatures for `np.vectorize` explicitly.
43+
* avoid directly changing the `shape` attribute of an array (use `.reshape` instead)

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ Furthermore, OQuPy implements methods to ...
112112
:caption: Development
113113

114114
pages/contributing
115+
pages/gpu_features
115116
pages/authors
116117
pages/how_to_cite
117118
pages/sharing

docs/pages/api.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ class :class:`oqupy.pt_tebd.PtTebd`
178178
dictionary.
179179

180180

181-
182181
Results
183182
-------
184183

@@ -207,3 +206,6 @@ module :mod:`oqupy.operators`
207206
function :func:`oqupy.helpers.plot_correlations_with_parameters`
208207
A helper function to plot an auto-correlation function and the sampling
209208
points given by a set of parameters for a TEMPO computation.
209+
210+
function :func:`oqupy.backends.enable_jax_features`
211+
Option to use JAX to support multiple device backends (CPUs/GPUs/TPUs).

docs/pages/gpu_features.rst

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
Experimental Support for GPUs/TPUs
2+
==================================
3+
The current development branch "dev/jax" implements experimental support
4+
for GPUs/TPUs.
5+
6+
Although OQuPy is built on top of the backend-agnostic
7+
`TensorNetwork <https://github.com/google/TensorNetwork>`__ library,
8+
OQuPy uses vanilla NumPy and SciPy throughout its implementation.
9+
10+
The "dev/jax" branch adds supports for GPUs/TPUs via the
11+
`JAX <https://jax.readthedocs.io/en/latest/>`__ library. A new
12+
``oqupy.backends.numerical_backend.py`` module handles the
13+
`breaking changes in JAX
14+
NumPy <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html>`__,
15+
while the rest of the modules utilizes ``numpy`` and ``scipy.linalg``
16+
instances from there without explicitly importing JAX-based libraries.
17+
18+
Enabling Experimental Features
19+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
20+
21+
To enable experimental features, switch to the ``dev/jax`` branch and use
22+
23+
.. code:: python
24+
25+
from oqupy.backends import enable_jax_features
26+
enable_jax_features()
27+
28+
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to
29+
initialize the jax backend by default.
30+
31+
Contributing Guidelines
32+
~~~~~~~~~~~~~~~~~~~~~~~
33+
34+
To contribute features compatible with the JAX backend,
35+
please adhere to the following set of guidelines:
36+
37+
- avoid wildcard imports of NumPy and SciPy.
38+
- use ``from oqupy.backends.numerical_backend import np`` instead of
39+
``import numpy as np`` and use the alias ``default_np`` in cases
40+
vanilla NumPy is explicitly required.
41+
- use ``from oqupy.backends.numerical_backend import la`` instead of
42+
``import scipy.linalg as la``, except that for non-symmetric
43+
eigen-decomposition, ``scipy.linalg.eig`` should be used.
44+
- use one of ``np.dtype_complex`` (``np.dtype_float``) or
45+
``oqupy.config.NumPyDtypeComplex`` (``oqupy.config.NumPyDtypeFloat``)
46+
instead of ``np.complex_`` (``np.float_``).
47+
- convert lists or tuples to arrays when passing them as arguments
48+
inside functions.
49+
- use ``array = np.update(array, indices, values)`` instead of
50+
``array[indices] = values``.
51+
- use ``np.get_random_floats(seed, shape)`` instead of
52+
``np.random.default_rng(seed).random(shape)``.
53+
- declare signatures for ``np.vectorize`` explicitly.
54+
- avoid directly changing the ``shape`` attribute of an array (use
55+
``.reshape`` instead)

examples/simple_dynamics_jax.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#!/usr/bin/env python
2+
3+
import sys
4+
sys.path.insert(0, '.')
5+
# set the 'OQUPY_BACKEND' environment variable
6+
# to 'jax' to initialize JAX backend by default
7+
# or switch to JAX backend using oqupy.backends
8+
import oqupy
9+
from oqupy.backends import enable_jax_features
10+
# import NumPy from numerical_backend
11+
#from oqupy.backends.numerical_backend import np
12+
#enable_jax_features()
13+
14+
import matplotlib.pyplot as plt
15+
sigma_x = oqupy.operators.sigma("x")
16+
sigma_z = oqupy.operators.sigma("z")
17+
up_density_matrix = oqupy.operators.spin_dm("z+")
18+
Omega = 1.0
19+
omega_cutoff = 5.0
20+
alpha = 0.3
21+
22+
system = oqupy.System(0.5 * Omega * sigma_x)
23+
correlations = oqupy.PowerLawSD(alpha=alpha,
24+
zeta=1,
25+
cutoff=omega_cutoff,
26+
cutoff_type='exponential')
27+
bath = oqupy.Bath(0.5 * sigma_z, correlations)
28+
tempo_parameters = oqupy.TempoParameters(dt=0.1, tcut=3.0, epsrel=10**(-4))
29+
30+
dynamics = oqupy.tempo_compute(system=system,
31+
bath=bath,
32+
initial_state=up_density_matrix,
33+
start_time=0.0,
34+
end_time=2.0,
35+
parameters=tempo_parameters,
36+
unique=True)
37+
t, s_z = dynamics.expectations(0.5*sigma_z, real=True)
38+
print(s_z)
39+
plt.plot(t, s_z, label=r'$\alpha=0.3$')
40+
plt.xlabel(r'$t\,\Omega$')
41+
plt.ylabel(r'$\langle\sigma_z\rangle$')
42+
plt.savefig('simple_dynamics_jax.png')
43+
#plt.show()

oqupy/backends/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Module to initialize OQuPy's backends."""
2+
3+
from oqupy.backends.numerical_backend import set_numerical_backends
4+
5+
def enable_jax_features():
6+
"""Function to enable experimental features."""
7+
8+
# set numerical backend to JAX
9+
set_numerical_backends('jax')

oqupy/backends/node_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515

1616
from typing import Any, List, Optional, Text, Tuple, Union
1717

18-
import numpy as np
1918
import tensornetwork as tn
2019
from tensornetwork import Node
2120
from tensornetwork.backends.base_backend import BaseBackend
2221

22+
from oqupy.backends.numerical_backend import np
2323

2424
class NodeArray:
2525
"""NodeArray class. """
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
"""
13+
Module containing NumPy-like and SciPy-like numerical backends.
14+
"""
15+
16+
import os
17+
18+
import numpy as default_np
19+
import scipy.linalg as default_la
20+
21+
from tensornetwork.backend_contextmanager import \
22+
set_default_backend
23+
24+
import oqupy.config as oc
25+
26+
# store instances of the initialized backends
27+
# this way, `oqupy.config` remains unchanged
28+
# and `ocupy.config.DEFAULT_BACKEND` is used
29+
# when NumPy and LinAlg are initialized
30+
NUMERICAL_BACKEND_INSTANCES = {}
31+
32+
def get_numerical_backends(
33+
backend_name: str,
34+
):
35+
"""Function to get numerical backend.
36+
37+
Parameters
38+
----------
39+
backend_name: str
40+
Name of the backend. Options are `'jax'` and `'numpy'`.
41+
42+
Returns
43+
-------
44+
backends: list
45+
NumPy and LinAlg backends.
46+
"""
47+
48+
_bn = backend_name.lower()
49+
if _bn in NUMERICAL_BACKEND_INSTANCES:
50+
set_default_backend(_bn)
51+
return NUMERICAL_BACKEND_INSTANCES[_bn]
52+
assert _bn in ['jax', 'numpy'], \
53+
"currently supported backends are `'jax'` and `'numpy'`"
54+
55+
if 'jax' in _bn:
56+
try:
57+
# explicitly import and configure jax
58+
import jax
59+
import jax.numpy as jnp
60+
import jax.scipy.linalg as jla
61+
jax.config.update('jax_enable_x64', True)
62+
63+
# # TODO: GPU memory allocation (default is 0.75)
64+
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
65+
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.5'
66+
67+
# set TensorNetwork backend
68+
set_default_backend('jax')
69+
70+
NUMERICAL_BACKEND_INSTANCES['jax'] = [jnp, jla]
71+
return NUMERICAL_BACKEND_INSTANCES['jax']
72+
except ImportError:
73+
print("JAX not installed, defaulting to NumPy")
74+
75+
# set TensorNetwork backend
76+
set_default_backend('numpy')
77+
78+
NUMERICAL_BACKEND_INSTANCES['numpy'] = [default_np, default_la]
79+
return NUMERICAL_BACKEND_INSTANCES['numpy']
80+
81+
class NumPy:
82+
"""
83+
The NumPy backend employing
84+
dynamic switching through `oqupy.config`.
85+
"""
86+
def __init__(self,
87+
backend_name=oc.DEFAULT_BACKEND,
88+
):
89+
"""Getter for the backend."""
90+
self.backend = get_numerical_backends(backend_name)[0]
91+
92+
@property
93+
def dtype_complex(self) -> default_np.dtype:
94+
"""Getter for the complex datatype."""
95+
return oc.NumPyDtypeComplex
96+
97+
@property
98+
def dtype_float(self) -> default_np.dtype:
99+
"""Getter for the float datatype."""
100+
return oc.NumPyDtypeFloat
101+
102+
def __getattr__(self,
103+
name: str,
104+
):
105+
"""Return the backend's default attribute."""
106+
return getattr(self.backend, name)
107+
108+
def update(self,
109+
array,
110+
indices: tuple,
111+
values,
112+
) -> default_np.ndarray:
113+
"""Option to update select indices of an array with given values."""
114+
if not isinstance(array, default_np.ndarray):
115+
return array.at[indices].set(values)
116+
array[indices] = values
117+
return array
118+
119+
def get_random_floats(self,
120+
seed,
121+
shape,
122+
):
123+
"""Method to obtain random floats with a given seed and shape."""
124+
random_floats = default_np.random.default_rng(seed).random(shape, \
125+
dtype=default_np.float64)
126+
return self.backend.array(random_floats, dtype=self.dtype_float)
127+
128+
class LinAlg:
129+
"""
130+
The Linear Algebra backend employing
131+
dynamic switching through `oqupy.config`.
132+
"""
133+
def __init__(self,
134+
backend_name=oc.DEFAULT_BACKEND,
135+
):
136+
"""Getter for the backend."""
137+
self.backend = get_numerical_backends(backend_name)[1]
138+
139+
def __getattr__(self,
140+
name: str,
141+
):
142+
"""Return the backend's default attribute."""
143+
return getattr(self.backend, name)
144+
145+
# setup libraries using environment variable
146+
# fall back to oqupy.config.DEFAULT_BACKEND
147+
try:
148+
BACKEND_NAME = os.environ[oc.BACKEND_ENV_VAR]
149+
except KeyError:
150+
BACKEND_NAME = oc.DEFAULT_BACKEND
151+
np = NumPy(backend_name=BACKEND_NAME)
152+
la = LinAlg(backend_name=BACKEND_NAME)
153+
154+
def set_numerical_backends(
155+
backend_name: str
156+
):
157+
"""Function to set numerical backend.
158+
159+
Parameters
160+
----------
161+
backend_name: str
162+
Name of the backend. Options are `'jax'` and `'numpy'`.
163+
"""
164+
backends = get_numerical_backends(backend_name)
165+
np.backend = backends[0]
166+
la.backend = backends[1]

0 commit comments

Comments
 (0)