@@ -36,6 +36,17 @@ def __init__(self, arg, hilbert_dims=None, shape=None, copy=True):
3636 hilbert_dims = arg .hilbert_space_dims
3737 base = arg
3838
39+ elif isinstance (arg , CuState ):
40+ if shape is not None :
41+ assert arg .shape == shape
42+ else :
43+ shape = arg .shape
44+ if hilbert_dims is not None :
45+ assert arg .base .hilbert_space_dims == hilbert_dims
46+ if copy :
47+ arg = arg .copy ()
48+ base = arg .base
49+
3950 elif (CuPyDense is not None and isinstance (arg , CuPyDense )) or isinstance (arg , cp .ndarray ):
4051 if CuPyDense is not None and isinstance (arg , CuPyDense ):
4152 arg = arg ._cp
@@ -221,6 +232,7 @@ def adjoint(self):
221232 arr = self .to_cupy ().transpose ().conj ()
222233 return CuState (arr , hilbert_dims = self .base .hilbert_space_dims , shape = (self .shape [1 ], self .shape [0 ]))
223234
235+
224236def CuState_from_Dense (mat ):
225237 return CuState (mat )
226238
@@ -357,7 +369,7 @@ def adjoint_cuState(state):
357369@_data .sub .register (CuState )
358370def sub_cuState (left , right ):
359371 return add_cuState (left , right , - 1 )
360-
372+
361373@_data .iszero .register (CuState )
362374def iszero_cuState (state ):
363375 return not cp .any (state .base .storage )
@@ -378,12 +390,12 @@ def matmul_cuState(left, right):
378390 ctx = settings .cuDensity ["ctx" ]
379391 if (left .shape [0 ] == 1 and right .shape [1 ] == 1 ):
380392 # Scalar case
381- hilbert_dims = (1 ,)
393+ hilbert_dims = (1 ,)
382394 else :
383395 hilbert_dims = left .base .hilbert_space_dims
384396
385397 left_array = left .to_cupy ()
386398 right_array = right .to_cupy ()
387399 arr = left_array @ right_array
388-
389- return CuState (arr , hilbert_dims = hilbert_dims , shape = output_shape )
400+
401+ return CuState (arr , hilbert_dims = hilbert_dims , shape = output_shape )
0 commit comments