
Commit 2506d0d

Authored by Amit Rotem (with co-author)
Reduce allocations for multiplying LazyTensor of sparse and dense (#80)

Avoid reshape by letting sparse LazyTensor gemm routines work on vectors. Also check dimensions.

Co-authored-by: Amit Rotem <[email protected]>
1 parent 3ed0e84 commit 2506d0d
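For context, a minimal sketch of the user-facing effect (this snippet is not part of the commit; the basis sizes, the example operator, and the use of the five-argument `mul!` are assumptions): multiplying a `LazyTensor` with sparse factors into a `Ket` now runs the sparse gemm kernels directly on the state's data vector instead of reshaping it into an n×1 matrix first.

```julia
using QuantumOpticsBase, LinearAlgebra

# Hypothetical setup: a two-mode Fock space with a sparse lowering
# operator acting lazily on the first factor.
bf = FockBasis(3)
L = LazyTensor(bf^2, bf^2, 1, sparse(destroy(bf)))

ψ = randstate(bf^2)
out = copy(ψ)                 # result must not alias the input (now checked)
mul!(out, L, ψ, true, false)  # in-place multiply; ψ.data is used as a vector
```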

File tree: 5 files changed, +78 −26 lines changed

src/operators.jl

Lines changed: 6 additions & 2 deletions

@@ -392,5 +392,9 @@ multiplicable(a::AbstractOperator, b::Ket) = multiplicable(a.basis_r, b.basis)
 multiplicable(a::Bra, b::AbstractOperator) = multiplicable(a.basis, b.basis_l)
 multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l)
 
-Base.size(op::AbstractOperator) = prod(length(op.basis_l),length(op.basis_r))
-Base.size(op::AbstractOperator, i::Int) = (i==1 ? length(op.basis_l) : length(op.basis_r))
+Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r))
+function Base.size(op::AbstractOperator, i::Int)
+    i < 1 && throw(ErrorException(lazy"dimension out of range, should be strictly positive, got $i"))
+    i > 2 && return 1
+    i==1 ? length(op.basis_l) : length(op.basis_r)
+end
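The old one-argument method even miscalled `prod` on two lengths; the new methods follow `Base`'s array conventions: a tuple from `size(op)`, `1` for trailing dimensions, and an error for non-positive ones. A quick illustrative check (the operator and bases are assumptions, not from the commit):

```julia
using QuantumOpticsBase

op = randoperator(FockBasis(2), SpinBasis(1//2))  # a 3×2 dense operator
size(op)        # (3, 2), a tuple, consistent with size(op.data)
size(op, 2)     # 2
size(op, 3)     # 1, like a trailing dimension of a Base array
# size(op, 0)   # would throw ErrorException ("dimension out of range")
```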

src/operators_dense.jl

Lines changed: 4 additions & 0 deletions

@@ -274,6 +274,10 @@ function _strides(shape)
     return S
 end
 
+function _strides(shape::Ty)::Ty where Ty <: Tuple
+    accumulate(*, (1,Base.front(shape)...))
+end
+
 # Dense operator version
 @generated function _ptrace(::Type{Val{RANK}}, a,
                             shape_l, shape_r,
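The new tuple method mirrors the existing `_strides(shape)` (which fills a vector `S`) but avoids the allocation: for column-major data, the stride of each dimension is the running product of the preceding extents, which `accumulate(*, (1, Base.front(shape)...))` computes directly. What it evaluates to for an assumed shape:

```julia
# strides for shape (2, 3, 4): (1, 1*2, 1*2*3)
accumulate(*, (1, Base.front((2, 3, 4))...))  # == (1, 2, 6)
```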

src/operators_lazytensor.jl

Lines changed: 50 additions & 24 deletions

@@ -572,9 +572,11 @@ end
 function _gemm_recursive_dense_lazy(i_k, N_k, K, J, val,
                         shape, strides_k, strides_j,
                         indices, h::LazyTensor,
-                        op::Matrix, result::Matrix)
+                        op::AbstractArray, result::AbstractArray)
     if i_k > N_k
-        for I=1:size(op, 1)
+        if isa(op, AbstractVector)
+            result[K] += val*op[J]
+        else for I=1:size(op, 1)
             result[I, K] += val*op[I, J]
-        end
+        end end
         return nothing
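The new vector branch is what lets the reshape go away: when `op` holds a `Bra`'s data, the old call sites wrapped it as a 1×n matrix, so the loop only ever read `op[1, J]`; reading `op[J]` from the vector is equivalent. A toy check of that equivalence (names are illustrative):

```julia
x = [10.0, 20.0, 30.0]
X = reshape(x, 1, length(x))  # the wrapper the old _mul_puresparse! methods built
@assert all(X[1, j] == x[j] for j in eachindex(x))
```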
@@ -609,7 +611,7 @@ end
 function _gemm_recursive_lazy_dense(i_k, N_k, K, J, val,
                         shape, strides_k, strides_j,
                         indices, h::LazyTensor,
-                        op::Matrix, result::Matrix)
+                        op::AbstractArray, result::AbstractArray)
     if i_k > N_k
         for I=1:size(op, 2)
             result[J, I] += val*op[K, I]
@@ -641,45 +643,69 @@ function _gemm_recursive_lazy_dense(i_k, N_k, K, J, val,
     end
 end
 
-function _gemm_puresparse(alpha, op::Matrix, h::LazyTensor{B1,B2,F,I,T}, beta, result::Matrix) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+"""
+    check_mul!_compatibility(R, A, B)
+Check that `R,A,B` are dimensionally compatible for `R.=A*B`, and that `R` is not aliased with either `A` nor `B`.
+"""
+function check_mul!_compatibility(R::AbstractVecOrMat, A, B)
+    _check_mul!_aliasing_compatibility(R, A, B)
+    _check_mul!_dim_compatibility(size(R), size(A), size(B))
+end
+function _check_mul!_dim_compatibility(sizeR::Tuple, sizeA::Tuple, sizeB::Tuple)
+    # R .= A*B
+    if sizeA[2] != sizeB[1]
+        throw(DimensionMismatch(lazy"A has dimensions $sizeA but B has dimensions $sizeB. Can't do `A*B`"))
+    end
+    if sizeR != (sizeA[1], Base.tail(sizeB)...) # using tail to account for vectors
+        throw(DimensionMismatch(lazy"R has dimensions $sizeR but A*B has dimensions $((sizeA[1], Base.tail(sizeB)...)). Can't do `R.=A*B`"))
+    end
+end
+function _check_mul!_aliasing_compatibility(R, A, B)
+    if R===A || R===B
+        throw(ArgumentError(lazy"output matrix must not be aliased with input matrix"))
+    end
+end
+
+function _gemm_puresparse(alpha, op::AbstractArray, h::LazyTensor{B1,B2,F,I,T}, beta, result::AbstractArray) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+    if op isa AbstractVector
+        # _gemm_recursive_dense_lazy will treat `op` as a `Bra`
+        _check_mul!_aliasing_compatibility(result, op, h)
+        _check_mul!_dim_compatibility(size(result), reverse(size(h)), size(op))
+    else
+        check_mul!_compatibility(result, op, h)
+    end
     if iszero(beta)
         fill!(result, beta)
     elseif !isone(beta)
         rmul!(result, beta)
     end
     N_k = length(h.basis_r.bases)
-    shape = [min(h.basis_l.shape[i], h.basis_r.shape[i]) for i=1:length(h.basis_l.shape)]
-    strides_j = _strides(h.basis_l.shape)
-    strides_k = _strides(h.basis_r.shape)
+    shape, strides_j, strides_k = _get_shape_and_strides(h)
     _gemm_recursive_dense_lazy(1, N_k, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
 end
 
-function _gemm_puresparse(alpha, h::LazyTensor{B1,B2,F,I,T}, op::Matrix, beta, result::Matrix) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+function _gemm_puresparse(alpha, h::LazyTensor{B1,B2,F,I,T}, op::AbstractArray, beta, result::AbstractArray) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
+    check_mul!_compatibility(result, h, op)
     if iszero(beta)
         fill!(result, beta)
     elseif !isone(beta)
         rmul!(result, beta)
     end
     N_k = length(h.basis_l.bases)
-    shape = [min(h.basis_l.shape[i], h.basis_r.shape[i]) for i=1:length(h.basis_l.shape)]
-    strides_j = _strides(h.basis_l.shape)
-    strides_k = _strides(h.basis_r.shape)
+    shape, strides_j, strides_k = _get_shape_and_strides(h)
     _gemm_recursive_lazy_dense(1, N_k, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
 end
 
+function _get_shape_and_strides(h)
+    shape_l, shape_r = _comp_size(h.basis_l), _comp_size(h.basis_r)
+    shape = min.(shape_l, shape_r)
+    strides_j, strides_k = _strides(shape_l), _strides(shape_r)
+    return shape, strides_j, strides_k
+end
+
 _mul_puresparse!(result::DenseOpType{B1,B3},h::LazyTensor{B1,B2,F,I,T},op::DenseOpType{B2,B3},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, h, op.data, beta, result.data); result)
 _mul_puresparse!(result::DenseOpType{B1,B3},op::DenseOpType{B1,B2},h::LazyTensor{B2,B3,F,I,T},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, op.data, h, beta, result.data); result)
+_mul_puresparse!(result::Ket{B1},a::LazyTensor{B1,B2,F,I,T},b::Ket{B2},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, a, b.data, beta, result.data); result)
+_mul_puresparse!(result::Bra{B2},a::Bra{B1},b::LazyTensor{B1,B2,F,I,T},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, a.data, b, beta, result.data); result)
 
-function _mul_puresparse!(result::Ket{B1},a::LazyTensor{B1,B2,F,I,T},b::Ket{B2},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
-    b_data = reshape(b.data, length(b.data), 1)
-    result_data = reshape(result.data, length(result.data), 1)
-    _gemm_puresparse(alpha, a, b_data, beta, result_data)
-    result
-end
-
-function _mul_puresparse!(result::Bra{B2},a::Bra{B1},b::LazyTensor{B1,B2,F,I,T},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
-    a_data = reshape(a.data, 1, length(a.data))
-    result_data = reshape(result.data, 1, length(result.data))
-    _gemm_puresparse(alpha, a_data, b, beta, result_data)
-    result
-end
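Together with the tests added below, these checks make incompatible products fail fast. A sketch of what now throws (same construction as in the tests):

```julia
using QuantumOpticsBase

b1, b2 = NLevelBasis(2), NLevelBasis(3)
Lop1 = LazyTensor(b1^2, b2^2, 2, sparse(randoperator(b1, b2)))

# Lop1 maps b2^2 to b1^2, so composing it with itself is dimensionally
# invalid; the new checks raise DimensionMismatch up front:
# Lop1 * Lop1   # throws DimensionMismatch
```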

test/test_operators.jl

Lines changed: 10 additions & 0 deletions

@@ -130,4 +130,14 @@ op12 = destroy(bfock)⊗sigmap(bspin)
 @test embed(b, [1,2], op12) == destroy(bfock)⊗sigmap(bspin)⊗one(bspin)
 @test embed(b, [1,3], op12) == destroy(bfock)⊗one(bspin)⊗sigmap(bspin)
 
+# size of AbstractOperator
+b1, b2 = NLevelBasis.((2, 3))
+Lop1 = LazyTensor(b1^2, b2^2, 2, sparse(randoperator(b1, b2)))
+@test size(Lop1) == size(dense(Lop1)) == size(dense(Lop1).data)
+@test all(size(Lop1, k) == size(dense(Lop1), k) for k=1:4)
+@test_throws ErrorException size(Lop1, 0)
+@test_throws ErrorException size(Lop1, -1)
+@test_throws ErrorException size(dense(Lop1), 0) # check for consistency
+@test_throws ErrorException size(dense(Lop1), -1)
+
 end # testset

test/test_operators_lazytensor.jl

Lines changed: 8 additions & 0 deletions

@@ -404,5 +404,13 @@ dop = randoperator(b3a⊗b3b, b2a⊗b2b)
 @test dop*lop' ≈ Operator(dop.basis_l, lop.basis_l, dop.data*dense(lop).data')
 @test lop*dop' ≈ Operator(lop.basis_l, dop.basis_l, dense(lop).data*dop.data')
 
+# Dimension mismatches for LazyTensor with sparse
+b1, b2 = NLevelBasis.((2, 3))
+Lop1 = LazyTensor(b1^2, b2^2, 2, sparse(randoperator(b1, b2)))
+@test_throws DimensionMismatch Lop1*Lop1
+@test_throws DimensionMismatch dense(Lop1)*Lop1
+@test_throws DimensionMismatch sparse(Lop1)*Lop1
+@test_throws DimensionMismatch Lop1*dense(Lop1)
+@test_throws DimensionMismatch Lop1*sparse(Lop1)
 
 end # testset
