|
572 | 572 | function _gemm_recursive_dense_lazy(i_k, N_k, K, J, val, |
573 | 573 | shape, strides_k, strides_j, |
574 | 574 | indices, h::LazyTensor, |
575 | | - op::Matrix, result::Matrix) |
| 575 | + op::AbstractArray, result::AbstractArray) |
576 | 576 | if i_k > N_k |
577 | | - for I=1:size(op, 1) |
| 577 | + if isa(op, AbstractVector) |
| 578 | + result[K] += val*op[J] |
| 579 | + else for I=1:size(op, 1) |
578 | 580 | result[I, K] += val*op[I, J] |
579 | 581 | end |
580 | 582 | return nothing |
|
609 | 611 | function _gemm_recursive_lazy_dense(i_k, N_k, K, J, val, |
610 | 612 | shape, strides_k, strides_j, |
611 | 613 | indices, h::LazyTensor, |
612 | | - op::Matrix, result::Matrix) |
| 614 | + op::AbstractArray, result::AbstractArray) |
613 | 615 | if i_k > N_k |
614 | 616 | for I=1:size(op, 2) |
615 | 617 | result[J, I] += val*op[K, I] |
@@ -641,45 +643,69 @@ function _gemm_recursive_lazy_dense(i_k, N_k, K, J, val, |
641 | 643 | end |
642 | 644 | end |
643 | 645 |
|
644 | | -function _gemm_puresparse(alpha, op::Matrix, h::LazyTensor{B1,B2,F,I,T}, beta, result::Matrix) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} |
"""
    check_mul!_compatibility(R, A, B)

Check that `R`, `A`, `B` are dimensionally compatible for the in-place product
`R .= A*B`, and that `R` is not aliased with either `A` or `B`.

Throws `ArgumentError` on aliasing and `DimensionMismatch` on incompatible sizes.
"""
function check_mul!_compatibility(R::AbstractVecOrMat, A, B)
    # Aliasing is checked first so that an aliased-and-misshaped call
    # reports the aliasing problem, matching the original check order.
    _check_mul!_aliasing_compatibility(R, A, B)
    _check_mul!_dim_compatibility(size(R), size(A), size(B))
    return nothing
end
# Validate sizes for `R .= A*B` given the size tuples of R, A and B.
# `B` may be a vector (1-tuple size); `Base.tail` drops its first dimension
# so the expected result size is computed uniformly for vectors and matrices.
function _check_mul!_dim_compatibility(sizeR::Tuple, sizeA::Tuple, sizeB::Tuple)
    # Inner dimensions must agree for the product A*B to exist.
    sizeA[2] == sizeB[1] ||
        throw(DimensionMismatch(lazy"A has dimensions $sizeA but B has dimensions $sizeB. Can't do `A*B`"))
    # The destination must match the product's size exactly.
    sizeR == (sizeA[1], Base.tail(sizeB)...) ||
        throw(DimensionMismatch(lazy"R has dimensions $sizeR but A*B has dimensions $((sizeA[1], Base.tail(sizeB)...)). Can't do `R.=A*B`"))
    return nothing
end
# Reject `R .= A*B` when the output array is the same object as an input.
# NOTE: `===` only catches exact identity, not views sharing the same memory.
function _check_mul!_aliasing_compatibility(R, A, B)
    (R === A || R === B) &&
        throw(ArgumentError(lazy"output matrix must not be aliased with input matrix"))
    return nothing
