2626# Input fixed in tests, but did not fail early
2727
2828
29- def _transpose_cu_operator (oper ):
30- if isinstance (oper , MultidiagonalOperator ):
31- out = MultidiagonalOperator (
32- oper .data ,
33- [- offset for offset in oper .offsets ],
34- callback = oper .callback ,
35- )
29+ def _transpose_cu_operator (oper , transpose : list [bool ]):
30+ """
31+ Transpose modes where ``transpose`` is True.
32+ """
33+ if isinstance (oper , MultidiagonalOperator ) and len (transpose ) == 1 :
34+ if transpose [0 ]:
35+ out = MultidiagonalOperator (
36+ oper .data ,
37+ [- offset for offset in oper .offsets ],
38+ callback = oper .callback ,
39+ )
3640 elif isinstance (oper , DenseOperator ):
3741 N = oper .num_modes
3842 batch_dims_oper = len (oper .data .shape ) % 2
39- perm = tuple (range (N , 2 * N )) + tuple (range (N ))
43+
44+ perm_l = []
45+ perm_r = []
46+ for i , trans in enumerate (transpose ):
47+ if trans :
48+ perm_l .append (i + N )
49+ perm_r .append (i )
50+ else :
51+ perm_l .append (i )
52+ perm_r .append (i + N )
53+ perm = tuple (perm_l + perm_r )
54+
4055 new_callback = None
4156
4257 if oper .callback is not None :
@@ -102,7 +117,8 @@ def _oper_to_ElementaryOperator(
102117 oper ,
103118 hilbert_idx ,
104119 hilbert_dims ,
105- transform ,
120+ transpose ,
121+ dag ,
106122 copy = False
107123):
108124 N = len (hilbert_idx )
@@ -111,32 +127,27 @@ def _oper_to_ElementaryOperator(
111127 if isinstance (oper , (DenseOperator , MultidiagonalOperator )):
112128 if N != 1 and isinstance (oper , MultidiagonalOperator ):
113129 raise ValueError (
114- "MultidiagonalOperator on multiple hilbert spaces"
130+ "MultidiagonalOperator on multiple hilbert spaces are not supported. "
115131 )
116132 if list (oper .shape [:len (oper .shape ) // 2 ]) != list (shape ):
117133 raise ValueError (
118134 f"Operator shape does not match hilbert spaces: "
119135 f"{ list (oper .shape [:len (oper .shape ) // 2 ])} , { shape } "
120136 )
121137
122- if transform == Transform .DIRECT :
123- out = oper
124- elif transform == Transform .ADJOINT :
125- out = oper .dag ()
126- elif transform == Transform .CONJ :
127- out = _transpose_cu_operator (oper ).dag ()
128- elif transform == Transform .TRANSPOSE :
129- out = _transpose_cu_operator (oper )
138+ cu_oper = oper
130139
131140 else :
132- if transform == Transform .DIRECT :
133- pass
134- elif transform == Transform .ADJOINT :
135- oper = oper .adjoint ()
136- elif transform == Transform .CONJ :
141+ if all (transpose ) and dag :
137142 oper = oper .conj ()
138- elif transform == Transform .TRANSPOSE :
143+ transpose = [False ]
144+ dag = False
145+ elif dag :
146+ oper = oper .adjoint ()
147+ dag = False
148+ elif all (transpose ):
139149 oper = oper .transpose ()
150+ transpose = [False ]
140151
141152 if isinstance (oper , _data .Dia ) and N == 1 :
142153 dia_matrix = oper .as_scipy ()
@@ -145,12 +156,17 @@ def _oper_to_ElementaryOperator(
145156 for i , offset in enumerate (offsets ):
146157 end = None if offset == 0 else - abs (offset )
147158 data [:end , i ] = dia_matrix .diagonal (offset )
148- out = MultidiagonalOperator (data , offsets )
159+ cu_oper = MultidiagonalOperator (data , offsets )
149160
150161 else :
151- out = DenseOperator (oper .to_array ().reshape (shape + shape ))
162+ cu_oper = DenseOperator (oper .to_array ().reshape (shape + shape ))
152163
153- return out
164+ if any (transpose ):
165+ cu_oper = _transpose_cu_operator (cu_oper , transpose )
166+ if dag :
167+ cu_oper = cu_oper .dag ()
168+
169+ return cu_oper
154170
155171
156172###############################################################################
@@ -538,11 +554,14 @@ def to_OperatorTerm(self, dual=False, copy=True, hilbert_dims=None):
538554 for term in self .terms :
539555 cuterm = tensor_product (dtype = "complex128" )
540556 for pterm in term .prod_terms :
557+ transpose = pterm .transform in [Transform .TRANSPOSE , Transform .CONJ ]
558+ dag = pterm .transform in [Transform .ADJOINT , Transform .CONJ ]
541559 oper = _oper_to_ElementaryOperator (
542560 pterm .operator ,
543561 pterm .hilbert ,
544562 self .hilbert_space_dims ,
545- pterm .transform ,
563+ [transpose ] * len (pterm .hilbert ),
564+ dag ,
546565 copy
547566 )
548567 # Inverted order confirmed by nvidia
@@ -552,40 +571,29 @@ def to_OperatorTerm(self, dual=False, copy=True, hilbert_dims=None):
552571 else :
553572 N_hilbert = len (self .hilbert_dims ) // 2
554573 # TODO: make this tests weak compare?
555- assert self .hilbert_dims [:N_hilbert ] == self .hilbert_dims [N_hilbert :]
574+ if self .hilbert_dims [:N_hilbert ] != self .hilbert_dims [N_hilbert :]:
575+ raise ValueError (
576+ f"Hilbert space inconsistent with square superoperator: { self .hilbert_dims } "
577+ )
556578 for term in self .terms :
557579 cuterm = tensor_product (dtype = "complex128" )
558580 for pterm in term .prod_terms :
559- if all (i < N_hilbert for i in pterm .hilbert ):
560- oper = _oper_to_ElementaryOperator (
561- pterm .operator ,
562- pterm .hilbert ,
563- self .hilbert_space_dims ,
564- trans_transform [pterm .transform ],
565- copy
566- )
567- # Inverted order confirmed by nvidia
568- cuterm = cuterm * tensor_product (
569- (oper , pterm .hilbert , (True ,))
570- )
571-
572- elif any (i < N_hilbert for i in pterm .hilbert ):
573- raise NotImplementedError (
574- "Operators acting on both original and "
575- "dual spaces are not supported."
576- )
581+ modes = tuple (mode % N_hilbert for mode in pterm .hilbert )
582+ duals = tuple (mode < N_hilbert for mode in pterm .hilbert )
583+ transpose = tuple (duals )
584+ dag = pterm .transform in [Transform .ADJOINT , Transform .CONJ ]
585+ if pterm .transform in [Transform .CONJ , Transform .TRANSPOSE ]:
586+ transpose = tuple (not trans for trans in transpose )
577587
578- else :
579- oper = _oper_to_ElementaryOperator (
580- pterm .operator ,
581- pterm .hilbert ,
582- self .hilbert_space_dims ,
583- pterm .transform ,
584- copy
585- )
586- cuterm = cuterm * tensor_product (
587- (oper , tuple (i - N_hilbert for i in pterm .hilbert ))
588- )
588+ oper = _oper_to_ElementaryOperator (
589+ pterm .operator ,
590+ pterm .hilbert ,
591+ self .hilbert_space_dims ,
592+ transpose ,
593+ dag ,
594+ copy
595+ )
596+ cuterm = cuterm * tensor_product ((oper , modes , duals ,))
589597
590598 out = out + (cuterm * term .factor )
591599
@@ -731,6 +739,22 @@ def isherm(operator, tol=-1):
731739 if tol < 0 :
732740 tol = settings .core ["atol" ]
733741 return cp .allclose (oper , oper .T .conj (), atol = tol )
742+
743+
744+ @_data .identity_like .register (CuOperator )
745+ def identity_like (data , / ):
746+ """
747+ Create an identity matrix of the same type and shape.
748+ """
749+ if not data .shape [0 ] == data .shape [1 ]:
750+ raise ValueError (
751+ "Can't create an identity matrix like a non square matrix."
752+ )
753+
754+ new = CuOperator (hilbert_dims = data .hilbert_dims )
755+ new .terms .append (Term ([], 1. ))
756+ return new
757+
734758###############################################################################
735759###############################################################################
736760
0 commit comments