Skip to content

Commit c24ce91

Browse files
Add support for plotting 1D function (#2250)
* add support for plotting 1D function * allow more than one variable * add a test * format * format * simplify * add comment * function -> method * add support for 1D StructuredMesh and 1D DGMulti * specialize method for 1D meshes * format * Apply suggestions from code review Co-authored-by: Hendrik Ranocha <[email protected]> * always use solution_variables = cons2cons * don't pass variable_names to recipe (title can be changed by `title`) * fix plotting scalar function for DGMulti * fix for StructuredMesh * Update src/visualization/types.jl Co-authored-by: Hendrik Ranocha <[email protected]> * fix comment * clarify output of function --------- Co-authored-by: Hendrik Ranocha <[email protected]>
1 parent 06e8a20 commit c24ce91

File tree

4 files changed

+115
-15
lines changed

4 files changed

+115
-15
lines changed

src/visualization/recipes_plots.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,19 @@ RecipesBase.@recipe function f(u, semi::AbstractSemidiscretization;
171171
end
172172
end
173173

174+
# Also allow plotting a function with signature `func(x, equations)`, e.g., for initial conditions.
175+
# We need this recipe in addition to the one above to avoid method ambiguities.
176+
RecipesBase.@recipe function f(func::Function, semi::AbstractSemidiscretization;
177+
solution_variables = nothing)
178+
n_variables = length(func(0.0, semi.equations))
179+
variable_names = SVector(["func[$i]" for i in 1:n_variables]...)
180+
if ndims(semi) == 1
181+
return PlotData1D(func, semi; solution_variables = cons2cons, variable_names)
182+
else
183+
throw(ArgumentError("Plotting of functions is only supported in 1D."))
184+
end
185+
end
186+
174187
# Recipe specifically for TreeMesh-type solutions
175188
# Note: If you change the defaults values here, you need to also change them in the PlotData1D or PlotData2D
176189
# constructor.
@@ -189,6 +202,22 @@ RecipesBase.@recipe function f(u, semi::SemidiscretizationHyperbolic{<:TreeMesh}
189202
end
190203
end
191204

205+
# Also allow plotting a function with signature `func(x, equations)`, e.g., for initial conditions.
206+
RecipesBase.@recipe function f(func::Function,
207+
semi::SemidiscretizationHyperbolic{<:TreeMesh};
208+
solution_variables = nothing,
209+
nvisnodes = nothing, slice = :xy,
210+
point = (0.0, 0.0, 0.0), curve = nothing)
211+
n_variables = length(func(0.0, semi.equations))
212+
variable_names = SVector(["func[$i]" for i in 1:n_variables]...)
213+
if ndims(semi) == 1
214+
return PlotData1D(func, semi; solution_variables = cons2cons, nvisnodes, slice,
215+
point, curve, variable_names)
216+
else
217+
throw(ArgumentError("Plotting of functions is only supported in 1D."))
218+
end
219+
end
220+
192221
# Series recipe for PlotData2DTriangulated
193222
RecipesBase.@recipe function f(pds::PlotDataSeries{<:PlotData2DTriangulated})
194223
pd = pds.plot_data

src/visualization/types.jl

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -519,10 +519,15 @@ end
519519
solution_variables=nothing, nvisnodes=nothing)
520520
521521
Create a new `PlotData1D` object that can be used for visualizing 1D DGSEM solution data array
522-
`u` with `Plots.jl`. All relevant geometrical information is extracted from the semidiscretization
523-
`semi`. By default, the primitive variables (if existent) or the conservative variables (otherwise)
524-
from the solution are used for plotting. This can be changed by passing an appropriate conversion
525-
function to `solution_variables`.
522+
`u` with `Plots.jl`. All relevant geometrical information is extracted from the
523+
semidiscretization `semi`. By default, the primitive variables (if existent)
524+
or the conservative variables (otherwise) from the solution are used for
525+
plotting. This can be changed by passing an appropriate conversion function to
526+
`solution_variables`, e.g., [`cons2cons`](@ref) or [`cons2prim`](@ref).
527+
528+
Alternatively, you can also pass a function `u` with signature `u(x, equations)`
529+
returning a vector. In this case, the `solution_variables` are ignored. This is useful,
530+
e.g., to visualize an analytical solution.
526531
527532
`nvisnodes` specifies the number of visualization nodes to be used. If it is `nothing`,
528533
twice the number of solution DG nodes are used for visualization, and if set to `0`,
@@ -547,11 +552,19 @@ function PlotData1D(u_ode, semi; kwargs...)
547552
kwargs...)
548553
end
549554

555+
function PlotData1D(func::Function, semi; kwargs...)
556+
PlotData1D(func,
557+
mesh_equations_solver_cache(semi)...;
558+
kwargs...)
559+
end
560+
550561
function PlotData1D(u, mesh::TreeMesh, equations, solver, cache;
551562
solution_variables = nothing, nvisnodes = nothing,
552-
slice = :x, point = (0.0, 0.0, 0.0), curve = nothing)
563+
slice = :x, point = (0.0, 0.0, 0.0), curve = nothing,
564+
variable_names = nothing)
553565
solution_variables_ = digest_solution_variables(equations, solution_variables)
554-
variable_names = SVector(varnames(solution_variables_, equations))
566+
variable_names_ = digest_variable_names(solution_variables_, equations,
567+
variable_names)
555568

