Skip to content

Commit 2f03fbc

Browse files
Sampreetpiperfw
authored andcommitted
Update numerical backend and add example
1 parent 93d168e commit 2f03fbc

File tree

9 files changed

+161
-63
lines changed

9 files changed

+161
-63
lines changed

docs/pages/api.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ class :class:`oqupy.pt_tebd.PtTebd`
178178
dictionary.
179179

180180

181-
182181
Results
183182
-------
184183

@@ -207,3 +206,6 @@ module :mod:`oqupy.operators`
207206
function :func:`oqupy.helpers.plot_correlations_with_parameters`
208207
A helper function to plot an auto-correlation function and the sampling
209208
points given by a set of parameters for a TEMPO computation.
209+
210+
function :func:`oqupy.backends.enable_jax_features`
211+
Option to use JAX to support multiple device backends (CPUs/GPUs/TPUs).

docs/pages/gpu_features.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ To enable experimental features, switch to the ``dev/jax`` branch and use
2222

2323
.. code:: python
2424
25-
from oqupy.backends import enable_gpu_features
26-
enable_gpu_features()
25+
from oqupy.backends import enable_jax_features
26+
enable_jax_features()
2727
2828
Contributing Guidelines
2929
~~~~~~~~~~~~~~~~~~~~~~~

examples/simple_dynamics_jax.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#!/usr/bin/env python
2+
3+
import sys
4+
sys.path.insert(0, '.')
5+
# set the 'OQUPY_BACKEND' environment variable
6+
# to 'jax' to initialize JAX backend by default
7+
# or switch to JAX backend using oqupy.backends
8+
import oqupy
9+
from oqupy.backends import enable_jax_features
10+
# import NumPy from numerical_backend
11+
from oqupy.backends.numerical_backend import np
12+
enable_jax_features()
13+
14+
import matplotlib.pyplot as plt
15+
sigma_x = oqupy.operators.sigma("x")
16+
sigma_z = oqupy.operators.sigma("z")
17+
up_density_matrix = oqupy.operators.spin_dm("z+")
18+
Omega = 1.0
19+
omega_cutoff = 5.0
20+
alpha = 0.3
21+
22+
system = oqupy.System(0.5 * Omega * sigma_x)
23+
correlations = oqupy.PowerLawSD(alpha=alpha,
24+
zeta=1,
25+
cutoff=omega_cutoff,
26+
cutoff_type='exponential')
27+
bath = oqupy.Bath(0.5 * sigma_z, correlations)
28+
tempo_parameters = oqupy.TempoParameters(dt=0.1, tcut=3.0, epsrel=10**(-4))
29+
30+
dynamics = oqupy.tempo_compute(system=system,
31+
bath=bath,
32+
initial_state=up_density_matrix,
33+
start_time=0.0,
34+
end_time=2.0,
35+
parameters=tempo_parameters,
36+
unique=True)
37+
t, s_z = dynamics.expectations(0.5*sigma_z, real=True)
38+
print(s_z)
39+
plt.plot(t, s_z, label=r'$\alpha=0.3$')
40+
plt.xlabel(r'$t\,\Omega$')
41+
plt.ylabel(r'$\langle\sigma_z\rangle$')
42+
#plt.savefig('simple_dynamics.png')
43+
plt.show()

oqupy/backends/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Module to initialize OQuPy's backends."""
2+
3+
from oqupy.backends.numerical_backend import set_numerical_backends
4+
5+
def enable_jax_features():
6+
"""Function to enable experimental features."""
7+
8+
# set numerical backend to JAX
9+
set_numerical_backends('jax')

oqupy/backends/numerical_backend.py

