Summary
Currently the `oqupy.util.create_delta(tensor, index_scrambling)` function returns an axis-scrambled tensor from a given tensor. The shape of the output tensor is determined by the `index_scrambling` parameter, whose entries are the axis indices of the original tensor. A way to speed up a special case of this function, where a rank-n tensor is converted to a rank-(n + 1) tensor (equivalent to `index_scrambling` set to `[0, 1, ..., n-2, n-1, n-1]`), has been proposed in a previous issue. I am opening this issue to discuss a more general and NumPy-friendly version of `create_delta` that eliminates the `while` loop and the recursive `increase_list_of_index` function.
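For instance (in my notation), for a rank-2 input $T_{ij}$ this special case corresponds to `index_scrambling = [0, 1, 1]` and produces the rank-3 output $T'_{ijk} = \delta_{jk}\, T_{ij}$, i.e. the last axis of the input is duplicated onto an additional diagonal axis.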
Changelog
- 2024-09-29: Updated symbols and function docs.
- 2024-09-01: Corrected typos, updated scripts and added new plots. Removed runtimes for the extension, as those included variable runtimes for the influence matrices (which utilize a least-recently-used cache).
Motivation
The rationale behind this change is not only to speed up the `create_delta` function but also to make it vectorizable and to support just-in-time (JIT) compilation for certain steps.
Related Theory
From what I understand,
- The `increase_list_of_index` function iterates over the axes of `tensor` backwards and returns `False` once all the indices are covered. This leads to a total number of iterations $n_{iters} = m_{n_{in}-1} \times \cdots \times m_1 \times m_0$ inside the `while` loop of `create_delta`, where $n_{in}$ is the rank of `tensor` and $m_k$ denotes the number of elements along axis $k$, such that the shape of the tensor is $(m_0, m_1, \dots, m_{n_{in}-1})$. The `while` loop and the condition function can therefore be replaced by a `for` loop running over $n_{iters}$ index tuples for each of the input and output axes.
- `ret_ndarray` is a tensor of rank $n_{out}$, determined by the total number of elements in `index_scrambling`. Since the elements of `index_scrambling` are the indices of the axes of `tensor`, one can construct a 2D array `indices_in` with shape $(n_{in}, n_{iters})$ whose rows index the individual axes of `tensor`, and then obtain the 2D array `indices_out` with shape $(n_{out}, n_{iters})$ holding the corresponding indices of `ret_ndarray` by slicing `indices_in` with `index_scrambling` (see the sketch below). This removes the requirement of a `for` loop altogether.
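As a minimal illustration of this indexing idea (using `np.indices` here rather than the proposed `get_indices`; the shapes are arbitrary):

```python
import numpy as np

shape = (2, 3)                                # n_in = 2, n_iters = 6
index_scrambling = [1, 0, 0, 1]               # n_out = 4

# all multi-indices of the input tensor, shape (n_in, n_iters)
indices_in = np.indices(shape).reshape(len(shape), -1)
# corresponding indices of the output tensor, shape (n_out, n_iters)
indices_out = indices_in[index_scrambling]

print(indices_in.shape, indices_out.shape)    # (2, 6) (4, 6)
```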
Implementation
The `create_delta` function is modified to:
```python
from typing import List

import numpy as np
from numpy import ndarray


def create_delta(
        tensor: ndarray,
        index_scrambling: List[int],
) -> ndarray:
    """Creates deltas in a tensor."""
    # convert to NumPy arrays for a future-proof implementation,
    # see https://github.com/google/jax/issues/4564
    # the shape of the tensor has n_in elements, whereas
    # index_scrambling has n_out elements
    _shape = np.array(tensor.shape, dtype=int)
    _idxs = np.array(index_scrambling, dtype=int)
    # obtain the selection indices for each axis
    _indices = get_indices(_shape, np.prod(_shape))
    # scramble the output tensor with elements of the input tensor
    scrambled_tensor = np.zeros(tuple(_shape[_idxs]),
                                dtype=tensor.dtype)
    scrambled_tensor[tuple(_indices[_idxs])] = tensor[tuple(_indices)]
    return scrambled_tensor


def get_indices(
        shape: ndarray,
        n_iters: int,
) -> ndarray:
    """Obtain the index matrix for scrambling."""
    # obtain divisors for each axis as values equal to the
    # number of elements contained in all succeeding axes,
    # e.g., shape [4, 5, 3] results in divisors [15, 3, 1]
    divisors = np.cumprod(np.concatenate([
        shape[1:],
        np.array([1], dtype=int)
    ])[::-1])[::-1]
    # prepare an iteration matrix of shape (n_iters, n_in)
    # to index each axis, e.g., n_iters = 4 x 5 x 3 = 60
    iteration_matrix = np.arange(0, n_iters).reshape(
        (n_iters, 1)).repeat(shape.shape[0], 1)
    # divide each element by the divisors obtained above
    # and take the remainder modulo the size of each axis;
    # return the index matrix with shape (n_in, n_iters)
    return ((iteration_matrix / divisors).astype(int) % shape).T
```

The idea behind segregating the `get_indices` function is to make it JIT-compatible. Here, `n_iters` is passed separately to maintain JAX compatibility with `np.arange`. All the corresponding changes can be viewed by comparing the `pr/enhancement-create-delta` branch.
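As a quick sanity check of the functions above (my own example, assuming both functions are in scope), the special case of duplicating the last axis reproduces the expected delta structure:

```python
import numpy as np

a = np.arange(6, dtype=complex).reshape(2, 3)
b = create_delta(a, [0, 1, 1])        # rank-2 -> rank-3
assert b.shape == (2, 3, 3)
# the last axis is duplicated onto a diagonal: b[i, j, j] == a[i, j]
assert np.allclose(b[0], np.diag(a[0]))
assert np.allclose(b[1], np.diag(a[1]))
```

For the JIT-ted JAX variant mentioned in the comparison below, a minimal sketch could look as follows (the name `get_indices_jax` and the choice of a static `n_iters` are my assumptions, not part of the proposal):

```python
from functools import partial

import jax
import jax.numpy as jnp


# n_iters must be static so that jnp.arange has a concrete size
@partial(jax.jit, static_argnums=(1,))
def get_indices_jax(shape: jnp.ndarray, n_iters: int) -> jnp.ndarray:
    """JIT-compatible sketch of get_indices."""
    divisors = jnp.cumprod(jnp.concatenate([
        shape[1:],
        jnp.array([1], dtype=int)
    ])[::-1])[::-1]
    iteration_matrix = jnp.arange(0, n_iters).reshape(
        (n_iters, 1)).repeat(shape.shape[0], 1)
    return ((iteration_matrix / divisors).astype(int) % shape).T
```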
Comparison
The following plots illustrate the speedups for four different scenarios: use-cases in `oqupy.backends.tempo_backend`, `oqupy.backends.pt_tempo_backend` and `oqupy.process_tensor`, and runtimes for multi-time correlations as demonstrated in Fig. 5 of arXiv:2406.16650.
Reproducing the Comparisons
The following snippet can be used to reproduce the first three plots:
```python
import time

import numpy as np

ms = np.arange(2, 16)**2  # axis dimensions
shapes = [(m, m) for m in ms]
index_scrambling = [1, 0, 0, 1]  # or [0, 1, 1, 0]
# # uncomment for the third plot
# chi_ds = np.arange(2, 503, 20)  # bond dimensions
# shapes = [(chi_d, chi_d, 4) for chi_d in chi_ds]
# index_scrambling = [0, 1, 2, 2]

# create_delta_old is the current implementation in oqupy.util
funcs = [create_delta_old, create_delta]
times = []
average_over = 5
for shape in shapes:
    # complex128 replaces the deprecated np.complex_ alias
    tensor_in = np.ones(shape, dtype=np.complex128)
    ts = []
    for j in range(len(funcs)):
        start = time.time()
        for i in range(average_over):
            _ = funcs[j](
                tensor=tensor_in,
                index_scrambling=index_scrambling
            )
        ts.append((time.time() - start) / average_over)
    times.append(ts)
```

The final plot is obtained with the same methods and parameters as mentioned in the arXiv preprint. The runtimes of the general case proposed here are close to those of the special-case snippet posted by @eoin-dp-oneill in the previous issue. I also observed slightly faster runtimes for higher dimensions with a JIT-ted JAX-CPU implementation of `get_indices`. All plots were obtained using an Intel i7-8700K processor throttled at 90% usage.
Extension
The `initialize_mps_mpo` method of `oqupy.backends.tempo_backend.BaseTempoBackend` makes several successive calls to `create_delta` within a `for` loop running for `dkmax_pre_compute - 1` steps, with the influence tensor having the same shape at every step (for `i != 0`). The same goes for the `oqupy.backends.pt_tempo_backend.PtTempoBackend.initialize_mps_mpo` method, but with one step fewer at the end. Since the influence tensor is reproducible for each `i`, a new `oqupy.util` function `create_deltas` (with an *s*) can be implemented to speed up this successive computation as follows:
```python
from typing import Callable, List

import numpy as np
from numpy import ndarray


def create_deltas(
        func_tensors: Callable[[int], ndarray],
        indices: List[int],
        index_scrambling: List[int]) -> List[ndarray]:
    """Creates deltas in multiple equally-shaped tensors."""
    # use a test tensor to obtain the indices only once
    tensor = func_tensors(indices[0])
    _shape = np.array(tensor.shape, dtype=int)
    _idxs = np.array(index_scrambling, dtype=int)
    _indices = get_indices(_shape, np.prod(_shape))
    indices_in = tuple(_indices)
    indices_out = tuple(_indices[_idxs])
    # accumulate the scrambled tensors and return them as a list
    scrambled_tensors = []
    for i in indices:
        array = np.zeros(tuple(_shape[_idxs]),
                         dtype=tensor.dtype)
        array[indices_out] = func_tensors(i)[indices_in]
        scrambled_tensors.append(array)
    return scrambled_tensors
```
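As a quick illustration (a hypothetical example of my own, not from the codebase), `create_deltas` takes a callable producing equally-shaped tensors and scrambles each of them, computing the indices only once:

```python
import numpy as np

def func_tensors(i: int) -> np.ndarray:
    """Hypothetical generator of equally-shaped tensors."""
    return np.full((3, 3), i, dtype=complex)

out = create_deltas(func_tensors, indices=[1, 2, 3],
                    index_scrambling=[1, 0, 0, 1])
print(len(out), out[0].shape)  # 3 (3, 3, 3, 3)
```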
The corresponding `for`-loop block in `initialize_mps_mpo` (say, for `BaseTempoBackend`) can be modified to:

```python
influences = []
# this block takes care of `i == 0`
infl = self._influence(0)
if self._degeneracy_maps is not None:
    infl_four_legs = np.zeros((tmp_west_deg_num_vals, self._dim**2,
                               tmp_north_deg_num_vals, self._dim**2),
                              dtype=NpDtype)
    # a little optimization is done here by removing
    # the `for` loop and updating slices instead
    _idxs = np.arange(self._dim**2)
    indices = (west_degeneracy_map[_idxs], _idxs,
               north_degeneracy_map[_idxs], _idxs)
    infl_four_legs[indices] = infl[indices[2]]
else:
    infl_four_legs = create_delta(infl, [1, 0, 0, 1])
infl_four_legs = np.dot(np.moveaxis(infl_four_legs, 1, -1),
                        self._super_u_dagg)
infl_four_legs = np.moveaxis(infl_four_legs, -1, 1)
infl_four_legs = np.dot(infl_four_legs, self._super_u.T)
influences.append(infl_four_legs)
# this block takes care of `i > 0`
if dkmax_pre_compute > 1:
    indices = list(range(1, dkmax_pre_compute))
    influences += create_deltas(self._influence, indices,
                                [1, 0, 0, 1])
    # # uncomment to test the new function against the old one
    # for index in indices:
    #     infl = self._influence(index)
    #     influences.append(create_delta(infl, [1, 0, 0, 1]))
```

I have tested the above implementations using tox and reproduced the plots of arXiv:2406.16650 for both of the mentioned approaches. I hope that the generalized function will be useful for implementing more platform-dependent special cases. Kindly share your views at your convenience.



