Skip to content

Commit ba7713b

Browse files
authored
Merge pull request #23 from Ericgig/version_update
Update for latest numpy version
2 parents ed6abee + 7bde8ae commit ba7713b

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/qutip_cuquantum/qobjevo.pyx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ cdef class CuQobjEvo(QobjEvo):
9191
self.expect_ready = True
9292
# Workaround for a bug in cudensity 0.2.0.
9393
settings.cuDensity["ctx"].release_workspace()
94-
return self.operator.compute_expectation(t, None, state.base).get()
94+
return self.operator.compute_expectation(t, None, state.base).get()[0]
9595

9696
def arguments(self, args):
9797
raise NotImplementedError
@@ -131,3 +131,6 @@ cdef class CuQobjEvo(QobjEvo):
131131

132132
def data(self, t):
133133
raise NotImplementedError
134+
135+
def __repr__(self):
136+
return "qutip-cuQuantum QobjEvo"

src/qutip_cuquantum/state.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,15 +245,15 @@ def trace_cuState(mat):
245245
if mat.shape[0] != mat.shape[1]:
246246
raise ValueError(...)
247247

248-
return complex(mat.base.trace())
248+
return complex(mat.base.trace()[0])
249249

250250

251251
@_data.inner.register(CuState)
252252
def inner_cuState(left, right, scalar_is_ket=False):
253253
if left.shape == (1, 1) and not scalar_is_ket:
254254
inner = left.base.storage[0] * right.base.storage[0]
255255
else:
256-
inner = left.base.inner_product(right.base)
256+
inner = left.base.inner_product(right.base)[0]
257257
return complex(inner)
258258

259259

@@ -288,7 +288,7 @@ def iadd_cuState(left, right, scale=1.):
288288

289289
@_data.norm.frobenius.register(CuState)
290290
def frobenius_cuState(mat):
291-
return float(mat.base.norm())**0.5
291+
return float(mat.base.norm()[0])**0.5
292292

293293

294294
@_data.norm.l2.register(CuState)
@@ -307,7 +307,7 @@ def wrmn_error_cuState(diff, state, atol, rtol):
307307
)
308308
diff.base.storage[:] = cp.abs(diff.base.storage)
309309
diff.base.storage[:] = diff.base.storage / (atol + rtol * cp.abs(state.base.storage))
310-
return float(diff.base.norm() / (diff.shape[0] * diff.shape[1]))**0.5
310+
return float(diff.base.norm()[0] / (diff.shape[0] * diff.shape[1]))**0.5
311311

312312

313313
@_data.reshape.register(CuState)

0 commit comments

Comments
 (0)