Lines changed: 97 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,81 @@
1313
Module containing NumPy-like and SciPy-like numerical backends.
1414
"""
1515

16+
import os
17+
1618
import numpy as default_np
1719
import scipy.linalg as default_la
1820

21+
from tensornetwork.backend_contextmanager import \
22+
set_default_backend
23+
1924
import oqupy.config as oc
2025

26+
# store instances of the initialized backends
27+
# this way, `oqupy.config` remains unchanged
28+
# and `ocupy.config.DEFAULT_BACKEND` is used
29+
# when NumPy and LinAlg are initialized
30+
NUMERICAL_BACKEND_INSTANCES = {}
31+
32+
def get_numerical_backends(
33+
backend_name: str,
34+
):
35+
"""Function to get numerical backend.
36+
37+
Parameters
38+
----------
39+
backend_name: str
40+
Name of the backend. Options are `'jax'` and `'numpy'`.
41+
42+
Returns
43+
-------
44+
backends: list
45+
NumPy and LinAlg backends.
46+
"""
47+
48+
_bn = backend_name.lower()
49+
if _bn in NUMERICAL_BACKEND_INSTANCES:
50+
set_default_backend(_bn)
51+
return NUMERICAL_BACKEND_INSTANCES[_bn]
52+
assert _bn in ['jax', 'numpy'], \
53+
"currently supported backends are `'jax'` and `'numpy'`"
54+
55+
if 'jax' in _bn:
56+
try:
57+
# explicitly import and configure jax
58+
import jax
59+
import jax.numpy as jnp
60+
import jax.scipy.linalg as jla
61+
jax.config.update('jax_enable_x64', True)
62+
63+
# # TODO: GPU memory allocation (default is 0.75)
64+
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
65+
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.5'
66+
67+
# set TensorNetwork backend
68+
set_default_backend('jax')
69+
70+
NUMERICAL_BACKEND_INSTANCES['jax'] = [jnp, jla]
71+
return NUMERICAL_BACKEND_INSTANCES['jax']
72+
except ImportError:
73+
print("JAX not installed, defaulting to NumPy")
74+
75+
# set TensorNetwork backend
76+
set_default_backend('numpy')
77+
78+
NUMERICAL_BACKEND_INSTANCES['numpy'] = [default_np, default_la]
79+
return NUMERICAL_BACKEND_INSTANCES['numpy']
80+
2181
class NumPy:
2282
"""
2383
The NumPy backend employing
2484
dynamic switching through `oqupy.config`.
2585
"""
26-
@property
27-
def backend(self) -> default_np:
86+
def __init__(self,
87+
backend_name=oc.DEFAULT_BACKEND,
88+
):
2889
"""Getter for the backend."""
29-
return oc.NUMERICAL_BACKEND_NUMPY
90+
self.backend = get_numerical_backends(backend_name)[0]
3091

3192
@property
3293
def dtype_complex(self) -> default_np.dtype:
@@ -42,12 +103,11 @@ def __getattr__(self,
42103
name: str,
43104
):
44105
"""Return the backend's default attribute."""
45-
backend = object.__getattribute__(self, 'backend')
46-
return getattr(backend, name)
106+
return getattr(self.backend, name)
47107

48108
def update(self,
49109
array,
50-
indices:tuple,
110+
indices: tuple,
51111
values,
52112
) -> default_np.ndarray:
53113
"""Option to update select indices of an array with given values."""
@@ -61,26 +121,46 @@ def get_random_floats(self,
61121
shape,
62122
):
63123
"""Method to obtain random floats with a given seed and shape."""
64-
backend = object.__getattribute__(self, 'backend')
65124
random_floats = default_np.random.default_rng(seed).random(shape, \
66125
dtype=default_np.float64)
67-
return backend.array(random_floats, dtype=self.dtype_float)
126+
return self.backend.array(random_floats, dtype=self.dtype_float)
68127

69128
class LinAlg:
70129
"""
71130
The Linear Algebra backend employing
72131
dynamic switching through `oqupy.config`.
73132
"""
74-
@property
75-
def backend(self) -> default_la:
133+
def __init__(self,
134+
backend_name=oc.DEFAULT_BACKEND,
135+
):
76136
"""Getter for the backend."""
77-
return oc.NUMERICAL_BACKEND_LINALG
137+
self.backend = get_numerical_backends(backend_name)[1]
78138

