Skip to content

Commit 3cc4b9f

Browse files
committed
introduce QuantumToolboxMetalExt
1 parent 5af51a2 commit 3cc4b9f

File tree

6 files changed

+187
-0
lines changed

6 files changed

+187
-0
lines changed

.buildkite/Metal_Ext.yml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
steps:
2+
- label: "Metal Julia {{matrix.version}}"
3+
matrix:
4+
setup:
5+
version:
6+
- "1.10" # oldest
7+
#- "1" # latest
8+
plugins:
9+
- JuliaCI/julia#v1:
10+
version: "{{matrix.version}}"
11+
- JuliaCI/julia-test#v1:
12+
test_args: "--quickfail"
13+
- JuliaCI/julia-coverage#v1:
14+
codecov: true
15+
dirs:
16+
- src
17+
- ext
18+
agents:
19+
queue: "juliaecosystem"
20+
os: "macos"
21+
arch: "aarch64"
22+
env:
23+
GROUP: "Metal_Ext"
24+
SECRET_CODECOV_TOKEN: "ZfhQu/IcRLqNyZ//ZNs5sjBPaV76IHfU5gui52Qn+Rp8tOurukqgScuyDt+3HQ4R0hJYBw1/Nqg6jmBsvWSc9NEUx8kGsUJFHfN3no0+b+PFxA8oJkWc9EpyIsjht5ZIjlsFWR3f0DpPqMEle/QyWOPcal63CChXR8oAoR+Fz1Bh8GkokLlnC8F9Ugp9xBlu401GCbyZhvLTZnNIgK5yy9q8HBJnBg1cPOhI81J6JvYpEmcIofEzFV/qkfpTUPclu43WNoFX2DZPzbxilf3fsAd5/+nRkRfkNML8KiN4mnmjHxPPbuY8F5zC/PS5ybXtDpfvaMQc01WApXCkZk0ZAQ==;U2FsdGVkX1+eDT7dqCME5+Ox5i8GvWRTQbwiP/VYjapThDbxXFDeSSIC6Opmon+M8go22Bun3bat6Fzie65ang=="
25+
timeout_in_minutes: 60
26+
if: |
27+
// Don't run Buildkite if the commit message includes the text [skip ci], [ci skip], or [no ci]
28+
// Don't run Buildkite for PR draft
29+
// Only run Buildkite when new commits and PR are made to main branch
30+
build.message !~ /\[skip ci\]/ &&
31+
build.message !~ /\[ci skip\]/ &&
32+
build.message !~ /\[no ci\]/ &&
33+
!build.pull_request.draft &&
34+
(build.branch =~ /main/ || build.pull_request.base_branch =~ /main/)

.buildkite/pipeline.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,15 @@ steps:
1212
- "test/ext-test/cuda_ext.jl"
1313
- "Project.toml"
1414
target: ".buildkite/CUDA_Ext.yml"
15+
- staticfloat/forerunner: # Metal.jl tests
16+
watch:
17+
- ".buildkite/pipeline.yml"
18+
- ".buildkite/Metal_Ext.yml"
19+
- "src/**"
20+
- "ext/QuantumToolboxMetalExt.jl"
21+
- "test/runtests.jl"
22+
- "test/ext-test/metal_ext.jl"
23+
- "Project.toml"
24+
target: ".buildkite/Metal_Ext.yml"
1525
agents:
1626
queue: "juliagpu"

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2525

2626
[weakdeps]
2727
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
28+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
2829

2930
[extensions]
3031
QuantumToolboxCUDAExt = "CUDA"
32+
QuantumToolboxMetalExt = "Metal"
3133

3234
[compat]
3335
ArrayInterface = "6, 7"
@@ -39,6 +41,7 @@ Graphs = "1.7"
3941
IncompleteLU = "0.2"
4042
LinearAlgebra = "<0.0.1, 1"
4143
LinearSolve = "2"
44+
Metal = "1"
4245
OrdinaryDiffEqCore = "1"
4346
OrdinaryDiffEqTsit5 = "1"
4447
Pkg = "<0.0.1, 1"
@@ -54,6 +57,7 @@ julia = "1.10"
5457

5558
[extras]
5659
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
60+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
5761
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5862

5963
[targets]

