Skip to content

Commit e96f8a4

Browse files
amilstedAshley Milsted
andauthored
LazyTensor: Do not assume intermediates should be Vectors (#77)
* LazyTensor: Do not assume we act on Array. * Turn off Aqua piracy checking for now. We implement Base methods for types defined in QuantumInterface, which is piracy. Aim to fix this in the future and reenable the check. --------- Co-authored-by: Ashley Milsted <[email protected]>
1 parent 20b286d commit e96f8a4

File tree

3 files changed

+31
-15
lines changed

3 files changed

+31
-15
lines changed

src/operators_lazytensor.jl

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,11 @@ function lazytensor_enable_cache(; maxsize::Int = -1, maxrelsize::Real = 0.0)
261261
return
262262
end
263263

264-
function _tp_matmul_first!(result::Base.ReshapedArray, a::AbstractMatrix, b::Base.ReshapedArray, α::Number, β::Number)
265-
d_first = size(b, 1)
264+
function _tp_matmul_first!(result, a::AbstractMatrix, b, α::Number, β::Number)
265+
d_first = size(a, 2)
266266
d_rest = length(b)÷d_first
267-
bp = b.parent
268-
rp = result.parent
267+
bp = parent(b)
268+
rp = parent(result)
269269
@uviews bp rp begin # avoid allocations on reshape
270270
br = reshape(bp, (d_first, d_rest))
271271
result_r = reshape(rp, (size(a, 1), d_rest))
@@ -274,11 +274,11 @@ function _tp_matmul_first!(result::Base.ReshapedArray, a::AbstractMatrix, b::Bas
274274
result
275275
end
276276

277-
function _tp_matmul_last!(result::Base.ReshapedArray, a::AbstractMatrix, b::Base.ReshapedArray, α::Number, β::Number)
278-
d_last = size(b, ndims(b))
277+
function _tp_matmul_last!(result, a::AbstractMatrix, b, α::Number, β::Number)
278+
d_last = size(a, 2)
279279
d_rest = length(b)÷d_last
280-
bp = b.parent
281-
rp = result.parent
280+
bp = parent(b)
281+
rp = parent(result)
282282
@uviews a bp rp begin # avoid allocations on reshape
283283
br = reshape(bp, (d_rest, d_last))
284284
result_r = reshape(rp, (d_rest, size(a, 1)))
@@ -287,7 +287,7 @@ function _tp_matmul_last!(result::Base.ReshapedArray, a::AbstractMatrix, b::Base
287287
result
288288
end
289289

290-
function _tp_matmul_get_tmp(::Type{T}, shp::NTuple{N,Int}, sym) where {T,N}
290+
function _tp_matmul_get_tmp(::Type{T}, shp::NTuple{N,Int}, sym, ::Array) where {T,N}
291291
len = prod(shp)
292292
use_cache = lazytensor_use_cache()
293293
key = (sym, taskid(), UInt(len), T)
@@ -301,7 +301,17 @@ function _tp_matmul_get_tmp(::Type{T}, shp::NTuple{N,Int}, sym) where {T,N}
301301
Base.ReshapedArray(tmp, shp, ())
302302
end
303303

304-
function _tp_matmul_mid!(result::Base.ReshapedArray, a::AbstractMatrix, loc::Integer, b::Base.ReshapedArray, α::Number, β::Number)
304+
function _tp_matmul_get_tmp(::Type{T}, shp::NTuple{N,Int}, sym, arr::AbstractArray) where {T,N}
305+
if parent(arr) === arr
306+
# This is a fallback that does not use the cache. Does not get triggered for arr <: Array.
307+
return similar(arr, T, shp)
308+
end
309+
# Unpack wrapped arrays. If we hit an Array, we will use the cache.
310+
# If we hit a different non-wrapped array-like, we will call `similar()`.
311+
_tp_matmul_get_tmp(T, shp, sym, parent(arr))
312+
end
313+
314+
function _tp_matmul_mid!(result, a::AbstractMatrix, loc::Integer, b, α::Number, β::Number)
305315
sz_b_1 = 1
306316
for i in 1:loc-1
307317
sz_b_1 *= size(b,i)
@@ -320,11 +330,11 @@ function _tp_matmul_mid!(result::Base.ReshapedArray, a::AbstractMatrix, loc::Int
320330
move_left = sz_b_1 < sz_b_3
321331
perm = move_left ? (2,1,3) : (1,3,2)
322332

323-
br_p = _tp_matmul_get_tmp(eltype(br), ((size(br, i) for i in perm)...,), :_tp_matmul_mid_in)
333+
br_p = _tp_matmul_get_tmp(eltype(br), ((size(br, i) for i in perm)...,), :_tp_matmul_mid_in, br)
324334
@strided permutedims!(br_p, br, perm)
325335
#permutedims!(br_p, br, perm)
326336

327-
result_r_p = _tp_matmul_get_tmp(eltype(result_r), ((size(result_r, i) for i in perm)...,), :_tp_matmul_mid_out)
337+
result_r_p = _tp_matmul_get_tmp(eltype(result_r), ((size(result_r, i) for i in perm)...,), :_tp_matmul_mid_out, result_r)
328338
β == 0.0 || @strided permutedims!(result_r_p, result_r, perm)
329339
#β == 0.0 || permutedims!(result_r_p, result_r, perm)
330340

@@ -366,7 +376,7 @@ end
366376

367377
function _tp_sum_get_tmp(op::AbstractMatrix{T}, loc::Integer, arr::AbstractArray{S,N}, sym) where {T,S,N}
368378
shp = ntuple(i -> i == loc ? size(op,1) : size(arr,i), N)
369-
_tp_matmul_get_tmp(promote_type(T,S), shp, sym)
379+
_tp_matmul_get_tmp(promote_type(T,S), shp, sym, arr)
370380
end
371381

372382
#Apply a tensor product of operators to a vector.
@@ -434,7 +444,7 @@ Base.size(A::_SimpleIsometry, i) = A.shape[i]
434444

435445
function _tp_sum_get_tmp(op::_SimpleIsometry, loc::Integer, arr::AbstractArray{S,N}, sym) where {S,N}
436446
shp = ntuple(i -> i == loc ? size(op,1) : size(arr,i), N)
437-
_tp_matmul_get_tmp(S, shp, sym)
447+
_tp_matmul_get_tmp(S, shp, sym, arr)
438448
end
439449

440450
function _tp_matmul!(result, a::_SimpleIsometry, loc::Integer, b, α::Number, β::Number)

test/test_aqua.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,7 @@ using QuantumOpticsBase
33
using Aqua
44

55
@testset "aqua" begin
6-
Aqua.test_all(QuantumOpticsBase)
6+
Aqua.test_all(QuantumOpticsBase,
7+
piracy=false # TODO: Due to Base methods in QuantumOpticsBase, for types defined in QuantumInterface
8+
)
79
end # testset

test/test_operators_lazytensor.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,15 @@ op_sp = LazyTensor(b_l, b_r, [1, 2, 3], sparse.((subop1, subop2, subop3)))*0.1
204204
op_ = 0.1*subop1 subop2 subop3
205205

206206
state = Ket(b_r, rand(ComplexF32, length(b_r)))
207+
state_sp = sparse(state) # to test no-cache path
207208
result_ = Ket(b_l, rand(ComplexF64, length(b_l)))
208209
result = deepcopy(result_)
209210
QuantumOpticsBase.mul!(result,op,state,complex(1.),complex(0.))
210211
@test 1e-6 > D(result, op_*state)
211212

213+
QuantumOpticsBase.mul!(result,op,state_sp,complex(1.),complex(0.))
214+
@test 1e-6 > D(result, op_*state)
215+
212216
QuantumOpticsBase.mul!(result,op_sp,state,complex(1.),complex(0.))
213217
@test 1e-6 > D(result, op_*state)
214218

0 commit comments

Comments
 (0)