Skip to content

Commit 8401557

Browse files
committed
improve lazy tensor warning
1 parent 4f81980 commit 8401557

File tree

4 files changed

+32
-21
lines changed

4 files changed

+32
-21
lines changed

src/qobj/functions.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ function LinearAlgebra.kron(
184184
B::AbstractQuantumObject{DT2,OpType,<:Dimensions},
185185
) where {DT1,DT2,OpType<:Union{KetQuantumObject,BraQuantumObject,OperatorQuantumObject}}
186186
QType = promote_op_type(A, B)
187+
_lazy_tensor_warning(A.data, B.data)
187188
return QType(kron(A.data, B.data), A.type, Dimensions((A.dimensions.to..., B.dimensions.to...)))
188189
end
189190

@@ -197,6 +198,7 @@ for ADimType in (:Dimensions, :GeneralDimensions)
197198
B::AbstractQuantumObject{DT2,OperatorQuantumObject,<:$BDimType},
198199
) where {DT1,DT2}
199200
QType = promote_op_type(A, B)
201+
_lazy_tensor_warning(A.data, B.data)
200202
return QType(
201203
kron(A.data, B.data),
202204
Operator,
@@ -221,6 +223,7 @@ for AOpType in (:KetQuantumObject, :BraQuantumObject, :OperatorQuantumObject)
221223
B::AbstractQuantumObject{DT2,$BOpType},
222224
) where {DT1,DT2}
223225
QType = promote_op_type(A, B)
226+
_lazy_tensor_warning(A.data, B.data)
224227
return QType(
225228
kron(A.data, B.data),
226229
Operator,

src/qobj/superoperators.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,17 @@ _spre(A::MatrixOperator, Id::AbstractMatrix) = MatrixOperator(_spre(A.A, Id))
2626
_spre(A::ScaledOperator, Id::AbstractMatrix) = ScaledOperator(A.λ, _spre(A.L, Id))
2727
_spre(A::AddedOperator, Id::AbstractMatrix) = AddedOperator(map(op -> _spre(op, Id), A.ops))
2828
function _spre(A::AbstractSciMLOperator, Id::AbstractMatrix)
29-
_lazy_tensor_warning("spre", A)
29+
_lazy_tensor_warning(Id, A)
3030
return kron(Id, A)
3131
end
3232

3333
_spost(B::MatrixOperator, Id::AbstractMatrix) = MatrixOperator(_spost(B.A, Id))
3434
_spost(B::ScaledOperator, Id::AbstractMatrix) = ScaledOperator(B.λ, _spost(B.L, Id))
3535
_spost(B::AddedOperator, Id::AbstractMatrix) = AddedOperator(map(op -> _spost(op, Id), B.ops))
3636
function _spost(B::AbstractSciMLOperator, Id::AbstractMatrix)
37-
_lazy_tensor_warning("spost", B)
38-
return kron(transpose(B), Id)
37+
B_T = transpose(B)
38+
_lazy_tensor_warning(B_T, Id)
39+
return kron(B_T, Id)
3940
end
4041

4142
## intrinsic liouvillian

src/utilities.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,21 @@ _non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
153153
join(arg, ", ") *
154154
")` instead of `$argname = $arg`." maxlog = 1
155155

156-
_lazy_tensor_warning(func_name::String, data::AbstractSciMLOperator) =
157-
@warn "The function `$func_name` uses lazy tensor (which can hurt performance) for data type: $(get_typename_wrapper(data))"
156+
# lazy tensor warning
157+
for AType in (:AbstractArray, :AbstractSciMLOperator)
158+
for BType in (:AbstractArray, :AbstractSciMLOperator)
159+
if AType == BType == :AbstractArray
160+
@eval begin
161+
_lazy_tensor_warning(::$AType, ::$BType) = nothing
162+
end
163+
else
164+
@eval begin
165+
_lazy_tensor_warning(A::$AType, B::$BType) =
166+
@warn "using lazy tensor (which can hurt performance) between data types: $(get_typename_wrapper(A)) and $(get_typename_wrapper(B))"
167+
end
168+
end
169+
end
170+
end
158171

159172
# functions for getting Float or Complex element type
160173
_FType(::AbstractArray{T}) where {T<:Number} = _FType(T)

test/core-test/quantum_objects_evo.jl

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -160,25 +160,19 @@
160160
@inferred a * a
161161
@inferred a * a'
162162

163-
# TODO: kron is currently not supported
164-
# @inferred kron(a)
165-
# @inferred kron(a, σx)
166-
# @inferred kron(a, eye(2))
163+
@inferred kron(a)
164+
@test_logs (:warn,) @inferred kron(a, σx)
165+
@test_logs (:warn,) @inferred kron(a, eye(2))
166+
@test_logs (:warn,) (:warn,) @inferred kron(a, eye(2), eye(2))
167167
end
168168
end
169169

170-
# TODO: tensor is currently not supported
171-
# @testset "tensor" begin
172-
# σx = sigmax()
173-
# X3 = kron(σx, σx, σx)
174-
# @test tensor(σx) == kron(σx)
175-
# @test tensor(fill(σx, 3)...) == X3
176-
# X_warn = @test_logs (
177-
# :warn,
178-
# "`tensor(A)` or `kron(A)` with `A` is a `Vector` can hurt performance. Try to use `tensor(A...)` or `kron(A...)` instead.",
179-
# ) tensor(fill(σx, 3))
180-
# @test X_warn == X3
181-
# end
170+
@testset "tensor" begin
171+
σx = QobjEvo(sigmax())
172+
X3 = @test_logs (:warn,) (:warn,) tensor(σx, σx, σx)
173+
X_warn = @test_logs (:warn,) (:warn,) (:warn,) tensor(fill(σx, 3))
174+
@test X_warn(0) == X3(0) == tensor(sigmax(), sigmax(), sigmax())
175+
end
182176

183177
@testset "Time Dependent Operators and SuperOperators" begin
184178
N = 10

0 commit comments

Comments
 (0)