Skip to content

Commit 18a3dce

Browse files
authored
Merge pull request #21 from ghanem-nv/misc_improvements
Misc improvements
2 parents c26264f + a91cecb commit 18a3dce

File tree

10 files changed

+496
-310
lines changed

10 files changed

+496
-310
lines changed

LICENSE

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
BSD 3-Clause License
22

3-
Copyright (c) 2025, QuTiP
3+
Copyright (c) 2025, NVIDIA CORPORATION, QuTiP developers and contributors.
4+
All rights reserved.
45

56
Redistribution and use in source and binary forms, with or without
67
modification, are permitted provided that the following conditions are met:

src/qutip_cuquantum/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from .operator import CuOperator
2929
from .state import CuState
3030
import numpy
31-
31+
from .mixed_dispatch import * #Ensure that the mixed dispatch is registered
3232

3333
# TODO: The split per density is not great
3434
# Add an operator / state split in qutip?
@@ -95,7 +95,6 @@ def set_as_default(ctx: cuquantum.densitymat.WorkStream=None, reverse=False):
9595
if not reverse:
9696
settings.cuDensity["ctx"] = ctx
9797
settings.core["default_dtype"] = "cuDensity"
98-
settings.core['numpy_backend'] = cupy
9998

10099
if True: # if mpi, how to check from ctx?
101100
settings.core["auto_real_casting"] = False
@@ -111,7 +110,6 @@ def set_as_default(ctx: cuquantum.densitymat.WorkStream=None, reverse=False):
111110

112111
else:
113112
settings.core["default_dtype"] = "core"
114-
settings.core['numpy_backend'] = numpy
115113
settings.core["auto_real_casting"] = True
116114

117115
SESolver.solver_options['method'] = "adams"
@@ -145,7 +143,6 @@ def __enter__(self):
145143
self.previous_values["default_dtype"] = qutip.settings.core["default_dtype"]
146144
settings.core["default_dtype"] = "cuDensity"
147145
self.previous_values["numpy_backend"] = qutip.settings.core["numpy_backend"]
148-
settings.core['numpy_backend'] = cupy
149146

150147
self.previous_values["auto_real"] = settings.core["auto_real_casting"]
151148
if True: # if mpi, how to check from ctx?
@@ -169,7 +166,6 @@ def __enter__(self):
169166

170167
def __exit__(self, exc_type, exc_value, traceback):
171168
settings.core["default_dtype"] = self.previous_values["default_dtype"]
172-
settings.core['numpy_backend'] = self.previous_values["numpy_backend"]
173169
settings.core["auto_real_casting"] = self.previous_values["auto_real"]
174170
SESolver.solver_options['method'] = self.previous_values["SESolverM"]
175171
MESolver.solver_options['method'] = self.previous_values["MESolverM"]

src/qutip_cuquantum/mixed_dispatch.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,22 @@
88
from cuquantum.densitymat import Operator
99

