Skip to content

Commit 73060dd

Browse files
committed
add storage_type, real_type to semidiscretize
1 parent 5205a08 commit 73060dd

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

examples/p4est_2d_dgsem/elixir_advection_basic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ semi = SemidiscretizationHyperbolic(mesh, equations, initial_condition_convergen
3131
# ODE solvers, callbacks etc.
3232

3333
# Create ODE problem with time span from 0.0 to 1.0
34-
ode = semidiscretize(semi, (0.0, 1.0))
34+
ode = semidiscretize(semi, (0.0, 1.0); real_type = nothing, storage_type = nothing)
3535

3636
# At the beginning of the main loop, the SummaryCallback prints a summary of the simulation setup
3737
# and resets the timers

src/semidiscretization/semidiscretization.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,35 @@ end
8282
8383
Wrap the semidiscretization `semi` as an ODE problem in the time interval `tspan`
8484
that can be passed to `solve` from the [SciML ecosystem](https://diffeq.sciml.ai/latest/).
85+
86+
The optional keyword arguments `storage_type` and `real_type` configure the underlying computational
87+
datastructures. `storage_type` changes the fundamental array type being used, allowing the
88+
experimental use of `CuArray` or other GPU array types. `real_type` changes the computational data type being used.
8589
"""
8690
function semidiscretize(semi::AbstractSemidiscretization, tspan;
87-
reset_threads = true)
91+
reset_threads = true,
92+
storage_type = nothing,
93+
real_type = nothing)
8894
# Optionally reset Polyester.jl threads. See
8995
# https://github.com/trixi-framework/Trixi.jl/issues/1583
9096
# https://github.com/JuliaSIMD/Polyester.jl/issues/30
9197
if reset_threads
9298
Polyester.reset_threads!()
9399
end
94100

101+
if !(storage_type === nothing && real_type === nothing)
102+
if storage_type === nothing
103+
storage_type = Array
104+
end
105+
if real_type === nothing
106+
real_type = Float64
107+
end
108+
semi = trixi_adapt(storage_type, real_type, semi)
109+
if eltype(tspan) !== real_type
110+
tspan = convert.(real_type, tspan)
111+
end
112+
end
113+
95114
u0_ode = compute_coefficients(first(tspan), semi)
96115
# TODO: MPI, do we want to synchronize loading and print debug statements, e.g. using
97116
# mpi_isparallel() && MPI.Barrier(mpi_comm())

test/test_p4est_2d.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,27 @@ isdir(outdir) && rm(outdir, recursive = true)
3535
@test real(semi32.mesh) == Float64
3636
end
3737

38+
@trixi_testset "elixir_advection_basic.jl (Float32)" begin
39+
@test_trixi_include(joinpath(EXAMPLES_DIR, "elixir_advection_basic.jl"),
40+
# Expected errors are exactly the same as with TreeMesh!
41+
l2=[8.311947673061856e-6],
42+
linf=[6.627000273229378e-5],
43+
real_type=Float32)
44+
# Ensure that we do not have excessive memory allocations
45+
# (e.g., from type instabilities)
46+
let
47+
t = sol.t[end]
48+
u_ode = sol.u[end]
49+
du_ode = similar(u_ode)
50+
@test (@allocated Trixi.rhs!(du_ode, u_ode, semi, t)) < 1000
51+
end
52+
@test real(ode.p.solver) == Float32
53+
@test real(ode.p.solver.basis) == Float32
54+
@test real(ode.p.solver.mortar) == Float32
55+
# TODO: remake ignores the mesh itself as well
56+
@test real(ode.p.mesh) == Float64
57+
end
58+
3859
@trixi_testset "elixir_advection_nonconforming_flag.jl" begin
3960
@test_trixi_include(joinpath(EXAMPLES_DIR,
4061
"elixir_advection_nonconforming_flag.jl"),

0 commit comments

Comments
 (0)