Skip to content

Commit 31860ec

Browse files
benegeeJoshuaLampertranochavchuravy
authored
Add mpi_isroot() (#46)
Co-authored-by: Joshua Lampert <51029046+JoshuaLampert@users.noreply.github.com> Co-authored-by: Benedict <135045760+benegee@users.noreply.github.com> Co-authored-by: Hendrik Ranocha <ranocha@users.noreply.github.com> Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
1 parent 8efaf07 commit 31860ec

File tree

7 files changed

+60
-12
lines changed

7 files changed

+60
-12
lines changed

ext/TrixiBaseMPIExt.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,28 @@
11
# Package extension for adding MPI-based features to TrixiBase.jl
22
module TrixiBaseMPIExt
33

4-
# Load package extension code on Julia v1.9 and newer
5-
if isdefined(Base, :get_extension)
6-
using MPI: MPI
7-
end
4+
using MPI
85
import TrixiBase
96

10-
# This is a really working version - assuming the same
7+
function __init__()
8+
TrixiBase.__MPI__AVAILABLE__[] = true
9+
end
10+
11+
# These are really working functions - assuming the same
1112
# communication pattern etc. used in Trixi.jl.
12-
function TrixiBase.mpi_isparallel(::Val{:MPIExt})
13+
function TrixiBase.mpi_isparallel_internal()
1314
if MPI.Initialized()
1415
return MPI.Comm_size(MPI.COMM_WORLD) > 1
1516
else
1617
return false
1718
end
1819
end
1920

21+
function TrixiBase.mpi_isroot_internal()
22+
if MPI.Initialized()
23+
return MPI.Comm_rank(MPI.COMM_WORLD) == 0
24+
else
25+
return true
26+
end
27+
end
2028
end

src/TrixiBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module TrixiBase
33
using ChangePrecision: ChangePrecision
44
using TimerOutputs: TimerOutput, TimerOutputs
55

6+
include("mpi.jl")
67
include("trixi_include.jl")
78
include("trixi_timeit.jl")
89

src/mpi.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# These are just dummy functions. We only implement real
2+
# versions if MPI.jl is loaded to avoid letting TrixiBase.jl
3+
# depend explicitly on MPI.jl.
4+
5+
# This will be true if TrixiBaseMPIExt is loaded
6+
const __MPI__AVAILABLE__ = Ref{Bool}(false)
7+
8+
# These functions are defined in the TrixiBaseMPIExt extension
9+
function mpi_isparallel_internal end
10+
function mpi_isroot_internal end
11+
12+
function mpi_isparallel()
13+
if __MPI__AVAILABLE__[]
14+
return mpi_isparallel_internal()
15+
else
16+
return false
17+
end
18+
end
19+
20+
function mpi_isroot()
21+
if __MPI__AVAILABLE__[]
22+
return mpi_isroot_internal()
23+
else
24+
return true
25+
end
26+
end

src/trixi_include.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function trixi_include(mapexpr::Function, mod::Module, elixir::AbstractString; k
4646
end
4747

4848
# Print information on potential wait time only in non-parallel case
49-
if !mpi_isparallel(Val{:MPIExt}())
49+
if !mpi_isparallel()
5050
@info "You just called `trixi_include`. Julia may now compile the code, please be patient."
5151
end
5252
Base.include(ex -> mapexpr(replace_assignments(insert_maxiters(ex); kwargs...)),
@@ -203,8 +203,3 @@ function find_assignment(expr, destination)
203203

204204
result
205205
end
206-
207-
# This is just a dummy function. We only implement a real
208-
# version if MPI.jl is loaded to avoid letting TrixiBase.jl
209-
# depend explicitly on MPI.jl.
210-
mpi_isparallel(x) = false

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
4+
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
45
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
56
TrixiTest = "0a316866-cbd0-4425-8bcb-08103b2c1f26"
67

78
[compat]
89
Aqua = "0.7, 0.8"
910
ExplicitImports = "1.0.1"
11+
MPI = "0.20"
1012
Test = "1"
1113
TrixiTest = "0.1"

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ include("test_util.jl")
22

33
@testset verbose=true "TrixiBase.jl Tests" begin
44
include("test_aqua.jl")
5+
include("test_mpi.jl")
6+
run(`$(mpiexec()) -n 2 $(Base.julia_cmd()) --threads=1 $(abspath("test_mpi.jl"))`)
57
include("trixi_include.jl")
68
include("test_timers.jl")
79
end;

test/test_mpi.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
include("test_util.jl")
3+
4+
@testset verbose=true "MPI helper functions" begin
5+
@test TrixiBase.mpi_isparallel() == false
6+
@test TrixiBase.mpi_isroot() == true
7+
8+
using MPI
9+
@test TrixiBase.mpi_isparallel() == false
10+
@test TrixiBase.mpi_isroot() == true
11+
MPI.Init()
12+
@test TrixiBase.mpi_isparallel() == (MPI.Comm_size(MPI.COMM_WORLD) > 1)
13+
@test TrixiBase.mpi_isroot() == (MPI.Comm_rank(MPI.COMM_WORLD) == 0)
14+
end

0 commit comments

Comments
 (0)