Skip to content

Commit a52ccf6

Browse files
Allow complex float conversion of general types and make _CType function name more Julia-like
1 parent 0bc7621 commit a52ccf6

File tree

10 files changed

+42
-40
lines changed

10 files changed

+42
-40
lines changed

src/qobj/quantum_object_base.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,5 +245,5 @@ _get_dims_length(::Space) = 1
245245
_get_dims_length(::EnrSpace{N}) where {N} = N
246246

247247
# functions for getting Float or Complex element type
248-
_FType(A::AbstractQuantumObject) = _FType(eltype(A))
249-
_CType(A::AbstractQuantumObject) = _CType(eltype(A))
248+
_float_type(A::AbstractQuantumObject) = _float_type(eltype(A))
249+
_complex_float_type(A::AbstractQuantumObject) = _complex_float_type(eltype(A))

src/spectrum.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,17 @@ function _spectrum(
127127
)
128128
check_dimensions(L, A, B)
129129

130-
ωList = convert(Vector{_FType(L)}, ωlist) # Convert it to support GPUs and avoid type instabilities
130+
ωList = convert(Vector{_float_type(L)}, ωlist) # Convert it to support GPUs and avoid type instabilities
131131
Length = length(ωList)
132-
spec = Vector{_FType(L)}(undef, Length)
132+
spec = Vector{_float_type(L)}(undef, Length)
133133

134134
# calculate vectorized steadystate, multiply by operator B on the left (spre)
135135
ρss = mat2vec(steadystate(L))
136136
b = (spre(B) * ρss).data
137137

138138
# multiply by operator A on the left (spre) and then perform trace operation
139139
D = prod(L.dimensions)
140-
_tr = SparseVector(D^2, [1 + n * (D + 1) for n in 0:(D-1)], ones(_CType(L), D)) # same as vec(system_identity_matrix)
140+
_tr = SparseVector(D^2, [1 + n * (D + 1) for n in 0:(D-1)], ones(_complex_float_type(L), D)) # same as vec(system_identity_matrix)
141141
_tr_A = transpose(_tr) * spre(A).data
142142