79-
def __getattr__(self, name: str):
139+
def __getattr__(self,
140+
name: str,
141+
):
80142
"""Return the backend's default attribute."""
81-
backend = object.__getattribute__(self, 'backend')
82-
return getattr(backend, name)
143+
return getattr(self.backend, name)
144+
145+
# setup libraries using environment variable
146+
# fall back to oqupy.config.DEFAULT_BACKEND
147+
try:
148+
BACKEND_NAME = os.environ[oc.BACKEND_ENV_VAR]
149+
except KeyError:
150+
BACKEND_NAME = oc.DEFAULT_BACKEND
151+
np = NumPy(backend_name=BACKEND_NAME)
152+
la = LinAlg(backend_name=BACKEND_NAME)
83153

84-
# initialize for import
85-
np = NumPy()
86-
la = LinAlg()
154+
def set_numerical_backends(
155+
backend_name: str
156+
):
157+
"""Function to set numerical backend.
158+
159+
Parameters
160+
----------
161+
backend_name: str
162+
Name of the backend. Options are `'jax'` and `'numpy'`.
163+
"""
164+
backends = get_numerical_backends(backend_name)
165+
np.backend = backends[0]
166+
la.backend = backends[1]

oqupy/config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515

1616
# numerical backend
1717
import numpy as default_np
18-
import scipy.linalg as default_la
19-
NUMERICAL_BACKEND_NUMPY = default_np
20-
NUMERICAL_BACKEND_LINALG = default_la
18+
BACKEND_ENV_VAR = 'OQUPY_BACKEND'
19+
DEFAULT_BACKEND = 'numpy'
2120
NumPyDtypeComplex = default_np.complex128 # earlier NpDtype
2221
NumPyDtypeFloat = default_np.float64 # earlier NpDtypeReal
2322

pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ disable=raw-checker-failed,
7979
consider-using-f-string, # GEFux: added for development
8080
too-many-arguments, # piperfw: added for development
8181
too-many-positional-arguments, # piperfw: added 2024-09-27
82+
import-outside-toplevel, # sampreet: added 2024-10-24 for JAX
8283
possibly-used-before-assignment, # GEFux: added by hand
8384
unnecessary-lambda-assignment # GEFux: added by hand
8485

tests/coverage/__init__.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,2 @@
1-
from importlib.util import find_spec
2-
3-
if find_spec('jax') is not None:
4-
# JAX configuration
5-
import jax
6-
import jax.numpy as jnp
7-
import jax.scipy.linalg as jla
8-
import oqupy.config as oc
9-
import tensornetwork as tn
10-
jax.config.update('jax_enable_x64', True)
11-
oc.NUMERICAL_BACKEND_NUMPY = jnp
12-
oc.NumPyDtypeComplex = jnp.complex128
13-
oc.NumPyDtypeFloat = jnp.float64
14-
oc.NUMERICAL_BACKEND_LINALG = jla
15-
tn.set_default_backend('jax')
16-
17-
# # TODO: GPU memory allocation (default is 0.75)
18-
# import os
19-
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
20-
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.5'
1+
# from oqupy.backends import enable_jax_features
2+
# enable_jax_features()

tests/physics/__init__.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,2 @@
1-
from importlib.util import find_spec
2-
3-
if find_spec('jax') is not None:
4-
# JAX configuration
5-
import jax
6-
import jax.numpy as jnp
7-
import jax.scipy.linalg as jla
8-
import oqupy.config as oc
9-
import tensornetwork as tn
10-
jax.config.update('jax_enable_x64', True)
11-
oc.NUMERICAL_BACKEND_NUMPY = jnp
12-
oc.NumPyDtypeComplex = jnp.complex128
13-
oc.NumPyDtypeFloat = jnp.float64
14-
oc.NUMERICAL_BACKEND_LINALG = jla
15-
tn.set_default_backend('jax')
16-
17-
# # TODO: GPU memory allocation (default is 0.75)
18-
# import os
19-
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
20-
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.5'
1+
# from oqupy.backends import enable_jax_features
2+
# enable_jax_features()

0 commit comments

Comments
 (0)