Skip to content

Commit 892c496

Browse files
Make clenshaw method as default for wigner (#279)
* Make clenshaw method as default for wigner * Minor changes
1 parent dea0bd7 commit 892c496

File tree

2 files changed

+94
-38
lines changed

2 files changed

+94
-38
lines changed

src/wigner.jl

Lines changed: 86 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,58 +12,114 @@ end
1212
WignerLaguerre(; parallel = false, tol = 1e-14) = WignerLaguerre(parallel, tol)
1313

1414
@doc raw"""
15-
wigner(state::QuantumObject, xvec::AbstractVector, yvec::AbstractVector; g::Real=√2,
16-
solver::WignerSolver=WignerLaguerre())
17-
18-
Generates the [Wigner quasipropability distribution](https://en.wikipedia.org/wiki/Wigner_quasiprobability_distribution)
19-
of `state` at points `xvec + 1im * yvec`. The `g` parameter is a scaling factor related to the value of ``\hbar`` in the
20-
commutation relation ``[x, y] = i \hbar`` via ``\hbar=2/g^2`` giving the default value ``\hbar=1``.
21-
22-
The `solver` parameter can be either `WignerLaguerre()` or `WignerClenshaw()`. The former uses the Laguerre polynomial
23-
expansion of the Wigner function, while the latter uses the Clenshaw algorithm. The Laguerre expansion is faster for
24-
sparse matrices, while the Clenshaw algorithm is faster for dense matrices. The `WignerLaguerre` solver has an optional
25-
`parallel` parameter which defaults to `true` and uses multithreading to speed up the calculation.
15+
wigner(
16+
state::QuantumObject{DT,OpType},
17+
xvec::AbstractVector,
18+
yvec::AbstractVector;
19+
g::Real = √2,
20+
method::WignerSolver = WignerClenshaw(),
21+
)
22+
23+
Generates the [Wigner quasipropability distribution](https://en.wikipedia.org/wiki/Wigner_quasiprobability_distribution) of `state` at points `xvec + 1im * yvec` in phase space. The `g` parameter is a scaling factor related to the value of ``\hbar`` in the commutation relation ``[x, y] = i \hbar`` via ``\hbar=2/g^2`` giving the default value ``\hbar=1``.
24+
25+
The `method` parameter can be either `WignerLaguerre()` or `WignerClenshaw()`. The former uses the Laguerre polynomial expansion of the Wigner function, while the latter uses the Clenshaw algorithm. The Laguerre expansion is faster for sparse matrices, while the Clenshaw algorithm is faster for dense matrices. The `WignerLaguerre` method has an optional `parallel` parameter which defaults to `true` and uses multithreading to speed up the calculation.
26+
27+
# Arguments
28+
- `state::QuantumObject`: The quantum state for which the Wigner function is calculated. It can be either a [`KetQuantumObject`](@ref), [`BraQuantumObject`](@ref), or [`OperatorQuantumObject`](@ref).
29+
- `xvec::AbstractVector`: The x-coordinates of the phase space grid.
30+
- `yvec::AbstractVector`: The y-coordinates of the phase space grid.
31+
- `g::Real`: The scaling factor related to the value of ``\hbar`` in the commutation relation ``[x, y] = i \hbar`` via ``\hbar=2/g^2``.
32+
- `method::WignerSolver`: The method used to calculate the Wigner function. It can be either `WignerLaguerre()` or `WignerClenshaw()`, with `WignerClenshaw()` as default. The `WignerLaguerre` method has the optional `parallel` and `tol` parameters, with default values `true` and `1e-14`, respectively.
33+
34+
# Returns
35+
- `W::Matrix`: The Wigner function of the state at the points `xvec + 1im * yvec` in phase space.
36+
37+
# Example
38+
```
39+
julia> ψ = fock(10, 0) + fock(10, 1) |> normalize
40+
Quantum Object: type=Ket dims=[10] size=(10,)
41+
10-element Vector{ComplexF64}:
42+
0.7071067811865475 + 0.0im
43+
0.7071067811865475 + 0.0im
44+
0.0 + 0.0im
45+
0.0 + 0.0im
46+
0.0 + 0.0im
47+
0.0 + 0.0im
48+
0.0 + 0.0im
49+
0.0 + 0.0im
50+
0.0 + 0.0im
51+
0.0 + 0.0im
52+
53+
julia> xvec = range(-5, 5, 200)
54+
-5.0:0.05025125628140704:5.0
55+
56+
julia> wig = wigner(ψ, xvec, xvec)
57+
200×200 Matrix{Float64}:
58+
2.63558e-21 4.30187e-21 6.98638e-21 1.12892e-20 1.81505e-20 … 1.50062e-20 9.28736e-21 5.71895e-21 3.50382e-21
59+
4.29467e-21 7.00905e-21 1.13816e-20 1.83891e-20 2.9562e-20 2.45173e-20 1.51752e-20 9.3454e-21 5.72614e-21
60+
6.96278e-21 1.13621e-20 1.8448e-20 2.98026e-20 4.79043e-20 3.98553e-20 2.46711e-20 1.51947e-20 9.31096e-21
61+
1.12314e-20 1.83256e-20 2.97505e-20 4.80558e-20 7.72344e-20 6.4463e-20 3.99074e-20 2.45808e-20 1.50639e-20
62+
1.80254e-20 2.94073e-20 4.77351e-20 7.70963e-20 1.23892e-19 1.0374e-19 6.42289e-20 3.95652e-20 2.42491e-20
63+
⋮ ⋱
64+
1.80254e-20 2.94073e-20 4.77351e-20 7.70963e-20 1.23892e-19 … 1.0374e-19 6.42289e-20 3.95652e-20 2.42491e-20
65+
1.12314e-20 1.83256e-20 2.97505e-20 4.80558e-20 7.72344e-20 6.4463e-20 3.99074e-20 2.45808e-20 1.50639e-20
66+
6.96278e-21 1.13621e-20 1.8448e-20 2.98026e-20 4.79043e-20 3.98553e-20 2.46711e-20 1.51947e-20 9.31096e-21
67+
4.29467e-21 7.00905e-21 1.13816e-20 1.83891e-20 2.9562e-20 2.45173e-20 1.51752e-20 9.3454e-21 5.72614e-21
68+
2.63558e-21 4.30187e-21 6.98638e-21 1.12892e-20 1.81505e-20 1.50062e-20 9.28736e-21 5.71895e-21 3.50382e-21
69+
```
70+
71+
or taking advantage of the parallel computation of the `WignerLaguerre` method
72+
73+
```
74+
julia> wig = wigner(ρ, xvec, xvec, method=WignerLaguerre(parallel=true))
75+
200×200 Matrix{Float64}:
76+
2.63558e-21 4.30187e-21 6.98638e-21 1.12892e-20 1.81505e-20 … 1.50062e-20 9.28736e-21 5.71895e-21 3.50382e-21
77+
4.29467e-21 7.00905e-21 1.13816e-20 1.83891e-20 2.9562e-20 2.45173e-20 1.51752e-20 9.3454e-21 5.72614e-21
78+
6.96278e-21 1.13621e-20 1.8448e-20 2.98026e-20 4.79043e-20 3.98553e-20 2.46711e-20 1.51947e-20 9.31096e-21
79+
1.12314e-20 1.83256e-20 2.97505e-20 4.80558e-20 7.72344e-20 6.4463e-20 3.99074e-20 2.45808e-20 1.50639e-20
80+
1.80254e-20 2.94073e-20 4.77351e-20 7.70963e-20 1.23892e-19 1.0374e-19 6.42289e-20 3.95652e-20 2.42491e-20
81+
⋮ ⋱
82+
1.80254e-20 2.94073e-20 4.77351e-20 7.70963e-20 1.23892e-19 … 1.0374e-19 6.42289e-20 3.95652e-20 2.42491e-20
83+
1.12314e-20 1.83256e-20 2.97505e-20 4.80558e-20 7.72344e-20 6.4463e-20 3.99074e-20 2.45808e-20 1.50639e-20
84+
6.96278e-21 1.13621e-20 1.8448e-20 2.98026e-20 4.79043e-20 3.98553e-20 2.46711e-20 1.51947e-20 9.31096e-21
85+
4.29467e-21 7.00905e-21 1.13816e-20 1.83891e-20 2.9562e-20 2.45173e-20 1.51752e-20 9.3454e-21 5.72614e-21
86+
2.63558e-21 4.30187e-21 6.98638e-21 1.12892e-20 1.81505e-20 1.50062e-20 9.28736e-21 5.71895e-21 3.50382e-21
87+
```
2688
"""
2789
function wigner(
28-
state::QuantumObject{<:AbstractArray{T},OpType},
90+
state::QuantumObject{DT,OpType},
2991
xvec::AbstractVector,
3092
yvec::AbstractVector;
3193
g::Real = 2,
32-
solver::MySolver = WignerLaguerre(),
33-
) where {T,OpType<:Union{BraQuantumObject,KetQuantumObject,OperatorQuantumObject},MySolver<:WignerSolver}
34-
if isket(state)
35-
ρ = (state * state').data
36-
elseif isbra(state)
37-
ρ = (state' * state).data
38-
else
39-
ρ = state.data
40-
end
94+
method::WignerSolver = WignerClenshaw(),
95+
) where {DT,OpType<:Union{BraQuantumObject,KetQuantumObject,OperatorQuantumObject}}
96+
ρ = ket2dm(state).data
4197

42-
return _wigner(ρ, xvec, yvec, g, solver)
98+
return _wigner(ρ, xvec, yvec, g, method)
4399
end
44100

45101
function _wigner(
46102
ρ::AbstractArray,
47103
xvec::AbstractVector{T},
48104
yvec::AbstractVector{T},
49105
g::Real,
50-
solver::WignerLaguerre,
106+
method::WignerLaguerre,
51107
) where {T<:BlasFloat}
52108
g = convert(T, g)
53109
X, Y = meshgrid(xvec, yvec)
54110
A = g / 2 * (X + 1im * Y)
55111
W = similar(A, T)
56112
W .= 0
57113

58-
return _wigner_laguerre(ρ, A, W, g, solver)
114+
return _wigner_laguerre(ρ, A, W, g, method)
59115
end
60116

61117
function _wigner(
62118
ρ::AbstractArray{T1},
63119
xvec::AbstractVector{T},
64120
yvec::AbstractVector{T},
65121
g::Real,
66-
solver::WignerClenshaw,
122+
method::WignerClenshaw,
67123
) where {T1<:BlasFloat,T<:BlasFloat}
68124
g = convert(T, g)
69125
M = size(ρ, 1)
@@ -90,11 +146,11 @@ function _wigner(
90146
return @. real(W) * exp(-B / 2) * g^2 / 2 / π
91147
end
92148

93-
function _wigner_laguerre::AbstractSparseArray, A::AbstractArray, W::AbstractArray, g::Real, solver::WignerLaguerre)
149+
function _wigner_laguerre::AbstractSparseArray, A::AbstractArray, W::AbstractArray, g::Real, method::WignerLaguerre)
94150
rows, cols, vals = findnz(ρ)
95151
B = @. 4 * abs2(A)
96152

97-
if solver.parallel
153+
if method.parallel
98154
iter = filter(x -> x[2] >= x[1], collect(zip(rows, cols, vals)))
99155
Wtot = similar(B, size(B)..., length(iter))
100156
Threads.@threads for i in eachindex(iter)
@@ -122,12 +178,12 @@ function _wigner_laguerre(ρ::AbstractSparseArray, A::AbstractArray, W::Abstract
122178
return @. W * g^2 * exp(-B / 2) / 2 / π
123179
end
124180

125-
function _wigner_laguerre::AbstractArray, A::AbstractArray, W::AbstractArray, g::Real, solver::WignerLaguerre)
126-
tol = solver.tol
181+
function _wigner_laguerre::AbstractArray, A::AbstractArray, W::AbstractArray, g::Real, method::WignerLaguerre)
182+
tol = method.tol
127183
M = size(ρ, 1)
128184
B = @. 4 * abs2(A)
129185

