4 changes: 4 additions & 0 deletions src/operators_dense.jl
@@ -274,6 +274,10 @@ function _strides(shape)
return S
end

+ function _strides(shape::Ty)::Ty where Ty <: Tuple
+     accumulate(*, (1, Base.front(shape)...))
+ end

# Dense operator version
@generated function _ptrace(::Type{Val{RANK}}, a,
shape_l, shape_r,
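Note (illustrative, not part of the diff): the new tuple method of _strides above computes the cumulative products (1, shape[1], shape[1]*shape[2], ...), i.e. the column-major strides of a composite shape, while keeping the result a tuple rather than a Vector (cf. the ::Ty return annotation). A minimal sketch with a hypothetical literal shape:

shape = (2, 3, 4)                                   # hypothetical composite-basis shape
strides = accumulate(*, (1, Base.front(shape)...))  # same expression as the new method's body
@assert strides == (1, 2, 6)                        # column-major strides for shape (2, 3, 4)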
15 changes: 9 additions & 6 deletions src/operators_lazytensor.jl
@@ -648,9 +648,7 @@ function _gemm_puresparse(alpha, op::Matrix, h::LazyTensor{B1,B2,F,I,T}, beta, r
rmul!(result, beta)
end
N_k = length(h.basis_r.bases)
- shape = [min(h.basis_l.shape[i], h.basis_r.shape[i]) for i=1:length(h.basis_l.shape)]
- strides_j = _strides(h.basis_l.shape)
- strides_k = _strides(h.basis_r.shape)
+ shape, strides_j, strides_k = _get_shape_and_strides(h)
_gemm_recursive_dense_lazy(1, N_k, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
end

@@ -661,12 +659,17 @@ function _gemm_puresparse(alpha, h::LazyTensor{B1,B2,F,I,T}, op::Matrix, beta, r
rmul!(result, beta)
end
N_k = length(h.basis_l.bases)
- shape = [min(h.basis_l.shape[i], h.basis_r.shape[i]) for i=1:length(h.basis_l.shape)]
- strides_j = _strides(h.basis_l.shape)
- strides_k = _strides(h.basis_r.shape)
+ shape, strides_j, strides_k = _get_shape_and_strides(h)
_gemm_recursive_lazy_dense(1, N_k, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
end

+ function _get_shape_and_strides(h)
+     shape_l, shape_r = _comp_size(h.basis_l), _comp_size(h.basis_r)
+     shape = min.(shape_l, shape_r)
+     strides_j, strides_k = _strides(shape_l), _strides(shape_r)
+     return shape, strides_j, strides_k
+ end

_mul_puresparse!(result::DenseOpType{B1,B3},h::LazyTensor{B1,B2,F,I,T},op::DenseOpType{B2,B3},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, h, op.data, beta, result.data); result)
_mul_puresparse!(result::DenseOpType{B1,B3},op::DenseOpType{B1,B2},h::LazyTensor{B2,B3,F,I,T},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, op.data, h, beta, result.data); result)

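For reference, a minimal sketch of what the new _get_shape_and_strides helper computes (illustrative only, not part of the diff; the literal tuples below stand in for whatever _comp_size returns for the LazyTensor's left and right bases, assumed here to be shape tuples):

shape_l, shape_r = (2, 3, 4), (2, 3, 2)                 # hypothetical left/right basis shapes
shape = min.(shape_l, shape_r)                          # elementwise minimum, as in the old comprehension
strides_j = accumulate(*, (1, Base.front(shape_l)...))  # i.e. _strides(shape_l)
strides_k = accumulate(*, (1, Base.front(shape_r)...))  # i.e. _strides(shape_r)
@assert shape == (2, 3, 2)
@assert strides_j == (1, 2, 6) && strides_k == (1, 2, 6)

The helper removes the duplicated shape/stride computation from the two _gemm_puresparse methods and, together with the tuple method of _strides, keeps shapes and strides as tuples rather than Vectors.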