Skip to content

Commit 2d7211e

Browse files
authored
Merge pull request #26 from Ericgig/Custate.init_fix
`Custate.__init__` from CuState fix.
2 parents 35007f3 + e3f80c2 commit 2d7211e

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

src/qutip_cuquantum/state.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ def __init__(self, arg, hilbert_dims=None, shape=None, copy=True):
3636
hilbert_dims = arg.hilbert_space_dims
3737
base = arg
3838

39+
elif isinstance(arg, CuState):
40+
if shape is not None:
41+
assert arg.shape == shape
42+
else:
43+
shape = arg.shape
44+
if hilbert_dims is not None:
45+
assert arg.base.hilbert_space_dims == hilbert_dims
46+
if copy:
47+
arg = arg.copy()
48+
base = arg.base
49+
3950
elif (CuPyDense is not None and isinstance(arg, CuPyDense)) or isinstance(arg, cp.ndarray):
4051
if CuPyDense is not None and isinstance(arg, CuPyDense):
4152
arg = arg._cp
@@ -221,6 +232,7 @@ def adjoint(self):
221232
arr = self.to_cupy().transpose().conj()
222233
return CuState(arr, hilbert_dims=self.base.hilbert_space_dims, shape=(self.shape[1], self.shape[0]))
223234

235+
224236
def CuState_from_Dense(mat):
225237
return CuState(mat)
226238

@@ -357,7 +369,7 @@ def adjoint_cuState(state):
357369
@_data.sub.register(CuState)
358370
def sub_cuState(left, right):
359371
return add_cuState(left, right, -1)
360-
372+
361373
@_data.iszero.register(CuState)
362374
def iszero_cuState(state):
363375
return not cp.any(state.base.storage)
@@ -378,12 +390,12 @@ def matmul_cuState(left, right):
378390
ctx = settings.cuDensity["ctx"]
379391
if(left.shape[0] == 1 and right.shape[1] == 1):
380392
# Scalar case
381-
hilbert_dims = (1,)
393+
hilbert_dims = (1,)
382394
else:
383395
hilbert_dims = left.base.hilbert_space_dims
384396

385397
left_array = left.to_cupy()
386398
right_array = right.to_cupy()
387399
arr = left_array @ right_array
388-
389-
return CuState(arr, hilbert_dims=hilbert_dims, shape=output_shape)
400+
401+
return CuState(arr, hilbert_dims=hilbert_dims, shape=output_shape)

0 commit comments

Comments
 (0)