Skip to content

Commit 3ed0e84

Browse files
amilstedAshley Milsted
andauthored
Check dimensions in sparse gemm and gemv (#83)
* Check dimensions. * Switch to DimensionMismatch * Add some tests. --------- Co-authored-by: Ashley Milsted <[email protected]>
1 parent f6b3987 commit 3ed0e84

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

src/sparsematrix.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ end
9999

100100

101101
function gemm!(alpha, M::SparseMatrixCSC, B::AbstractMatrix, beta, result::AbstractMatrix)
102+
size(M, 2) == size(B, 1) || throw(DimensionMismatch())
103+
size(M, 1) == size(result, 1) || throw(DimensionMismatch())
104+
size(B, 2) == size(result, 2) || throw(DimensionMismatch())
102105
if iszero(beta)
103106
fill!(result, beta)
104107
elseif !isone(beta)
@@ -112,6 +115,9 @@ function gemm!(alpha, M::SparseMatrixCSC, B::AbstractMatrix, beta, result::Abstr
112115
end
113116

114117
function gemm!(alpha, B::AbstractMatrix, M::SparseMatrixCSC, beta, result::AbstractMatrix)
118+
size(M, 1) == size(B, 2) || throw(DimensionMismatch())
119+
size(M, 2) == size(result,2) || throw(DimensionMismatch())
120+
size(B, 1) == size(result,1) || throw(DimensionMismatch())
115121
if iszero(beta)
116122
fill!(result, beta)
117123
elseif !isone(beta)
@@ -146,6 +152,9 @@ function gemm!(alpha, M_::Adjoint{T,<:SparseMatrixCSC{T}}, B::AbstractMatrix, be
146152
if nnz(M) > 550
147153
LinearAlgebra.mul!(result, M_, B, alpha, beta)
148154
else
155+
size(M_, 2) == size(B, 1) || throw(DimensionMismatch())
156+
size(M_, 1) == size(result, 1) || throw(DimensionMismatch())
157+
size(B, 2) == size(result, 2) || throw(DimensionMismatch())
149158
if iszero(beta)
150159
fill!(result, beta)
151160
elseif !isone(beta)
@@ -156,6 +165,9 @@ function gemm!(alpha, M_::Adjoint{T,<:SparseMatrixCSC{T}}, B::AbstractMatrix, be
156165
end
157166

158167
function gemm!(alpha, B::AbstractMatrix, M::Adjoint{T,<:SparseMatrixCSC{T}}, beta, result::AbstractMatrix) where T
168+
size(M, 1) == size(B, 2) || throw(DimensionMismatch())
169+
size(M, 2) == size(result,2) || throw(DimensionMismatch())
170+
size(B, 1) == size(result,1) || throw(DimensionMismatch())
159171
if iszero(beta)
160172
fill!(result, beta)
161173
elseif !isone(beta)
@@ -178,6 +190,9 @@ function gemm!(alpha, A::SparseMatrixCSC, B::SparseMatrixCSC, beta, result::Abst
178190
end
179191

180192
function gemv!(alpha, M::SparseMatrixCSC, v::AbstractVector, beta, result::AbstractVector)
193+
size(M, 2) == size(v, 1) || throw(DimensionMismatch())
194+
size(M, 1) == size(result, 1) || throw(DimensionMismatch())
195+
181196
if iszero(beta)
182197
fill!(result, beta)
183198
elseif !isone(beta)
@@ -201,6 +216,9 @@ function gemv!(alpha, M::SparseMatrixCSC, v::AbstractVector, beta, result::Abstr
201216
end
202217

203218
function gemv!(alpha, v::AbstractVector, M::SparseMatrixCSC, beta, result::AbstractVector)
219+
size(M, 1) == size(v, 1) || throw(DimensionMismatch())
220+
size(M, 2) == size(result, 1) || throw(DimensionMismatch())
221+
204222
if iszero(beta)
205223
fill!(result, beta)
206224
elseif !isone(beta)

test/test_operators_dense.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,13 @@ op3 = randoperator(bf)
382382
@test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3
383383
@test_throws ErrorException cos.(op1)
384384

385+
# Dimension mismatches
386+
b1, b2, b3 = NLevelBasis.((2,3,4)) # N is not a type parameter
387+
@test_throws DimensionMismatch mul!(randstate(b1), randoperator(b2), randstate(b3))
388+
@test_throws DimensionMismatch mul!(randstate(b1)', randstate(b3)', randoperator(b2))
389+
@test_throws DimensionMismatch mul!(randoperator(b1), randoperator(b2), randoperator(b3))
390+
@test_throws DimensionMismatch mul!(randoperator(b1), randoperator(b3)', randoperator(b2))
391+
385392
end # testset
386393

387394
@testset "State-operator tensor products" begin

test/test_operators_sparse.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,13 @@ op_ .+= op1
389389
@test op_ == 2*op1
390390
@test_throws ErrorException cos.(op_)
391391

392+
# Dimension mismatches
393+
b1, b2, b3 = NLevelBasis.((2,3,4)) # N is not a type parameter
394+
@test_throws DimensionMismatch mul!(randstate(b1), sparse(randoperator(b2)), randstate(b3))
395+
@test_throws DimensionMismatch mul!(randstate(b1)', randstate(b3)', sparse(randoperator(b2)))
396+
@test_throws DimensionMismatch mul!(randoperator(b1), sparse(randoperator(b2)), randoperator(b3))
397+
@test_throws DimensionMismatch mul!(randoperator(b1), randoperator(b3)', sparse(randoperator(b2)))
398+
392399
end # testset
393400

394401
@testset "State-operator tensor products, sparse" begin

0 commit comments

Comments
 (0)