ext/QuantumToolboxMetalExt.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
module QuantumToolboxMetalExt
2+
3+
using QuantumToolbox
4+
import Metal: mtl, MtlArray
5+
6+
@doc raw"""
7+
MtlArray(A::QuantumObject)
8+
If `A.data` is an arbitrary array, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `Metal.MtlArray` for gpu calculations.
9+
Note that this function will always change element type into `32`-bit (`Int32`, `Float32`, and `ComplexF32`).
10+
"""
11+
MtlArray(A::QuantumObject{<:AbstractArray{T}}) where {T<:Number} = QuantumObject(MtlArray(A.data), A.type, A.dims)
12+
MtlArray(A::QuantumObject{<:AbstractArray{T}}) where {T<:Int64} = QuantumObject(MtlArray{Int32}(A.data), A.type, A.dims)
13+
MtlArray(A::QuantumObject{<:AbstractArray{T}}) where {T<:Float64} =
14+
QuantumObject(MtlArray{Float32}(A.data), A.type, A.dims)
15+
MtlArray(A::QuantumObject{<:AbstractArray{T}}) where {T<:ComplexF64} =
16+
QuantumObject(MtlArray{ComplexF32}(A.data), A.type, A.dims)
17+
18+
@doc raw"""
19+
MtlArray{T}(A::QuantumObject)
20+
If `A.data` is an arbitrary array, return a new [`QuantumObject`](@ref) where `A.data` is in the type of `Metal.MtlArray` with element type `T` for gpu calculations.
21+
"""
22+
MtlArray{T}(A::QuantumObject{<:AbstractArray{Tq}}) where {T,Tq<:Number} =
23+
QuantumObject(MtlArray{T}(A.data), A.type, A.dims)
24+
25+
@doc raw"""
26+
mtl(A::QuantumObject)
27+
Return a new [`QuantumObject`](@ref) where `A.data` is in the type of `Metal` arrays for gpu calculations.
28+
Note that this function will always change element type into `32`-bit (`Int32`, `Float32`, and `ComplexF32`).
29+
"""
30+
mtl(A::QuantumObject{<:AbstractArray{T}}) where {T<:Int64} = QuantumObject(MtlArray{Int32}(A.data), A.type, A.dims)
31+
mtl(A::QuantumObject{<:AbstractArray{T}}) where {T<:Float64} = QuantumObject(MtlArray{Float32}(A.data), A.type, A.dims)
32+
mtl(A::QuantumObject{<:AbstractArray{T}}) where {T<:ComplexF64} =
33+
QuantumObject(MtlArray{ComplexF32}(A.data), A.type, A.dims)
34+
35+
## TODO: Remove the following part if Metal.jl support `sparse`
36+
import LinearAlgebra: Transpose, Adjoint
37+
import QuantumToolbox: _spre, _spost, _sprepost
38+
_spre(A::MtlArray, Id::AbstractMatrix) = kron(Id, A)
39+
_spre(A::Tranpose{T,<:MtlArray}, Id::AbstractMatrix) where {T<:Number} = kron(Id, A)
40+
_spre(A::Adjoint{T,<:MtlArray}, Id::AbstractMatrix) where {T<:Number} = kron(Id, A)
41+
_spost(B::MtlArray, Id::AbstractMatrix) = kron(transpose(B), Id)
42+
_spost(B::Tranpose{T,<:MtlArray}, Id::AbstractMatrix) where {T<:Number} = kron(transpose(B), Id)
43+
_spost(B::Adjoint{T,<:MtlArray}, Id::AbstractMatrix) where {T<:Number} = kron(transpose(B), Id)
44+
_sprepost(A::MtlArray, B::MtlArray) = kron(transpose(B), A)
45+
_sprepost(A::MtlArray, B::Tranpose{T,<:MtlArray}) where {T<:Number} = kron(transpose(B), A)
46+
_sprepost(A::MtlArray, B::Adjoint{T,<:MtlArray}) where {T<:Number} = kron(transpose(B), A)
47+
_sprepost(A::Tranpose{T,<:MtlArray}, B::MtlArray) where {T<:Number} = kron(transpose(B), A)
48+
_sprepost(A::Tranpose{T1,<:MtlArray}, B::Tranpose{T2,<:MtlArray}) where {T1<:Number,T2<:Number} = kron(transpose(B), A)
49+
_sprepost(A::Tranpose{T1,<:MtlArray}, B::Adjoint{T2,<:MtlArray}) where {T1<:Number,T2<:Number} = kron(transpose(B), A)
50+
_sprepost(A::Adjoint{T,<:MtlArray}, B::MtlArray) where {T<:Number} = kron(transpose(B), A)
51+
_sprepost(A::Adjoint{T1,<:MtlArray}, B::Tranpose{T2,<:MtlArray}) where {T1<:Number,T2<:Number} = kron(transpose(B), A)
52+
_sprepost(A::Adjoint{T1,<:MtlArray}, B::Adjoint{T2,<:MtlArray}) where {T1<:Number,T2<:Number} = kron(transpose(B), A)
53+
54+
end

