1+ from typing import Any
12from cuquantum .densitymat import DensePureState , DenseMixedState
23
34import numpy as np
1213except ImportError :
1314 CuPyDense = None
1415
16+ try :
17+ import mpi4py .MPI as MPI
18+ except ImportError :
19+ MPI = None
20+
1521
1622class CuState (Data ):
1723 def __init__ (self , arg , hilbert_dims = None , shape = None , copy = True ):
@@ -30,48 +36,69 @@ def __init__(self, arg, hilbert_dims=None, shape=None, copy=True):
3036 hilbert_dims = arg .hilbert_space_dims
3137 base = arg
3238
33- elif CuPyDense is not None and isinstance (arg , CuPyDense ):
39+ elif (CuPyDense is not None and isinstance (arg , CuPyDense )) or isinstance (arg , cp .ndarray ):
40+ if CuPyDense is not None and isinstance (arg , CuPyDense ):
41+ arg = arg ._cp
42+
43+ if arg .ndim == 1 :
44+ arg = arg .reshape (- 1 , 1 )
45+ elif arg .ndim > 2 :
46+ raise ValueError ("Only 1D or 2D arrays are supported" )
47+
3448 if shape is None :
3549 shape = arg .shape
3650 if hilbert_dims is None :
37- hilbert_dims = arg .shape [:1 ]
51+ if (arg .shape [0 ] == 1 ):
52+ hilbert_dims = (arg .shape [1 ],)
53+ else :
54+ hilbert_dims = (arg .shape [0 ],)
3855
39- if arg .shape [0 ] != np .prod (hilbert_dims ) or arg .shape [1 ] != 1 :
40- # TODO: Add sanity check for hilbert_dims
56+ if arg .shape [0 ] != 1 and arg .shape [1 ] != 1 :
57+ is_hilbert_dim_matching = (arg .shape [0 ] == np .prod (hilbert_dims ) and arg .shape [1 ] == np .prod (hilbert_dims ))
58+ if not is_hilbert_dim_matching :
59+ raise ValueError (f"Shape { arg .shape } does not match hilbert_dims { hilbert_dims } for mixed state" )
4160 base = DenseMixedState (ctx , hilbert_dims , 1 , "complex128" )
4261 sizes , offsets = base .local_info
4362 sls = tuple (slice (s , s + n ) for s , n in zip (offsets , sizes ))[:- 1 ]
4463 N = np .prod (sizes )
45- if len (arg . _cp ) == N :
64+ if len (arg ) == N :
4665 base .attach_storage (cp .array (
47- arg . _cp
66+ arg
4867 .reshape (hilbert_dims * 2 )[sls ]
4968 .ravel (order = "F" ),
69+ dtype = "complex128" ,
5070 copy = copy
5171 ))
5272 else :
5373 base .allocate_storage ()
5474 base .storage [:N ] = (
55- arg . _cp
75+ arg
5676 .reshape (hilbert_dims * 2 )[sls ]
5777 .ravel (order = "F" )
5878 )
5979
6080 else :
81+ is_hilbert_dim_matching = ((arg .shape [1 ] == 1 and arg .shape [0 ] == np .prod (hilbert_dims )) or
82+ (arg .shape [0 ] == 1 and arg .shape [1 ] == np .prod (hilbert_dims )))
83+ if not is_hilbert_dim_matching :
84+ raise ValueError (f"Shape { arg .shape } does not match hilbert_dims { hilbert_dims } for pure state" )
85+
6186 base = DensePureState (ctx , hilbert_dims , 1 , "complex128" )
6287 sizes , offsets = base .local_info
6388 sls = tuple (slice (s , s + n ) for s , n in zip (offsets , sizes ))[:- 1 ]
6489 N = np .prod (sizes )
65- if len (arg . _cp ) == N :
90+ if len (arg ) == N :
6691 base .attach_storage (cp .array (
67- arg . _cp
92+ arg
6893 .reshape (hilbert_dims )[sls ]
69- .ravel (order = "F" ), copy = copy
94+ .ravel (order = "F" ),
95+ dtype = "complex128" ,
96+ copy = copy
7097 ))
7198 else :
7299 base .allocate_storage ()
73100 base .storage [:N ] = (
74- arg . _cp
101+ arg
75102 .reshape (hilbert_dims )[sls ]
76103 .ravel (order = "F" )
77104 )
@@ -122,17 +149,25 @@ def to_array(self, as_tensor=False):
122149 return self .to_cupy (as_tensor ).get ()
123150
124151 def to_cupy (self , as_tensor = False ):
125- # TODO: How to implement for mpi?
126152 if type (self .base ) is DenseMixedState :
127153 tensor_shape = self .base .hilbert_space_dims * 2
128154 else :
129155 tensor_shape = self .base .hilbert_space_dims
156+
157+ local_tensor = self .base .view ()[..., 0 ]
130158 if self .base .local_info [0 ][:- 1 ] != tensor_shape :
131- raise NotImplementedError (
132- "Not Implemented for MPI distributed array."
133- f"{ self .base .local_info [0 ][:- 1 ]} vs { self .base .hilbert_space_dims } "
134- )
135- tensor = self .base .view ()[..., 0 ]
159+ if MPI is None :
160+ raise ImportError ("mpi4py is not imported. Distributed tensor assembly requires mpi4py." )
161+ comm = MPI .COMM_WORLD
162+ tensor = cp .empty (tensor_shape , dtype = cp .complex128 )
163+ sizes , offsets = self .base .local_info
164+ local_sls = tuple (slice (s , s + n ) for s , n in zip (offsets , sizes ))[:- 1 ]
165+ all_sls = comm .allgather (local_sls )
166+ all_tensor = comm .allgather (local_tensor )
167+ for rank in range (comm .Get_size ()):
168+ tensor [all_sls [rank ]] = all_tensor [rank ]
169+ else :
170+ tensor = local_tensor
136171 if not as_tensor :
137172 tensor = tensor .reshape (* self .shape , order = "C" )
138173 return tensor
@@ -145,7 +180,8 @@ def __add__(self, other):
145180 if isinstance (other , Data ):
146181 return _data .add (self , other )
147182 return NotImplemented
148-
183+ if (self .shape != other .shape ):
184+ raise ValueError ("Incompatible shapes" )
149185 new = self .copy ()
150186 new .base .inplace_accumulate (other .base , 1. )
151187 return new
@@ -156,6 +192,8 @@ def __sub__(self, other):
156192 return _data .sub (self , other )
157193 return NotImplemented
158194
195+ if (self .shape != other .shape ):
196+ raise ValueError ("Incompatible shapes" )
159197 new = self .copy ()
160198 new .base .inplace_accumulate (other .base , - 1. )
161199 return new
@@ -175,17 +213,20 @@ def conj(self):
175213 )
176214
177215 def transpose (self ):
178- raise NotImplementedError ()
216+ arr = self .to_cupy ().transpose ()
217+ return CuState (arr , hilbert_dims = self .base .hilbert_space_dims , shape = (self .shape [1 ], self .shape [0 ]))
179218
180- def adjoint (self ):
181- raise NotImplementedError ()
182219
220+ def adjoint (self ):
221+ arr = self .to_cupy ().transpose ().conj ()
222+ return CuState (arr , hilbert_dims = self .base .hilbert_space_dims , shape = (self .shape [1 ], self .shape [0 ]))
183223
184224def CuState_from_Dense (mat ):
185225 return CuState (mat )
186226
187227
188228def Dense_from_CuState (mat ):
229+ print ("Dense_from_CuState" )
189230 return _data .Dense (mat .to_array ())
190231
191232
@@ -300,3 +341,49 @@ def isherm(state, tol=-1):
300341
301342def zeros_like_cuState (state ):
302343 return CuState (state .base .clone (cp .zeros_like (state .base .storage , order = "F" )))
344+
345+ @_data .conj .register (CuState )
346+ def conj_cuState (state ):
347+ return state .conj ()
348+
349+ @_data .transpose .register (CuState )
350+ def transpose_cuState (state ):
351+ return state .transpose ()
352+
353+ @_data .adjoint .register (CuState )
354+ def adjoint_cuState (state ):
355+ return state .adjoint ()
356+
357+ @_data .sub .register (CuState )
358+ def sub_cuState (left , right ):
359+ return add_cuState (left , right , - 1 )
360+
361+ @_data .iszero .register (CuState )
362+ def iszero_cuState (state ):
363+ return not cp .any (state .base .storage )
364+
365+
366+ @_data .matmul .register (CuState )
367+ def matmul_cuState (left , right ):
368+ if (left .shape [1 ] != right .shape [0 ]):
369+ raise ValueError ("Incompatible shapes" )
370+
371+ if left .base .hilbert_space_dims != right .base .hilbert_space_dims :
372+ raise ValueError (
373+ f"Incompatible hilbert space: { left .base .hilbert_space_dims } "
374+ f"and { right .base .hilbert_space_dims } ."
375+ )
376+
377+ output_shape = (left .shape [0 ], right .shape [1 ])
378+ ctx = settings .cuDensity ["ctx" ]
379+ if (left .shape [0 ] == 1 and right .shape [1 ] == 1 ):
380+ # Scalar case
381+ hilbert_dims = (1 ,)
382+ else :
383+ hilbert_dims = left .base .hilbert_space_dims
384+
385+ left_array = left .to_cupy ()
386+ right_array = right .to_cupy ()
387+ arr = left_array @ right_array
388+
389+ return CuState (arr , hilbert_dims = hilbert_dims , shape = output_shape )
0 commit comments