Skip to content

Commit cb86d5d

Browse files
CopilotTendonFFF
andcommitted
Add AMDGPU sparse matrix support extension
Co-authored-by: TendonFFF <[email protected]>
1 parent 679c2b5 commit cb86d5d

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,22 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3030
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
3131

3232
[weakdeps]
33+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3334
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3435
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3536
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
3637
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3738
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
3839

3940
[extensions]
41+
QuantumToolboxAMDGPUExt = "AMDGPU"
4042
QuantumToolboxCUDAExt = "CUDA"
4143
QuantumToolboxChainRulesCoreExt = "ChainRulesCore"
4244
QuantumToolboxGPUArraysExt = ["GPUArrays", "KernelAbstractions"]
4345
QuantumToolboxMakieExt = "Makie"
4446

4547
[compat]
48+
AMDGPU = "1, 2"
4649
ArrayInterface = "6, 7"
4750
CUDA = "5.0 - 5.8, 5.9.4 - 5"
4851
ChainRulesCore = "1"

ext/QuantumToolboxAMDGPUExt.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
module QuantumToolboxAMDGPUExt
2+
3+
using QuantumToolbox
4+
using QuantumToolbox: makeVal, getVal
5+
import QuantumToolbox: _sparse_similar, _convert_eltype_wordsize
6+
import AMDGPU: roc, ROCArray, allowscalar
7+
import AMDGPU.rocSPARSE: ROCSparseVector, ROCSparseMatrixCSC, ROCSparseMatrixCSR, AbstractROCSparseArray
8+
import SparseArrays: SparseVector, SparseMatrixCSC, sparse
9+
import AMDGPU.Adapt: adapt
10+
11+
allowscalar(false)
12+
13+
@doc raw"""
14+
ROCArray(A::QuantumObject)
15+
16+
If `A.data` is a dense array, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `AMDGPU.ROCArray` for gpu calculations.
17+
"""
18+
ROCArray(A::QuantumObject) = QuantumObject(ROCArray(A.data), A.type, A.dimensions)
19+
20+
@doc raw"""
21+
ROCArray{T}(A::QuantumObject)
22+
23+
If `A.data` is a dense array, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `AMDGPU.ROCArray` with element type `T` for gpu calculations.
24+
"""
25+
ROCArray{T}(A::QuantumObject) where {T} = QuantumObject(ROCArray{T}(A.data), A.type, A.dimensions)
26+
27+
@doc raw"""
28+
ROCSparseVector(A::QuantumObject)
29+
30+
If `A.data` is a sparse vector, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `AMDGPU.rocSPARSE.ROCSparseVector` for gpu calculations.
31+
"""
32+
ROCSparseVector(A::QuantumObject) = QuantumObject(ROCSparseVector(A.data), A.type, A.dimensions)
33+
34+
@doc raw"""
35+
ROCSparseVector{T}(A::QuantumObject)
36+
37+
If `A.data` is a sparse vector, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `AMDGPU.rocSPARSE.ROCSparseVector` with element type `T` for gpu calculations.
38+
"""
39+
ROCSparseVector{T}(A::QuantumObject) where {T} = QuantumObject(ROCSparseVector{T}(A.data), A.type, A.dimensions)
40+
41+
@doc raw"""
42+
ROCSparseMatrixCSC(A::QuantumObject)
43+
44+
If `A.data` is in the type of `SparseMatrixCSC`, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `AMDGPU.rocSPARSE.ROCSparseMatrixCSC` for gpu calculations.
45+
"""
46+
ROCSparseMatrixCSC(A::QuantumObject) = QuantumObject(ROCSparseMatrixCSC(A.data), A.type, A.dimensions)
47+
48+
@doc raw"""
49+
ROCSparseMatrixCSC{T}(A::QuantumObject)
50+
51+
If `A.data` is in the type of `SparseMatrixCSC`, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `AMDGPU.rocSPARSE.ROCSparseMatrixCSC` with element type `T` for gpu calculations.
52+
"""
53+
ROCSparseMatrixCSC{T}(A::QuantumObject) where {T} = QuantumObject(ROCSparseMatrixCSC{T}(A.data), A.type, A.dimensions)
54+
55+
@doc raw"""
56+
ROCSparseMatrixCSR(A::QuantumObject)
57+
58+
If `A.data` is in the type of `SparseMatrixCSC`, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `AMDGPU.rocSPARSE.ROCSparseMatrixCSR` for gpu calculations.
59+
"""
60+
ROCSparseMatrixCSR(A::QuantumObject) = QuantumObject(ROCSparseMatrixCSR(A.data), A.type, A.dimensions)
61+
62+
@doc raw"""
63+
ROCSparseMatrixCSR(A::QuantumObject)
64+
65+
If `A.data` is in the type of `SparseMatrixCSC`, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `AMDGPU.rocSPARSE.ROCSparseMatrixCSR` with element type `T` for gpu calculations.
66+
"""
67+
ROCSparseMatrixCSR{T}(A::QuantumObject) where {T} = QuantumObject(ROCSparseMatrixCSR{T}(A.data), A.type, A.dimensions)
68+
69+
@doc raw"""
70+
roc(A::QuantumObject; word_size::Int=64)
71+
72+
Return a new [`QuantumObject`](@ref) where `A.data` is in the type of `AMDGPU` arrays for gpu calculations.
73+
74+
# Arguments
75+
- `A::QuantumObject`: The [`QuantumObject`](@ref)
76+
- `word_size::Int`: The word size of the element type of `A`, can be either `32` or `64`. Default to `64`.
77+
"""
78+
function roc(A::QuantumObject; word_size::Union{Val,Int} = Val(64))
79+
_word_size = getVal(makeVal(word_size))
80+
81+
((_word_size == 64) || (_word_size == 32)) || throw(DomainError(_word_size, "The word size should be 32 or 64."))
82+
83+
return roc(A, makeVal(word_size))
84+
end
85+
roc(A::QuantumObject, word_size::Union{Val{32},Val{64}}) =
86+
QuantumObject(adapt(ROCArray{_convert_eltype_wordsize(eltype(A), word_size)}, A.data), A.type, A.dimensions)
87+
function roc(
88+
A::QuantumObject{ObjType,DimsType,<:SparseVector},
89+
word_size::Union{Val{32},Val{64}},
90+
) where {ObjType<:QuantumObjectType,DimsType<:AbstractDimensions}
91+
return ROCSparseVector{_convert_eltype_wordsize(eltype(A), word_size)}(A)
92+
end
93+
function roc(
94+
A::QuantumObject{ObjType,DimsType,<:SparseMatrixCSC},
95+
word_size::Union{Val{32},Val{64}},
96+
) where {ObjType<:QuantumObjectType,DimsType<:AbstractDimensions}
97+
return ROCSparseMatrixCSC{_convert_eltype_wordsize(eltype(A), word_size)}(A)
98+
end
99+
100+
QuantumToolbox.to_dense(A::MT) where {MT<:AbstractROCSparseArray} = ROCArray(A)
101+
102+
QuantumToolbox.to_dense(::Type{T1}, A::ROCArray{T2}) where {T1<:Number,T2<:Number} = ROCArray{T1}(A)
103+
QuantumToolbox.to_dense(::Type{T}, A::AbstractROCSparseArray) where {T<:Number} = ROCArray{T}(A)
104+
105+
QuantumToolbox._sparse_similar(A::ROCSparseMatrixCSC, args...) = sparse(args..., fmt = :csc)
106+
QuantumToolbox._sparse_similar(A::ROCSparseMatrixCSR, args...) = sparse(args..., fmt = :csr)
107+
108+
end

0 commit comments

Comments
 (0)