1313Module containing NumPy-like and SciPy-like numerical backends.
1414"""
1515
16+ import os
17+
1618import numpy as default_np
1719import scipy .linalg as default_la
1820
21+ from tensornetwork .backend_contextmanager import \
22+ set_default_backend
23+
1924import oqupy .config as oc
2025
26+ # store instances of the initialized backends
27+ # this way, `oqupy.config` remains unchanged
28+ # and `ocupy.config.DEFAULT_BACKEND` is used
29+ # when NumPy and LinAlg are initialized
30+ NUMERICAL_BACKEND_INSTANCES = {}
31+
32+ def get_numerical_backends (
33+ backend_name : str ,
34+ ):
35+ """Function to get numerical backend.
36+
37+ Parameters
38+ ----------
39+ backend_name: str
40+ Name of the backend. Options are `'jax'` and `'numpy'`.
41+
42+ Returns
43+ -------
44+ backends: list
45+ NumPy and LinAlg backends.
46+ """
47+
48+ _bn = backend_name .lower ()
49+ if _bn in NUMERICAL_BACKEND_INSTANCES :
50+ set_default_backend (_bn )
51+ return NUMERICAL_BACKEND_INSTANCES [_bn ]
52+ assert _bn in ['jax' , 'numpy' ], \
53+ "currently supported backends are `'jax'` and `'numpy'`"
54+
55+ if 'jax' in _bn :
56+ try :
57+ # explicitly import and configure jax
58+ import jax
59+ import jax .numpy as jnp
60+ import jax .scipy .linalg as jla
61+ jax .config .update ('jax_enable_x64' , True )
62+
63+ # # TODO: GPU memory allocation (default is 0.75)
64+ # os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
65+ # os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.5'
66+
67+ # set TensorNetwork backend
68+ set_default_backend ('jax' )
69+
70+ NUMERICAL_BACKEND_INSTANCES ['jax' ] = [jnp , jla ]
71+ return NUMERICAL_BACKEND_INSTANCES ['jax' ]
72+ except ImportError :
73+ print ("JAX not installed, defaulting to NumPy" )
74+
75+ # set TensorNetwork backend
76+ set_default_backend ('numpy' )
77+
78+ NUMERICAL_BACKEND_INSTANCES ['numpy' ] = [default_np , default_la ]
79+ return NUMERICAL_BACKEND_INSTANCES ['numpy' ]
80+
2181class NumPy :
2282 """
2383 The NumPy backend employing
2484 dynamic switching through `oqupy.config`.
2585 """
26- @property
27- def backend (self ) -> default_np :
86+ def __init__ (self ,
87+ backend_name = oc .DEFAULT_BACKEND ,
88+ ):
2889 """Getter for the backend."""
29- return oc . NUMERICAL_BACKEND_NUMPY
90+ self . backend = get_numerical_backends ( backend_name )[ 0 ]
3091
3192 @property
3293 def dtype_complex (self ) -> default_np .dtype :
@@ -42,12 +103,11 @@ def __getattr__(self,
42103 name : str ,
43104 ):
44105 """Return the backend's default attribute."""
45- backend = object .__getattribute__ (self , 'backend' )
46- return getattr (backend , name )
106+ return getattr (self .backend , name )
47107
48108 def update (self ,
49109 array ,
50- indices :tuple ,
110+ indices : tuple ,
51111 values ,
52112 ) -> default_np .ndarray :
53113 """Option to update select indices of an array with given values."""
@@ -61,26 +121,46 @@ def get_random_floats(self,
61121 shape ,
62122 ):
63123 """Method to obtain random floats with a given seed and shape."""
64- backend = object .__getattribute__ (self , 'backend' )
65124 random_floats = default_np .random .default_rng (seed ).random (shape , \
66125 dtype = default_np .float64 )
67- return backend .array (random_floats , dtype = self .dtype_float )
126+ return self . backend .array (random_floats , dtype = self .dtype_float )
68127
69128class LinAlg :
70129 """
71130 The Linear Algebra backend employing
72131 dynamic switching through `oqupy.config`.
73132 """
74- @property
75- def backend (self ) -> default_la :
133+ def __init__ (self ,
134+ backend_name = oc .DEFAULT_BACKEND ,
135+ ):
76136 """Getter for the backend."""
77- return oc . NUMERICAL_BACKEND_LINALG
137+ self . backend = get_numerical_backends ( backend_name )[ 1 ]
78138
79- def __getattr__ (self , name : str ):
139+ def __getattr__ (self ,
140+ name : str ,
141+ ):
80142 """Return the backend's default attribute."""
81- backend = object .__getattribute__ (self , 'backend' )
82- return getattr (backend , name )
143+ return getattr (self .backend , name )
144+
145+ # setup libraries using environment variable
146+ # fall back to oqupy.config.DEFAULT_BACKEND
147+ try :
148+ BACKEND_NAME = os .environ [oc .BACKEND_ENV_VAR ]
149+ except KeyError :
150+ BACKEND_NAME = oc .DEFAULT_BACKEND
151+ np = NumPy (backend_name = BACKEND_NAME )
152+ la = LinAlg (backend_name = BACKEND_NAME )
83153
84- # initialize for import
85- np = NumPy ()
86- la = LinAlg ()
154+ def set_numerical_backends (
155+ backend_name : str
156+ ):
157+ """Function to set numerical backend.
158+
159+ Parameters
160+ ----------
161+ backend_name: str
162+ Name of the backend. Options are `'jax'` and `'numpy'`.
163+ """
164+ backends = get_numerical_backends (backend_name )
165+ np .backend = backends [0 ]
166+ la .backend = backends [1 ]
0 commit comments