556569
original_nodes = cache.elements.node_coordinates
557570
unstructured_data = get_unstructured_data(u, solution_variables_, mesh, equations,
@@ -610,15 +623,17 @@ function PlotData1D(u, mesh::TreeMesh, equations, solver, cache;
610623
end
611624
end
612625

613-
return PlotData1D(x, data, variable_names, mesh_vertices_x,
626+
return PlotData1D(x, data, variable_names_, mesh_vertices_x,
614627
orientation_x)
615628
end
616629

617630
function PlotData1D(u, mesh, equations, solver, cache;
618631
solution_variables = nothing, nvisnodes = nothing,
619-
slice = :x, point = (0.0, 0.0, 0.0), curve = nothing)
632+
slice = :x, point = (0.0, 0.0, 0.0), curve = nothing,
633+
variable_names = nothing)
620634
solution_variables_ = digest_solution_variables(equations, solution_variables)
621-
variable_names = SVector(varnames(solution_variables_, equations))
635+
variable_names_ = digest_variable_names(solution_variables_, equations,
636+
variable_names)
622637

623638
original_nodes = cache.elements.node_coordinates
624639
unstructured_data = get_unstructured_data(u, solution_variables_, mesh, equations,
@@ -642,15 +657,25 @@ function PlotData1D(u, mesh, equations, solver, cache;
642657
slice, point, nvisnodes)
643658
end
644659

645-
return PlotData1D(x, data, variable_names, mesh_vertices_x,
660+
return PlotData1D(x, data, variable_names_, mesh_vertices_x,
646661
orientation_x)
647662
end
648663

664+
function PlotData1D(func::Function, mesh, equations, dg::DGMulti{1}, cache;
665+
solution_variables = nothing, variable_names = nothing)
666+
x = mesh.md.x
667+
u = func.(x, equations)
668+
669+
return PlotData1D(u, mesh, equations, dg, cache;
670+
solution_variables, variable_names)
671+
end
672+
649673
# Specializes the `PlotData1D` constructor for one-dimensional `DGMulti` solvers.
650674
function PlotData1D(u, mesh, equations, dg::DGMulti{1}, cache;
651-
solution_variables = nothing)
675+
solution_variables = nothing, variable_names = nothing)
652676
solution_variables_ = digest_solution_variables(equations, solution_variables)
653-
variable_names = SVector(varnames(solution_variables_, equations))
677+
variable_names_ = digest_variable_names(solution_variables_, equations,
678+
variable_names)
654679

655680
orientation_x = 0 # Set 'orientation' to zero on default.
656681

@@ -679,11 +704,16 @@ function PlotData1D(u, mesh, equations, dg::DGMulti{1}, cache;
679704
# Same as above - we create `data_plot` as array of size `num_plotting_points`
680705
# by "number of plotting variables".
681706
x_plot = vec(x)
682-
data_plot = permutedims(reinterpret(reshape, eltype(eltype(data)), vec(data)),
683-
(2, 1))
707+
data_ = reinterpret(reshape, eltype(eltype(data)), vec(data))
708+
# If there is only one solution variable, we need to add a singleton dimension
709+
if ndims(data_) == 1
710+
data_plot = reshape(data_, :, 1)
711+
else
712+
data_plot = permutedims(data_, (2, 1))
713+
end
684714
end
685715

686-
return PlotData1D(x_plot, data_plot, variable_names, mesh.md.VX, orientation_x)
716+
return PlotData1D(x_plot, data_plot, variable_names_, mesh.md.VX, orientation_x)
687717
end
688718

689719
"""

src/visualization/utilities.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ function digest_solution_variables(equations, solution_variables::Nothing)
256256
end
257257
end
258258

259+
digest_variable_names(solution_variables_, equations, variable_names) = variable_names
260+
digest_variable_names(solution_variables_, equations, ::Nothing) = SVector(varnames(solution_variables_,
261+
equations))
262+
259263
"""
260264
adapt_to_mesh_level!(u_ode, semi, level)
261265
adapt_to_mesh_level!(sol::Trixi.TrixiODESolution, level)
@@ -481,6 +485,18 @@ function get_unstructured_data(u, solution_variables, mesh, equations, solver, c
481485
return unstructured_data
482486
end
483487

488+
# This method is only for plotting 1D functions
489+
function get_unstructured_data(func::Function, solution_variables,
490+
mesh::AbstractMesh{1}, equations, solver, cache)
491+
original_nodes = cache.elements.node_coordinates
492+
# original_nodes has size (1, nnodes, nelements)
493+
# we want u to have size (nvars, nnodes, nelements)
494+
# func.(original_nodes, equations) has size (1, nnodes, nelements), where each component has length n_vars
495+
# Therefore, we drop the first (singleton) dimension and then stack the components
496+
u = stack(func.(SVector.(dropdims(original_nodes; dims = 1)), equations))
497+
return get_unstructured_data(u, solution_variables, mesh, equations, solver, cache)
498+
end
499+
484500
# Convert cell-centered values to node-centered values by averaging over all
485501
# four neighbors and making use of the periodicity of the solution
486502
#

test/test_visualization.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,10 @@ end
188188
@test_nowarn_mod Plots.plot(pd)
189189
@test_nowarn_mod Plots.plot(pd["p"])
190190
@test_nowarn_mod Plots.plot(getmesh(pd))
191+
initial_condition_t_end(x, equations) = initial_condition(x, last(tspan),
192+
equations)
193+
@test_nowarn_mod Plots.plot(initial_condition_t_end, semi)
194+
@test_nowarn_mod Plots.plot((x, equations) -> x, semi)
191195
end
192196

193197
# Fake a PlotDataXD objects to test code for plotting multiple variables on at least two rows
@@ -220,13 +224,34 @@ end
220224
tspan = (0.0, 0.0),
221225
approximation_type = Polynomial())
222226
@test PlotData1D(sol) isa PlotData1D
227+
initial_condition_t_end(x, equations) = initial_condition(x, last(tspan), equations)
228+
@test_nowarn_mod Plots.plot(initial_condition_t_end, semi)
229+
@test_nowarn_mod Plots.plot((x, equations) -> x, semi)
223230

224231
@test_nowarn_mod trixi_include(@__MODULE__,
225232
joinpath(examples_dir(), "dgmulti_1d",
226233
"elixir_euler_flux_diff.jl"),
227234
tspan = (0.0, 0.0),
228235
approximation_type = SBP())
229236
@test PlotData1D(sol) isa PlotData1D
237+
@test_nowarn_mod Plots.plot(initial_condition_t_end, semi)
238+
@test_nowarn_mod Plots.plot((x, equations) -> x, semi)
239+
end
240+
241+
@timed_testset "1D plot recipes (StructuredMesh)" begin
242+
@test_nowarn_mod trixi_include(@__MODULE__,
243+
joinpath(examples_dir(), "structured_1d_dgsem",
244+
"elixir_euler_source_terms.jl"),
245+
tspan = (0.0, 0.0))
246+
247+
pd = PlotData1D(sol)
248+
initial_condition_t_end(x, equations) = initial_condition(x, last(tspan), equations)
249+
@test_nowarn_mod Plots.plot(sol)
250+
@test_nowarn_mod Plots.plot(pd)
251+
@test_nowarn_mod Plots.plot(pd["p"])
252+
@test_nowarn_mod Plots.plot(sol.u[end], semi)
253+
@test_nowarn_mod Plots.plot(initial_condition_t_end, semi)
254+
@test_nowarn_mod Plots.plot((x, equations) -> x, semi)
230255
end
231256

232257
@timed_testset "plot time series" begin

0 commit comments

Comments
 (0)