end
| 668 | + |
| 669 | + |
# result .= beta*result + alpha*op*h for a LazyTensor of purely sparse factors.
# `op` multiplies `h` from the LEFT; a vector `op` is therefore treated as a
# `Bra` by `_gemm_recursive_dense_lazy`, which is why the dimension check below
# compares against the reversed size of `h`.
function _gemm_puresparse(alpha, op::AbstractArray, h::LazyTensor{B1,B2,F,I,T}, beta, result::AbstractArray) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
    if isa(op, AbstractVector)
        # `size(h)` can't be fed to the generic check directly for the Bra case,
        # so aliasing and dimensions are checked separately here.
        _check_mul!_aliasing_compatibility(result, op, h)
        _check_mul!_dim_compatibility(size(result), reverse(size(h)), size(op))
    else
        check_mul!_compatibility(result, op, h)
    end
    # Pre-scale the destination; `fill!` (rather than `rmul!`) when beta == 0
    # so stale NaN/Inf entries in `result` cannot leak into the output.
    if iszero(beta)
        fill!(result, beta)
    elseif !isone(beta)
        rmul!(result, beta)
    end
    shape, strides_j, strides_k = _get_shape_and_strides(h)
    n_right = length(h.basis_r.bases)
    _gemm_recursive_dense_lazy(1, n_right, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
end
656 | 687 |
|
# result .= beta*result + alpha*h*op for a LazyTensor of purely sparse factors;
# here `h` multiplies `op` from the LEFT (Ket / operator-times-matrix case).
function _gemm_puresparse(alpha, h::LazyTensor{B1,B2,F,I,T}, op::AbstractArray, beta, result::AbstractArray) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}}
    check_mul!_compatibility(result, h, op)
    # Pre-scale the destination; `fill!` (rather than `rmul!`) when beta == 0
    # so stale NaN/Inf entries in `result` cannot leak into the output.
    if iszero(beta)
        fill!(result, beta)
    elseif !isone(beta)
        rmul!(result, beta)
    end
    shape, strides_j, strides_k = _get_shape_and_strides(h)
    n_left = length(h.basis_l.bases)
    _gemm_recursive_lazy_dense(1, n_left, 1, 1, alpha*h.factor, shape, strides_k, strides_j, h.indices, h, op, result)
end
669 | 699 |
|
# Shared setup for both gemm kernels: the recursion iterates over the
# element-wise minimum of the left/right composite shapes, while strides are
# computed per basis from the full shapes.
function _get_shape_and_strides(h)
    left = _comp_size(h.basis_l)
    right = _comp_size(h.basis_r)
    return min.(left, right), _strides(left), _strides(right)
end
| 706 | + |
# In-place multiplication entry points, dispatching on which side of the
# product the LazyTensor sits. Each method unwraps `.data` and forwards to the
# matching `_gemm_puresparse` kernel, then returns the mutated `result`.

# result = alpha*h*op + beta*result  (LazyTensor acting from the left on a dense operator)
_mul_puresparse!(result::DenseOpType{B1,B3},h::LazyTensor{B1,B2,F,I,T},op::DenseOpType{B2,B3},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, h, op.data, beta, result.data); result)
# result = alpha*op*h + beta*result  (LazyTensor acting from the right on a dense operator)
_mul_puresparse!(result::DenseOpType{B1,B2},op::DenseOpType{B1,B2},h::LazyTensor{B2,B3,F,I,T},alpha,beta) where {B1,B2,B3,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, op.data, h, beta, result.data); result)
# result = alpha*a*b + beta*result  (LazyTensor applied to a Ket)
_mul_puresparse!(result::Ket{B1},a::LazyTensor{B1,B2,F,I,T},b::Ket{B2},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, a, b.data, beta, result.data); result)
# result = alpha*a*b + beta*result  (Bra applied to a LazyTensor; the vector data
# is treated as a row vector by the kernel)
_mul_puresparse!(result::Bra{B2},a::Bra{B1},b::LazyTensor{B1,B2,F,I,T},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} = (_gemm_puresparse(alpha, a.data, b, beta, result.data); result)
672 | 711 |
|
673 | | -function _mul_puresparse!(result::Ket{B1},a::LazyTensor{B1,B2,F,I,T},b::Ket{B2},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} |
674 | | - b_data = reshape(b.data, length(b.data), 1) |
675 | | - result_data = reshape(result.data, length(result.data), 1) |
676 | | - _gemm_puresparse(alpha, a, b_data, beta, result_data) |
677 | | - result |
678 | | -end |
679 | | - |
680 | | -function _mul_puresparse!(result::Bra{B2},a::Bra{B1},b::LazyTensor{B1,B2,F,I,T},alpha,beta) where {B1,B2,F,I,T<:Tuple{Vararg{SparseOpPureType}}} |
681 | | - a_data = reshape(a.data, 1, length(a.data)) |
682 | | - result_data = reshape(result.data, 1, length(result.data)) |
683 | | - _gemm_puresparse(alpha, a_data, b, beta, result_data) |
684 | | - result |
685 | | -end |
|
0 commit comments