Skip to content

Commit 4d3ebe0

Browse files
authored
Merge pull request #17 from Ericgig/better_oper
Rework to operterm
2 parents 8d12af3 + 2798eb3 commit 4d3ebe0

File tree

3 files changed

+346
-62
lines changed

3 files changed

+346
-62
lines changed

setup.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import setuptools
1010
from Cython.Build import cythonize
1111
import qutip
12+
import qutip.core.cy as qutip_cc
1213
import numpy
1314

1415

@@ -94,7 +95,9 @@ def get_ext_modules(options):
9495
pyx_file = os.path.join("src", "qutip_cuquantum", "qobjevo.pyx")
9596
include_dirs = [
9697
numpy.get_include(),
97-
os.path.abspath(os.path.join(qutip.core.data.__file__, os.pardir))
98+
os.path.abspath(os.path.join(qutip.core.data.__file__, os.pardir)),
99+
os.path.abspath(os.path.join(qutip_cc.__file__, os.pardir)),
100+
os.path.abspath(os.path.join(qutip.__file__, os.pardir))
98101
]
99102
print("*********************************************************************************")
100103
print(include_dirs)
@@ -107,7 +110,7 @@ def get_ext_modules(options):
107110
language="c++",
108111
)
109112

110-
return cythonize(ext)
113+
return cythonize(ext, include_path=include_dirs)
111114

112115

113116
if __name__ == "__main__":

src/qutip_cuquantum/operator.py

Lines changed: 82 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,32 @@
2626
# Input fixed in tests, but did not fail early
2727

2828

29-
def _transpose_cu_operator(oper):
30-
if isinstance(oper, MultidiagonalOperator):
31-
out = MultidiagonalOperator(
32-
oper.data,
33-
[-offset for offset in oper.offsets],
34-
callback=oper.callback,
35-
)
29+
def _transpose_cu_operator(oper, transpose: list[bool]):
30+
"""
31+
Transpose modes where ``transpose`` is True.
32+
"""
33+
if isinstance(oper, MultidiagonalOperator) and len(transpose) == 1:
34+
if transpose[0]:
35+
out = MultidiagonalOperator(
36+
oper.data,
37+
[-offset for offset in oper.offsets],
38+
callback=oper.callback,
39+
)
3640
elif isinstance(oper, DenseOperator):
3741
N = oper.num_modes
3842
batch_dims_oper = len(oper.data.shape) % 2
39-
perm = tuple(range(N, 2*N)) + tuple(range(N))
43+
44+
perm_l = []
45+
perm_r = []
46+
for i, trans in enumerate(transpose):
47+
if trans:
48+
perm_l.append(i + N)
49+
perm_r.append(i)
50+
else:
51+
perm_l.append(i)
52+
perm_r.append(i+ N)
53+
perm = tuple(perm_l + perm_r)
54+
4055
new_callback = None
4156