130-
if solver.parallel
186+
if method.parallel
131187
throw(ArgumentError("Parallel version is not implemented for dense matrices"))
132188
else
133189
for m in 0:M-1

test/core-test/wigner.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
xvec = LinRange(-3, 3, 300)
66
yvec = LinRange(-3, 3, 300)
77

8-
wig = wigner(ψ, xvec, yvec, solver = WignerLaguerre(tol = 1e-6))
9-
wig2 = wigner(ρ, xvec, yvec, solver = WignerLaguerre(parallel = false))
10-
wig3 = wigner(ρ, xvec, yvec, solver = WignerLaguerre(parallel = true))
11-
wig4 = wigner(ψ, xvec, yvec, solver = WignerClenshaw())
8+
wig = wigner(ψ, xvec, yvec, method = WignerLaguerre(tol = 1e-6))
9+
wig2 = wigner(ρ, xvec, yvec, method = WignerLaguerre(parallel = false))
10+
wig3 = wigner(ρ, xvec, yvec, method = WignerLaguerre(parallel = true))
11+
wig4 = wigner(ψ, xvec, yvec, method = WignerClenshaw())
1212

1313
@test sqrt(sum(abs.(wig2 .- wig)) / length(wig)) < 1e-3
1414
@test sqrt(sum(abs.(wig3 .- wig)) / length(wig)) < 1e-3
@@ -22,9 +22,9 @@
2222
@test sqrt(sum(abs.(wig2 .- wig)) / length(wig)) < 0.1
2323

2424
@testset "Type Inference (wigner)" begin
25-
@inferred wigner(ψ, xvec, yvec, solver = WignerLaguerre(tol = 1e-6))
26-
@inferred wigner(ρ, xvec, yvec, solver = WignerLaguerre(parallel = false))
27-
@inferred wigner(ρ, xvec, yvec, solver = WignerLaguerre(parallel = true))
28-
@inferred wigner(ψ, xvec, yvec, solver = WignerClenshaw())
25+
@inferred wigner(ψ, xvec, yvec, method = WignerLaguerre(tol = 1e-6))
26+
@inferred wigner(ρ, xvec, yvec, method = WignerLaguerre(parallel = false))
27+
@inferred wigner(ρ, xvec, yvec, method = WignerLaguerre(parallel = true))
28+
@inferred wigner(ψ, xvec, yvec, method = WignerClenshaw())
2929
end
3030
end

0 commit comments

Comments
 (0)