143143
Id = I(D^2)
@@ -169,8 +169,8 @@ function _spectrum(
169169
check_dimensions(L, A, B)
170170

171171
# Define type shortcuts
172-
fT = _FType(L)
173-
cT = _CType(L)
172+
fT = _float_type(L)
173+
cT = _complex_float_type(L)
174174

175175
# Calculate |v₁> = B|ρss>
176176
ρss =

src/steadystate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ function _steadystate(L::QuantumObject{SuperOperator}, solver::SteadyStateODESol
206206
abstol = haskey(kwargs, :abstol) ? kwargs[:abstol] : DEFAULT_ODE_SOLVER_OPTIONS.abstol
207207
reltol = haskey(kwargs, :reltol) ? kwargs[:reltol] : DEFAULT_ODE_SOLVER_OPTIONS.reltol
208208

209-
ftype = _FType(ψ0)
209+
ftype = _float_type(ψ0)
210210
_terminate_func = SteadyStateODECondition(similar(mat2vec(ket2dm(ψ0)).data))
211211
cb = TerminateSteadyState(abstol, reltol, _terminate_func)
212212
sol = mesolve(

src/time_evolution/lr_mesolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ function lr_mesolveProblem(
412412
c_ops = get_data.(c_ops)
413413
e_ops = get_data.(e_ops)
414414

415-
t_l = _check_tlist(tlist, _FType(H))
415+
t_l = _check_tlist(tlist, _float_type(H))
416416

417417
# Initialization of Arrays
418418
expvals = Array{ComplexF64}(undef, length(e_ops), length(t_l))

src/time_evolution/mcsolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function mcsolveProblem(
125125
c_ops isa Nothing &&
126126
throw(ArgumentError("The list of collapse operators must be provided. Use sesolveProblem instead."))
127127

128-
tlist = _check_tlist(tlist, _FType(ψ0))
128+
tlist = _check_tlist(tlist, _float_type(ψ0))
129129

130130
H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, c_ops)
131131

src/time_evolution/mesolve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,16 @@ function mesolveProblem(
7979
haskey(kwargs, :save_idxs) &&
8080
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
8181

82-
tlist = _check_tlist(tlist, _FType(ψ0))
82+
tlist = _check_tlist(tlist, _float_type(ψ0))
8383

8484
L_evo = _mesolve_make_L_QobjEvo(H, c_ops)
8585
check_dimensions(L_evo, ψ0)
8686

8787
T = Base.promote_eltype(L_evo, ψ0)
8888
ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
89-
to_dense(_CType(T), copy(ψ0.data))
89+
to_dense(_complex_float_type(T), copy(ψ0.data))
9090
else
91-
to_dense(_CType(T), mat2vec(ket2dm(ψ0).data))
91+
to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
9292
end
9393
L = L_evo.data
9494

src/time_evolution/sesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,14 @@ function sesolveProblem(
6161
haskey(kwargs, :save_idxs) &&
6262
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))
6363

64-
tlist = _check_tlist(tlist, _FType(ψ0))
64+
tlist = _check_tlist(tlist, _float_type(ψ0))
6565

6666
H_evo = _sesolve_make_U_QobjEvo(H) # Multiply by -i
6767
isoper(H_evo) || throw(ArgumentError("The Hamiltonian must be an Operator."))
6868
check_dimensions(H_evo, ψ0)
6969

7070
T = Base.promote_eltype(H_evo, ψ0)
71-
ψ0 = to_dense(_CType(T), get_data(ψ0)) # Convert it to dense vector with complex element type
71+
ψ0 = to_dense(_complex_float_type(T), get_data(ψ0)) # Convert it to dense vector with complex element type
7272
U = H_evo.data
7373

7474
kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_ODE_SOLVER_OPTIONS; kwargs...)

src/time_evolution/smesolve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,17 @@ function smesolveProblem(
9494
sc_ops_list = _make_c_ops_list(sc_ops) # If it is an AbstractQuantumObject but we need to iterate
9595
sc_ops_isa_Qobj = sc_ops isa AbstractQuantumObject # We can avoid using non-diagonal noise if sc_ops is just an AbstractQuantumObject
9696

97-
tlist = _check_tlist(tlist, _FType(ψ0))
97+
tlist = _check_tlist(tlist, _float_type(ψ0))
9898

9999
L_evo = _mesolve_make_L_QobjEvo(H, c_ops) + _mesolve_make_L_QobjEvo(nothing, sc_ops_list)
100100
check_dimensions(L_evo, ψ0)
101101
dims = L_evo.dimensions
102102

103103
T = Base.promote_eltype(L_evo, ψ0)
104104
ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
105-
to_dense(_CType(T), copy(ψ0.data))
105+
to_dense(_complex_float_type(T), copy(ψ0.data))
106106
else
107-
to_dense(_CType(T), mat2vec(ket2dm(ψ0).data))
107+
to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
108108
end
109109

110110
progr = ProgressBar(length(tlist), enable = getVal(progress_bar))

src/time_evolution/ssesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,14 @@ function ssesolveProblem(
9494
sc_ops_list = _make_c_ops_list(sc_ops) # If it is an AbstractQuantumObject but we need to iterate
9595
sc_ops_isa_Qobj = sc_ops isa AbstractQuantumObject # We can avoid using non-diagonal noise if sc_ops is just an AbstractQuantumObject
9696

97-
tlist = _check_tlist(tlist, _FType(ψ0))
97+
tlist = _check_tlist(tlist, _float_type(ψ0))
9898

9999
H_eff_evo = _mcsolve_make_Heff_QobjEvo(H, sc_ops_list)
100100
isoper(H_eff_evo) || throw(ArgumentError("The Hamiltonian must be an Operator."))
101101
check_dimensions(H_eff_evo, ψ0)
102102
dims = H_eff_evo.dimensions
103103

104-
ψ0 = to_dense(_CType(ψ0), get_data(ψ0))
104+
ψ0 = to_dense(_complex_float_type(ψ0), get_data(ψ0))
105105

106106
progr = ProgressBar(length(tlist), enable = getVal(progress_bar))
107107

src/utilities.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ where ``\hbar`` is the reduced Planck constant, and ``k_B`` is the Boltzmann con
4646
function n_thermal::T1, ω_th::T2) where {T1<:Real,T2<:Real}
4747
x = exp/ ω_th)
4848
n = ((x != 1) && (ω_th > 0)) ? 1 / (x - 1) : 0
49-
return _FType(promote_type(T1, T2))(n)
49+
return _float_type(promote_type(T1, T2))(n)
5050
end
5151

5252
@doc raw"""
@@ -125,7 +125,7 @@ julia> round(convert_unit(1, :meV, :mK), digits=4)
125125
function convert_unit(value::T, unit1::Symbol, unit2::Symbol) where {T<:Real}
126126
!haskey(_energy_units, unit1) && throw(ArgumentError("Invalid unit :$(unit1)"))
127127
!haskey(_energy_units, unit2) && throw(ArgumentError("Invalid unit :$(unit2)"))
128-
return _FType(T)(value * (_energy_units[unit1] / _energy_units[unit2]))
128+
return _float_type(T)(value * (_energy_units[unit1] / _energy_units[unit2]))
129129
end
130130

131131
get_typename_wrapper(A) = Base.typename(typeof(A)).wrapper
@@ -174,24 +174,26 @@ for AType in (:AbstractArray, :AbstractSciMLOperator)
174174
end
175175

176176
# functions for getting Float or Complex element type
177-
_FType(::AbstractArray{T}) where {T<:Number} = _FType(T)
178-
_FType(::Type{Int32}) = Float32
179-
_FType(::Type{Int64}) = Float64
180-
_FType(::Type{Float32}) = Float32
181-
_FType(::Type{Float64}) = Float64
182-
_FType(::Type{Complex{Int32}}) = Float32
183-
_FType(::Type{Complex{Int64}}) = Float64
184-
_FType(::Type{Complex{Float32}}) = Float32
185-
_FType(::Type{Complex{Float64}}) = Float64
186-
_CType(::AbstractArray{T}) where {T<:Number} = _CType(T)
187-
_CType(::Type{Int32}) = ComplexF32
188-
_CType(::Type{Int64}) = ComplexF64
189-
_CType(::Type{Float32}) = ComplexF32
190-
_CType(::Type{Float64}) = ComplexF64
191-
_CType(::Type{Complex{Int32}}) = ComplexF32
192-
_CType(::Type{Complex{Int64}}) = ComplexF64
193-
_CType(::Type{Complex{Float32}}) = ComplexF32
194-
_CType(::Type{Complex{Float64}}) = ComplexF64
177+
_float_type(::AbstractArray{T}) where {T<:Number} = _float_type(T)
178+
_float_type(::Type{Int32}) = Float32
179+
_float_type(::Type{Int64}) = Float64
180+
_float_type(::Type{Float32}) = Float32
181+
_float_type(::Type{Float64}) = Float64
182+
_float_type(::Type{Complex{Int32}}) = Float32
183+
_float_type(::Type{Complex{Int64}}) = Float64
184+
_float_type(::Type{Complex{Float32}}) = Float32
185+
_float_type(::Type{Complex{Float64}}) = Float64
186+
_float_type(T::Type{<:Real}) = T # Allow other untracked Real types, like ForwardDiff.Dual
187+
_complex_float_type(::AbstractArray{T}) where {T<:Number} = _complex_float_type(T)
188+
_complex_float_type(::Type{Int32}) = ComplexF32
189+
_complex_float_type(::Type{Int64}) = ComplexF64
190+
_complex_float_type(::Type{Float32}) = ComplexF32
191+
_complex_float_type(::Type{Float64}) = ComplexF64
192+
_complex_float_type(::Type{Complex{Int32}}) = ComplexF32
193+
_complex_float_type(::Type{Complex{Int64}}) = ComplexF64
194+
_complex_float_type(::Type{Complex{Float32}}) = ComplexF32
195+
_complex_float_type(::Type{Complex{Float64}}) = ComplexF64
196+
_complex_float_type(T::Type{<:Complex}) = T # Allow other untracked Complex types, like ForwardDiff.Dual
195197

196198
_convert_eltype_wordsize(::Type{T}, ::Val{64}) where {T<:Int} = Int64
197199
_convert_eltype_wordsize(::Type{T}, ::Val{32}) where {T<:Int} = Int32

0 commit comments

Comments
 (0)