Skip to content

Commit 6f3aea1

Browse files
Sampreetpiperfw
authored andcommitted
Update functions and tests for JAX backend
1 parent 342a393 commit 6f3aea1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+467
-416
lines changed

oqupy/backends/numerical_backend.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import oqupy.config as oc
2020

21-
class Numpy:
21+
class NumPy:
2222
"""
2323
The NumPy backend employing
2424
dynamic switching through `oqupy.config`.
@@ -38,19 +38,28 @@ def dtype_float(self) -> default_np.dtype:
3838
"""Getter for the float datatype."""
3939
return oc.NumPyDtypeFloat
4040

41-
def __getattr__(self, name: str):
41+
def __getattr__(self,
42+
name: str,
43+
):
4244
"""Return the backend's default attribute."""
4345
backend = object.__getattribute__(self, 'backend')
4446
return getattr(backend, name)
4547

46-
def update(self, array, indices:tuple, values) -> default_np.ndarray:
48+
def update(self,
49+
array,
50+
indices:tuple,
51+
values,
52+
) -> default_np.ndarray:
4753
"""Option to update select indices of an array with given values."""
4854
if not isinstance(array, default_np.ndarray):
4955
return array.at[indices].set(values)
5056
array[indices] = values
5157
return array
5258

53-
def get_random_floats(self, seed, shape):
59+
def get_random_floats(self,
60+
seed,
61+
shape,
62+
):
5463
"""Method to obtain random floats with a given seed and shape."""
5564
backend = object.__getattribute__(self, 'backend')
5665
random_floats = default_np.random.default_rng(seed).random(shape, \
@@ -73,5 +82,5 @@ def __getattr__(self, name: str):
7382
return getattr(backend, name)
7483

7584
# initialize for import
76-
np = Numpy()
85+
np = NumPy()
7786
la = LinAlg()

oqupy/backends/pt_tempo_backend.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,6 @@ def initialize(self) -> None:
150150
# the inner `for` loop can be optimized for parallelization
151151
if self._num_infl > 2:
152152
indices = list(range(1, self._num_infl - 1))
153-
# influences_mpo += create_deltas(self._influence, indices,
154-
# [0, 1, 1, 0])
155-
# influences_mps += create_deltas(self._influence, indices,
156-
# [0, 1, 0], scale)
157153
for index in indices:
158154
infl = self._influence(index)
159155
influences_mpo.append(create_delta(infl, [0, 1, 1, 0]))

oqupy/backends/tempo_backend.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,6 @@ def initialize_mps_mpo(self):
434434
# the inner `for` loop can be optimized for parallelization
435435
if dkmax_pre_compute > 1:
436436
indices = list(range(1, dkmax_pre_compute))
437-
# influences += create_deltas(self._influence, indices,
438-
# [1, 0, 0, 1])
439437
for index in indices:
440438
infl = self._influence(index)
441439
influences.append(create_delta(infl, [1, 0, 0, 1]))

oqupy/bath_dynamics.py

Lines changed: 109 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def occupation(
195195
np.sum(_sys_correlations.real*re_kernel \
196196
+ 1j*_sys_correlations.imag*im_kernel, axis = 0)
197197
).real * coup
198-
bath_occupation = np.append([0], bath_occupation)
198+
bath_occupation = np.append(np.array([0]), bath_occupation)
199199
if not change_only and self._temp > 0:
200200
bath_occupation += np.exp(-freq/self._temp) \
201201
/ (1 - np.exp(-freq/self._temp))
@@ -396,54 +396,121 @@ def phase(region, swap_ts = False):
396396
ph += np.exp(a * tk*dt + b * tkp*dt) / (a * b)
397397
return ph
398398

399-
400399
if dagg == (0, 1):
401-
re_kernel[regions['a']] = phase('a') + phase('a', 1)
402-
403-
re_kernel[regions['b']] = phase('b')
404-
405-
im_kernel[regions['a']] = ((2*n_1 + 1) * phase('a') -
406-
(2*n_2 + 1) * phase('a', 1))
407-
408-
im_kernel[regions['b']] = (2*n_1 + 1) * phase('b')
409-
410-
im_kernel[regions['c']] = -2 * (n_1 + 1) * phase('c')
400+
re_kernel = np.update(
401+
array=re_kernel,
402+
indices=regions['a'],
403+
values=phase('a') + phase('a', 1)
404+
)
405+
re_kernel = np.update(
406+
array=re_kernel,
407+
indices=regions['b'],
408+
values=phase('b')
409+
)
410+
411+
im_kernel = np.update(
412+
array=im_kernel,
413+
indices=regions['a'],
414+
values=((2 * n_1 + 1) * phase('a') -
415+
(2 * n_2 + 1) * phase('a', 1))
416+
)
417+
im_kernel = np.update(
418+
array=im_kernel,
419+
indices=regions['b'],
420+
values=(2 * n_1 + 1) * phase('b')
421+
)
422+
im_kernel = np.update(
423+
array=im_kernel,
424+
indices=regions['c'],
425+
values=- 2 * (n_1 + 1) * phase('c')
426+
)
411427

412428
elif dagg == (1, 0):
413-
re_kernel[regions['a']] = phase('a') + phase('a', 1)
414-
415-
re_kernel[regions['b']] = phase('b')
416-
417-
im_kernel[regions['a']] = ((2*n_1 + 1) * phase('a') -
418-
(2*n_2 + 1) * phase('a', 1))
419-
420-
im_kernel[regions['b']] = (2*n_1 + 1) * phase('b')
421-
422-
im_kernel[regions['c']] = 2 * n_1 * phase('c')
429+
re_kernel = np.update(
430+
array=re_kernel,
431+
indices=regions['a'],
432+
values=phase('a') + phase('a', 1)
433+
)
434+
re_kernel = np.update(
435+
array=re_kernel,
436+
indices=regions['b'],
437+
values=phase('b')
438+
)
439+
440+
im_kernel = np.update(
441+
array=im_kernel,
442+
indices=regions['a'],
443+
values=((2*n_1 + 1) * phase('a') -
444+
(2*n_2 + 1) * phase('a', 1))
445+
)
446+
im_kernel = np.update(
447+
array=im_kernel,
448+
indices=regions['b'],
449+
values=(2*n_1 + 1) * phase('b')
450+
)
451+
im_kernel = np.update(
452+
array=im_kernel,
453+
indices=regions['c'],
454+
values=2 * n_1 * phase('c')
455+
)
423456

424457
elif dagg == (1, 1):
425-
re_kernel[regions['a']] = -(phase('a') + phase('a', 1))
426-
427-
re_kernel[regions['b']] = -phase('b')
428-
429-
im_kernel[regions['a']] = ((2*n_1 + 1) * phase('a') +
430-
(2*n_2 + 1) * phase('a', 1))
431-
432-
im_kernel[regions['b']] = (2*n_1 + 1) * phase('b')
433-
434-
im_kernel[regions['c']] = 2 * (n_1 + 1) * phase('c')
458+
re_kernel = np.update(
459+
array=re_kernel,
460+
indices=regions['a'],
461+
values=- (phase('a') + phase('a', 1))
462+
)
463+
re_kernel = np.update(
464+
array=re_kernel,
465+
indices=regions['b'],
466+
values=- phase('b')
467+
)
468+
469+
im_kernel = np.update(
470+
array=im_kernel,
471+
indices=regions['a'],
472+
values=((2*n_1 + 1) * phase('a') +
473+
(2*n_2 + 1) * phase('a', 1))
474+
)
475+
im_kernel = np.update(
476+
array=im_kernel,
477+
indices=regions['b'],
478+
values=(2*n_1 + 1) * phase('b')
479+
)
480+
im_kernel = np.update(
481+
array=im_kernel,
482+
indices=regions['c'],
483+
values=2 * (n_1 + 1) * phase('c')
484+
)
435485

436486
elif dagg == (0, 0):
437-
re_kernel[regions['a']] = -(phase('a') + phase('a', 1))
438-
439-
re_kernel[regions['b']] = -phase('b')
440-
441-
im_kernel[regions['a']] = -((2*n_2 + 1) * phase('a', 1) +
442-
(2*n_1 + 1) * phase('a'))
443-
444-
im_kernel[regions['b']] = -(2*n_1 + 1) * phase('b')
445-
446-
im_kernel[regions['c']] = -2 * n_1 * phase('c')
487+
re_kernel = np.update(
488+
array=re_kernel,
489+
indices=regions['a'],
490+
values=- (phase('a') + phase('a', 1))
491+
)
492+
re_kernel = np.update(
493+
array=re_kernel,
494+
indices=regions['b'],
495+
values=- phase('b')
496+
)
497+
498+
im_kernel = np.update(
499+
array=im_kernel,
500+
indices=regions['a'],
501+
values=- ((2*n_1 + 1) * phase('a') +
502+
(2*n_2 + 1) * phase('a', 1))
503+
)
504+
im_kernel = np.update(
505+
array=im_kernel,
506+
indices=regions['b'],
507+
values=- (2*n_1 + 1) * phase('b')
508+
)
509+
im_kernel = np.update(
510+
array=im_kernel,
511+
indices=regions['c'],
512+
values=- 2 * n_1 * phase('c')
513+
)
447514

448515
re_kernel = np.triu(re_kernel) #only keep triangular region
449516
im_kernel = np.triu(im_kernel)

oqupy/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import scipy.linalg as default_la
1919
NUMERICAL_BACKEND_NUMPY = default_np
2020
NUMERICAL_BACKEND_LINALG = default_la
21-
NumPyDtypeComplex = default_np.complex128
22-
NumPyDtypeFloat = default_np.float64
21+
NumPyDtypeComplex = default_np.complex128 # earlier NpDtype
22+
NumPyDtypeFloat = default_np.float64 # earlier NpDtypeReal
2323

2424
# Separator string for __str__ functions
2525
SEPERATOR = "----------------------------------------------\n"

oqupy/dynamics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def __str__(self) -> Text:
5252
ret.append(" length = {} timesteps \n".format(len(self)))
5353
if len(self) > 0:
5454
ret.append(" min time = {} \n".format(
55-
np.min(self._times)))
55+
np.min(self.times)))
5656
ret.append(" max time = {} \n".format(
57-
np.max(self._times)))
57+
np.max(self.times)))
5858
return "".join(ret)
5959