test/ext-test/metal_ext.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using Metal
2+
3+
QuantumToolbox.about()
4+
Metal.versioninfo()
5+
6+
@testset "Metal Extension" verbose = true begin
7+
ψdi = Qobj(Int64[1, 0])
8+
ψdf = Qobj(Float64[1, 0])
9+
ψdc = Qobj(ComplexF64[1, 0])
10+
ψsi = dense_to_sparse(ψdi)
11+
ψsf = dense_to_sparse(ψdf)
12+
ψsc = dense_to_sparse(ψdc)
13+
14+
Xdi = Qobj(Int64[0 1; 1 0])
15+
Xdf = Qobj(Float64[0 1; 1 0])
16+
Xdc = Qobj(ComplexF64[0 1; 1 0])
17+
Xsi = dense_to_sparse(Xdi)
18+
Xsf = dense_to_sparse(Xdf)
19+
Xsc = dense_to_sparse(Xdc)
20+
21+
# type conversion of dense arrays
22+
@test typeof(mtl(ψdi).data) == typeof(MtlArray{Int32}(ψdi).data) <: MtlVector{Int32}
23+
@test typeof(mtl(ψdf).data) ==
24+
typeof(MtlArray(ψdf).data) ==
25+
typeof(MtlArray{Float32}(ψdf).data) <:
26+
MtlVector{Float32}
27+
@test typeof(mtl(ψdc).data) ==
28+
typeof(MtlArray(ψdc).data) ==
29+
typeof(MtlArray{ComplexF32}(ψdc).data) <:
30+
MtlVector{ComplexF32}
31+
@test typeof(mtl(Xdi).data) == typeof(MtlArray{Int32}(Xdi).data) <: MtlMatrix{Int32}
32+
@test typeof(mtl(Xdf).data) ==
33+
typeof(MtlArray(Xdf).data) ==
34+
typeof(MtlArray{Float32}(Xdf).data) <:
35+
MtlMatrix{Float32}
36+
@test typeof(mtl(Xdc).data) ==
37+
typeof(MtlArray(Xdc).data) ==
38+
typeof(MtlArray{ComplexF32}(Xdc).data) <:
39+
MtlMatrix{ComplexF32}
40+
41+
# type conversion of sparse arrays
42+
@test typeof(mtl(ψsi).data) == typeof(MtlArray{Int32}(ψsi).data) <: MtlVector{Int32}
43+
@test typeof(mtl(ψsf).data) ==
44+
typeof(MtlArray(ψsf).data) ==
45+
typeof(MtlArray{Float32}(ψsf).data) <:
46+
MtlVector{Float32}
47+
@test typeof(mtl(ψsc).data) ==
48+
typeof(MtlArray(ψsc).data) ==
49+
typeof(MtlArray{ComplexF32}(ψsc).data) <:
50+
MtlVector{ComplexF32}
51+
@test typeof(mtl(Xsi).data) == typeof(MtlArray{Int32}(Xsi).data) <: MtlMatrix{Int32}
52+
@test typeof(mtl(Xsf).data) ==
53+
typeof(MtlArray(Xsf).data) ==
54+
typeof(MtlArray{Float32}(Xsf).data) <:
55+
MtlMatrix{Float32}
56+
@test typeof(mtl(Xsc).data) ==
57+
typeof(MtlArray(Xsc).data) ==
58+
typeof(MtlArray{ComplexF32}(Xsc).data) <:
59+
MtlMatrix{ComplexF32}
60+
61+
# brief example in README and documentation
62+
N = 5 # cannot be too large since Metal.jl does not support sparse matrix
63+
ω = 1.0f0 # Float32
64+
γ = 0.1f0 # Float32
65+
tlist = range(0, 10, 100)
66+
67+
## calculate by CPU
68+
a_cpu = destroy(N)
69+
ψ0_cpu = fock(N, 3)
70+
H_cpu = ω * a_cpu' * a_cpu
71+
sol_cpu = mesolve(H_cpu, ψ0_cpu, tlist, [sqrt(γ) * a_cpu], e_ops = [a_cpu' * a_cpu], progress_bar = Val(false))
72+
73+
## calculate by GPU
74+
a_gpu = mtl(destroy(N))
75+
ψ0_gpu = mtl(fock(N, 3))
76+
H_gpu = ω * a_gpu' * a_gpu
77+
sol_gpu = mesolve(H_gpu, ψ0_gpu, tlist, [sqrt(γ) * a_gpu], e_ops = [a_gpu' * a_gpu], progress_bar = Val(false))
78+
79+
@test all(isapprox.(sol_cpu.expect, sol_gpu.expect; atol = 1e-6))
80+
end

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,8 @@ if (GROUP == "CUDA_Ext")# || (GROUP == "All")
4343
Pkg.add("CUDA")
4444
include(joinpath(testdir, "ext-test", "cuda_ext.jl"))
4545
end
46+
47+
if (GROUP == "Metal_Ext")# || (GROUP == "All")
48+
Pkg.add("Metal")
49+
include(joinpath(testdir, "ext-test", "metal_ext.jl"))
50+
end

0 commit comments

Comments
 (0)