1010
@_data.matmul.register(CuOperator, CuState, CuState)
11-
def matmul_cuoperator_custate_custate(left, right, scale=1., out=None):
12-
merged_hilbert = _compare_hilbert(left.hilbert_dims, right.base.hilbert_space_dims)
13-
if not merged_hilbert:
14-
raise ValueError("Hilbert space missmatch")
11+
def matmul_cuoperator_custate_custate(left, right, scale=1., out=None):
12+
1513
if left.shape[1] == right.shape[0]:
1614
dual = False
15+
merged_hilbert = _compare_hilbert(left.hilbert_dims, right.base.hilbert_space_dims)
1716
elif left.shape[1] == right.shape[0] * right.shape[1]:
1817
dual = True
18+
merged_hilbert = _compare_hilbert(left.hilbert_dims[:len(left.hilbert_dims) // 2], right.base.hilbert_space_dims)
1919
else:
2020
raise ValueError("Shape missmatch")
2121

22+
if not merged_hilbert:
23+
raise ValueError("Hilbert space missmatch")
24+
25+
if(scale != 1.):
26+
left = left * scale
2227
oper = Operator(merged_hilbert, [left.to_OperatorTerm(dual=dual, hilbert_dims=merged_hilbert)])
2328

2429
oper.prepare_action(settings.cuDensity["ctx"], right.base)
@@ -28,3 +33,7 @@ def matmul_cuoperator_custate_custate(left, right, scale=1., out=None):
2833
oper.compute_action(0, [], state_in=right.base, state_out=out.base)
2934

3035
return out
36+
37+
@_data.matmul.register(CuState, CuOperator, CuState)
38+
def matmul_custate_cuoperator_custate(left, right, scale=1., out=None):
39+
return matmul_cuoperator_custate_custate(right.transpose(), left.transpose(), scale, out).transpose()

src/qutip_cuquantum/qobjevo.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ cdef class CuQobjEvo(QobjEvo):
9191
self.expect_ready = True
9292
# Workaround for a bug in cudensity 0.2.0.
9393
settings.cuDensity["ctx"].release_workspace()
94-
return self.operator.compute_expectation(t, None, state.base)
94+
return self.operator.compute_expectation(t, None, state.base).get()
9595

9696
def arguments(self, args):
9797
raise NotImplementedError

src/qutip_cuquantum/state.py

Lines changed: 108 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Any
12
from cuquantum.densitymat import DensePureState, DenseMixedState
23

34
import numpy as np
@@ -12,6 +13,11 @@
1213
except ImportError:
1314
CuPyDense = None
1415

16+
try:
17+
import mpi4py.MPI as MPI
18+
except ImportError:
19+
MPI = None
20+
1521

1622
class CuState(Data):
1723
def __init__(self, arg, hilbert_dims=None, shape=None, copy=True):
@@ -30,48 +36,69 @@ def __init__(self, arg, hilbert_dims=None, shape=None, copy=True):
3036
hilbert_dims = arg.hilbert_space_dims
3137
base = arg
3238

33-
elif CuPyDense is not None and isinstance(arg, CuPyDense):
39+
elif (CuPyDense is not None and isinstance(arg, CuPyDense)) or isinstance(arg, cp.ndarray):
40+
if CuPyDense is not None and isinstance(arg, CuPyDense):
41+
arg = arg._cp
42+
43+
if arg.ndim == 1:
44+
arg = arg.reshape(-1, 1)
45+
elif arg.ndim > 2:
46+
raise ValueError("Only 1D or 2D arrays are supported")
47+
3448
if shape is None:
3549
shape = arg.shape
3650
if hilbert_dims is None:
37-
hilbert_dims = arg.shape[:1]
51+
if(arg.shape[0] == 1):
52+
hilbert_dims = (arg.shape[1],)
53+
else:
54+
hilbert_dims = (arg.shape[0],)
3855

39-
if arg.shape[0] != np.prod(hilbert_dims) or arg.shape[1] != 1:
40-
# TODO: Add sanity check for hilbert_dims
56+
if arg.shape[0] != 1 and arg.shape[1] != 1:
57+
is_hilbert_dim_matching = (arg.shape[0] == np.prod(hilbert_dims) and arg.shape[1] == np.prod(hilbert_dims))
58+
if not is_hilbert_dim_matching:
59+
raise ValueError(f"Shape {arg.shape} does not match hilbert_dims {hilbert_dims} for mixed state")
4160
base = DenseMixedState(ctx, hilbert_dims, 1, "complex128")
4261
sizes, offsets = base.local_info
4362
sls = tuple(slice(s, s+n) for s, n in zip(offsets, sizes))[:-1]
4463
N = np.prod(sizes)
45-
if len(arg._cp) == N:
64+
if len(arg) == N:
4665
base.attach_storage(cp.array(
47-
arg._cp
66+
arg
4867
.reshape(hilbert_dims * 2)[sls]
4968
.ravel(order="F"),
69+
dtype="complex128",
5070
copy=copy
5171
))
5272
else:
5373
base.allocate_storage()
5474
base.storage[:N] = (
55-
arg._cp
75+
arg
5676
.reshape(hilbert_dims * 2)[sls]
5777
.ravel(order="F")
5878
)
5979

6080
else:
81+
is_hilbert_dim_matching = ((arg.shape[1] == 1 and arg.shape[0] == np.prod(hilbert_dims)) or
82+
(arg.shape[0] == 1 and arg.shape[1] == np.prod(hilbert_dims)))
83+
if not is_hilbert_dim_matching:
84+
raise ValueError(f"Shape {arg.shape} does not match hilbert_dims {hilbert_dims} for pure state")
85+
6186
base = DensePureState(ctx, hilbert_dims, 1, "complex128")
6287
sizes, offsets = base.local_info
6388
sls = tuple(slice(s, s+n) for s, n in zip(offsets, sizes))[:-1]
6489
N = np.prod(sizes)
65-
if len(arg._cp) == N:
90+
if len(arg) == N:
6691
base.attach_storage(cp.array(
67-
arg._cp
92+
arg
6893
.reshape(hilbert_dims)[sls]
69-
.ravel(order="F"), copy=copy
94+
.ravel(order="F"),
95+
dtype="complex128",
96+
copy=copy
7097
))
7198
else:
7299
base.allocate_storage()
73100
base.storage[:N] = (
74-
arg._cp
101+
arg
75102
.reshape(hilbert_dims)[sls]
76103
.ravel(order="F")
77104
)
@@ -122,17 +149,25 @@ def to_array(self, as_tensor=False):
122149
return self.to_cupy(as_tensor).get()
123150

124151
def to_cupy(self, as_tensor=False):
125-
# TODO: How to implement for mpi?
126152
if type(self.base) is DenseMixedState:
127153
tensor_shape = self.base.hilbert_space_dims * 2
128154
else:
129155
tensor_shape = self.base.hilbert_space_dims
156+
157+
local_tensor = self.base.view()[..., 0]
130158
if self.base.local_info[0][:-1] != tensor_shape:
131-
raise NotImplementedError(
132-
"Not Implemented for MPI distributed array."
133-
f"{self.base.local_info[0][:-1]} vs {self.base.hilbert_space_dims}"
134-
)
135-
tensor = self.base.view()[..., 0]
159+
if MPI is None:
160+
raise ImportError("mpi4py is not imported. Distributed tensor assembly requires mpi4py.")
161+
comm = MPI.COMM_WORLD
162+
tensor = cp.empty(tensor_shape, dtype=cp.complex128)
163+
sizes, offsets = self.base.local_info
164+
local_sls = tuple(slice(s, s+n) for s, n in zip(offsets, sizes))[:-1]
165+
all_sls = comm.allgather(local_sls)
166+
all_tensor = comm.allgather(local_tensor)
167+
for rank in range(comm.Get_size()):
168+
tensor[all_sls[rank]] = all_tensor[rank]
169+
else:
170+
tensor = local_tensor
136171
if not as_tensor:
137172
tensor = tensor.reshape(*self.shape, order="C")
138173
return tensor
@@ -145,7 +180,8 @@ def __add__(self, other):
145180
if isinstance(other, Data):
146181
return _data.add(self, other)
147182
return NotImplemented
148-
183+
if(self.shape != other.shape):
184+
raise ValueError("Incompatible shapes")
149185
new = self.copy()
150186
new.base.inplace_accumulate(other.base, 1.)
151187
return new
@@ -156,6 +192,8 @@ def __sub__(self, other):
156192
return _data.sub(self, other)
157193
return NotImplemented
158194

195+
if(self.shape != other.shape):
196+
raise ValueError("Incompatible shapes")
159197
new = self.copy()
160198
new.base.inplace_accumulate(other.base, -1.)
161199
return new
@@ -175,17 +213,20 @@ def conj(self):
175213
)
176214

177215
def transpose(self):
178-
raise NotImplementedError()
216+
arr = self.to_cupy().transpose()
217+
return CuState(arr, hilbert_dims=self.base.hilbert_space_dims, shape=(self.shape[1], self.shape[0]))
179218

180-
def adjoint(self):
181-
raise NotImplementedError()
182219

220+
def adjoint(self):
221+
arr = self.to_cupy().transpose().conj()
222+
return CuState(arr, hilbert_dims=self.base.hilbert_space_dims, shape=(self.shape[1], self.shape[0]))
183223

184224
def CuState_from_Dense(mat):
185225
return CuState(mat)
186226

187227

188228
def Dense_from_CuState(mat):
229+
print("Dense_from_CuState")
189230
return _data.Dense(mat.to_array())
190231

191232

@@ -300,3 +341,49 @@ def isherm(state, tol=-1):
300341

301342
def zeros_like_cuState(state):
302343
return CuState(state.base.clone(cp.zeros_like(state.base.storage, order="F")))
344+
345+
@_data.conj.register(CuState)
346+
def conj_cuState(state):
347+
return state.conj()
348+
349+
@_data.transpose.register(CuState)
350+
def transpose_cuState(state):
351+
return state.transpose()
352+
353+
@_data.adjoint.register(CuState)
354+
def adjoint_cuState(state):
355+
return state.adjoint()
356+
357+
@_data.sub.register(CuState)
358+
def sub_cuState(left, right):
359+
return add_cuState(left, right, -1)
360+
361+
@_data.iszero.register(CuState)
362+
def iszero_cuState(state):
363+
return not cp.any(state.base.storage)
364+
365+
366+
@_data.matmul.register(CuState)
367+
def matmul_cuState(left, right):
368+
if(left.shape[1] != right.shape[0]):
369+
raise ValueError("Incompatible shapes")
370+
371+
if left.base.hilbert_space_dims != right.base.hilbert_space_dims:
372+
raise ValueError(
373+
f"Incompatible hilbert space: {left.base.hilbert_space_dims} "
374+
f"and {right.base.hilbert_space_dims}."
375+
)
376+
377+
output_shape = (left.shape[0], right.shape[1])
378+
ctx = settings.cuDensity["ctx"]
379+
if(left.shape[0] == 1 and right.shape[1] == 1):
380+
# Scalar case
381+
hilbert_dims = (1,)
382+
else:
383+
hilbert_dims = left.base.hilbert_space_dims
384+
385+
left_array = left.to_cupy()
386+
right_array = right.to_cupy()
387+
arr = left_array @ right_array
388+
389+
return CuState(arr, hilbert_dims=hilbert_dims, shape=output_shape)

0 commit comments

Comments
 (0)