4257
if oper.callback is not None:
@@ -102,7 +117,8 @@ def _oper_to_ElementaryOperator(
102117
oper,
103118
hilbert_idx,
104119
hilbert_dims,
105-
transform,
120+
transpose,
121+
dag,
106122
copy=False
107123
):
108124
N = len(hilbert_idx)
@@ -111,32 +127,27 @@ def _oper_to_ElementaryOperator(
111127
if isinstance(oper, (DenseOperator, MultidiagonalOperator)):
112128
if N != 1 and isinstance(oper, MultidiagonalOperator):
113129
raise ValueError(
114-
"MultidiagonalOperator on multiple hilbert spaces"
130+
"MultidiagonalOperator on multiple hilbert spaces are not supported."
115131
)
116132
if list(oper.shape[:len(oper.shape) // 2]) != list(shape):
117133
raise ValueError(
118134
f"Operator shape does not match hilbert spaces: "
119135
f"{list(oper.shape[:len(oper.shape) // 2])}, {shape}"
120136
)
121137

122-
if transform == Transform.DIRECT:
123-
out = oper
124-
elif transform == Transform.ADJOINT:
125-
out = oper.dag()
126-
elif transform == Transform.CONJ:
127-
out = _transpose_cu_operator(oper).dag()
128-
elif transform == Transform.TRANSPOSE:
129-
out = _transpose_cu_operator(oper)
138+
cu_oper = oper
130139

131140
else:
132-
if transform == Transform.DIRECT:
133-
pass
134-
elif transform == Transform.ADJOINT:
135-
oper = oper.adjoint()
136-
elif transform == Transform.CONJ:
141+
if all(transpose) and dag:
137142
oper = oper.conj()
138-
elif transform == Transform.TRANSPOSE:
143+
transpose = [False]
144+
dag = False
145+
elif dag:
146+
oper = oper.adjoint()
147+
dag = False
148+
elif all(transpose):
139149
oper = oper.transpose()
150+
transpose = [False]
140151

141152
if isinstance(oper, _data.Dia) and N == 1:
142153
dia_matrix = oper.as_scipy()
@@ -145,12 +156,17 @@ def _oper_to_ElementaryOperator(
145156
for i, offset in enumerate(offsets):
146157
end = None if offset == 0 else -abs(offset)
147158
data[:end, i] = dia_matrix.diagonal(offset)
148-
out = MultidiagonalOperator(data, offsets)
159+
cu_oper = MultidiagonalOperator(data, offsets)
149160

150161
else:
151-
out = DenseOperator(oper.to_array().reshape(shape + shape))
162+
cu_oper = DenseOperator(oper.to_array().reshape(shape + shape))
152163

153-
return out
164+
if any(transpose):
165+
cu_oper = _transpose_cu_operator(cu_oper, transpose)
166+
if dag:
167+
cu_oper = cu_oper.dag()
168+
169+
return cu_oper
154170

155171

156172
###############################################################################
@@ -538,11 +554,14 @@ def to_OperatorTerm(self, dual=False, copy=True, hilbert_dims=None):
538554
for term in self.terms:
539555
cuterm = tensor_product(dtype="complex128")
540556
for pterm in term.prod_terms:
557+
transpose = pterm.transform in [Transform.TRANSPOSE, Transform.CONJ]
558+
dag = pterm.transform in [Transform.ADJOINT, Transform.CONJ]
541559
oper = _oper_to_ElementaryOperator(
542560
pterm.operator,
543561
pterm.hilbert,
544562
self.hilbert_space_dims,
545-
pterm.transform,
563+
[transpose] * len(pterm.hilbert),
564+
dag,
546565
copy
547566
)
548567
# Inverted order confirmed by nvidia
@@ -552,40 +571,29 @@ def to_OperatorTerm(self, dual=False, copy=True, hilbert_dims=None):
552571
else:
553572
N_hilbert = len(self.hilbert_dims) // 2
554573
# TODO: make this tests weak compare?
555-
assert self.hilbert_dims[:N_hilbert] == self.hilbert_dims[N_hilbert:]
574+
if self.hilbert_dims[:N_hilbert] != self.hilbert_dims[N_hilbert:]:
575+
raise ValueError(
576+
f"Hilbert space inconsistent with square superoperator: {self.hilbert_dims}"
577+
)
556578
for term in self.terms:
557579
cuterm = tensor_product(dtype="complex128")
558580
for pterm in term.prod_terms:
559-
if all(i < N_hilbert for i in pterm.hilbert):
560-
oper = _oper_to_ElementaryOperator(
561-
pterm.operator,
562-
pterm.hilbert,
563-
self.hilbert_space_dims,
564-
trans_transform[pterm.transform],
565-
copy
566-
)
567-
# Inverted order confirmed by nvidia
568-
cuterm = cuterm * tensor_product(
569-
(oper, pterm.hilbert, (True,))
570-
)
571-
572-
elif any(i < N_hilbert for i in pterm.hilbert):
573-
raise NotImplementedError(
574-
"Operators acting on both original and "
575-
"dual spaces are not supported."
576-
)
581+
modes = tuple(mode % N_hilbert for mode in pterm.hilbert)
582+
duals = tuple(mode < N_hilbert for mode in pterm.hilbert)
583+
transpose = tuple(duals)
584+
dag = pterm.transform in [Transform.ADJOINT, Transform.CONJ]
585+
if pterm.transform in [Transform.CONJ, Transform.TRANSPOSE]:
586+
transpose = tuple(not trans for trans in transpose)
577587

578-
else:
579-
oper = _oper_to_ElementaryOperator(
580-
pterm.operator,
581-
pterm.hilbert,
582-
self.hilbert_space_dims,
583-
pterm.transform,
584-
copy
585-
)
586-
cuterm = cuterm * tensor_product(
587-
(oper, tuple(i - N_hilbert for i in pterm.hilbert))
588-
)
588+
oper = _oper_to_ElementaryOperator(
589+
pterm.operator,
590+
pterm.hilbert,
591+
self.hilbert_space_dims,
592+
transpose,
593+
dag,
594+
copy
595+
)
596+
cuterm = cuterm * tensor_product((oper, modes, duals,))
589597

590598
out = out + (cuterm * term.factor)
591599

@@ -731,6 +739,22 @@ def isherm(operator, tol=-1):
731739
if tol < 0:
732740
tol = settings.core["atol"]
733741
return cp.allclose(oper, oper.T.conj(), atol=tol)
742+
743+
744+
@_data.identity_like.register(CuOperator)
745+
def identity_like(data, /):
746+
"""
747+
Create an identity matrix of the same type and shape.
748+
"""
749+
if not data.shape[0] == data.shape[1]:
750+
raise ValueError(
751+
"Can't create an identity matrix like a non square matrix."
752+
)
753+
754+
new = CuOperator(hilbert_dims=data.hilbert_dims)
755+
new.terms.append(Term([], 1.))
756+
return new
757+
734758
###############################################################################
735759
###############################################################################
736760

0 commit comments

Comments
 (0)