4 changes: 4 additions & 0 deletions src/operators_dense.jl
@@ -274,6 +274,10 @@ function _strides(shape)
return S
end

function _strides(shape::Ty)::Ty where Ty <: Tuple
accumulate(*, (1,Base.front(shape)...))
end
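# Illustrative sketch (not part of the diff): the new Tuple method computes the
# same cumulative-product strides as the Vector-based `_strides(shape)` above,
# but returns a Tuple of matching type and avoids allocating a Vector.
# For a column-major shape, for example:
#     _strides((2, 3, 4)) == (1, 2, 6)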

# Dense operator version
@generated function _ptrace(::Type{Val{RANK}}, a,
shape_l, shape_r,
27 changes: 12 additions & 15 deletions src/operators_lazytensor.jl
@@ -609,7 +609,7 @@ end
function _gemm_recursive_lazy_dense(i_k, N_k, K, J, val,
shape, strides_k, strides_j,
indices, h::LazyTensor,
op::Matrix, result::Matrix)
op::VecOrMat, result::VecOrMat)
if i_k > N_k
for I=1:size(op, 2)
result[J, I] += val*op[K, I]
@@ -648,34 +648,31 @@ function _gemm_puresparse(alpha, op::Matrix, h::LazyTensor{B1,B2,F,I,T}, beta, r
rmul!(result, beta)
end
N_k = length(h.basis_r.bases)
shape = [min(h.basis_l.shape[i], h.basis_r.shape[i]) for i=1:length(h.basis_l.shape)]
strides_j = _strides(h.basis_l.shape)
strides_k = _strides(h.basis_r.shape)
shape, strides_j, strides_k = _get_shape_and_strides(h)
_gemm_recursive_dense_lazy(1, N_k, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
end

function _gemm_puresparse(alpha, h::LazyTensor{B1,B2,F,I,T}, op::Matrix, beta, result::Matrix) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
function _gemm_puresparse(alpha, h::LazyTensor{B1,B2,F,I,T}, op::VecOrMat, beta, result::VecOrMat) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
if iszero(beta)
fill!(result, beta)
elseif !isone(beta)
rmul!(result, beta)
end
N_k = length(h.basis_l.bases)
shape = [min(h.basis_l.shape[i], h.basis_r.shape[i]) for i=1:length(h.basis_l.shape)]
strides_j = _strides(h.basis_l.shape)
strides_k = _strides(h.basis_r.shape)
shape, strides_j, strides_k = _get_shape_and_strides(h)
_gemm_recursive_lazy_dense(1, N_k, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
end

function _get_shape_and_strides(h)
shape_l, shape_r = _comp_size(h.basis_l), _comp_size(h.basis_r)
shape = min.(shape_l, shape_r)
strides_j, strides_k = _strides(shape_l), _strides(shape_r)
return shape, strides_j, strides_k
end
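# Illustrative note (not part of the diff): this helper consolidates the
# shape/stride setup that was previously repeated inline in both
# `_gemm_puresparse` methods. Assuming `_comp_size(b)` returns the basis shape
# as a Tuple, the old Vector-based pattern
#     shape     = [min(h.basis_l.shape[i], h.basis_r.shape[i]) for i=1:length(h.basis_l.shape)]
#     strides_j = _strides(h.basis_l.shape)
#     strides_k = _strides(h.basis_r.shape)
# becomes a single call that dispatches to the allocation-free Tuple method of
# `_strides` added in operators_dense.jl.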

_mul_puresparse!(result::DenseOpType{B1,B3},h::LazyTensor{B1,B2,F,I,T},op::DenseOpType{B2,B3},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, h, op.data, beta, result.data); result)
_mul_puresparse!(result::DenseOpType{B1,B3},op::DenseOpType{B1,B2},h::LazyTensor{B2,B3,F,I,T},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, op.data, h, beta, result.data); result)

function _mul_puresparse!(result::Ket{B1},a::LazyTensor{B1,B2,F,I,T},b::Ket{B2},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
b_data = reshape(b.data, length(b.data), 1)
result_data = reshape(result.data, length(result.data), 1)
_gemm_puresparse(alpha, a, b_data, beta, result_data)
result
end
_mul_puresparse!(result::Ket{B1},a::LazyTensor{B1,B2,F,I,T},b::Ket{B2},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, a, b.data, beta, result.data); result)
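# Illustrative note (not part of the diff): since `_gemm_puresparse` now accepts
# `VecOrMat`, the Ket's data vector is passed through directly; the previous
# method first wrapped both vectors in temporary n×1 matrices via
#     b_data = reshape(b.data, length(b.data), 1)
#     result_data = reshape(result.data, length(result.data), 1)
# before calling `_gemm_puresparse`.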

function _mul_puresparse!(result::Bra{B2},a::Bra{B1},b::LazyTensor{B1,B2,F,I,T},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
a_data = reshape(a.data, 1, length(a.data))