6060
def __len__(self) -> int:
@@ -181,7 +181,7 @@ def add(
181181
@property
182182
def shape(self):
183183
"""The shape of the states. """
184-
return np.copy(self._shape)
184+
return copy(self._shape)
185185

186186
class MeanFieldDynamics(BaseAPIClass):
187187
"""

oqupy/gradient.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def state_gradient(
7676
'gradient' : derivatives of Z with respect to the parameters
7777
'dynamics' : the dynamics of the system
7878
"""
79-
check_isinstance(parameters, ndarray, 'parameters')
79+
check_isinstance(parameters, np.ndarray, 'parameters')
8080

8181
num_steps = len(process_tensors[0])
8282
dt = process_tensors[0].dt
@@ -154,14 +154,20 @@ def combine_derivs(
154154
prog_bar.update(i)
155155

156156
for j in range(0,num_parameters):
157-
total_derivs[2*i][j] = combine_derivs(
158-
adjoint_tensor[i],
159-
first_half_prop_derivs[j].T,
160-
second_half_prop.T)
161-
total_derivs[2*i+1][j] = combine_derivs(
162-
adjoint_tensor[i],
163-
first_half_prop.T,
164-
second_half_prop_derivs[j].T)
157+
total_derivs = np.update(
158+
array=total_derivs,
159+
indices=(2 * i, j),
160+
values=combine_derivs(adjoint_tensor[i],
161+
first_half_prop_derivs[j].T,
162+
second_half_prop.T)
163+
)
164+
total_derivs = np.update(
165+
array=total_derivs,
166+
indices=(2 * i + 1, j),
167+
values=combine_derivs(adjoint_tensor[i],
168+
first_half_prop.T,
169+
second_half_prop_derivs[j].T)
170+
)
165171

166172
prog_bar.update(num_steps)
167173
prog_bar.exit()
@@ -265,7 +271,7 @@ def controls(step: int):
265271
# edges 0, 1, .., num_envs-1 are the bond legs of the environments
266272
# edge -1 is the state leg
267273
initial_ndarray = initial_state.reshape(hs_dim**2)
268-
initial_ndarray.shape = tuple([1]*num_envs+[hs_dim**2])
274+
initial_ndarray = initial_ndarray.reshape(tuple([1]*num_envs+[hs_dim**2]))
269275
current_node = tn.Node(initial_ndarray)
270276
current_edges = current_node[:]
271277

@@ -356,7 +362,7 @@ def controls(step: int):
356362

357363
target_ndarray = target_derivative
358364
target_ndarray = target_ndarray.reshape(hs_dim**2)
359-
target_ndarray.shape = tuple([1]*num_envs+[hs_dim**2])
365+
target_ndarray = target_ndarray.reshape(tuple([1]*num_envs+[hs_dim**2]))
360366
current_node = tn.Node(target_ndarray)
361367
current_edges = current_node[:]
362368

oqupy/mps_mpo.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,10 @@ def compute_nn_gate(
199199
# exponentiate the liouvillian to become a propagator
200200
propagator = la.expm(dt*liouvillian)
201201
# split leg 0 and leg 1 each into left and right.
202-
propagator.shape = [hs_dim_l**2, # left output
203-
hs_dim_r**2, # right output
204-
hs_dim_l**2, # left input
205-
hs_dim_r**2] # right input
202+
propagator = propagator.reshape([hs_dim_l**2, # left output
203+
hs_dim_r**2, # right output
204+
hs_dim_l**2, # left input
205+
hs_dim_r**2]) # right input
206206
temp = np.swapaxes(propagator, 1, 2)
207207
temp = temp.reshape([hs_dim_l**2 * hs_dim_l**2,
208208
hs_dim_r**2 * hs_dim_r**2])
@@ -396,11 +396,14 @@ def __init__(
396396
if rank == 4:
397397
tmp_gamma = np.swapaxes(tmp_gamma,2,3)
398398
elif rank == 3:
399-
tmp_gamma.shape = (shape[0], shape[1], 1, shape[2])
399+
tmp_gamma = tmp_gamma.reshape((shape[0], shape[1], \
400+
1, shape[2]))
400401
elif rank == 2:
401-
tmp_gamma.shape = (1, shape[0]*shape[1], 1, 1)
402+
tmp_gamma = tmp_gamma.reshape((1, shape[0] * shape[1], \
403+
1, 1))
402404
elif rank == 1:
403-
tmp_gamma.shape = (1, shape[0], 1, 1)
405+
tmp_gamma = tmp_gamma.reshape((1, shape[0], \
406+
1, 1))
404407
else:
405408
raise ValueError()
406409
tmp_gammas.append(tmp_gamma)

0 commit comments

Comments
 (0)